/* ************************************ * Author: M. Babai (M.Babai@rug.nl) * * * * pid classifier * * * * Created: 23-03-2010 * * Modified: * * * * ************************************/ //=================== #include "PndPidMvaAssociatorTask.h" // Standard C++ includes #include // Root includes. #include "TClonesArray.h" // PANDA and Fair includes. #include "FairTask.h" #include "FairRootManager.h" #include "PndPidCandidate.h" #include "PndPidProbability.h" // MVA Headers. #include "PndMvaClassifier.h" #include "PndKnnClassify.h" #include "PndLVQClassify.h" #include "PndMultiClassMlpClassify.h" #include "PndMultiClassBdtClassify.h" ClassImp(PndPidMvaAssociatorTask) //==================== #define PIDMVA_ASSOCIATORT_DEBUG 0 //========================================================== #if (PIDMVA_ASSOCIATORT_DEBUG != 0) // Function to use for debugging void printResult(std::map& res) { std::cout << "\n\t================================== \n"; for( std::map::iterator ii=res.begin(); ii != res.end(); ++ii) { std::cout << "\t" << (*ii).first << "\t=> " << (*ii).second << '\n'; } std::cout << "\n\t================================== \n"; } #endif //========================================================== /** * Default Constructor. */ PndPidMvaAssociatorTask::PndPidMvaAssociatorTask() : FairTask("PndPidMvaAssociatorTaskSTD"), fManager(0), fVarNames (std::vector()), fClassNames(std::vector()), fWeightsFileName(std::string(getenv("VMCWORKDIR")) + std::string("/PndTools/MVA/PndMVAWeights/")), fNumNeigh(200), fScFact(0.8), fWeight(1.00), fClassifier(0), fMethodType(UNKNOWN_METHOD), fPidChargedCand(0), fPidChargedProb(new TClonesArray("PndPidProbability")), fMCTrack(0), fMethodName("UNKNOWN_METHOD") { // Init neutral probab. containers. // fPidNeutralProb = new TClonesArray("PndPidProbability"); // Set Default path to the weight file // SetDefaultWeightsPath(); } //___________________________________________________________ /** * Constructor. */ PndPidMvaAssociatorTask::PndPidMvaAssociatorTask(char const* name) : FairTask(name), fManager(0), fVarNames (std::vector()), fClassNames(std::vector()), fWeightsFileName(std::string(getenv("VMCWORKDIR")) + std::string("/PndTools/MVA/PndMVAWeights/")), fNumNeigh(200), fScFact(0.8), fWeight(1.00), fClassifier(0), fMethodType(UNKNOWN_METHOD), fPidChargedCand(0), fPidChargedProb(new TClonesArray("PndPidProbability")), fMCTrack(0), fMethodName("UNKNOWN_METHOD") { // Init neutral probab. containers. // fPidNeutralProb = new TClonesArray("PndPidProbability"); // Set Default path to the weight file //SetDefaultWeightsPath(); } /* * Set the default path where the weights are stored. */ void PndPidMvaAssociatorTask::SetDefaultWeightsPath() { fWeightsFileName = std::string(getenv("VMCWORKDIR")); fWeightsFileName += std::string("/PndTools/MVA/PndMVAWeights/"); std::cout<<" Default Weights path is set to " << fWeightsFileName << '\n'; } //___________________________________________________________ /** * Destructor. */ PndPidMvaAssociatorTask::~PndPidMvaAssociatorTask() { // Clean-up allocated stuff. if(fManager) { fManager->Write(); delete fManager; } if(fPidChargedCand) { delete fPidChargedCand; } if(fPidChargedProb) { delete fPidChargedProb; } //if(fPidNeutralCand) //delete fPidNeutralCand; //if(fPidNeutralProb) //delete fPidNeutralProb; if(fMCTrack) { delete fMCTrack; } if(fClassifier) { delete fClassifier; } } //___________________________________________________________ InitStatus PndPidMvaAssociatorTask::Init() { std::cout << "<-I-> InitStatus PndPidMvaAssociatorTask::Init()\n"; fManager = FairRootManager::Instance(); if( !fManager ) { std::cerr << " PndPidMvaAssociatorTask::Init:\n" << "\t Could not init FairRootManager." << std::endl; return kERROR; } // Get charged candidates. fPidChargedCand = (TClonesArray *)fManager->GetObject("PidChargedCand"); if ( !fPidChargedCand) { std::cerr << " PndPidMvaAssociatorTask::Init: No PidChargedCand there!" << std::endl; return kERROR; } // Get Neutral candidates. /* fPidNeutralCand = (TClonesArray *)fManager->GetObject("PidNeutralCand"); if ( ! fPidNeutralCand) { std::cerr << " PndPidMvaAssociatorTask::Init: No PidNeutralCand there!" << std::endl; return kERROR; } */ std::cout << " Using weight file " << fWeightsFileName << '\n'; // Init Classifier object switch(fMethodType) { case TMVA_MLP:// Multi label MLP classifier from TMVA. { PndMultiClassMlpClassify* TmvaMlpCls = new PndMultiClassMlpClassify(fWeightsFileName, fClassNames, fVarNames); if(!TmvaMlpCls) { std::cerr << " Failed to initialize TMVA_MLP classifier." << std::endl; return kERROR; } // Init TmvaMlpCls->Initialize(); //fClassifier = dynamic_cast(TmvaMlpCls); fClassifier = TmvaMlpCls; std::cout << " TMVA_MLP initialized using " << fWeightsFileName << '\n'; fMethodName = "TMVAMLP"; } break; case TMVA_BDT:// Multi label BDT classifier from TMVA. { PndMultiClassBdtClassify* TmvaBdtCls = new PndMultiClassBdtClassify(fWeightsFileName, fClassNames, fVarNames); if(!TmvaBdtCls) { std::cerr << " Failed to initialize TMVA_BDT classifier." << std::endl; return kERROR; } // INIT TmvaBdtCls->Initialize(); //fClassifier = dynamic_cast(TmvaBdtCls); fClassifier = TmvaBdtCls; std::cout << " TMVA_BDT initialized using " << fWeightsFileName << '\n'; fMethodName = "TMVABDT"; } break; case LVQ: { PndLVQClassify* LvqCls = new PndLVQClassify(fWeightsFileName, fClassNames, fVarNames); if(!LvqCls) { std::cerr << " Failed to initialize LVQ classifier." << std::endl; return kERROR; } // Init LvqCls->Initialize(); //fClassifier = dynamic_cast(LvqCls); fClassifier = LvqCls; std::cout << " LVQ initialized using " << fWeightsFileName << '\n'; fMethodName = "LVQ"; } break; case KNN: default: { PndKnnClassify* KnnCls = new PndKnnClassify(fWeightsFileName, fClassNames, fVarNames); if(!KnnCls) { std::cerr << " Failed to initialize KNN classifier." << std::endl; return kERROR; } // Set parameters. KnnCls->SetEvtParam(fScFact, fWeight); KnnCls->SetKnn(fNumNeigh); KnnCls->Initialize(); //fClassifier = dynamic_cast(KnnCls); fClassifier = KnnCls; std::cout << " KNN initialized using " << fWeightsFileName << '\n'; fMethodName = "KNN"; } break; }// End of switch(fMethodType) // Register objects in the output chain Register(); std::cout << " PndPidMvaAssociatorTask::Init: Success!\n"; return kSUCCESS; } //______________________________________________________ void PndPidMvaAssociatorTask::SetParContainers() {} void PndPidMvaAssociatorTask::SetClassifier(std::string const& methodNameStr) { if(methodNameStr == "KNN") { fMethodType = KNN; fMethodName = "KNN"; } else if(methodNameStr == "LVQ") { fMethodType = LVQ; fMethodName = "LVQ"; } else if(methodNameStr == "TMVA_MLP") { fMethodType = TMVA_MLP; fMethodName = "TMVAMLP"; } else if(methodNameStr == "TMVA_BDT") { fMethodType = TMVA_BDT; fMethodName = "TMVABDT"; } else { std::cerr << " Unknown Method." << std::endl; } }; //______________________________________________________ void PndPidMvaAssociatorTask::Exec(Option_t* option) { std::cout << option << '\n'; if (fPidChargedProb->GetEntriesFast() != 0) { fPidChargedProb->Delete(); } #if ( PIDMVA_ASSOCIATORT_DEBUG != 0 ) std::cout << " Call to Exec with options = " << option << "___\n"; #endif if(fVerbose > 1) { std::cout << "-I- Start PndPidMvaAssociatorTask.\n"; } // Charged Candidates Loop for(int i = 0; i < fPidChargedCand->GetEntriesFast(); i++) { PndPidCandidate* pidcand = (PndPidCandidate*)fPidChargedCand->At(i); TClonesArray& pidRef = *fPidChargedProb; // initializes with zeros PndPidProbability* prob = new(pidRef[i]) PndPidProbability(); if(fVerbose > 1) { std::cout << "-I- PndPidMVAAssociatorTask Ch BEFORE " << pidcand->GetLorentzVector().M() << '\n'; } // Classify DoPidMatch(*pidcand, *prob); if(fVerbose > 1) { std::cout << "-I- PndPidMVAAssociatorTask Ch AFTER " << pidcand->GetLorentzVector().M() << '\n'; } } // Get the Neutral Candidates /* for(int i = 0; i < fPidNeutralCand->GetEntriesFast(); i++) { PndPidCandidate* pidcand = (PndPidCandidate*)fPidNeutralCand->At(i); TClonesArray& pidRef = *fPidNeutralProb; // initializes with zeros PndPidProbability* prob = new(pidRef[i]) PndPidProbability(); // Classify DoPidMatch(*pidcand, *prob); } */ } /** * Performs the actual classification. *@param pidcand Current pid candidate to be classified. *@param prob Output probabilities. */ void PndPidMvaAssociatorTask::DoPidMatch(PndPidCandidate& pidcand, PndPidProbability& prob) { std::map out; std::vector const* evtPidData = PrepareEvtVect(pidcand); // Perform Recognition. if( evtPidData ) { fClassifier->GetMvaValues( *evtPidData, out); delete evtPidData; } else { // Feature vector is empty or damaged. delete evtPidData; evtPidData = 0; return; } #if ( PIDMVA_ASSOCIATORT_DEBUG != 0 ) std::cout << "****************************************************\n" << "Momentum = " << (pidcand.GetMomentum()).Mag() << "\nGetEnergy = " << pidcand.GetEnergy() << "\nEMC = " << pidcand.GetEmcCalEnergy() << "\nEMC/P = " << (pidcand.GetEmcCalEnergy())/((pidcand.GetMomentum()).Mag()) << "\nEMCZ20 = " << pidcand.GetEmcClusterZ20() << "\nEMCZ53 = " << pidcand.GetEmcClusterZ53() << "\nEMCLAT = " << pidcand.GetEmcClusterLat() << "\nEmcE1 = " << pidcand.GetEmcClusterE1() << "\nEmcE9 = " << pidcand.GetEmcClusterE9() << "\nEmcE25 = " << pidcand.GetEmcClusterE25() << "\nSTT = " << pidcand.GetSttMeanDEDX() << "\nMVD = " << pidcand.GetMvdDEDX() << "\nDRC_TC = " << pidcand.GetDrcThetaC() << '\n'; printResult(out); std::cout << "====================================================\n"; #endif // Set probs. for(size_t i = 0; i < fClassNames.size(); i++) { std::string name = fClassNames[i]; if(name == "electron") { prob.SetElectronPdf(out[name]); } else if(name == "muon") { prob.SetMuonPdf(out[name]); } else if(name == "pion") { prob.SetPionPdf(out[name]); } else if(name == "kaon") { prob.SetKaonPdf(out[name]); } else if(name == "proton") { prob.SetProtonPdf(out[name]); } else { std::cerr << " Unknown label (class Name).\n" << std::flush; } } } std::vector const* PndPidMvaAssociatorTask::PrepareEvtVect(PndPidCandidate const& pidcand) const { std::vector* vect = new std::vector(); float mom = (pidcand.GetMomentum()).Mag(); for(size_t i = 0; i < fVarNames.size(); i++) { if( fVarNames[i] == "p" ) { vect->push_back((pidcand.GetMomentum()).Mag()); } else if(fVarNames[i] == "emc") { if(mom > 0.00) { // E/p vect->push_back( (pidcand.GetEmcCalEnergy())/mom); } else { std::cerr << " (p > 0) failed. The event is skipped.\n" << " p = " << mom << std::endl; delete vect; vect = 0; return 0;// Can not proceed. Break the procedure } } // Cluster Ex parameters. else if( (fVarNames[i] == "e1") || (fVarNames[i] == "E1") ) { vect->push_back(pidcand.GetEmcClusterE1()); } else if( (fVarNames[i] == "e9") || (fVarNames[i] == "E9") ) { vect->push_back(pidcand.GetEmcClusterE9()); } else if( (fVarNames[i] == "e25") || (fVarNames[i] == "E25") ) { vect->push_back(pidcand.GetEmcClusterE25()); } else if( (fVarNames[i] == "e1e9") || (fVarNames[i] == "E1E9") ) { if( pidcand.GetEmcClusterE9() > 0 ) { vect->push_back(pidcand.GetEmcClusterE1()/pidcand.GetEmcClusterE9()); } else { std::cerr << " (EmcClusterE9 > 0) failed. The event is skipped.\n" << std::flush; delete vect; vect = 0; return 0;// Can not proceed. Break the procedure } } else if( (fVarNames[i] == "e9e25") || (fVarNames[i] == "E9E25") ) { if( pidcand.GetEmcClusterE25() > 0 ) { vect->push_back(pidcand.GetEmcClusterE9()/pidcand.GetEmcClusterE25()); } else { std::cerr << " (EmcClusterE25 > 0) failed. The event is skipped.\n" << std::flush; delete vect; vect = 0; return 0;// Can not proceed. Break the procedure } } //======== Zernike & moments else if(fVarNames[i] == "z20") { vect->push_back(pidcand.GetEmcClusterZ20()); } else if(fVarNames[i] == "z53") { vect->push_back(pidcand.GetEmcClusterZ53()); } // Cluster Second lat. moment else if(fVarNames[i] == "lat") { vect->push_back(pidcand.GetEmcClusterLat()); } // ========== other detectors else if(fVarNames[i] == "stt") { vect->push_back(pidcand.GetSttMeanDEDX()); } else if(fVarNames[i] == "mvd") { vect->push_back(pidcand.GetMvdDEDX()); } else if(fVarNames[i] == "thetaC") { vect->push_back(pidcand.GetDrcThetaC()); } } return vect; } //_________________________________________________________________ void PndPidMvaAssociatorTask::Register() { std::string tcaName = fMethodName + "MvaProb"; //--- FairRootManager::Instance()->Register(tcaName.c_str(),"Pid", fPidChargedProb, kTRUE); // FairRootManager::Instance()->Register("MvaNeutralProb","Pid", fPidNeutralProb, kTRUE); } //_________________________________________________________________ void PndPidMvaAssociatorTask::Finish() {} //_________________________________________________________________ void PndPidMvaAssociatorTask::Reset() {}