00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055 #if !defined(ALIZE_TrainWorld_cpp)
00056 #define ALIZE_TrainWorld_cpp
00057
00058 #include <iostream>
00059 #include <cmath>
00060 #include "liatools.h"
00061 #include "TrainWorld.h"
00062
00063 using namespace alize;
00064 using namespace std;
00065
00066 void featureStream(Config &config,String filename,FeatureServer *&fs,SegServer *&segServ,SegCluster *&segCluster,String labelSelectedFrames){
00067 fs=new FeatureServer(config,filename);
00068 try{
00069
00070
00071 segServ=new SegServer;
00072 LabelServer labelServer;
00073 initializeClusters(filename,*segServ,labelServer,config);
00074 verifyClusterFile(*segServ,*fs,config);
00075 unsigned long codeSelectedFrame=labelServer.getLabelIndexByString(labelSelectedFrames);
00076 segCluster=&(segServ->getCluster(codeSelectedFrame));
00077 }
00078 catch (Exception& e){
00079 cout << e.toString() << endl;
00080 }
00081 }
00082 void reserveMem(FeatureServer** &fsTab,SegServer** &segServTab,SegCluster** &segTab,double *&weightTab,unsigned long nbStream){
00083 fsTab=new FeatureServer*[nbStream];
00084 segServTab=new SegServer*[nbStream];
00085 segTab=new SegCluster*[nbStream];
00086 weightTab=new double [nbStream];
00087 for (unsigned long i=0;i<nbStream;i++)weightTab[i]=1/(double) nbStream;
00088 }
00089 void freeMem(FeatureServer** &fsTab,SegServer** &segServTab,SegCluster** &segTab,double *&weightTab,unsigned long nbStream){
00090 for (unsigned long i=0;i<nbStream;i++){
00091 delete fsTab[i];
00092 delete segServTab[i];
00093 }
00094 delete [] fsTab;
00095 delete [] segServTab;
00096 delete [] segTab;
00097 delete [] weightTab;
00098 }
00099
00100
00101 int trainWorld(Config& config){
00102 if (verbose) cout << "Begin world model training"<<endl;
00103 try{
00104
00105 unsigned long nbStream=0;
00106 FeatureServer **fsTab=NULL;
00107 SegServer **segServTab=NULL;
00108 SegCluster **segTab=NULL;
00109 double *weightTab=NULL;
00110 String outputWorldFilename = config.getParam("outputWorldFilename");
00111 bool fileInit=config.existsParam("inputWorldFilename");
00112 bool saveInitModel=true;
00113 if (config.existsParam("saveInitModel")) saveInitModel=config.getParam("saveInitModel").toBool();
00114 String inputWorldFilename="";
00115 if (fileInit) inputWorldFilename=config.getParam("inputWorldFilename");
00116 String labelSelectedFrames =config.getParam("labelSelectedFrames");
00117 TrainCfg trainCfg(config);
00118
00119
00120 if(config.existsParam("inputStreamList")){
00121 XList tmp(config.getParam("inputStreamList"),config);
00122 XLine & listInputFilename=tmp.getAllElements();
00123 nbStream=listInputFilename.getElementCount();
00124 if (nbStream==0) throw Exception("TrainWorld error:no input stream" , __FILE__, __LINE__);
00125 reserveMem(fsTab,segServTab,segTab,weightTab,nbStream);
00126 for (unsigned i=0;i<nbStream;i++)
00127 featureStream(config,listInputFilename.getElement(i),fsTab[i],segServTab[i],segTab[i],labelSelectedFrames);
00128 if (config.existsParam("weightStreamList")){
00129 XList tmpW(config.getParam("weightStreamList"),config);
00130 XLine & listW=tmpW.getAllElements();
00131 if (listW.getElementCount()!=nbStream) throw Exception("TrainWorld error: number of weigths differs than number of input streams" , __FILE__, __LINE__);
00132 for (unsigned i=0;i<nbStream;i++) weightTab[i]=listW.getElement(i).toDouble();
00133 }
00134 }
00135 else{
00136 nbStream=1;
00137 reserveMem(fsTab,segServTab,segTab,weightTab,nbStream);
00138 featureStream(config,config.getParam("inputFeatureFilename"),fsTab[0],segServTab[0],segTab[0],labelSelectedFrames);
00139 }
00140 unsigned long vectSize=fsTab[0]->getVectSize();
00141
00142 MixtureServer ms(config);
00143 StatServer ss(config, ms);
00144 if (debug || verbose) cout << "Stream mode, nb Stream="<<nbStream<<endl;
00145 if (debug|| (verboseLevel>2)){
00146 for (unsigned long i=0;i<nbStream;i++){
00147 cout <<"Stream["<<i<<"]"<<endl;
00148 segTab[i]->rewind();
00149 Seg *seg;
00150 while((seg=segTab[i]->getSeg())!=NULL)
00151 cout << "File["<<seg->sourceName()<<"] Segment begin["<<
00152 seg->begin()<<"] length["<<seg->length()<<"] index in the feature server["<<fsTab[i]->getFirstFeatureIndexOfASource(seg->sourceName())<<"]"<<endl;
00153 }
00154 }
00155
00156 bool use01=false;
00157 if (config.existsParam("use01")) use01=config.getParam("use01").toBool();
00158 if (verbose){ if (use01) cout<<"Use 0 mean, 1 cov "<<endl; else cout << "Compute global mean and cov"<<endl;}
00159 DoubleVector globalMean;
00160 DoubleVector globalCov;
00161 if (!use01){
00162 FrameAccGD globalFrameAcc;
00163 unsigned long nbFrame=computeMeanCov(config,fsTab,segTab,nbStream,globalMean,globalCov);
00164 if (verboseLevel>1){
00165 cout <<"global mean and cov of training data, number of frame= ["<<nbFrame<<"]"<<endl;
00166 for (unsigned i=0; i < vectSize; i++)cout << "mean[" << i << "=" << globalMean[i] << "]\tcov[" << globalCov[i] << "]" << endl;
00167 }
00168 }
00169 else initialize01(vectSize,globalMean,globalCov);
00170 MixtureGD &world=ms.createMixtureGD();
00171 if (fileInit){
00172 if (verbose) cout << "Load initial world model ["<<inputWorldFilename<<"]" << endl;
00173 world=ms.loadMixtureGD(inputWorldFilename);
00174 }
00175 else{
00176 if (verbose) cout <<"World model init from scratch"<<endl;
00177 mixtureInit(ms,fsTab,segTab,weightTab,nbStream,world,globalCov,config,trainCfg);
00178 if (saveInitModel) world.save(outputWorldFilename+"init", config);
00179 }
00180 MixtureGD *newWorld=&world;
00181 trainModelStream(config,ms,ss,fsTab,segTab,weightTab,nbStream,globalMean,globalCov,newWorld,trainCfg);
00182 if (verbose) cout << "Save world model ["<<outputWorldFilename<<"]" << endl;
00183 newWorld->save(outputWorldFilename, config);
00184
00185 freeMem(fsTab,segServTab,segTab,weightTab,nbStream);
00186 }
00187 catch (Exception& e){
00188 cout << e.toString() << endl;
00189 }
00190 return 0;
00191 }
00192
00193
00194 #endif // !defined(ALIZE_TrainWorld_cpp)