00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055 #if !defined(ALIZE_TrainTarget_cpp)
00056 #define ALIZE_TrainTarget_cpp
00057
00058 #include <iostream>
00059 #include <fstream>
00060 #include <cstdio>
00061 #include <cassert>
00062 #include <cmath>
00063 #include <liatools.h>
00064 #include "TrainTarget.h"
00065 #include "AccumulateJFAStat.h"
00066
00067 using namespace alize;
00068 using namespace std;
00069
00070
00071
00072
00073 int InfoTarget(Config& config)
00074 {
00075 String inputClientListFileName = config.getParam("targetIdList");
00076 bool fixedLabelSelectedFrame;
00077 String labelSelectedFrames;
00078 if (config.existsParam("useIdForSelectedFrame"))
00079 fixedLabelSelectedFrame=false;
00080 else{
00081 labelSelectedFrames=config.getParam("labelSelectedFrames");
00082 if (verbose) cout << "Computing on" << labelSelectedFrames << " label" << endl;
00083 fixedLabelSelectedFrame=true;
00084 }
00085 unsigned long maxFrame=config.getParam("maxFrame").toLong();
00086 String outputFilename=config.getParam("outputFilename");
00087
00088
00089 ofstream outputFile(outputFilename.c_str(),ios::out| ios::trunc);
00090 try{
00091 XList inputClientList(inputClientListFileName,config);
00092 XLine * linep;
00093 if (verbose) cout << "InfoTarget" << endl;
00094
00095 while ((linep=inputClientList.getLine()) != NULL){
00096 String *id=linep->getElement();
00097 outputFile<<*id;
00098 String currentFile="";
00099 XLine featureFileListp=linep->getElements();
00100 if (verbose) cout << "Info model ["<<*id<<"]"<<endl;
00101 if (!fixedLabelSelectedFrame){
00102 labelSelectedFrames=*id;
00103 if (debug) cout <<*id<<" is used for label selected frames"<<endl;
00104 }
00105
00106 SegServer segmentsServer;
00107 LabelServer labelServer;
00108 initializeClusters(featureFileListp,segmentsServer,labelServer,config);
00109 unsigned long codeSelectedFrame=labelServer.getLabelIndexByString(labelSelectedFrames);
00110 SegCluster& selectedSegments=segmentsServer.getCluster(codeSelectedFrame);
00111 Seg *seg;
00112 unsigned long frameCount=0;
00113 selectedSegments.rewind();
00114 while(((seg=selectedSegments.getSeg())!=NULL) && (frameCount<maxFrame)){
00115 frameCount+=seg->length();
00116 cout << seg->sourceName()<<" "<<seg->begin()<<" "<<seg->length()<<" Total time="<<frameCount<<endl;
00117 if (seg->sourceName()!=currentFile){
00118 outputFile<<" "<<seg->sourceName();
00119 currentFile=seg->sourceName();
00120 }
00121 }
00122 outputFile<<endl;
00123 if (verbose) cout << "Save info client ["<<*id<<"]" << endl;
00124 }
00125 }
00126
00127 catch (Exception& e)
00128 {
00129 cout << e.toString().c_str() << endl;
00130 }
00131 return 0;
00132 }
00133
00134
00135
00136
00137
00138 int TrainTarget(Config& config)
00139 {
00140 String inputClientListFileName = config.getParam("targetIdList");
00141 String inputWorldFilename = config.getParam("inputWorldFilename");
00142 String outputSERVERFilename = "";
00143 if (config.existsParam("mixtureServer")) outputSERVERFilename =config.getParam("mixtureServer");
00144 bool initByClient=false;
00145 if (config.existsParam("initByClient")) initByClient=config.getParam("initByClient").toBool();
00146 bool saveEmptyModel=false;
00147 if (config.existsParam("saveEmptyModel")) saveEmptyModel=config.getParam("saveEmptyModel").toBool();
00148
00149 bool fixedLabelSelectedFrame=true;
00150 String labelSelectedFrames;
00151 if (config.existsParam("useIdForSelectedFrame"))
00152 fixedLabelSelectedFrame=(config.getParam("useIdForSelectedFrame").toBool()==false);
00153 if (fixedLabelSelectedFrame)
00154 labelSelectedFrames=config.getParam("labelSelectedFrames");
00155 bool modelData=false;
00156 if (config.existsParam("useModelData")) modelData=config.getParam("useModelData").toBool();
00157 String initModelS=inputWorldFilename;
00158 if (modelData) if (config.existsParam("initModel")) initModelS=config.getParam("initModel");
00159 bool outputAdaptParam=false;
00160 if (config.existsParam("superVector")) outputAdaptParam=true;
00161 bool NAP=false;
00162 Matrix <double> ChannelMatrix;
00163 if (config.existsParam("NAP")) {
00164 if (verbose) cout<< "Removing channel effect with NAP from " << config.getParam("NAP") << " of size: [";
00165 NAP=true;
00166 ChannelMatrix.load(config.getParam("NAP"),config);
00167 if (verbose) cout << ChannelMatrix.rows() << "," <<ChannelMatrix.cols() << "]" << endl;
00168 }
00169
00170
00171 bool saveCompleteServer=false;
00172
00173 try{
00174 XList inputClientList(inputClientListFileName,config);
00175 XLine * linep;
00176 inputClientList.getLine(0);
00177 MixtureServer ms(config);
00178 StatServer ss(config, ms);
00179 if (verbose) cout << "TrainTarget - Load world model [" << inputWorldFilename<<"]"<<endl;
00180 MixtureGD& world = ms.loadMixtureGD(inputWorldFilename);
00181 MixtureGD& initModel =ms.loadMixtureGD(initModelS);
00182
00183 if (verbose) cout <<"Use["<<initModelS<<"] for initializing EM"<<endl;
00184
00185
00186 while ((linep=inputClientList.getLine()) != NULL){
00187 String *id=linep->getElement();
00188 XLine featureFileListp=linep->getElements();
00189 if (verbose) cout << "Train model ["<<*id<<"]"<<endl;
00190 if (!fixedLabelSelectedFrame){
00191 labelSelectedFrames=*id;
00192 if (verbose) cout <<*id<<" is used for label selected frames"<<endl;
00193 }
00194 FeatureServer fs(config,featureFileListp);
00195 SegServer segmentsServer;
00196 LabelServer labelServer;
00197 initializeClusters(featureFileListp,segmentsServer,labelServer,config);
00198 verifyClusterFile(segmentsServer,fs,config);
00199 MixtureGD & adaptedMixture = ms.duplicateMixture(world,DUPL_DISTRIB);
00200 MixtureGD & clientMixture= ms.duplicateMixture(world,DUPL_DISTRIB);
00201 if (initByClient){
00202 clientMixture= ms.loadMixtureGD(*id);
00203 adaptedMixture=clientMixture;
00204 }
00205 long codeSelectedFrame=labelServer.getLabelIndexByString(labelSelectedFrames);
00206 if (codeSelectedFrame==-1){
00207 cout << " WARNING - NO DATA FOR TRAINING ["<<*id<<"]";
00208 if (saveEmptyModel){
00209 cout <<" World model is returned"<<endl;
00210 if (verbose) cout << "Save client model ["<<*id<<"]" << endl;
00211 adaptedMixture.save(*id, config);
00212 }
00213 }
00214 else{
00215 SegCluster& selectedSegments=segmentsServer.getCluster(codeSelectedFrame);
00216 if (!initByClient) ms.setMixtureId(clientMixture,*id);
00217 if (modelData) modelBasedadaptModel(config,ss,ms,fs,selectedSegments,world,clientMixture,initModel);
00218 else adaptModel(config,ss,ms,fs,selectedSegments,world,clientMixture);
00219 if (NAP) {
00220 if (verbose) cout << "NAP on SVs" << endl;
00221 computeNap(clientMixture,ChannelMatrix);
00222 }
00223 if (outputAdaptParam) {
00224 RealVector<double> v;
00225 getSuperVector(v,world,clientMixture,config);
00226 String out=config.getParam("vectorFilesPath")+*id+config.getParam("vectorFilesExtension");
00227 Matrix <double> vv=(Matrix<double>)v;
00228 vv.save(out,config);
00229 }
00230 if (!outputAdaptParam) {
00231 if (verbose) cout << "Save client model ["<<*id<<"]" << endl;
00232 clientMixture.save(*id, config);
00233 }
00234 if (!saveCompleteServer){
00235 long tid=ms.getMixtureIndex(*id);
00236 ms.deleteMixtures(tid,tid);
00237 ms.deleteUnusedDistribs();
00238 }
00239 }
00240 }
00241
00242
00243
00244 }
00245
00246
00247
00248 catch (Exception& e)
00249 {
00250 cout << e.toString().c_str() << endl;
00251 }
00252 return 0;
00253 }
00254
00255
00256
00257
00258
00259
00260
00261 int TrainTargetByLabel(Config& config)
00262 {
00263 String inputClientListFileName = config.getParam("targetIdList");
00264 String inputWorldFilename = config.getParam("inputWorldFilename");
00265 String outputSERVERFilename = config.getParam("mixtureServer");
00266
00267
00268 bool initByClient=false;
00269 bool aprioriWorld=true;
00270 if (config.existsParam("initByClient")) initByClient=true;
00271 if (config.existsParam("aprioriClient")){
00272 aprioriWorld=false;
00273 initByClient=true;
00274 }
00275 bool saveCompleteServer=false;
00276 bool outputAdaptParam=false;
00277 if (config.existsParam("outputAdaptParam")) outputAdaptParam=config.getParam("outputAdaptParam").toBool();
00278
00279 try{
00280 XList inputClientList(inputClientListFileName,config);
00281 XLine *linep;
00282 inputClientList.getLine(0);
00283 MixtureServer ms(config);
00284 StatServer ss(config, ms);
00285 if (verbose) cout << "TrainTarget - by label opption - Load world model [" << inputWorldFilename<<"]"<<endl;
00286 MixtureGD& world = ms.loadMixtureGD(inputWorldFilename);
00287
00288 while ((linep=inputClientList.getLine()) != NULL){
00289 String clientId=(*linep->getElement());
00290 XLine featureFileListp=linep->getElements();
00291 FeatureServer fs(config,featureFileListp);
00292 if (verbose) cout << "Train label models for client ["<<clientId<<"]"<<endl;
00293 MixtureGD &clientGModel=ms.createMixtureGD();
00294 if (initByClient) {
00295 if (verbose) cout << "Load client model [" << clientId <<"]"<<endl;
00296 clientGModel = ms.loadMixtureGD(clientId);
00297 }
00298 SegServer segmentsServer;
00299 LabelServer labelServer;
00300 initializeClusters(featureFileListp,segmentsServer,labelServer,config);
00301 verifyClusterFile(segmentsServer,fs,config);
00302 for (unsigned long codeSelectedFrame=0;codeSelectedFrame<segmentsServer.getClusterCount();codeSelectedFrame++){
00303 String clientIdByLabel=clientId+"_"+labelServer.getLabel(codeSelectedFrame).getString();
00304 if (verbose) cout << "Train labeldependent model ["<<clientIdByLabel<<"]"<<endl;
00305 SegCluster& selectedSegments=segmentsServer.getCluster(codeSelectedFrame);
00306 MixtureGD & clientMixture = ms.duplicateMixture(world,DUPL_DISTRIB);
00307 ms.setMixtureId(clientMixture,clientIdByLabel);
00308 if (initByClient)
00309 clientMixture=clientGModel;
00310 if (aprioriWorld)
00311 adaptModel(config,ss,ms,fs,selectedSegments,world,clientMixture);
00312 else adaptModel(config,ss,ms,fs,selectedSegments,clientGModel,clientMixture);
00313 if (!outputAdaptParam) {
00314 if (verbose) cout << "Save client model ["<<clientIdByLabel<<"]" << endl;
00315 clientMixture.save(clientIdByLabel, config);
00316 }
00317 if (!saveCompleteServer){
00318 long tid=ms.getMixtureIndex(clientIdByLabel);
00319 ms.deleteMixtures(tid,tid);
00320 ms.deleteUnusedDistribs();
00321 }
00322 }
00323 if (!saveCompleteServer){
00324 long tid=ms.getMixtureIndex(clientId);
00325 ms.deleteMixtures(tid,tid);
00326 ms.deleteUnusedDistribs();
00327 }
00328 }
00329
00330
00331
00332 }
00333
00334
00335
00336 catch (Exception& e)
00337 {
00338 cout << e.toString().c_str() << endl;
00339 }
00340 return 0;
00341 }
00342
00343 int TrainTargetFA(Config& config)
00344 {
00345 String inputClientListFileName = config.getParam("targetIdList");
00346 String inputWorldFilename = config.getParam("inputWorldFilename");
00347 String outputSERVERFilename = "";
00348 if (config.existsParam("mixtureServer")) outputSERVERFilename =config.getParam("mixtureServer");
00349 bool initByClient=false;
00350 if (config.existsParam("initByClient")) initByClient=config.getParam("initByClient").toBool();
00351 bool saveEmptyModel=false;
00352 if (config.existsParam("saveEmptyModel")) saveEmptyModel=config.getParam("saveEmptyModel").toBool();
00353
00354 bool fixedLabelSelectedFrame=true;
00355 String labelSelectedFrames;
00356 if (config.existsParam("useIdForSelectedFrame"))
00357 fixedLabelSelectedFrame=(config.getParam("useIdForSelectedFrame").toBool()==false);
00358 if (fixedLabelSelectedFrame)
00359 labelSelectedFrames=config.getParam("labelSelectedFrames");
00360 bool modelData=false;
00361 if (config.existsParam("useModelData")) modelData=config.getParam("useModelData").toBool();
00362 String initModelS=inputWorldFilename;
00363 if (modelData) if (config.existsParam("initModel")) initModelS=config.getParam("initModel");
00364 bool outputAdaptParam=false;
00365 if (config.existsParam("superVectors")) outputAdaptParam=true;
00366 Matrix <double> ChannelMatrix;
00367 if (verbose) cout<< "EigenMAP and Eigenchannel with [" << config.getParam("initChannelMatrix") << "] of size: [";
00368 ChannelMatrix.load(config.getParam("initChannelMatrix"),config);
00369 if (verbose) cout << ChannelMatrix.rows() << "," <<ChannelMatrix.cols() << "]" << endl;
00370 bool varAdapt=false;
00371 if (config.existsParam("FAVarAdapt")) varAdapt=true;
00372 bool saveCompleteServer=false;
00373
00374 try{
00375 XList inputClientList(inputClientListFileName,config);
00376 XLine * linep;
00377 inputClientList.getLine(0);
00378 MixtureServer ms(config);
00379 StatServer ss(config, ms);
00380 if (verbose) cout << "(TrainTarget) Factor Analysis - Load world model [" << inputWorldFilename<<"]"<<endl;
00381 MixtureGD& world = ms.loadMixtureGD(inputWorldFilename);
00382 if (verbose) cout <<"(TrainTarget) Use["<<initModelS<<"] for initializing EM"<<endl;
00383
00384
00385 while ((linep=inputClientList.getLine()) != NULL){
00386
00387 String *id=linep->getElement();
00388 XLine featureFileListp=linep->getElements();
00389 if (verbose) cout << "(TrainTarget) Train model ["<<*id<<"]"<<endl;
00390 FeatureServer fs(config,featureFileListp);
00391 SegServer segmentsServer;
00392 LabelServer labelServer;
00393 initializeClusters(featureFileListp,segmentsServer,labelServer,config);
00394 verifyClusterFile(segmentsServer,fs,config);
00395 MixtureGD & adaptedMixture = ms.duplicateMixture(world,DUPL_DISTRIB);
00396 MixtureGD & clientMixture= ms.duplicateMixture(world,DUPL_DISTRIB);
00397 long codeSelectedFrame=labelServer.getLabelIndexByString(labelSelectedFrames);
00398 if (codeSelectedFrame==-1){
00399 cout << " WARNING - NO DATA FOR TRAINING ["<<*id<<"]";
00400 if (saveEmptyModel){
00401 cout <<" World model is returned"<<endl;
00402 if (verbose) cout << "Save client model ["<<*id<<"]" << endl;
00403 adaptedMixture.save(*id, config);
00404 }
00405 }
00406 else{
00407 SegCluster& selectedSegments=segmentsServer.getCluster(codeSelectedFrame);
00409 XList faNdx;
00410 faNdx.addLine()=featureFileListp;
00411 FactorAnalysisStat FA(faNdx,fs,config);
00412
00413
00414 for(int i=0;i<config.getParam("nbTrainIt").toLong();i++){
00415 if (verbose) cout << "------ Iteration ["<<i<<"] ------"<<endl;
00416 FA.computeAndAccumulateGeneralFAStats(selectedSegments,fs,config);
00417
00418
00419
00420 FA.estimateAndInverseL(config);
00421 FA.substractSpeakerStats();
00422 FA.getXEstimate();
00423 FA.substractChannelStats();
00424 FA.getYEstimate();
00425 }
00426 MixtureGD & sessionMixture= ms.duplicateMixture(world,DUPL_DISTRIB);
00427 bool saveSessionModel=false;
00428 if (config.existsParam("saveSessionModel")) saveSessionModel=true;
00429 if (saveSessionModel) FA.getSessionModel(sessionMixture,linep->getElement(1));
00430 if (!varAdapt) FA.getTrueSpeakerModel(clientMixture,linep->getElement(1));
00431 else FA.getFactorAnalysisModel(clientMixture,linep->getElement(1));
00432 if (verbose) cout << "Final LLK for model["<<*id<<"]="<<FA.getLLK(selectedSegments,clientMixture,fs,config) << endl;
00433
00435 if (!outputAdaptParam) {
00436 if (verbose) cout << "Save client model ["<<*id<<"]" << endl;
00437 clientMixture.save(*id, config);
00438 if (saveSessionModel) {
00439 String sessionfile=*id+".session";
00440 if (verbose) cout << "Save session model ["<<sessionfile<<"]" << endl;
00441 sessionMixture.save(sessionfile,config);
00442 }
00443 }
00444 if (!saveCompleteServer){
00445 long tid=ms.getMixtureIndex(*id);
00446 ms.deleteMixtures(tid,tid);
00447 ms.deleteUnusedDistribs();
00448 }
00449 }
00450 }
00451 }
00452 catch (Exception& e) {cout << e.toString().c_str() << endl;}
00453 return 0;
00454 }
00455
00456
00457 int TrainTargetJFA(Config& config)
00458 {
00459 String inputClientListFileName = config.getParam("targetIdList");
00460 String inputWorldFilename = config.getParam("inputWorldFilename");
00461 String outputSERVERFilename = "";
00462 if (config.existsParam("mixtureServer")) outputSERVERFilename =config.getParam("mixtureServer");
00463 bool initByClient=false;
00464 if (config.existsParam("initByClient")) initByClient=config.getParam("initByClient").toBool();
00465 bool saveEmptyModel=false;
00466 if (config.existsParam("saveEmptyModel")) saveEmptyModel=config.getParam("saveEmptyModel").toBool();
00467
00468 bool fixedLabelSelectedFrame=true;
00469 String labelSelectedFrames;
00470 if (config.existsParam("useIdForSelectedFrame"))
00471 fixedLabelSelectedFrame=(config.getParam("useIdForSelectedFrame").toBool()==false);
00472 if (fixedLabelSelectedFrame)
00473 labelSelectedFrames=config.getParam("labelSelectedFrames");
00474 bool modelData=false;
00475 if (config.existsParam("useModelData")) modelData=config.getParam("useModelData").toBool();
00476 String initModelS=inputWorldFilename;
00477 if (modelData) if (config.existsParam("initModel")) initModelS=config.getParam("initModel");
00478 bool outputAdaptParam=false;
00479 if (config.existsParam("superVectors")) outputAdaptParam=true;
00480
00481 try{
00482 XList inputClientList(inputClientListFileName,config);
00483 XLine * linep;
00484 inputClientList.getLine(0);
00485 MixtureServer ms(config);
00486 StatServer ss(config, ms);
00487 if (verbose) cout << "(TrainTarget) Joint Factor Analysis - Load world model [" << inputWorldFilename<<"]"<<endl;
00488 MixtureGD& world = ms.loadMixtureGD(inputWorldFilename);
00489 if (verbose) cout <<"(TrainTarget) Use["<<initModelS<<"] for initializing EM"<<endl;
00490
00491
00492 Matrix<double> U, V;
00493 DoubleVector D;
00494
00495
00496 if(config.existsParam("eigenChannelMatrix")){
00497 String uName = config.getParam("matrixFilesPath") + config.getParam("eigenChannelMatrix") + config.getParam("loadMatrixFilesExtension");
00498 U.load (uName, config);
00499 if (verboseLevel >=1) cout << "(TrainTargetJFA) Init EC matrix from "<< config.getParam("eigenChannelMatrix") <<" from EigenChannel Matrix: "<<", rank: ["<<U.rows() << "] sv size: [" << U.cols() <<"]"<<endl;
00500 }
00501 else{
00502 unsigned long sS = world.getVectSize() * world.getDistribCount();
00503 U.setDimensions(1,sS);
00504 U.setAllValues(0.0);
00505 if (verboseLevel >1) cout << "(TrainTargetJFA) Init EC matrix to 0"<<endl;
00506 }
00507
00508
00509 if(config.existsParam("eigenVoiceMatrix")){
00510 String vName = config.getParam("matrixFilesPath") + config.getParam("eigenVoiceMatrix") + config.getParam("loadMatrixFilesExtension");
00511 V.load (vName, config);
00512 if (verboseLevel >=1) cout << "(TrainTargetJFA) Init EV matrix from "<< config.getParam("eigenVoiceMatrix") <<" from EigenVoice Matrix: "<<", rank: ["<<V.rows() << "] sv size: [" << V.cols() <<"]"<<endl;
00513 }
00514 else{
00515 unsigned long sS = world.getVectSize() * world.getDistribCount();
00516 V.setDimensions(1,sS);
00517 V.setAllValues(0.0);
00518 if (verboseLevel >=1) cout << "(TrainTargetJFA) Init EV matrix to 0"<<endl;
00519 }
00520
00521
00522 if(config.existsParam("DMatrix")){
00523 String dName = config.getParam("matrixFilesPath") + config.getParam("DMatrix") + config.getParam("loadMatrixFilesExtension");
00524 Matrix<double> tmpD(dName, config);
00525
00526 if( (tmpD.rows() != 1) || ( tmpD.cols() != world.getVectSize()*world.getDistribCount() ) ){
00527 throw Exception("Incorrect dimension of D Matrix",__FILE__,__LINE__);
00528 }
00529 else{
00530 D.setSize(world.getVectSize()*world.getDistribCount());
00531 D.setAllValues(0.0);
00532 for(unsigned long i=0; i<world.getVectSize()*world.getDistribCount(); i++){
00533 D[i] = tmpD(0,i);
00534 }
00535 if (verboseLevel >1) cout << "(TrainTargetJFA) Init D matrix from "<<config.getParam("DMatrix")<<endl;
00536 }
00537 }
00538 else{
00539 unsigned long sS = world.getVectSize() * world.getDistribCount();
00540 D.setSize(sS,sS);
00541 D.setAllValues(0.0);
00542 if (verboseLevel >1) cout << "(TrainTargetJFA) Init D matrix to 0"<<endl;
00543 }
00544
00545
00546 while ((linep=inputClientList.getLine()) != NULL){
00547
00548 String *id=linep->getElement();
00549 XLine featureFileListp=linep->getElements();
00550 if (verbose) cout << "(TrainTarget) Train model ["<<*id<<"]"<<endl;
00551
00552 XList ndx; ndx.addLine() = featureFileListp;
00553 JFAAcc jfaAcc(ndx,config);
00554
00555
00556 jfaAcc.loadEV(V, config); jfaAcc.loadEC(U, config); jfaAcc.loadD(D);
00557
00558
00559 jfaAcc.initVU();
00560
00561 FeatureServer fs(config,featureFileListp);
00562 SegServer segmentsServer;
00563 LabelServer labelServer;
00564 initializeClusters(featureFileListp,segmentsServer,labelServer,config);
00565 verifyClusterFile(segmentsServer,fs,config);
00566
00567 MixtureGD & adaptedMixture = ms.duplicateMixture(world,DUPL_DISTRIB);
00568 MixtureGD & clientMixture= ms.duplicateMixture(world,DUPL_DISTRIB);
00569 long codeSelectedFrame=labelServer.getLabelIndexByString(labelSelectedFrames);
00570 if (codeSelectedFrame==-1){
00571 cout << " WARNING - NO DATA FOR TRAINING ["<<*id<<"]";
00572 if (saveEmptyModel){
00573 cout <<" World model is returned"<<endl;
00574 if (verbose) cout << "Save client model ["<<*id<<"]" << endl;
00575 adaptedMixture.save(*id, config);
00576 }
00577 }
00578
00579 else{
00580 SegCluster& selectedSegments=segmentsServer.getCluster(codeSelectedFrame);
00581
00582
00583 jfaAcc.computeAndAccumulateJFAStat(selectedSegments,fs,config);
00584
00585
00586 jfaAcc.storeAccs();
00587 jfaAcc.estimateVUEVUT(config);
00588 jfaAcc.estimateAndInverseL_VU(config);
00589 jfaAcc.substractMplusDZ(config);
00590 jfaAcc.estimateYX();
00591
00592 jfaAcc.resetTmpAcc();
00593 jfaAcc.restoreAccs();
00594
00595
00596 jfaAcc.splitYX();
00597
00598
00599 jfaAcc.substractMplusVUYX();
00600
00601 jfaAcc.estimateZ();
00602
00603 jfaAcc.resetTmpAcc();
00604 jfaAcc.restoreAccs();
00605
00606 bool varAdapt = false;
00607 if((config.existsParam("varAdapt")) && ( config.getParam("varAdapt").toBool() )){
00608 varAdapt = true;
00609 }
00610
00611 DoubleVector clientSV(jfaAcc.getSvSize(), jfaAcc.getSvSize());
00612 clientSV.setSize(jfaAcc.getSvSize());
00613 DoubleVector clientModel(jfaAcc.getSvSize(), jfaAcc.getSvSize());
00614 clientModel.setSize(jfaAcc.getSvSize());
00615
00616 bool saveMixture = true;
00617 if((config.existsParam("saveMixture")) && !( config.getParam("saveMixture").toBool() )) saveMixture = false;
00618 bool saveSuperVector = true;
00619 if((config.existsParam("saveSuperVector")) && !( config.getParam("saveSuperVector").toBool() )) saveSuperVector = false;
00620
00621 jfaAcc.getVYplusDZ(clientSV, 0);
00622 jfaAcc.getMplusVYplusDZ(clientModel, 0);
00623
00624
00625 for(unsigned long i=0; i<jfaAcc.getSvSize(); i++){
00626 clientSV[i] *= jfaAcc.getUbmInvVar()[i];
00627 }
00628
00629
00630 if(saveMixture){
00631 svToModel(clientModel, clientMixture);
00632 clientMixture.save(*id, config);
00633 }
00634
00635 if(saveSuperVector){
00636 String svPath=config.getParam("vectorFilesPath");
00637 String svExt=config.getParam("vectorFilesExtension");
00638 String svFile=svPath+*id+svExt;
00639 ((Matrix<double>)clientSV).save(svFile,config);
00640 }
00641
00642 long tid=ms.getMixtureIndex(*id);
00643 ms.deleteMixtures(tid,tid);
00644 ms.deleteUnusedDistribs();
00645 }
00646 }
00647 }
00648 catch (Exception& e) {cout << e.toString().c_str() << endl;}
00649 return 0;
00650 }
00651
00652
00653 int TrainTargetLFA(Config& config)
00654 {
00655 String inputClientListFileName = config.getParam("targetIdList");
00656 String inputWorldFilename = config.getParam("inputWorldFilename");
00657 String outputSERVERFilename = "";
00658 if (config.existsParam("mixtureServer")) outputSERVERFilename =config.getParam("mixtureServer");
00659 bool initByClient=false;
00660 if (config.existsParam("initByClient")) initByClient=config.getParam("initByClient").toBool();
00661 bool saveEmptyModel=false;
00662 if (config.existsParam("saveEmptyModel")) saveEmptyModel=config.getParam("saveEmptyModel").toBool();
00663
00664 bool fixedLabelSelectedFrame=true;
00665 String labelSelectedFrames;
00666 if (config.existsParam("useIdForSelectedFrame"))
00667 fixedLabelSelectedFrame=(config.getParam("useIdForSelectedFrame").toBool()==false);
00668 if (fixedLabelSelectedFrame)
00669 labelSelectedFrames=config.getParam("labelSelectedFrames");
00670 bool modelData=false;
00671 if (config.existsParam("useModelData")) modelData=config.getParam("useModelData").toBool();
00672 String initModelS=inputWorldFilename;
00673 if (modelData) if (config.existsParam("initModel")) initModelS=config.getParam("initModel");
00674 bool outputAdaptParam=false;
00675 if (config.existsParam("superVectors")) outputAdaptParam=true;
00676
00677 try{
00678 XList inputClientList(inputClientListFileName,config);
00679 XLine * linep;
00680 inputClientList.getLine(0);
00681 MixtureServer ms(config);
00682 StatServer ss(config, ms);
00683 if (verbose) cout << "(TrainTarget) Joint Factor Analysis - Load world model [" << inputWorldFilename<<"]"<<endl;
00684 MixtureGD& world = ms.loadMixtureGD(inputWorldFilename);
00685 if (verbose) cout <<"(TrainTarget) Use["<<initModelS<<"] for initializing EM"<<endl;
00686
00687
00688 unsigned long svsize=world.getDistribCount()*world.getVectSize();
00689 Matrix<double> U, V;
00690 DoubleVector D(svsize,svsize);
00691
00692
00693 if(config.existsParam("eigenChannelMatrix")){
00694 String uName = config.getParam("matrixFilesPath") + config.getParam("eigenChannelMatrix") + config.getParam("loadMatrixFilesExtension");
00695 U.load (uName, config);
00696 if (verboseLevel >=1) cout << "(TrainTargetLFA) Init EC matrix from "<< config.getParam("eigenChannelMatrix") <<" from EigenChannel Matrix: "<<", rank: ["<<U.rows() << "] sv size: [" << U.cols() <<"]"<<endl;
00697 }
00698 else{
00699 U.setDimensions(1,svsize);
00700 U.setAllValues(0.0);
00701 if (verboseLevel >1) cout << "(TrainTargetLFA) Init EC matrix to 0"<<endl;
00702 }
00703
00704 V.setDimensions(1,svsize);
00705 V.setAllValues(0.0);
00706 if (verboseLevel >=1) cout << "(TrainTargetLFA) Init EV matrix to 0"<<endl;
00707
00708
00709 for(unsigned long i=0; i<world.getDistribCount(); i++){
00710 for(unsigned long j = 0; j<world.getVectSize(); j++){
00711 D[i*world.getVectSize()+j] = sqrt(1.0/(world.getDistrib(i).getCovInv(j)*config.getParam("regulationFactor").toDouble()));
00712 }
00713 }
00714
00715
00716 while ((linep=inputClientList.getLine()) != NULL){
00717
00718 String *id=linep->getElement();
00719 XLine featureFileListp=linep->getElements();
00720 if (verbose) cout << "(TrainTargetLFA) Train model ["<<*id<<"]"<<endl;
00721
00722 XList ndx; ndx.addLine() = featureFileListp;
00723 JFAAcc jfaAcc(ndx,config);
00724
00725
00726 jfaAcc.loadEV(V, config); jfaAcc.loadEC(U, config); jfaAcc.loadD(D);
00727
00728
00729 jfaAcc.initVU();
00730
00731 FeatureServer fs(config,featureFileListp);
00732 SegServer segmentsServer;
00733 LabelServer labelServer;
00734 initializeClusters(featureFileListp,segmentsServer,labelServer,config);
00735 verifyClusterFile(segmentsServer,fs,config);
00736
00737 MixtureGD & adaptedMixture = ms.duplicateMixture(world,DUPL_DISTRIB);
00738 MixtureGD & clientMixture= ms.duplicateMixture(world,DUPL_DISTRIB);
00739 long codeSelectedFrame=labelServer.getLabelIndexByString(labelSelectedFrames);
00740 if (codeSelectedFrame==-1){
00741 cout << " WARNING - NO DATA FOR TRAINING ["<<*id<<"]";
00742 if (saveEmptyModel){
00743 cout <<" World model is returned"<<endl;
00744 if (verbose) cout << "Save client model ["<<*id<<"]" << endl;
00745 adaptedMixture.save(*id, config);
00746 }
00747 }
00748
00749 else{
00750 SegCluster& selectedSegments=segmentsServer.getCluster(codeSelectedFrame);
00751
00752
00753 jfaAcc.computeAndAccumulateJFAStat(selectedSegments,fs,config);
00754
00755
00756 jfaAcc.storeAccs();
00757 jfaAcc.estimateVUEVUT(config);
00758 jfaAcc.estimateAndInverseL_VU(config);
00759 jfaAcc.substractMplusDZ(config);
00760 jfaAcc.estimateYX();
00761
00762 jfaAcc.resetTmpAcc();
00763 jfaAcc.restoreAccs();
00764
00765
00766 jfaAcc.splitYX();
00767
00768
00769 jfaAcc.substractMplusVUYX();
00770
00771 double tau = config.getParam("regulationFactor").toLong();
00772 jfaAcc.estimateZMAP(tau);
00773
00774 jfaAcc.resetTmpAcc();
00775 jfaAcc.restoreAccs();
00776
00777 bool varAdapt = false;
00778 if((config.existsParam("varAdapt")) && ( config.getParam("varAdapt").toBool() )){
00779 varAdapt = true;
00780 }
00781
00782 DoubleVector clientModel(jfaAcc.getSvSize(), jfaAcc.getSvSize());
00783 clientModel.setSize(jfaAcc.getSvSize());
00784
00785 jfaAcc.getMplusVYplusDZ(clientModel, 0);
00786
00787
00788 svToModel(clientModel, clientMixture);
00789 clientMixture.save(*id, config);
00790
00791 long tid=ms.getMixtureIndex(*id);
00792 ms.deleteMixtures(tid,tid);
00793 ms.deleteUnusedDistribs();
00794 }
00795 }
00796 }
00797 catch (Exception& e) {cout << e.toString().c_str() << endl;}
00798 return 0;
00799 }
00800
00801 #endif //!defined(ALIZE_TrainTarget_cpp)