/* **********************************************
 * MVA classifiers interface.                   *
 * Author: M.Babai@rug.nl                       *
 * LICENSE:                                     *
 * Version:                                     *
 * License:                                     *
 * ********************************************** */
#ifndef PND_MVA_CLASSIFIER_H
#define PND_MVA_CLASSIFIER_H

// C++ includes
// NOTE(review): the header names of the seven #include directives below are
// missing in this copy of the file (the <...> part appears to have been
// stripped). The declarations in this header use std::string, std::vector and
// std::map, so at least <string>, <vector> and <map> are required; the
// remaining names must be restored from the original file.
#include
#include
#include
#include
#include
#include
#include

// PND PID includes.
#include "PndMvaDataSet.h"
#include "PndMvaUtil.h"

// PANDA And ROOT includes

// NOTE(review): every std::vector / std::map in this header appears without a
// template argument list; the arguments look stripped in the same way as the
// include names above and must be restored before this header can compile.

/**
 * Main interface definition.
 *
 * Abstract base class for MVA classifiers: GetMvaValues() and Classify() are
 * pure virtual and have to be provided by the concrete classifier. The class
 * keeps a PndMvaDataSet member (m_dataSets) and exposes its class and
 * variable lists through GetClasses() and GetVariables().
 */
class PndMvaClassifier
{
public:
  //! Constructor
  /**
   * @param InPut      Input file name containing weights.
   * @param ClassNames Names of available (signal) classes.
   * @param VarNames   Names of the used variables (features).
   */
  PndMvaClassifier(std::string const& InPut,
                   std::vector const& ClassNames,
                   std::vector const& VarNames);

  //! Destructor
  virtual ~PndMvaClassifier();

  /**
   * Compute the MVA values of one event (pure virtual).
   *
   * @param EvtData Event data to be classified (passed by value).
   * @param result  Classification results. Currently the shortest
   *                distance for each class is stored in result.
   */
  virtual void GetMvaValues( std::vector EvtData,
                             std::map& result ) = 0;

  /**
   * Classify one event (pure virtual).
   *
   * @param EvtData Event to be classified (passed by value).
   * @return Name of the class with the best MVA value.
   *
   * NOTE(review): the result is returned as a raw std::string pointer and
   * this header does not state who owns it -- confirm against the
   * implementations.
   */
  virtual std::string *Classify( std::vector EvtData ) = 0;

  // Disabled overload that would also hand back the per-class results:
  //virtual std::string *Classify( std::vector EvtData,
  //                               std::map& result ) = 0;

  //! Initialization hook; virtual, so derived classifiers may override it.
  //! Only declared here, the definition is not part of this header.
  virtual void Initialize();

  //! Get the list of available classes (labels).
  inline std::vector const& GetClasses() const;

  //! Get the list of available variables
  inline std::vector const& GetVariables() const;

protected:
  //! Normalize the given event vector.
  /**
   * @param EvtVector Event vector, taken by non-const reference.
   */
  virtual void NormalizeEvent(std::vector& EvtVector) const;

  //! Forward the application type to the data set (m_dataSets).
  // AppType is not declared in this header; presumably it comes from one of
  // the PND includes above -- TODO confirm.
  inline void SetAppType(AppType t);

  //! Data set. Holds event Weights
  PndMvaDataSet m_dataSets;

private:
  // To avoid mistakes.
  // Copy construction and copy assignment are declared private, so code
  // outside the class cannot copy a classifier. No definition is visible in
  // this header.
  PndMvaClassifier (PndMvaClassifier const& other);
  PndMvaClassifier& operator=(PndMvaClassifier const& other);
};

//! Set the application type by delegating to m_dataSets.SetAppType().
inline void PndMvaClassifier::SetAppType(AppType t)
{
  m_dataSets.SetAppType(t);
};
//End of class interface

//___________________ Inline implement ___________
//! Get the list of available classes (labels).
//! Returns the reference obtained from m_dataSets.GetClasses().
inline std::vector const& PndMvaClassifier::GetClasses() const
{
  return m_dataSets.GetClasses();
};

//! Get the list of available variables
//! Returns the reference obtained from m_dataSets.GetVars().
inline std::vector const& PndMvaClassifier::GetVariables() const
{
  return m_dataSets.GetVars();
};

#endif