[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_decisionTree.hxx
1/************************************************************************/
2/* */
3/* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35
36#ifndef VIGRA_RANDOM_FOREST_DT_HXX
37#define VIGRA_RANDOM_FOREST_DT_HXX
38
39#include <algorithm>
40#include <map>
41#include <numeric>
42#include "vigra/multi_array.hxx"
43#include "vigra/mathutil.hxx"
44#include "vigra/metaprogramming.hxx"
45#include "vigra/array_vector.hxx"
46#include "vigra/sized_int.hxx"
47#include "vigra/matrix.hxx"
48#include "vigra/random.hxx"
49#include "vigra/functorexpression.hxx"
50#include <vector>
51
52#include "rf_common.hxx"
53#include "rf_visitors.hxx"
54#include "rf_nodeproxy.hxx"
55namespace vigra
56{
57
58namespace detail
59{
60 // todo FINALLY DECIDE TO USE CAMEL CASE OR UNDERSCORES !!!!!!
61/* decisiontree classifier.
62 *
63 * This class is actually meant to be used in conjunction with the
64 * Random Forest Classifier
65 * - My suggestion would be to use the RandomForest classifier with
66 * following parameters instead of directly using this
67 * class (Preprocessing default values etc is handled in there):
68 *
69 * \code
70 * RandomForest decisionTree(RF_Traits::Options_t()
71 * .features_per_node(RF_ALL)
72 * .tree_count(1) );
73 * \endcode
74 *
75 * \todo remove the classCount and featurecount from the topology
76 * array. Pass ext_param_ to the nodes!
77 * \todo Use relative addressing of nodes?
78 */
79class DecisionTree
80{
81 /* \todo make private?*/
82 public:
83
84 /* value type of container array. use whenever referencing it
85 */
86 typedef Int32 TreeInt;
87
88 ArrayVector<TreeInt> topology_;
89 ArrayVector<double> parameters_;
90
91 ProblemSpec<> ext_param_;
92 unsigned int classCount_;
93
94
95 public:
96 /* \brief Create tree with parameters */
97 template<class T>
98 DecisionTree(ProblemSpec<T> ext_param)
99 :
100 ext_param_(ext_param),
101 classCount_(ext_param.class_count_)
102 {}
103
104 /* clears all memory used.
105 */
106 void reset(unsigned int classCount = 0)
107 {
108 if(classCount)
109 classCount_ = classCount;
110 topology_.clear();
111 parameters_.clear();
112 }
113
114
115 /* learn a Tree
116 *
117 * \tparam StackEntry_t The Stackentry containing Node/StackEntry_t
118 * Information used during learning. Each Split functor has a
119 * Stack entry associated with it (Split_t::StackEntry_t)
120 * \sa RandomForest::learn()
121 */
122 template < class U, class C,
123 class U2, class C2,
124 class StackEntry_t,
125 class Stop_t,
126 class Split_t,
127 class Visitor_t,
128 class Random_t >
129 void learn( MultiArrayView<2, U, C> const & features,
130 MultiArrayView<2, U2, C2> const & labels,
131 StackEntry_t const & stack_entry,
132 Split_t split,
133 Stop_t stop,
134 Visitor_t & visitor,
135 Random_t & randint);
136 template < class U, class C,
137 class U2, class C2,
138 class StackEntry_t,
139 class Stop_t,
140 class Split_t,
141 class Visitor_t,
142 class Random_t>
143 void continueLearn( MultiArrayView<2, U, C> const & features,
144 MultiArrayView<2, U2, C2> const & labels,
145 StackEntry_t const & stack_entry,
146 Split_t split,
147 Stop_t stop,
148 Visitor_t & visitor,
149 Random_t & randint,
150 //an index to which the last created exterior node will be moved (because it is not used anymore)
151 int garbaged_child=-1);
152
153 /* is a node a Leaf Node? */
154 inline bool isLeafNode(TreeInt in) const
155 {
156 return (in & LeafNodeTag) == LeafNodeTag;
157 }
158
159 /* data driven traversal from root to leaf
160 *
161 * traverse through tree with data given in features. Use Visitors to
162 * collect statistics along the way.
163 */
164 template<class U, class C, class Visitor_t>
165 TreeInt getToLeaf(MultiArrayView<2, U, C> const & features,
166 Visitor_t & visitor) const
167 {
168 TreeInt index = 2;
169 while(!isLeafNode(topology_[index]))
170 {
171 visitor.visit_internal_node(*this, index, topology_[index],features);
172 switch(topology_[index])
173 {
174 case i_ThresholdNode:
175 {
176 Node<i_ThresholdNode>
177 node(topology_, parameters_, index);
178 index = node.next(features);
179 break;
180 }
181 case i_HyperplaneNode:
182 {
183 Node<i_HyperplaneNode>
184 node(topology_, parameters_, index);
185 index = node.next(features);
186 break;
187 }
188 case i_HypersphereNode:
189 {
190 Node<i_HypersphereNode>
191 node(topology_, parameters_, index);
192 index = node.next(features);
193 break;
194 }
195#if 0
196 // for quick prototyping! has to be implemented.
197 case i_VirtualNode:
198 {
199 Node<i_VirtualNode>
200 node(topology_, parameters, index);
201 index = node.next(features);
202 }
203#endif
204 default:
205 vigra_fail("DecisionTree::getToLeaf():"
206 "encountered unknown internal Node Type");
207 }
208 }
209 visitor.visit_external_node(*this, index, topology_[index],features);
210 return index;
211 }
212 /* traverse tree to get statistics
213 *
214 * Tree is traversed in order the Nodes are in memory (i.e. if no
215 * relearning//pruning scheme is utilized this will be pre order)
216 */
217 template<class Visitor_t>
218 void traverse_mem_order(Visitor_t visitor) const
219 {
220 UInt32 index = 2;
221 while(index < topology_.size())
222 {
223 if(isLeafNode(topology_[index]))
224 {
225 visitor
226 .visit_external_node(*this, index, topology_[index]);
227 }
228 else
229 {
230 visitor
231 ._internal_node(*this, index, topology_[index]);
232 }
233 }
234 }
235
236 template<class Visitor_t>
237 void traverse_post_order(Visitor_t visitor, TreeInt /*start*/ = 2) const
238 {
239 typedef TinyVector<double, 2> Entry;
240 std::vector<Entry > stack;
241 std::vector<double> result_stack;
242 stack.push_back(Entry(2, 0));
243 int addr;
244 while(!stack.empty())
245 {
246 addr = stack.back()[0];
247 NodeBase node(topology_, parameters_, stack.back()[0]);
248 if(stack.back()[1] == 1)
249 {
250 stack.pop_back();
251 double leftRes = result_stack.back();
252 double rightRes = result_stack.back();
253 result_stack.pop_back();
254 result_stack.pop_back();
255 result_stack.push_back(rightRes+ leftRes);
256 visitor.visit_internal_node(*this,
257 addr,
258 node.typeID(),
259 rightRes+leftRes);
260 }
261 else
262 {
263 if(isLeafNode(node.typeID()))
264 {
265 visitor.visit_external_node(*this,
266 addr,
267 node.typeID(),
268 node.weights());
269 stack.pop_back();
270 result_stack.push_back(node.weights());
271 }
272 else
273 {
274 stack.back()[1] = 1;
275 stack.push_back(Entry(node.child(0), 0));
276 stack.push_back(Entry(node.child(1), 0));
277 }
278
279 }
280 }
281 }
282
283 /* same thing as above, without any visitors */
284 template<class U, class C>
285 TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const
286 {
287 ::vigra::rf::visitors::StopVisiting stop;
288 return getToLeaf(features, stop);
289 }
290
291
292 template <class U, class C>
293 ArrayVector<double>::iterator
294 predict(MultiArrayView<2, U, C> const & features) const
295 {
296 TreeInt nodeindex = getToLeaf(features);
297 switch(topology_[nodeindex])
298 {
299 case e_ConstProbNode:
300 return Node<e_ConstProbNode>(topology_,
301 parameters_,
302 nodeindex).prob_begin();
303 break;
304#if 0
305 //first make the Logistic regression stuff...
306 case e_LogRegProbNode:
307 return Node<e_LogRegProbNode>(topology_,
308 parameters_,
309 nodeindex).prob_begin();
310#endif
311 default:
312 vigra_fail("DecisionTree::predict() :"
313 " encountered unknown external Node Type");
314 }
315 return ArrayVector<double>::iterator();
316 }
317
318
319
320 template <class U, class C>
321 Int32 predictLabel(MultiArrayView<2, U, C> const & features) const
322 {
323 ArrayVector<double>::const_iterator weights = predict(features);
324 return argMax(weights, weights+classCount_) - weights;
325 }
326
327};
328
329
330template < class U, class C,
331 class U2, class C2,
332 class StackEntry_t,
333 class Stop_t,
334 class Split_t,
335 class Visitor_t,
336 class Random_t>
337void DecisionTree::learn( MultiArrayView<2, U, C> const & features,
338 MultiArrayView<2, U2, C2> const & labels,
339 StackEntry_t const & stack_entry,
340 Split_t split,
341 Stop_t stop,
342 Visitor_t & visitor,
343 Random_t & randint)
344{
345 this->reset();
346 topology_.reserve(256);
347 parameters_.reserve(256);
348 topology_.push_back(features.shape(1));
349 topology_.push_back(classCount_);
350 continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
351}
352
353template < class U, class C,
354 class U2, class C2,
355 class StackEntry_t,
356 class Stop_t,
357 class Split_t,
358 class Visitor_t,
359 class Random_t>
360void DecisionTree::continueLearn( MultiArrayView<2, U, C> const & features,
361 MultiArrayView<2, U2, C2> const & labels,
362 StackEntry_t const & stack_entry,
363 Split_t split,
364 Stop_t stop,
365 Visitor_t & visitor,
366 Random_t & randint,
367 //an index to which the last created exterior node will be moved (because it is not used anymore)
368 int garbaged_child)
369{
370 std::vector<StackEntry_t> stack;
371 stack.reserve(128);
372 ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
373 stack.push_back(stack_entry);
374 size_t last_node_pos = 0;
375 StackEntry_t top=stack.back();
376
377 while(!stack.empty())
378 {
379
380 // Take an element of the stack. Obvious ain't it?
381 top = stack.back();
382 stack.pop_back();
383
384 // Make sure no data from the last round has remained in Pipeline;
385 child_stack_entry[0].reset();
386 child_stack_entry[1].reset();
387 split.reset();
388
389
390 //Either the Stopping criterion decides that the split should
391 //produce a Terminal Node or the Split itself decides what
392 //kind of node to make
393 TreeInt NodeID;
394
395 if(stop(top))
396 NodeID = split.makeTerminalNode(features,
397 labels,
398 top,
399 randint);
400 else
401 {
402 //TIC;
403 NodeID = split.findBestSplit(features,
404 labels,
405 top,
406 child_stack_entry,
407 randint);
408 //std::cerr << TOC <<" " << NodeID << ";" <<std::endl;
409 }
410
411 // do some visiting yawn - just added this comment as eye candy
412 // (looks odd otherwise with my syntax highlighting....
413 visitor.visit_after_split(*this, split, top,
414 child_stack_entry[0],
415 child_stack_entry[1],
416 features,
417 labels);
418
419
420 // Update the Child entries of the parent
421 // Using InteriorNodeBase because exact parameter form not needed.
422 // look at the Node base before getting scared.
423 last_node_pos = topology_.size();
424 if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
425 {
426 NodeBase(topology_,
427 parameters_,
428 top.leftParent).child(0) = last_node_pos;
429 }
430 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
431 {
432 NodeBase(topology_,
433 parameters_,
434 top.rightParent).child(1) = last_node_pos;
435 }
436
437
438 // Supply the split functor with the Node type it requires.
439 // set the address to which the children of this node should point
440 // to and push back children onto stack
441 if(!isLeafNode(NodeID))
442 {
443 child_stack_entry[0].leftParent = topology_.size();
444 child_stack_entry[1].rightParent = topology_.size();
445 child_stack_entry[0].rightParent = -1;
446 child_stack_entry[1].leftParent = -1;
447 stack.push_back(child_stack_entry[0]);
448 stack.push_back(child_stack_entry[1]);
449 }
450
451 //copy the newly created node form the split functor to the
452 //decision tree.
453 NodeBase node(split.createNode(), topology_, parameters_ );
454 ignore_argument(node);
455 }
456 if(garbaged_child!=-1)
457 {
458 Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));
459
460 int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
461 topology_.resize(last_node_pos);
462 parameters_.resize(parameters_.size() - last_parameter_size);
463
464 if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
465 NodeBase(topology_,
466 parameters_,
467 top.leftParent).child(0) = garbaged_child;
468 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
469 NodeBase(topology_,
470 parameters_,
471 top.rightParent).child(1) = garbaged_child;
472 }
473}
474
475} //namespace detail
476
477} //namespace vigra
478
479#endif //VIGRA_RANDOM_FOREST_DT_HXX
detail::SelectIntegerType< 32, detail::UnsignedIntTypes >::type UInt32
32-bit unsigned int
Definition sized_int.hxx:183
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.2 (Mon Apr 14 2025)