#include "PndKnnClassify.h"

#include <cassert>
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace
{
// Extra neighbours requested from the kNN search beyond m_knn. Only the first
// m_knn entries of the returned list are counted.
// NOTE(review): mirrors the "knn + 2" padding of TMVA's MethodKNN; the list is
// presumably ordered nearest-first by TMVA — confirm.
const unsigned int kSearchPadding = 2;

// Centres each variable on its training mean and scales it by its
// normalisation factor, in place. Only the first vars.size() entries are touched.
template <typename VarInfoVec>
void NormalizeEvent(std::vector<float>& evt, const VarInfoVec& vars)
{
  for (std::size_t k = 0; k < vars.size(); ++k)
  {
    assert(vars[k].NormFactor != 0);
    evt[k] -= vars[k].Mean;
    evt[k] /= vars[k].NormFactor;
  }
}

// Replaces the content of result with one zero-valued entry per class name.
template <typename ClassInfoVec>
void ResetResults(const ClassInfoVec& classes, std::map<std::string, float>& result)
{
  result.clear();
  for (std::size_t cls = 0; cls < classes.size(); ++cls)
  {
    result.insert(std::make_pair(classes[cls].Name, 0.0f));
  }
}

// Turns per-class neighbour counts into values that sum to one. Each count is
// first divided by the class's number of training examples (presumably to
// compensate for unequal class sizes). Classes without examples, and the case
// where no neighbour contributed at all, yield 0 instead of NaN/inf.
template <typename ClassInfoVec>
void FillProbabilities(const ClassInfoVec& classes,
                       const std::vector<int>& counts,
                       std::map<std::string, float>& result)
{
  float probSum = 0.0f;
  for (std::size_t cls = 0; cls < classes.size(); ++cls)
  {
    const float nExamples = static_cast<float>(classes[cls].NExamples);
    const float density =
        (nExamples > 0.0f) ? static_cast<float>(counts[cls]) / nExamples : 0.0f;
    result[classes[cls].Name] = density;
    probSum += density;
  }

  if (probSum > 0.0f)
  {
    for (std::map<std::string, float>::iterator it = result.begin(); it != result.end(); ++it)
    {
      it->second /= probSum;
    }
  }
}
} // namespace

// Builds the classifier from the given input file, class names and variable
// names, and creates the TMVA kNN module that performs the actual search.
PndKnnClassify::PndKnnClassify(const std::string& inputFile,
                               const std::vector<std::string>& classNames,
                               const std::vector<std::string>& varNames)
  : PndGpidClassifier(inputFile, classNames, varNames)
{
  // Map each class name to its position in classNames; this index is the
  // event "type" handed to TMVA. insert() keeps the first index on duplicates.
  for (std::size_t cls = 0; cls < classNames.size(); ++cls)
  {
    m_classIndices.insert(std::make_pair(classNames[cls], cls));
  }

  // The TMVA kNN module does all the tree building and searching.
  m_module = new TMVA::kNN::ModulekNN();

  // Default scale factor (passed to TMVA as a percentage in InitKNN).
  m_ScaleFact = 0.8;
  // No neighbour count yet; GetMvaValues refuses to run until it is set.
  m_knn = 0;
}

// Releases the kNN module owned by this classifier.
PndKnnClassify::~PndKnnClassify()
{
  if (m_module)
  {
    m_module->Clear();
    delete m_module;
  }
}

// Loads every training event from m_dataSets into the kNN module and builds
// its search structure. Events whose class name was not passed to the
// constructor are reported on stderr and skipped.
void PndKnnClassify::InitKNN()
{
  std::cerr << " Initializing KNN classifier." << std::endl;

  typedef std::vector<std::pair<std::string, std::vector<float>*> > EventList;
  const EventList& events = m_dataSets.GetData();

  for (std::size_t j = 0; j < events.size(); ++j)
  {
    const std::string& className = events[j].first;

    // Check with count() first: operator[] would silently insert an unknown
    // name with a value-initialised (0) index and mislabel the event.
    if (m_classIndices.count(className) == 0)
    {
      std::cerr << "\t Skipping event with unknown class " << className << std::endl;
      continue;
    }
    const int cls = m_classIndices[className];

    TMVA::kNN::Event evt(*events[j].second, m_weight, cls);
    m_module->Add(evt);
  }

  // First argument: number of variables; second: m_ScaleFact in percent.
  // NOTE(review): TMVA seems to consult the fraction only when the option
  // string contains "metric"; with "" m_ScaleFact may have no effect — confirm.
  const bool filled = m_module->Fill(static_cast<unsigned short>(m_dataSets.GetVars().size()),
                                     static_cast<unsigned int>(100.0 * m_ScaleFact),
                                     "");
  if (!filled)
  {
    std::cerr << "\t Failed to fill the KNN module." << std::endl;
    return;
  }

  std::cout << " Done initializing." << std::endl;
}

// Not implemented: prints a notice and returns a reference to an empty label.
const std::string& PndKnnClassify::Classify(std::vector<float> EvtData) const
{
  // The event is unused until classification is implemented.
  (void)EvtData;
  std::cout << "Not implemented yet." << std::endl;

  // A function-local static gives the returned reference a valid lifetime
  // without leaking a heap-allocated string on every call.
  static const std::string kEmptyLabel;
  return kEmptyLabel;
}

// Computes per-class values for one event from its m_knn nearest training
// neighbours.
//   eventData: raw (unnormalised) variable values, taken by value because it
//              is normalised in place.
//   result:    on success, one entry per class summing to one (or all zeros if
//              no neighbour contributed). Left untouched when m_knn is zero or
//              eventData is too short; all zeros if the search fails.
void PndKnnClassify::GetMvaValues(std::vector<float> eventData,
                                  std::map<std::string, float>& result)
{
  if (m_knn == 0)
  {
    std::cerr << "\t Number neighbours can not be zero." << std::endl;
    return;
  }
  if (eventData.size() < m_dataSets.GetVars().size())
  {
    std::cerr << "\t Event has fewer variables than the classifier expects." << std::endl;
    return;
  }

  // Every known class gets an entry up front, so callers see the full class
  // set even if the search below fails.
  ResetResults(m_dataSets.GetClasses(), result);

  NormalizeEvent(eventData, m_dataSets.GetVars());

  // NOTE(review): type 20 looks like a placeholder for the query event — confirm
  // TMVA never reads it during the search.
  TMVA::kNN::Event evt(eventData, m_weight, 20);
  if (!m_module->Find(evt, m_knn + kSearchPadding))
  {
    // Without this check a stale neighbour list from a previous query is reused.
    std::cerr << "\t KNN search failed." << std::endl;
    return;
  }

  const std::size_t numClasses = m_dataSets.GetClasses().size();
  std::vector<int> countsPerClass(numClasses, 0);

  // Count the class of the first m_knn neighbours only; the padding entries
  // are not part of the vote.
  const ResList& neighbours = m_module->GetkNNList();
  std::size_t used = 0;
  for (ResList::const_iterator iter = neighbours.begin();
       iter != neighbours.end() && used < static_cast<std::size_t>(m_knn);
       ++iter, ++used)
  {
    // The event type stored in the tree is the class index set in InitKNN.
    const TMVA::kNN::Event& event = iter->first->GetEvent();
    const int type = event.GetType();
    if (type >= 0 && static_cast<std::size_t>(type) < numClasses)
    {
      ++countsPerClass[type];
    }
  }

  FillProbabilities(m_dataSets.GetClasses(), countsPerClass, result);
}