/usr/include/root/TMVA/DecisionTree.h is in libroot-tmva-dev 5.34.19+dfsg-1.2.
This file is owned by root:root, with mode 0o644.
The actual contents of the file can be viewed below.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 | // @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Jan Therhaag, Eckhard von Toerne
/**********************************************************************************
* Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
* Package: TMVA *
* Class : DecisionTree *
* Web : http://tmva.sourceforge.net *
* *
* Description: *
* Implementation of a Decision Tree *
* *
* Authors (alphabetical): *
* Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
* Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
* Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
* Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
* Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
* *
* Copyright (c) 2005-2011: *
* CERN, Switzerland *
* U. of Victoria, Canada *
* MPI-K Heidelberg, Germany *
* U. of Bonn, Germany *
* *
* Redistribution and use in source and binary forms, with or without *
* modification, are permitted according to the terms listed in LICENSE *
* (http://mva.sourceforge.net/license.txt) *
* *
**********************************************************************************/
#ifndef ROOT_TMVA_DecisionTree
#define ROOT_TMVA_DecisionTree
//////////////////////////////////////////////////////////////////////////
// //
// DecisionTree //
// //
// Implementation of a Decision Tree //
// //
//////////////////////////////////////////////////////////////////////////
#ifndef ROOT_TH2
#include "TH2.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_DecisionTreeNode
#include "TMVA/DecisionTreeNode.h"
#endif
#ifndef ROOT_TMVA_BinaryTree
#include "TMVA/BinaryTree.h"
#endif
#ifndef ROOT_TMVA_BinarySearchTree
#include "TMVA/BinarySearchTree.h"
#endif
#ifndef ROOT_TMVA_SeparationBase
#include "TMVA/SeparationBase.h"
#endif
#ifndef ROOT_TMVA_RegressionVariance
#include "TMVA/RegressionVariance.h"
#endif
class TRandom3;
namespace TMVA {
class Event;
class DecisionTree : public BinaryTree {
private:
static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds
public:
typedef std::vector<TMVA::Event*> EventList;
typedef std::vector<const TMVA::Event*> EventConstList;
// the constructur needed for the "reading" of the decision tree from weight files
DecisionTree( void );
// the constructur needed for constructing the decision tree via training with events
DecisionTree( SeparationBase *sepType, Float_t minSize,
Int_t nCuts,
UInt_t cls =0,
Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
UInt_t nMaxDepth=9999999,
Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
Int_t treeID = 0);
// copy constructor
DecisionTree (const DecisionTree &d);
virtual ~DecisionTree( void );
// Retrieves the address of the root node
virtual DecisionTreeNode* GetRoot() const { return dynamic_cast<TMVA::DecisionTreeNode*>(fRoot); }
virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
virtual const char* ClassName() const { return "DecisionTree"; }
// building of a tree by recursivly splitting the nodes
// UInt_t BuildTree( const EventList & eventSample,
// DecisionTreeNode *node = NULL);
UInt_t BuildTree( const EventConstList & eventSample,
DecisionTreeNode *node = NULL);
// determine the way how a node is split (which variable, which cut value)
Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );
void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
// fill at tree with a given structure already (just see how many signa/bkgr
// events end up in each node
void FillTree( const EventList & eventSample);
// fill the existing the decision tree structure by filling event
// in from the top node and see where they happen to end up
void FillEvent( const TMVA::Event & event,
TMVA::DecisionTreeNode *node );
// returns: 1 = Signal (right), -1 = Bkg (left)
Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;
TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;
// return the individual relative variable importance
std::vector< Double_t > GetVariableImportance();
Double_t GetVariableImportance(UInt_t ivar);
// clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
void ClearTree();
// set pruning method
enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }
// recursive pruning of the tree, validation sample required for automatic pruning
Double_t PruneTree( const EventConstList* validationSample = NULL );
// manage the pruning strength parameter (iff < 0 -> automate the pruning process)
void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
Double_t GetPruneStrength( ) const { return fPruneStrength; }
// apply pruning validation sample to a decision tree
void ApplyValidationSample( const EventConstList* validationSample ) const;
// return the misclassification rate of a pruned tree
Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
// pass a single validation event throught a pruned decision tree
void CheckEventWithPrunedTree( const TMVA::Event* ) const;
// calculate the normalization factor for a pruning validation sample
Double_t GetSumWeights( const EventConstList* validationSample ) const;
void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }
void DescendTree( Node *n = NULL );
void SetParentTreeInNodes( Node *n = NULL );
// retrieve node from the tree. Its position (up to a maximal tree depth of 64)
// is coded as a sequence of left-right moves starting from the root, coded as
// 0-1 bit patterns stored in the "long-integer" together with the depth
Node* GetNode( ULong_t sequence, UInt_t depth );
UInt_t CleanTree(DecisionTreeNode *node=NULL);
void PruneNode(TMVA::DecisionTreeNode *node);
// prune a node from the tree without deleting its descendants; allows one to
// effectively prune a tree many times without making deep copies
void PruneNodeInPlace( TMVA::DecisionTreeNode* node );
Int_t GetNNodesBeforePruning(){return (fNNodesBeforePruning)?fNNodesBeforePruning:fNNodesBeforePruning=GetNNodes();}
UInt_t CountLeafNodes(TMVA::Node *n = NULL);
void SetTreeID(Int_t treeID){fTreeID = treeID;};
Int_t GetTreeID(){return fTreeID;};
Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
inline void SetUseFisherCuts(Bool_t t=kTRUE) { fUseFisherCuts = t;}
inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
inline void SetNVars(Int_t n){fNvars = n;}
private:
// utility functions
// calculate the Purity out of the number of sig and bkg events collected
// from individual samples.
// calculates the purity S/(S+B) of a given event sample
Double_t SamplePurity(EventList eventSample);
UInt_t fNvars; // number of variables used to separate S and B
Int_t fNCuts; // number of grid point in variable cut scans
Bool_t fUseFisherCuts; // use multivariate splits using the Fisher criterium
Double_t fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in fisher criterium in node splitting
Bool_t fUseExclusiveVars; // individual variables already used in fisher criterium are not anymore analysed individually for node splitting
SeparationBase *fSepType; // the separation crition
RegressionVariance *fRegType; // the separation crition used in Regression
Double_t fMinSize; // min number of events in node
Double_t fMinNodeSize; // min fraction of training events in node
Double_t fMinSepGain; // min number of separation gain to perform node splitting
Bool_t fUseSearchTree; // cut scan done with binary trees or simple event loop.
Double_t fPruneStrength; // a parameter to set the "amount" of pruning..needs to be adjusted
EPruneMethod fPruneMethod; // method used for prunig
Int_t fNNodesBeforePruning; //remember this one (in case of pruning, it allows to monitor the before/after
Double_t fNodePurityLimit;// purity limit to decide whether a node is signal
Bool_t fRandomisedTree; // choose at each node splitting a random set of variables
Int_t fUseNvars; // the number of variables used in randomised trees;
Bool_t fUsePoissonNvars; // use "fUseNvars" not as fixed number but as mean of a possion distr. in each split
TRandom3 *fMyTrandom; // random number generator for randomised trees
std::vector< Double_t > fVariableImportance; // the relative importance of the different variables
UInt_t fMaxDepth; // max depth
UInt_t fSigClass; // class which is treated as signal when building the tree
static const Int_t fgDebugLevel = 0; // debug level determining some printout/control plots etc.
Int_t fTreeID; // just an ID number given to the tree.. makes debugging easier as tree knows who he is.
Types::EAnalysisType fAnalysisType; // kClassification(=0=false) or kRegression(=1=true)
ClassDef(DecisionTree,0) // implementation of a Decision Tree
};
} // namespace TMVA
#endif
|