/* *************************************** * KNN based Classifier using kd-tree * * data structure for better recognition * * performance. * * Author: M.Babai@rug.nl * * LICENSE: * * Version 1.0 beta1. * * *************************************** */ #pragma once #ifndef PndKnnClassify_H #define PndKnnClassify_H // Standard C++ libraries #include #include #include #include // Root includes #include "TFile.h" #include "TTree.h" // kd-Tree and boost includes #include "kdtree2.hpp" #include /** * Two dimensional array of floats to sotore input data before * generating the search tree. */ typedef multi_array array2dfloat; /** * Dummy object to hold the distances computed during classification */ class DistObject{ public: //! Constructors DistObject(): m_cls(""), m_dist(0.0){}; DistObject(float d, std::string c): m_cls(c), m_dist(d){}; //! Destructor ~DistObject(){}; //! Operators inline bool operator < (const DistObject &other)const { return (m_dist < other.m_dist); } inline bool operator>(const DistObject &other)const { return (m_dist > other.m_dist); } //! Local members. std::string m_cls;/**!< Class name. */ float m_dist;/**!< Computed distance. */ }; /** * KNN based classification alg. implementation. */ class PndKnnClassify{ public: /** * Constructor. * @param InputPutFile: File that holds the weights * @param ClassNames: Class names. * @param VarNames: Variable names from which the feature vector is * built. */ PndKnnClassify(const char *InputPutFile, const std::vector& ClassNames, const std::vector& VarNames); //! Destructor virtual ~PndKnnClassify(); /** * Classification function. * @param EvtData: Feature vector of the current event. * @param Neighbours: Number of Neighbours. * @param result: Holds the normalized results of classification */ void Classify(std::vector &EvtData, unsigned int Neighbours, std::map& result); protected: /** * Euclidean distance between two given vectors of event features. */ float ComputeDist(std::vector &EvtData, std::vector &Example); private: //! Class names container std::vector m_ClassNames; //! Variable names container std::vector m_VarNames; //! Contains distances labled for every class std::vector m_dists; //! Container to store the optimized per class trees std::vector< std::pair > m_EventTreeCont; //! Pairs to hold the number of available examples per calss std::map< std::string, int > m_perClassExamples; }; /** * Function used for sorting the distances container */ inline bool LessFunct(const DistObject& p1, const DistObject& p2) { return (p1.m_dist < p2.m_dist); } #endif