00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055 #if !defined(ALIZE_StatServer_cpp)
00056 #define ALIZE_StatServer_cpp
00057
00058 #include <new>
00059 #include <cmath>
00060 #if defined(_WIN32)
00061 #include <cfloat>
00062 #define ISNAN(x) _isnan(x)
00063
00064
00065 #elif defined(linux) || defined(__linux) || defined(__CYGWIN__) || defined(__APPLE__)
00066 #define ISNAN(x) isnan(x)
00067 #else
00068 #error "Unsupported OS\n"
00069 #endif
00070
00071 #include "MixtureServer.h"
00072 #include "StatServer.h"
00073 #include "MixtureStat.h"
00074 #include "MixtureGDStat.h"
00075 #include "MixtureGFStat.h"
00076 #include "Mixture.h"
00077 #include "Exception.h"
00078 #include "Config.h"
00079 #include "RealVector.h"
00080 #include "ULongVector.h"
00081 #include "ViterbiAccum.h"
00082 #include "FrameAccGD.h"
00083 #include "FrameAccGF.h"
00084
00085 using namespace alize;
00086 using namespace std;
00087
00088 #include <cstdio>
00089 #include <iostream>
00090 using namespace std;
00091
00092
// Short alias for StatServer, used to keep the out-of-line member
// definitions below compact (S::...).
typedef StatServer S;
00094
00095 S::StatServer(const Config& c)
00096 :Object(), _config(c), _pMixtureServer(NULL),
00097 _topDistribsVect(0, 0), _minLLK(c.getParam_minLLK()),
00098 _maxLLK(c.getParam_maxLLK()){
00099 reset();
00100 }
00101
00102 S::StatServer(const Config& c, MixtureServer& ms)
00103 :Object(), _config(c), _pMixtureServer(&ms),
00104 _topDistribsVect(0, 0), _minLLK(c.getParam_minLLK()),
00105 _maxLLK(c.getParam_maxLLK())
00106
00107 { reset(); }
00108
00109 void S::reset()
00110 {
00111 _mixtureStatVect.deleteAllObjects();
00112 _viterbiAccumVect.deleteAllObjects();
00113 _pLastMixture = NULL;
00114 _pLastMixtureStat = NULL;
00115 _topDistribsVect.clear();
00116 }
00117
00118 real_t S::getAccumulatedOccFeatureCount(const Mixture& m)
00119 { return getMixtureStat(m).getAccumulatedOccFeatureCount(); }
00120
00121 void S::resetLLK(const Mixture& m) { getMixtureStat(m).resetLLK(); }
00122
00123 lk_t S::computeAndAccumulateLLK(const Mixture& m, const Feature& f,
00124 const TopDistribsAction& a)
00125 { return computeAndAccumulateLLK(m, f, (double)1.0f, a); }
00126
00127 lk_t S::computeAndAccumulateLLK(const Mixture& m, const Feature& f,
00128 double w, const TopDistribsAction& a)
00129 { return getMixtureStat(m).computeAndAccumulateLLK(f, w, a); }
00130
00131 void S::accumulateLLK(const Mixture& m, lk_t v, double w)
00132 { getMixtureStat(m).accumulateLLK(v, w); }
00133
00134 const LKVector& S::getTopDistribIndexVector() const
00135 { return _topDistribsVect; }
00136
00137 const DoubleVector& S::getDistribLKVector(const K&) const
00138 { return _distribLKVect; }
00139
00140 LKVector& S::getTopDistribIndexVector(const K&) { return _topDistribsVect; }
00141
00142 lk_t S::computeAndAccumulateLLK(const Mixture& m)
00143 { return getMixtureStat(m).computeAndAccumulateLLK(); }
00144
00145 lk_t S::getLLK(const Mixture& m) { return getMixtureStat(m).getLLK(); }
00146
00147 lk_t S::getMeanLLK(const Mixture& m) { return getMixtureStat(m).getMeanLLK(); }
00148
00149 lk_t S::computeLLK(const Mixture& m, const Feature& f) const
00150 {
00151 lk_t lk = 0.0;
00152 weight_t* w = m.getTabWeight().getArray();
00153 Distrib** d = m.getTabDistrib();
00154 unsigned long distribCount = m.getDistribCount();
00155 for (unsigned long c=0; c<distribCount; c++) {
00156 lk += w[c] * d[c]->computeLK(f);
00157 }
00158 return computeLLK(lk);
00159 }
00160
00161 lk_t S::computeLLK(const Mixture& m, const Feature& f, unsigned long idx) const
00162 {
00163 lk_t lk = 0.0;
00164 weight_t* w = m.getTabWeight().getArray();
00165 Distrib** d = m.getTabDistrib();
00166 unsigned long distribCount = m.getDistribCount();
00167 for (unsigned long c=0; c<distribCount; c++)
00168 lk += w[c] * d[c]->computeLK(f, idx);
00169 return computeLLK(lk);
00170 }
00171
00172 lk_t S::computeLLK(const K&, const Mixture& m) const
00173 {
00174 const weight_t* weightVect = m.getTabWeight().getArray();
00175 Distrib** distribVect = m.getTabDistrib();
00176 const lk_t* lkVect = _distribLKVect.getArray();
00177 unsigned long distribCount = m.getDistribCount();
00178
00179 lk_t lk = 0.0;
00180 for (unsigned long c=0; c<distribCount; c++)
00181 lk += lkVect[distribVect[c]->dictIndex(K::k)] * weightVect[c];
00182 return computeLLK(lk);
00183 }
00184
00185 lk_t S::computeLLK(lk_t lk) const
00186 {
00187 if ( ISNAN(lk) || lk == 0 || lk<_minLLK )
00188 lk = _minLLK;
00189 else
00190 {
00191 lk = log(lk);
00192 if (lk > _maxLLK)
00193 lk = _maxLLK;
00194 }
00195 return lk;
00196 }
00197
00198 lk_t S::computeLLK(const K&, const Mixture& m, const Feature& f,
00199 const TopDistribsAction& a)
00200 {
00201 if (a == TOP_DISTRIBS_NO_ACTION)
00202 return computeLLK(m, f);
00203
00204 LKVector& lkVect = _topDistribsVect;
00205 lk_t lk = 0.0;
00206 weight_t* w = m.getTabWeight().getArray();
00207 Distrib** d = m.getTabDistrib();
00208 unsigned long distribCount = m.getDistribCount();
00209 unsigned long c, i, nTop = _config.getParam_topDistribsCount();
00210
00211 if (a == USE_TOP_DISTRIBS)
00212 {
00213 if (nTop >= distribCount)
00214 nTop = distribCount;
00215 if (distribCount != lkVect.size())
00216 throw Exception("", __FILE__, __LINE__);
00217 LKVector::type* v = lkVect.getArray();
00218 real_t sumTopDistribWeights = 0.0;
00219
00220 for (i=0; i<nTop; i++)
00221 {
00222 c = v[i].idx;
00223 sumTopDistribWeights += w[c];
00224
00225 lk +=(v[c].lk =(w[c] * d[c]->computeLK(f)));
00226 }
00227 if (_config.getParam_computeLLKWithTopDistribs())
00228 lk += lkVect.sumNonTopDistribLK *
00229 (1.0 - sumTopDistribWeights) / lkVect.sumNonTopDistribWeights;
00230 else
00231 if (nTop != 0)
00232 lk /= sumTopDistribWeights;
00233 return computeLLK(lk);
00234 }
00235
00236 lkVect.setSize(distribCount);
00237 LKVector::type* v = lkVect.getArray();
00238 lkVect.topDistribsCount = nTop;
00239
00240 for (c=0; c<distribCount; c++)
00241 {
00242 v[c].idx = c;
00243 lk += (v[c].lk = w[c] * d[c]->computeLK(f));
00244 }
00245 lkVect.descendingSort();
00246
00247 if (_config.getParam_computeLLKWithTopDistribs() == true)
00248 {
00249 real_t sumTopDistribWeights = 0.0;
00250 real_t sumTopDistribLK = 0.0;
00251 for (i=0; i<nTop; i++)
00252 {
00253 sumTopDistribWeights += w[v[i].idx];
00254 sumTopDistribLK += v[i].lk;
00255 }
00256 lkVect.sumNonTopDistribWeights = 1.0 - sumTopDistribWeights;
00257 lkVect.sumNonTopDistribLK = lk - sumTopDistribLK;
00258 if (lkVect.sumNonTopDistribLK < EPS_LK)
00259 lkVect.sumNonTopDistribLK = EPS_LK;
00260 }
00261 return computeLLK(lk);
00262 }
00263
00264 lk_t S::computeLLK(const K&, const Mixture& m, const Feature& f,
00265 const LKVector& lkVect)
00266 {
00267 lk_t lk = 0.0;
00268 weight_t* w = m.getTabWeight().getArray();
00269 Distrib** d = m.getTabDistrib();
00270 unsigned long distribCount = m.getDistribCount();
00271 unsigned long c, i, nTop = _config.getParam_topDistribsCount();
00272
00273
00274 if (nTop >= distribCount)
00275 nTop = distribCount;
00276 LKVector::type* v = lkVect.getArray();
00277 real_t sumTopDistribWeights = 0.0;
00278
00279 for (i=0; i<nTop; i++)
00280 {
00281 c = v[i].idx;
00282 sumTopDistribWeights += w[c];
00283 lk += w[c] * d[c]->computeLK(f);
00284 }
00285 if (_config.getParam_computeLLKWithTopDistribs()) {
00286 lk += lkVect.sumNonTopDistribLK *
00287 (1.0 - sumTopDistribWeights) / lkVect.sumNonTopDistribWeights;
00288 }
00289
00290
00291
00292
00293
00294
00295 return computeLLK(lk);
00296 }
00297
00298 void S::computeAllDistribLK(const Feature& f)
00299 {
00300 if (_pMixtureServer == NULL)
00301 throw Exception("No mixture server connected to this stat server"
00302 , __FILE__, __LINE__);
00303 _distribLKVect.clear();
00304 unsigned long n = _pMixtureServer->getDistribCount();
00305 for (unsigned long i=0; i<n; i++)
00306 _distribLKVect.addValue(_pMixtureServer->getDistrib(i).computeLK(f));
00307 }
00308
00309 void S::resetOcc(const Mixture& m) { getMixtureStat(m).resetOcc(); }
00310
00311 real_t S::computeAndAccumulateOcc(const Mixture& m, const Feature& f)
00312 { return getMixtureStat(m).computeAndAccumulateOcc(f); }
00313
00314 occ_t* S::getMeanOccVect(const Mixture& m)
00315 { return getMixtureStat(m).getMeanOccVect().getArray(); }
00316
00317 occ_t* S::getAccumulatedOccVect(const Mixture& m)
00318 { return getMixtureStat(m).getAccumulatedOccVect().getArray(); }
00319
00320 occ_t S::getAccumulatedOcc(const Mixture& m)
00321 { return getMixtureStat(m).getAccumulatedOcc(); }
00322
00323 void S::setTopDistribIndexVector(const ULongVector& indexVect,
00324 real_t w, real_t l)
00325 {
00326 unsigned long topDistribCount = indexVect.size();
00327 if (topDistribCount>_topDistribsVect.size())
00328 throw Exception("", __FILE__, __LINE__);
00329 LKVector::type* v = _topDistribsVect.getArray();
00330 for (unsigned long i=0; i<topDistribCount; i++)
00331 v[i].idx = indexVect[i];
00332 _topDistribsVect.sumNonTopDistribWeights = w;
00333 _topDistribsVect.sumNonTopDistribLK = l;
00334 }
00335
00336 MixtureStat& S::getMixtureStat(const Mixture& m)
00337 {
00338
00339 if (&m == _pLastMixture)
00340 {
00341 assert(_pLastMixtureStat != NULL);
00342 return *_pLastMixtureStat;
00343 }
00344 const unsigned long n = _mixtureStatVect.size();
00345 for (unsigned long i=0; i<n; i++)
00346 {
00347 MixtureStat& ms = _mixtureStatVect.getObject(i);
00348 if (ms.getMixture().isSameObject(m))
00349 {
00350 _pLastMixture = &m;
00351 _pLastMixtureStat = &ms;
00352 return ms;
00353 }
00354 }
00355 return createAndStoreMixtureStat(m);
00356 }
00357
00358 MixtureStat& S::createAndStoreMixtureStat(const Mixture& m)
00359 {
00360 MixtureStat& ms = m.createNewMixtureStatObject(K::k,*this, _config);
00361 _mixtureStatVect.addObject(ms);
00362 _pLastMixture = &m;
00363 _pLastMixtureStat = &ms;
00364 return ms;
00365 }
00366
00367 MixtureGDStat& S::createAndStoreMixtureStat(MixtureGD& m)
00368 {
00369 return static_cast<MixtureGDStat&>(
00370 createAndStoreMixtureStat(static_cast<Mixture&>(m)));
00371 }
00372
00373 MixtureGFStat& S::createAndStoreMixtureStat(MixtureGF& m)
00374 {
00375 return static_cast<MixtureGFStat&>(
00376 createAndStoreMixtureStat(static_cast<Mixture&>(m)));
00377 }
00378
00379 MixtureGDStat& S::createAndStoreMixtureGDStat(Mixture& m)
00380 {
00381 MixtureGD* p = dynamic_cast<MixtureGD*>(&m);
00382 if (p == NULL)
00383 throw Exception("Wrong mixture type", __FILE__, __LINE__);
00384 return createAndStoreMixtureStat(*p);
00385 }
00386
00387 MixtureGFStat& S::createAndStoreMixtureGFStat(Mixture& m)
00388 {
00389 MixtureGF* p = dynamic_cast<MixtureGF*>(&m);
00390 if (p == NULL)
00391 throw Exception("Wrong mixture type", __FILE__, __LINE__);
00392 return createAndStoreMixtureStat(*p);
00393 }
00394
00395 unsigned long S::getMixtureStatCount() const
00396 { return _mixtureStatVect.size(); }
00397
00398 MixtureStat& S::getMixtureStat(unsigned long idx)
00399 { return _mixtureStatVect.getObject(idx); }
00400
00401 MixtureGDStat& S::getMixtureGDStat(unsigned long idx)
00402 {
00403 MixtureGDStat* p = dynamic_cast<MixtureGDStat*>(
00404 &_mixtureStatVect.getObject(idx));
00405 if (p == NULL)
00406 throw Exception("No mixtureGDStat object for index "+String::valueOf(idx),
00407 __FILE__, __LINE__);
00408 return *p;
00409 }
00410
00411 MixtureGFStat& S::getMixtureGFStat(unsigned long idx)
00412 {
00413 MixtureGFStat* p = dynamic_cast<MixtureGFStat*>(
00414 &_mixtureStatVect.getObject(idx));
00415 if (p == NULL)
00416 throw Exception("No mixtureGFStat object for index "+String::valueOf(idx),
00417 __FILE__, __LINE__);
00418 return *p;
00419 }
00420
00421 void S::deleteMixtureStat(MixtureStat& m)
00422 {
00423 delete &_mixtureStatVect.removeObject(m);
00424 _pLastMixtureStat = NULL;
00425 }
00426
00427 void S::deleteMixtureStat(unsigned long b, unsigned long e)
00428 {
00429 _mixtureStatVect.removeObjects(b, e, DELETE);
00430 _pLastMixtureStat = NULL;
00431 }
00432
00433 void S::deleteAllMixtureStat()
00434 {
00435 _mixtureStatVect.deleteAllObjects();
00436 _pLastMixtureStat = NULL;
00437 }
00438
00439 unsigned long S::getMixtureStatIndex(MixtureStat& m) const
00440 { return _mixtureStatVect.getObjectIndex(m); }
00441
00442
00443
00444
00445
00446
00447
00448 void S::resetEM(const Mixture& m)
00449 { getMixtureStat(m).resetEM(); }
00450
00451 occ_t S::computeAndAccumulateEM(const Mixture& m, const Feature& f)
00452 { return getMixtureStat(m).computeAndAccumulateEM(f); }
00453
00454 const Mixture& S::getEM(const Mixture& m)
00455 { return getMixtureStat(m).getEM(); }
00456
00457
00458
00459
00460
00461
00462 ViterbiAccum& S::createViterbiAccum()
00463 {
00464 ViterbiAccum& va = ViterbiAccum::create(*this, _config, K::k);
00465 _viterbiAccumVect.addObject(va);
00466 return va;
00467 }
00468
00469
00470
00471
00472 FrameAccGD S::createFrameAccGD()
00473 { return FrameAccGD(); }
00474
00475 FrameAccGF S::createFrameAccGF()
00476 { return FrameAccGF(); }
00477
00478
00479
00480
00481
00482
00483
00484 const String& S::getServerName() const { return _serverName; }
00485
00486 void S::setServerName(const String& s) { _serverName = s; }
00487
00488 String S::getClassName() const { return "StatServer"; }
00489
00490 String S::toString() const
00491 {
00492 String s(Object::toString()
00493 + "\n serverName = '" + _serverName + "'");
00494 if (_pMixtureServer != NULL)
00495 return s
00496 + "\n mixture server = [ " + _pMixtureServer->getClassName()
00497 + " " + _pMixtureServer->getAddress() + " ] name = '"
00498 + _pMixtureServer->getServerName() + "'";
00499 else
00500 return s + "\n mixture server = [NULL]";
00501 }
00502
00503 S::~StatServer()
00504 {
00505
00506 _viterbiAccumVect.deleteAllObjects();
00507 }
00508
00509
00510 #endif // !defined(ALIZE_StatServer_cpp)