[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_visitors.hxx
1/************************************************************************/
2/* */
3/* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef RF_VISITORS_HXX
36#define RF_VISITORS_HXX
37
38#ifdef HasHDF5
39# include "vigra/hdf5impex.hxx"
40#endif // HasHDF5
41#include <vigra/windows.h>
42#include <iostream>
43#include <iomanip>
44#include <random>
45
46#include <vigra/metaprogramming.hxx>
47#include <vigra/multi_pointoperators.hxx>
48#include <vigra/timing.hxx>
49
50namespace vigra
51{
52namespace rf
53{
54/** \brief Visitors to extract information during training of \ref vigra::RandomForest version 2.
55
56 \ingroup MachineLearning
57
58 This namespace contains all classes and methods related to extracting information during
59 learning of the random forest. All Visitors share the same interface defined in
60 visitors::VisitorBase. The member methods are invoked at certain points of the main code in
61 the order they were supplied.
62
63 For the Random Forest the Visitor concept is implemented as a statically linked list
64 (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The
65 VisitorNode object calls the Next Visitor after one of its visit() methods have terminated.
66
67 To simplify usage create_visitor() factory methods are supplied.
68 Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
69 It is possible to supply more than one visitor. They will then be invoked in serial order.
70
71 The calculated information are stored as public data members of the class. - see documentation
72 of the individual visitors
73
74 While creating a new visitor the new class should therefore publicly inherit from this class
75 (i.e.: see visitors::OOB_Error).
76
77 \code
78
 79 typedef xxx feature_t // replace xxx with whichever type
 80 typedef yyy label_t // same for the label type.
81 MultiArrayView<2, feature_t> f = get_some_features();
82 MultiArrayView<2, label_t> l = get_some_labels();
83 RandomForest<> rf()
84
85 //calculate OOB Error
86 visitors::OOB_Error oob_v;
87 //calculate Variable Importance
88 visitors::VariableImportanceVisitor varimp_v;
89
 90 double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
91 //the data can be found in the attributes of oob_v and varimp_v now
92
93 \endcode
94*/
95namespace visitors
96{
97
98
/** Base class from which all Visitors derive. Can be used as a template to
 *  create new Visitors: override any subset of the visit_*() hooks below;
 *  every default implementation is a no-op.
 */
class VisitorBase
{
  public:
    bool active_;   // when false, VisitorNode skips this visitor's hooks

    VisitorBase()
    : active_(true)
    {}

    /** is this visitor currently participating in the visitor chain? */
    bool is_active()
    {
        return active_;
    }

    /** whether return_val() yields a meaningful value (see return_val()) */
    bool has_value()
    {
        return false;
    }

    /** exclude this visitor from all subsequent visit_*() calls */
    void deactivate()
    {
        active_ = false;
    }

    /** re-include this visitor in all subsequent visit_*() calls */
    void activate()
    {
        active_ = true;
    }

    /** do something after the Split has decided how to process the Region
     * (stack entry)
     *
     * \param tree      reference to the tree that is currently being learned
     * \param split     reference to the split object
     * \param parent    current stack entry which was used to decide the split
     * \param leftChild left stack entry that will be pushed
     * \param rightChild
     *                  right stack entry that will be pushed.
     * \param features  features matrix
     * \param labels    label matrix
     * \sa RF_Traits::StackEntry_t
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree          & tree,
                            Split         & split,
                            Region        & parent,
                            Region        & leftChild,
                            Region        & rightChild,
                            Feature_t     & features,
                            Label_t       & labels)
    {
        // default no-op; casts silence unused-parameter warnings
        (void)tree; (void)split; (void)parent; (void)leftChild;
        (void)rightChild; (void)features; (void)labels;
    }

    /** do something after each tree has been learned
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the preprocessor that processed the input
     * \param sm        reference to the sampler object
     * \param st        reference to the first stack entry
     * \param index     index of current tree
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
    {
        (void)rf; (void)pr; (void)sm; (void)st; (void)index;
    }

    /** do something after all trees have been learned
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the preprocessor that processed the input
     */
    template<class RF, class PR>
    void visit_at_end(RF const & rf, PR const & pr)
    {
        (void)rf; (void)pr;
    }

    /** do something before learning starts
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the Processor class used.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & pr)
    {
        (void)rf; (void)pr;
    }

    /** do something while traversing tree after it has been learned
     * (external nodes)
     *
     * \param tr        reference to the tree object that called this visitor
     * \param index     index in the topology_ array we currently are at
     * \param node_t    type of node we have (will be e_.... - )
     * \param features  feature matrix
     * \sa NodeTags;
     *
     * you can create the node by using a switch on node_tag and using the
     * corresponding Node objects. Or - if you do not care about the type -
     * use the NodeBase class.
     */
    template<class TR, class IntT, class TopT, class Feat>
    void visit_external_node(TR & tr, IntT index, TopT node_t, Feat & features)
    {
        (void)tr; (void)index; (void)node_t; (void)features;
    }

    /** do something when visiting an internal node after it has been learned
     *
     * \sa visit_external_node
     */
    template<class TR, class IntT, class TopT, class Feat>
    void visit_internal_node(TR & /* tr */, IntT /* index */, TopT /* node_t */, Feat & /* features */)
    {}

    /** return a double value. The value of the first
     * visitor encountered that has a return value is returned with the
     * RandomForest::learn() method - or -1.0 if no return value visitor
     * existed. This functionality basically only exists so that the
     * OOB - visitor can return the oob error rate like in the old version
     * of the random forest.
     */
    double return_val()
    {
        return -1.0;
    }
};
231
232
233/** Last Visitor that should be called to stop the recursion.
234 */
236{
237 public:
238 bool has_value()
239 {
240 return true;
241 }
242 double return_val()
243 {
244 return -1.0;
245 }
246};
247namespace detail
248{
249/** Container elements of the statically linked Visitor list.
250 *
251 * use the create_visitor() factory functions to create visitors up to size 10;
252 *
253 */
254template <class Visitor, class Next = StopVisiting>
256{
257 public:
258
259 StopVisiting stop_;
260 Next next_;
261 Visitor & visitor_;
262 VisitorNode(Visitor & visitor, Next & next)
263 :
264 next_(next), visitor_(visitor)
265 {}
266
267 VisitorNode(Visitor & visitor)
268 :
269 next_(stop_), visitor_(visitor)
270 {}
271
272 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
273 void visit_after_split( Tree & tree,
274 Split & split,
275 Region & parent,
278 Feature_t & features,
279 Label_t & labels)
280 {
281 if(visitor_.is_active())
282 visitor_.visit_after_split(tree, split,
283 parent, leftChild, rightChild,
284 features, labels);
285 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
286 features, labels);
287 }
288
289 template<class RF, class PR, class SM, class ST>
290 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
291 {
292 if(visitor_.is_active())
293 visitor_.visit_after_tree(rf, pr, sm, st, index);
294 next_.visit_after_tree(rf, pr, sm, st, index);
295 }
296
297 template<class RF, class PR>
298 void visit_at_beginning(RF & rf, PR & pr)
299 {
300 if(visitor_.is_active())
301 visitor_.visit_at_beginning(rf, pr);
302 next_.visit_at_beginning(rf, pr);
303 }
304 template<class RF, class PR>
305 void visit_at_end(RF & rf, PR & pr)
306 {
307 if(visitor_.is_active())
308 visitor_.visit_at_end(rf, pr);
309 next_.visit_at_end(rf, pr);
310 }
311
312 template<class TR, class IntT, class TopT,class Feat>
313 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
314 {
315 if(visitor_.is_active())
316 visitor_.visit_external_node(tr, index, node_t,features);
317 next_.visit_external_node(tr, index, node_t,features);
318 }
319 template<class TR, class IntT, class TopT,class Feat>
320 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
321 {
322 if(visitor_.is_active())
323 visitor_.visit_internal_node(tr, index, node_t,features);
324 next_.visit_internal_node(tr, index, node_t,features);
325 }
326
327 double return_val()
328 {
329 if(visitor_.is_active() && visitor_.has_value())
330 return visitor_.return_val();
331 return next_.return_val();
332 }
333};
334
335} //namespace detail
336
337//////////////////////////////////////////////////////////////////////////////
338// Visitor Factory function up to 10 visitors //
339//////////////////////////////////////////////////////////////////////////////
340
341/** factory method to to be used with RandomForest::learn()
342 */
343template<class A>
346{
348 _0_t _0(a);
349 return _0;
350}
351
352
353/** factory method to to be used with RandomForest::learn()
354 */
355template<class A, class B>
356detail::VisitorNode<A, detail::VisitorNode<B> >
357create_visitor(A & a, B & b)
358{
360 _1_t _1(b);
362 _0_t _0(a, _1);
363 return _0;
364}
365
366
367/** factory method to to be used with RandomForest::learn()
368 */
369template<class A, class B, class C>
370detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
371create_visitor(A & a, B & b, C & c)
372{
374 _2_t _2(c);
376 _1_t _1(b, _2);
378 _0_t _0(a, _1);
379 return _0;
380}
381
382
383/** factory method to to be used with RandomForest::learn()
384 */
385template<class A, class B, class C, class D>
386detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
387 detail::VisitorNode<D> > > >
388create_visitor(A & a, B & b, C & c, D & d)
389{
391 _3_t _3(d);
393 _2_t _2(c, _3);
395 _1_t _1(b, _2);
397 _0_t _0(a, _1);
398 return _0;
399}
400
401
402/** factory method to to be used with RandomForest::learn()
403 */
404template<class A, class B, class C, class D, class E>
405detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
406 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
407create_visitor(A & a, B & b, C & c,
408 D & d, E & e)
409{
411 _4_t _4(e);
413 _3_t _3(d, _4);
415 _2_t _2(c, _3);
417 _1_t _1(b, _2);
419 _0_t _0(a, _1);
420 return _0;
421}
422
423
424/** factory method to to be used with RandomForest::learn()
425 */
426template<class A, class B, class C, class D, class E,
427 class F>
428detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
429 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
430create_visitor(A & a, B & b, C & c,
431 D & d, E & e, F & f)
432{
434 _5_t _5(f);
436 _4_t _4(e, _5);
438 _3_t _3(d, _4);
440 _2_t _2(c, _3);
442 _1_t _1(b, _2);
444 _0_t _0(a, _1);
445 return _0;
446}
447
448
449/** factory method to to be used with RandomForest::learn()
450 */
451template<class A, class B, class C, class D, class E,
452 class F, class G>
453detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
454 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
455 detail::VisitorNode<G> > > > > > >
456create_visitor(A & a, B & b, C & c,
457 D & d, E & e, F & f, G & g)
458{
460 _6_t _6(g);
462 _5_t _5(f, _6);
464 _4_t _4(e, _5);
466 _3_t _3(d, _4);
468 _2_t _2(c, _3);
470 _1_t _1(b, _2);
472 _0_t _0(a, _1);
473 return _0;
474}
475
476
477/** factory method to to be used with RandomForest::learn()
478 */
479template<class A, class B, class C, class D, class E,
480 class F, class G, class H>
481detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
482 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
483 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
484create_visitor(A & a, B & b, C & c,
485 D & d, E & e, F & f,
486 G & g, H & h)
487{
489 _7_t _7(h);
491 _6_t _6(g, _7);
493 _5_t _5(f, _6);
495 _4_t _4(e, _5);
497 _3_t _3(d, _4);
499 _2_t _2(c, _3);
501 _1_t _1(b, _2);
503 _0_t _0(a, _1);
504 return _0;
505}
506
507
508/** factory method to to be used with RandomForest::learn()
509 */
510template<class A, class B, class C, class D, class E,
511 class F, class G, class H, class I>
512detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
513 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
514 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
515create_visitor(A & a, B & b, C & c,
516 D & d, E & e, F & f,
517 G & g, H & h, I & i)
518{
520 _8_t _8(i);
522 _7_t _7(h, _8);
524 _6_t _6(g, _7);
526 _5_t _5(f, _6);
528 _4_t _4(e, _5);
530 _3_t _3(d, _4);
532 _2_t _2(c, _3);
534 _1_t _1(b, _2);
536 _0_t _0(a, _1);
537 return _0;
538}
539
540/** factory method to to be used with RandomForest::learn()
541 */
542template<class A, class B, class C, class D, class E,
543 class F, class G, class H, class I, class J>
544detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
545 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
546 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
547 detail::VisitorNode<J> > > > > > > > > >
548create_visitor(A & a, B & b, C & c,
549 D & d, E & e, F & f,
550 G & g, H & h, I & i,
551 J & j)
552{
554 _9_t _9(j);
556 _8_t _8(i, _9);
558 _7_t _7(h, _8);
560 _6_t _6(g, _7);
562 _5_t _5(f, _6);
564 _4_t _4(e, _5);
566 _3_t _3(d, _4);
568 _2_t _2(c, _3);
570 _1_t _1(b, _2);
572 _0_t _0(a, _1);
573 return _0;
574}
575
576//////////////////////////////////////////////////////////////////////////////
577// Visitors of communal interest. //
578//////////////////////////////////////////////////////////////////////////////
579
580
581/** Visitor to gain information, later needed for online learning.
582 */
583
585{
586public:
587 //Set if we adjust thresholds
588 bool adjust_thresholds;
589 //Current tree id
590 int tree_id;
591 //Last node id for finding parent
592 int last_node_id;
593 //Need to now the label for interior node visiting
594 vigra::Int32 current_label;
595 //marginal distribution for interior nodes
596 //
598 adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
599 {}
600 struct MarginalDistribution
601 {
602 ArrayVector<Int32> leftCounts;
603 Int32 leftTotalCounts;
604 ArrayVector<Int32> rightCounts;
605 Int32 rightTotalCounts;
606 double gap_left;
607 double gap_right;
608 };
610
611 //All information for one tree
612 struct TreeOnlineInformation
613 {
614 std::vector<MarginalDistribution> mag_distributions;
615 std::vector<IndexList> index_lists;
616 //map for linear index of mag_distributions
617 std::map<int,int> interior_to_index;
618 //map for linear index of index_lists
619 std::map<int,int> exterior_to_index;
620 };
621
622 //All trees
623 std::vector<TreeOnlineInformation> trees_online_information;
624
625 /** Initialize, set the number of trees
626 */
627 template<class RF,class PR>
628 void visit_at_beginning(RF & rf,const PR & /* pr */)
629 {
630 tree_id=0;
631 trees_online_information.resize(rf.options_.tree_count_);
632 }
633
634 /** Reset a tree
635 */
636 void reset_tree(int tree_id)
637 {
638 trees_online_information[tree_id].mag_distributions.clear();
639 trees_online_information[tree_id].index_lists.clear();
640 trees_online_information[tree_id].interior_to_index.clear();
641 trees_online_information[tree_id].exterior_to_index.clear();
642 }
643
644 /** simply increase the tree count
645 */
646 template<class RF, class PR, class SM, class ST>
647 void visit_after_tree(RF & /* rf */, PR & /* pr */, SM & /* sm */, ST & /* st */, int /* index */)
648 {
649 tree_id++;
650 }
651
652 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
653 void visit_after_split( Tree & tree,
654 Split & split,
655 Region & parent,
658 Feature_t & features,
659 Label_t & /* labels */)
660 {
661 int linear_index;
662 int addr=tree.topology_.size();
663 if(split.createNode().typeID() == i_ThresholdNode)
664 {
665 if(adjust_thresholds)
666 {
667 //Store marginal distribution
668 linear_index=trees_online_information[tree_id].mag_distributions.size();
669 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
670 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
671
672 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
673 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
674
675 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
676 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
677 //Store the gap
678 double gap_left,gap_right;
679 int i;
680 gap_left=features(leftChild[0],split.bestSplitColumn());
681 for(i=1;i<leftChild.size();++i)
682 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
683 gap_left=features(leftChild[i],split.bestSplitColumn());
684 gap_right=features(rightChild[0],split.bestSplitColumn());
685 for(i=1;i<rightChild.size();++i)
686 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
687 gap_right=features(rightChild[i],split.bestSplitColumn());
688 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
689 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
690 }
691 }
692 else
693 {
694 //Store index list
695 linear_index=trees_online_information[tree_id].index_lists.size();
696 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
697
698 trees_online_information[tree_id].index_lists.push_back(IndexList());
699
700 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
701 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
702 }
703 }
704 void add_to_index_list(int tree,int node,int index)
705 {
706 if(!this->active_)
707 return;
708 TreeOnlineInformation &ti=trees_online_information[tree];
709 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
710 }
711 void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
712 {
713 if(!this->active_)
714 return;
715 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
716 trees_online_information[src_tree].exterior_to_index.erase(src_index);
717 }
718 /** do something when visiting a internal node during getToLeaf
719 *
720 * remember as last node id, for finding the parent of the last external node
721 * also: adjust class counts and borders
722 */
723 template<class TR, class IntT, class TopT,class Feat>
724 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
725 {
726 last_node_id=index;
727 if(adjust_thresholds)
728 {
729 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
730 //Check if we are in the gap
731 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
732 TreeOnlineInformation &ti=trees_online_information[tree_id];
733 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
734 if(value>m.gap_left && value<m.gap_right)
735 {
736 //Check which site we want to go
737 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
738 {
739 //We want to go left
740 m.gap_left=value;
741 }
742 else
743 {
744 //We want to go right
745 m.gap_right=value;
746 }
747 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
748 }
749 //Adjust class counts
750 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
751 {
752 ++m.rightTotalCounts;
753 ++m.rightCounts[current_label];
754 }
755 else
756 {
757 ++m.leftTotalCounts;
758 ++m.rightCounts[current_label];
759 }
760 }
761 }
762 /** do something when visiting a extern node during getToLeaf
763 *
764 * Store the new index!
765 */
766};
767
768//////////////////////////////////////////////////////////////////////////////
769// Out of Bag Error estimates //
770//////////////////////////////////////////////////////////////////////////////
771
772
773/** Visitor that calculates the oob error of each individual randomized
774 * decision tree.
775 *
776 * After training a tree, all those samples that are OOB for this particular tree
777 * are put down the tree and the error estimated.
778 * the per tree oob error is the average of the individual error estimates.
779 * (oobError = average error of one randomized tree)
780 * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
781 * visitor)
782 */
784{
785public:
786 /** Average error of one randomized decision tree
787 */
788 double oobError;
789
790 int totalOobCount;
791 ArrayVector<int> oobCount,oobErrorCount;
792
794 : oobError(0.0),
795 totalOobCount(0)
796 {}
797
798
799 bool has_value()
800 {
801 return true;
802 }
803
804
805 /** does the basic calculation per tree*/
806 template<class RF, class PR, class SM, class ST>
807 void visit_after_tree(RF & rf, PR & pr, SM & sm, ST &, int index)
808 {
809 //do the first time called.
810 if(int(oobCount.size()) != rf.ext_param_.row_count_)
811 {
812 oobCount.resize(rf.ext_param_.row_count_, 0);
813 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
814 }
815 // go through the samples
816 for(int l = 0; l < rf.ext_param_.row_count_; ++l)
817 {
818 // if the lth sample is oob...
819 if(!sm.is_used()[l])
820 {
821 ++oobCount[l];
822 if( rf.tree(index)
823 .predictLabel(rowVector(pr.features(), l))
824 != pr.response()(l,0))
825 {
826 ++oobErrorCount[l];
827 }
828 }
829
830 }
831 }
832
833 /** Does the normalisation
834 */
835 template<class RF, class PR>
836 void visit_at_end(RF & rf, PR &)
837 {
838 // do some normalisation
839 for(int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
840 {
841 if(oobCount[l])
842 {
843 oobError += double(oobErrorCount[l]) / oobCount[l];
844 ++totalOobCount;
845 }
846 }
847 oobError/=totalOobCount;
848 }
849
850};
851
852/** Visitor that calculates the oob error of the ensemble
853 *
854 * This rate serves as a quick estimate for the crossvalidation
855 * error rate.
856 * Here, each sample is put down the trees for which this sample
857 * is OOB, i.e., if sample #1 is OOB for trees 1, 3 and 5, we calculate
858 * the output using the ensemble consisting only of trees 1 3 and 5.
859 *
860 * Using normal bagged sampling each sample is OOB for approx. 33% of trees.
861 * The error rate obtained as such therefore corresponds to a crossvalidation
862 * rate obtained using a ensemble containing 33% of the trees.
863 */
864class OOB_Error : public VisitorBase
865{
867 int class_count;
868 bool is_weighted;
869 MultiArray<2,double> tmp_prob;
870 public:
871
872 MultiArray<2, double> prob_oob;
873 /** Ensemble oob error rate
874 */
876
877 MultiArray<2, double> oobCount;
878 ArrayVector< int> indices;
879 OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
880#ifdef HasHDF5
881 void save(std::string filen, std::string pathn)
882 {
883 if(*(pathn.end()-1) != '/')
884 pathn += "/";
885 const char* filename = filen.c_str();
886 MultiArray<2, double> temp(Shp(1,1), 0.0);
887 temp[0] = oob_breiman;
888 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
889 }
890#endif
891 // negative value if sample was ib, number indicates how often.
892 // value >=0 if sample was oob, 0 means fail 1, correct
893
894 template<class RF, class PR>
895 void visit_at_beginning(RF & rf, PR &)
896 {
897 class_count = rf.class_count();
898 tmp_prob.reshape(Shp(1, class_count), 0);
899 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
900 is_weighted = rf.options().predict_weighted_;
901 indices.resize(rf.ext_param().row_count_);
902 if(int(oobCount.size()) != rf.ext_param_.row_count_)
903 {
904 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
905 }
906 for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
907 {
908 indices[ii] = ii;
909 }
910 }
911
912 template<class RF, class PR, class SM, class ST>
913 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
914 {
915 // go through the samples
916 int total_oob =0;
917 // FIXME: magic number 10000: invoke special treatment when when msample << sample_count
918 // (i.e. the OOB sample ist very large)
919 // 40000: use at most 40000 OOB samples per class for OOB error estimate
920 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
921 {
922 ArrayVector<int> oob_indices;
923 ArrayVector<int> cts(class_count, 0);
924 std::random_device rd;
925 std::mt19937 g(rd());
926 std::shuffle(indices.begin(), indices.end(), g);
927 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
928 {
929 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
930 {
931 oob_indices.push_back(indices[ii]);
932 ++cts[pr.response()(indices[ii], 0)];
933 }
934 }
935 for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
936 {
937 // update number of trees in which current sample is oob
938 ++oobCount[oob_indices[ll]];
939
940 // update number of oob samples in this tree.
941 ++total_oob;
942 // get the predicted votes ---> tmp_prob;
943 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
944 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
945 rf.tree(index).parameters_,
946 pos);
947 tmp_prob.init(0);
948 for(int ii = 0; ii < class_count; ++ii)
949 {
950 tmp_prob[ii] = node.prob_begin()[ii];
951 }
952 if(is_weighted)
953 {
954 for(int ii = 0; ii < class_count; ++ii)
955 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
956 }
957 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
958
959 }
960 }else
961 {
962 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
963 {
964 // if the lth sample is oob...
965 if(!sm.is_used()[ll])
966 {
967 // update number of trees in which current sample is oob
968 ++oobCount[ll];
969
970 // update number of oob samples in this tree.
971 ++total_oob;
972 // get the predicted votes ---> tmp_prob;
973 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
974 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
975 rf.tree(index).parameters_,
976 pos);
977 tmp_prob.init(0);
978 for(int ii = 0; ii < class_count; ++ii)
979 {
980 tmp_prob[ii] = node.prob_begin()[ii];
981 }
982 if(is_weighted)
983 {
984 for(int ii = 0; ii < class_count; ++ii)
985 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
986 }
987 rowVector(prob_oob, ll) += tmp_prob;
988 }
989 }
990 }
991 // go through the ib samples;
992 }
993
994 /** Normalise variable importance after the number of trees is known.
995 */
996 template<class RF, class PR>
997 void visit_at_end(RF & rf, PR & pr)
998 {
999 // ullis original metric and breiman style stuff
1000 int totalOobCount =0;
1001 int breimanstyle = 0;
1002 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1003 {
1004 if(oobCount[ll])
1005 {
1006 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1007 ++breimanstyle;
1008 ++totalOobCount;
1009 }
1010 }
1011 oob_breiman = double(breimanstyle)/totalOobCount;
1012 }
1013};
1014
1015
1016/** Visitor that calculates different OOB error statistics
1017 */
1019{
1021 int class_count;
1022 bool is_weighted;
1023 MultiArray<2,double> tmp_prob;
1024 public:
1025
1026 /** OOB Error rate of each individual tree
1027 */
1029 /** Mean of oob_per_tree
1030 */
1031 double oob_mean;
1032 /**Standard deviation of oob_per_tree
1033 */
1034 double oob_std;
1035
1036 MultiArray<2, double> prob_oob;
1037 /** Ensemble OOB error
1038 *
1039 * \sa OOB_Error
1040 */
1042
1043 MultiArray<2, double> oobCount;
1044 MultiArray<2, double> oobErrorCount;
1045 /** Per Tree OOB error calculated as in OOB_PerTreeError
1046 * (Ulli's version)
1047 */
1049
1050 /**Column containing the development of the Ensemble
1051 * error rate with increasing number of trees
1052 */
1054 /** 4 dimensional array containing the development of confusion matrices
1055 * with number of trees - can be used to estimate ROC curves etc.
1056 *
1057 * oobroc_per_tree(ii,jj,kk,ll)
1058 * corresponds true label = ii
1059 * predicted label = jj
1060 * confusion matrix after ll trees
1061 *
1062 * explanation of third index:
1063 *
1064 * Two class case:
1065 * kk = 0 - (treeCount-1)
1066 * Threshold is on Probability for class 0 is kk/(treeCount-1);
1067 * More classes:
1068 * kk = 0. Threshold on probability set by argMax of the probability array.
1069 */
1071
1073
1074#ifdef HasHDF5
1075 /** save to HDF5 file
1076 */
1077 void save(std::string filen, std::string pathn)
1078 {
1079 if(*(pathn.end()-1) != '/')
1080 pathn += "/";
1081 const char* filename = filen.c_str();
1082 MultiArray<2, double> temp(Shp(1,1), 0.0);
1083 writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1084 writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1085 writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1086 temp[0] = oob_mean;
1087 writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1088 temp[0] = oob_std;
1089 writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1090 temp[0] = oob_breiman;
1091 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1092 temp[0] = oob_per_tree2;
1093 writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1094 }
1095#endif
1096 // negative value if sample was ib, number indicates how often.
1097 // value >=0 if sample was oob, 0 means fail 1, correct
1098
1099 template<class RF, class PR>
1100 void visit_at_beginning(RF & rf, PR &)
1101 {
1102 class_count = rf.class_count();
1103 if(class_count == 2)
1104 oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1105 else
1106 oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
1107 tmp_prob.reshape(Shp(1, class_count), 0);
1108 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1109 is_weighted = rf.options().predict_weighted_;
1110 oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1111 breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1112 //do the first time called.
1113 if(int(oobCount.size()) != rf.ext_param_.row_count_)
1114 {
1115 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1116 oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1117 }
1118 }
1119
    /** Update the OOB statistics after tree number \a index has been learned.
     *
     *  For every sample that is out-of-bag w.r.t. this tree, the tree's leaf
     *  probabilities are accumulated into prob_oob (the ensemble vote so far)
     *  and the per-tree / Breiman-style error counters are updated.  In the
     *  two-class case an ROC slice (one threshold per tree) is filled in;
     *  otherwise a single confusion matrix per tree is accumulated.
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
    {
        // go through the samples
        int total_oob =0;
        int wrong_oob =0;
        for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
        {
            // if the lth sample is oob...
            if(!sm.is_used()[ll])
            {
                // update number of trees in which current sample is oob
                ++oobCount[ll];

                // update number of oob samples in this tree.
                ++total_oob;
                // get the predicted votes ---> tmp_prob;
                int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
                Node<e_ConstProbNode> node ( rf.tree(index).topology_,
                                             rf.tree(index).parameters_,
                                             pos);
                tmp_prob.init(0);
                for(int ii = 0; ii < class_count; ++ii)
                {
                    tmp_prob[ii] = node.prob_begin()[ii];
                }
                if(is_weighted)
                {
                    // the entry directly before prob_begin() carries the
                    // leaf weight used for weighted prediction
                    for(int ii = 0; ii < class_count; ++ii)
                        tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
                }
                rowVector(prob_oob, ll) += tmp_prob;
                int label = argMax(tmp_prob);

                if(label != pr.response()(ll, 0))
                {
                    // update number of wrong oob samples in this tree.
                    ++wrong_oob;
                    // update number of trees in which current sample is wrong oob
                    ++oobErrorCount[ll];
                }
            }
        }
        // second pass: ensemble-level (Breiman-style) error over all samples
        // that have been oob in at least one tree so far
        int breimanstyle = 0;
        int totalOobCount = 0;
        for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
        {
            if(oobCount[ll])
            {
                if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
                    ++breimanstyle;
                ++totalOobCount;
                // multi-class case (third extent == 1): accumulate the
                // confusion matrix of the argMax prediction directly
                if(oobroc_per_tree.shape(2) == 1)
                {
                    oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
                }
            }
        }
        if(oobroc_per_tree.shape(2) == 1)
            oobroc_per_tree.bindOuter(index)/=totalOobCount;
        if(oobroc_per_tree.shape(2) > 1)
        {
            // two-class case: one confusion matrix per threshold gg
            MultiArrayView<3, double> current_roc
                = oobroc_per_tree.bindOuter(index);
            for(int gg = 0; gg < current_roc.shape(2); ++gg)
            {
                for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
                {
                    if(oobCount[ll])
                    {
                        // NOTE(review): prob_oob accumulates raw (unnormalised)
                        // votes over trees while the threshold gg/shape(2) lies
                        // in [0,1) -- confirm this comparison is intended
                        // before relying on the ROC data.
                        int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
                                1 : 0;
                        current_roc(pr.response()(ll, 0), pred, gg)+= 1;
                    }
                }
                current_roc.bindOuter(gg)/= totalOobCount;
            }
        }
        breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
        oob_per_tree[index] = double(wrong_oob)/double(total_oob);
        // go through the ib samples;
    }
1202
1203 /** Normalise variable importance after the number of trees is known.
1204 */
1205 template<class RF, class PR>
1206 void visit_at_end(RF & rf, PR & pr)
1207 {
1208 // ullis original metric and breiman style stuff
1209 oob_per_tree2 = 0;
1210 int totalOobCount =0;
1211 int breimanstyle = 0;
1212 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1213 {
1214 if(oobCount[ll])
1215 {
1216 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1217 ++breimanstyle;
1218 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1219 ++totalOobCount;
1220 }
1221 }
1222 oob_per_tree2 /= totalOobCount;
1223 oob_breiman = double(breimanstyle)/totalOobCount;
1224 // mean error of each tree
1226 MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1227 rowStatistics(oob_per_tree, mean, stdDev);
1228 }
1229};
1230
1231/** calculate variable importance while learning.
1232 */
1234{
1235 public:
1236
1237 /** This Array has the same entries as the R - random forest variable
1238 * importance.
1239 * Matrix is featureCount by (classCount +2)
1240 * variable_importance_(ii,jj) is the variable importance measure of
1241 * the ii-th variable according to:
1242 * jj = 0 - (classCount-1)
1243 * classwise permutation importance
 * jj = columnCount(variable_importance_) - 2
 *     permutation importance
 * jj = columnCount(variable_importance_) - 1
 *     gini decrease importance.
1248 *
1249 * permutation importance:
1250 * The difference between the fraction of OOB samples classified correctly
1251 * before and after permuting (randomizing) the ii-th column is calculated.
1252 * The ii-th column is permuted rep_cnt times.
1253 *
1254 * class wise permutation importance:
1255 * same as permutation importance. We only look at those OOB samples whose
1256 * response corresponds to class jj.
1257 *
1258 * gini decrease importance:
1259 * row ii corresponds to the sum of all gini decreases induced by variable ii
1260 * in each node of the random forest.
1261 */
1263 int repetition_count_;
1264 bool in_place_;
1265
#ifdef HasHDF5
    /** Write the variable importance matrix to an HDF5 file.
     *
     * \param filename  HDF5 file to write to
     * \param prefix    dataset name suffix; stored as
     *                  "variable_importance_<prefix>"
     */
    void save(std::string filename, std::string prefix)
    {
        prefix = "variable_importance_" + prefix;
        // the dataset argument was lost in extraction; restored -- the only
        // data this visitor produces is variable_importance_
        writeHDF5(filename.c_str(),
                  prefix.c_str(),
                  variable_importance_);
    }
#endif
1275
1276 /* Constructor
1277 * \param rep_cnt (defautl: 10) how often should
1278 * the permutation take place. Set to 1 to make calculation faster (but
1279 * possibly more instable)
1280 */
1282 : repetition_count_(rep_cnt)
1283
1284 {}
1285
1286 /** calculates impurity decrease based variable importance after every
1287 * split.
1288 */
1289 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1291 Split & split,
1292 Region & /* parent */,
1293 Region & /* leftChild */,
1294 Region & /* rightChild */,
1295 Feature_t & /* features */,
1296 Label_t & /* labels */)
1297 {
1298 //resize to right size when called the first time
1299
1300 Int32 const class_count = tree.ext_param_.class_count_;
1301 Int32 const column_count = tree.ext_param_.column_count_;
1302 if(variable_importance_.size() == 0)
1303 {
1304
1306 .reshape(MultiArrayShape<2>::type(column_count,
1307 class_count+2));
1308 }
1309
1310 if(split.createNode().typeID() == i_ThresholdNode)
1311 {
1312 Node<i_ThresholdNode> node(split.createNode());
1313 variable_importance_(node.column(),class_count+1)
1314 += split.region_gini_ - split.minGini();
1315 }
1316 }
1317
    /**compute permutation based var imp.
     * (Only an Array of size oob_sample_count x 1 is created.
     *  - apposed to oob_sample_count x feature_count in the other method.
     *
     * \sa FieldProxy
     */
    template<class RF, class PR, class SM, class ST>
    void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & /* st */, int index)
    {
        // NOTE(review): several declaration lines of this function appear to
        // have been lost in extraction (the declarations of oob_indices,
        // backup_column, the randint functor, the oob_right/perm_oob_right
        // result arrays, and the variable_importance_ target of the
        // subarray update below). Compare with the upstream rf_visitors.hxx
        // before compiling.
        Int32 column_count = rf.ext_param_.column_count_;
        Int32 class_count = rf.ext_param_.class_count_;

        /* This solution saves memory uptake but not multithreading
         * compatible
         */
        // remove the const cast on the features (yep , I know what I am
        // doing here.) data is not destroyed.
        //typename PR::Feature_t & features
        //    = const_cast<typename PR::Feature_t &>(pr.features());

        typedef typename PR::FeatureWithMemory_t FeatureArray;
        typedef typename FeatureArray::value_type FeatureValue;

        // working copy of the features; permutations happen on this copy
        FeatureArray features = pr.features();

        //find the oob indices of current tree.
        ArrayVector<Int32>::iterator
            iter;
        for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
            if(!sm.is_used()[ii])
                oob_indices.push_back(ii);

        //create space to back up a column

        // Random foo
#ifdef CLASSIFIER_TEST
        RandomMT19937 random(1);
#else
        RandomMT19937 random(RandomSeed);
#endif
            randint(random);


        //make some space for the results
            oob_right(Shp_t(1, class_count + 1));
            perm_oob_right (Shp_t(1, class_count + 1));


        // get the oob success rate with the original samples
        // (per-class counts in [0..class_count-1], total in [class_count])
        for(iter = oob_indices.begin();
            iter != oob_indices.end();
            ++iter)
        {
            if(rf.tree(index)
                    .predictLabel(rowVector(features, *iter))
                == pr.response()(*iter, 0))
            {
                //per class
                ++oob_right[pr.response()(*iter,0)];
                //total
                ++oob_right[class_count];
            }
        }
        //get the oob rate after permuting the ii'th dimension.
        for(int ii = 0; ii < column_count; ++ii)
        {
            perm_oob_right.init(0.0);
            //make backup of original column
            backup_column.clear();
            for(iter = oob_indices.begin();
                iter != oob_indices.end();
                ++iter)
            {
                backup_column.push_back(features(*iter,ii));
            }

            //get the oob rate after permuting the ii'th dimension.
            for(int rr = 0; rr < repetition_count_; ++rr)
            {
                //permute dimension.
                // (looks like a Fisher-Yates shuffle over the oob rows --
                // confirm randint(jj+1) is inclusive of 0..jj)
                int n = oob_indices.size();
                for(int jj = n-1; jj >= 1; --jj)
                    std::swap(features(oob_indices[jj], ii),
                              features(oob_indices[randint(jj+1)], ii));

                //get the oob success rate after permuting
                for(iter = oob_indices.begin();
                    iter != oob_indices.end();
                    ++iter)
                {
                    if(rf.tree(index)
                            .predictLabel(rowVector(features, *iter))
                        == pr.response()(*iter, 0))
                    {
                        //per class
                        ++perm_oob_right[pr.response()(*iter, 0)];
                        //total
                        ++perm_oob_right[class_count];
                    }
                }
            }


            //normalise and add to the variable_importance array.
            perm_oob_right /= repetition_count_;
            perm_oob_right *= -1;
                .subarray(Shp_t(ii,0),
                          Shp_t(ii+1,class_count+1)) += perm_oob_right;
            //copy back permuted dimension
            for(int jj = 0; jj < int(oob_indices.size()); ++jj)
                features(oob_indices[jj], ii) = backup_column[jj];
        }
    }
1440
1441 /** calculate permutation based impurity after every tree has been
1442 * learned default behaviour is that this happens out of place.
1443 * If you have very big data sets and want to avoid copying of data
1444 * set the in_place_ flag to true.
1445 */
1446 template<class RF, class PR, class SM, class ST>
1447 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1448 {
1449 after_tree_ip_impl(rf, pr, sm, st, index);
1450 }
1451
1452 /** Normalise variable importance after the number of trees is known.
1453 */
1454 template<class RF, class PR>
1455 void visit_at_end(RF & rf, PR & /* pr */)
1456 {
1457 variable_importance_ /= rf.trees_.size();
1458 }
1459};
1460
1461/** Verbose output
1462 */
1464 public:
1466
1467 template<class RF, class PR, class SM, class ST>
1468 void visit_after_tree(RF& rf, PR &, SM &, ST &, int index){
1469 if(index != rf.options().tree_count_-1) {
1470 std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1471 << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1472 }
1473 else {
1474 std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1475 }
1476 }
1477
1478 template<class RF, class PR>
1479 void visit_at_end(RF const & rf, PR const &) {
1480 std::string a = TOCS;
1481 std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
1482 }
1483
1484 template<class RF, class PR>
1485 void visit_at_beginning(RF const & rf, PR const &) {
1486 TIC;
1487 std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
1488 }
1489
1490 private:
1491 USETICTOC;
1492};
1493
1494
1495/** Computes Correlation/Similarity Matrix of features while learning
1496 * random forest.
1497 */
1499{
1500 public:
1501 /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1502 * created on variable ii(when variable ii was chosen)
1503 */
1505 MultiArray<2, int> tmp_labels;
1506 /** additional noise features.
1507 */
1509 MultiArray<2, double> noise_l;
1510 /** how well can a noise column describe a partition created on variable ii.
1511 */
1513 MultiArray<2, double> corr_l;
1514
1515 /** Similarity Matrix
1516 *
1517 * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
1518 * gini_missc
1519 * - row normalized by the number of times the column was chosen
1520 * - mean of corr_noise subtracted
1521 * - and symmetrised.
1522 *
1523 */
1525 /** Distance Matrix 1-similarity
1526 */
1528 ArrayVector<int> tmp_cc;
1529
1530 /** How often was variable ii chosen
1531 */
    /** Serialization stub: the HDF5 export below is currently disabled
     *  (kept as commented-out reference code), so calling this function
     *  is a no-op.
     */
    void save(std::string, std::string)
    {
        /*
        std::string tmp;
#define VAR_WRITE(NAME) \
        tmp = #NAME;\
        tmp += "_";\
        tmp += prefix;\
        vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
        VAR_WRITE(gini_missc);
        VAR_WRITE(corr_noise);
        VAR_WRITE(distance);
        VAR_WRITE(similarity);
        vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
#undef VAR_WRITE
*/
    }
1552
1553 template<class RF, class PR>
1554 void visit_at_beginning(RF const & rf, PR & pr)
1555 {
1556 typedef MultiArrayShape<2>::type Shp;
1557 int n = rf.ext_param_.column_count_;
1558 gini_missc.reshape(Shp(n +1,n+ 1));
1559 corr_noise.reshape(Shp(n + 1, 10));
1560 corr_l.reshape(Shp(n +1, 10));
1561
1562 noise.reshape(Shp(pr.features().shape(0), 10));
1563 noise_l.reshape(Shp(pr.features().shape(0), 10));
1564 RandomMT19937 random(RandomSeed);
1565 for(int ii = 0; ii < noise.size(); ++ii)
1566 {
1567 noise[ii] = random.uniform53();
1568 noise_l[ii] = random.uniform53() > 0.5;
1569 }
1570 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1571 tmp_labels.reshape(pr.response().shape());
1572 tmp_cc.resize(2);
1573 numChoices.resize(n+1);
1574 // look at all axes
1575 }
    /** Post-process the accumulated gini decreases into the final
     *  similarity / distance matrices: row-normalise by choice counts,
     *  subtract the noise baseline, symmetrise and fix the diagonal.
     *
     *  NOTE(review): a few lines were lost in extraction here -- the
     *  initialisation of `similarity` (presumably from gini_missc) and the
     *  declaration of `mean_noise` before the rowStatistics call, as well
     *  as a statement between the distance.reshape and the closing brace
     *  (presumably distance -= similarity). Compare with the upstream
     *  rf_visitors.hxx before compiling.
     */
    template<class RF, class PR>
    void visit_at_end(RF const &, PR const &)
    {
        typedef MultiArrayShape<2>::type Shp;
        rowStatistics(corr_noise, mean_noise);
        int rC = similarity.shape(0);
        // normalise each feature row by how often it was chosen and remove
        // the per-row noise baseline
        for(int jj = 0; jj < rC-1; ++jj)
        {
            rowVector(similarity, jj) /= numChoices[jj];
            rowVector(similarity, jj) -= mean_noise(jj, 0);
        }
        // last row (class label column) is normalised column-wise
        for(int jj = 0; jj < rC; ++jj)
        {
            similarity(rC -1, jj) /= numChoices[jj];
        }
        rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
        similarity = abs(similarity);
        FindMinMax<double> minmax;
        inspectMultiArray(srcMultiArrayRange(similarity), minmax);

        // temporarily set the diagonal to the maximum so it does not
        // disturb the symmetrisation below
        for(int jj = 0; jj < rC; ++jj)
            similarity(jj, jj) = minmax.max;

        similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
            += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
        similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
        columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
        for(int jj = 0; jj < rC; ++jj)
            similarity(jj, jj) = 0;

        // final diagonal: the new maximum of the symmetrised matrix
        FindMinMax<double> minmax2;
        inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
        for(int jj = 0; jj < rC; ++jj)
            similarity(jj, jj) = minmax2.max;
        distance.reshape(gini_missc.shape(), minmax2.max);
    }
1617
    /** After each accepted split, measure how well every other variable
     *  (and the reference noise columns) can reproduce the partition
     *  induced by the chosen split column, accumulating the results into
     *  the gini_missc / corr_noise / corr_l matrices.
     *
     *  NOTE(review): several left-hand-side lines of the `+=` statements
     *  below were lost in extraction (the gini_missc/corr_noise targets
     *  after the label-functor call, inside the `#if 1` block, in the
     *  noise loop, and before the final region_gini accumulation).
     *  Compare with the upstream rf_visitors.hxx before compiling.
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree &,
                            Split & split,
                            Region & parent,
                            Region &,
                            Region &,
                            Feature_t & features,
                            Label_t & labels)
    {
        if(split.createNode().typeID() == i_ThresholdNode)
        {
            double wgini;
            // binarise the parent region by the chosen split: label 1 if the
            // sample falls below the threshold, 0 otherwise
            tmp_cc.init(0);
            for(int ii = 0; ii < parent.size(); ++ii)
            {
                tmp_labels[parent[ii]]
                    = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
                ++tmp_cc[tmp_labels[parent[ii]]];
            }
            double region_gini = bgfunc.loss_of_region(tmp_labels,
                                                       parent.begin(),
                                                       parent.end(),
                                                       tmp_cc);

            // book-keeping: how often column n (and any column at all) was chosen
            int n = split.bestSplitColumn();
            ++numChoices[n];
            ++(*(numChoices.end()-1));
            //this functor does all the work
            for(int k = 0; k < features.shape(1); ++k)
            {
                bgfunc(columnVector(features, k),
                       tmp_labels,
                       parent.begin(), parent.end(),
                       tmp_cc);
                wgini = (region_gini - bgfunc.min_gini_);
                gini_missc(n, k)
                    += wgini;
            }
            // same measurement against the 10 continuous noise columns
            for(int k = 0; k < 10; ++k)
            {
                bgfunc(columnVector(noise, k),
                       tmp_labels,
                       parent.begin(), parent.end(),
                       tmp_cc);
                wgini = (region_gini - bgfunc.min_gini_);
                corr_noise(n, k)
                    += wgini;
            }

            // ... and against the 10 binary noise columns
            for(int k = 0; k < 10; ++k)
            {
                bgfunc(columnVector(noise_l, k),
                       tmp_labels,
                       parent.begin(), parent.end(),
                       tmp_cc);
                wgini = (region_gini - bgfunc.min_gini_);
                corr_l(n, k)
                    += wgini;
            }
            // how well the true class labels reproduce the partition
            bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
            wgini = (region_gini - bgfunc.min_gini_);
                += wgini;

            region_gini = split.region_gini_;
#if 1
            Node<i_ThresholdNode> node(split.createNode());
                     node.column())
                +=split.region_gini_ - split.minGini();
#endif
            for(int k = 0; k < 10; ++k)
            {
                split.bgfunc(columnVector(noise, k),
                             labels,
                             parent.begin(), parent.end(),
                             parent.classCounts());
                         k)
                    += wgini;
            }
#if 0
            for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
            {
                wgini = region_gini - split.min_gini_[k];

                         split.splitColumns[k])
                    += wgini;
            }

            for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
            {
                split.bgfunc(columnVector(features, split.splitColumns[k]),
                             labels,
                             parent.begin(), parent.end(),
                             parent.classCounts());
                wgini = region_gini - split.bgfunc.min_gini_;
                         split.splitColumns[k]) += wgini;
            }
#endif
            // remember to partition the data according to the best.
                += region_gini;
            SortSamplesByDimensions<Feature_t>
                sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
            std::partition(parent.begin(), parent.end(), sorter);
        }
    }
1729};
1730
1731
1732} // namespace visitors
1733} // namespace rf
1734} // namespace vigra
1735
1736#endif // RF_VISITORS_HXX
const_pointer data() const
Definition array_vector.hxx:209
const_iterator end() const
Definition array_vector.hxx:237
MultiArrayView subarray(difference_type p, difference_type q) const
Definition multi_array.hxx:1530
const difference_type & shape() const
Definition multi_array.hxx:1650
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const
Definition multi_array.hxx:2186
difference_type_1 size() const
Definition multi_array.hxx:1643
MultiArrayView< N, T, StridedArrayTag > transpose() const
Definition multi_array.hxx:1569
void reshape(const difference_type &shape)
Definition multi_array.hxx:2863
Class for a single RGB value.
Definition rgbvalue.hxx:128
void init(Iterator i, Iterator end)
Definition tinyvector.hxx:708
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Definition rf_visitors.hxx:1019
double oob_per_tree2
Definition rf_visitors.hxx:1048
MultiArray< 2, double > breiman_per_tree
Definition rf_visitors.hxx:1053
double oob_mean
Definition rf_visitors.hxx:1031
double oob_breiman
Definition rf_visitors.hxx:1041
MultiArray< 2, double > oob_per_tree
Definition rf_visitors.hxx:1028
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:1206
MultiArray< 4, double > oobroc_per_tree
Definition rf_visitors.hxx:1070
double oob_std
Definition rf_visitors.hxx:1034
Definition rf_visitors.hxx:1499
MultiArray< 2, double > distance
Definition rf_visitors.hxx:1527
MultiArray< 2, double > corr_noise
Definition rf_visitors.hxx:1512
MultiArray< 2, double > gini_missc
Definition rf_visitors.hxx:1504
MultiArray< 2, double > similarity
Definition rf_visitors.hxx:1524
ArrayVector< int > numChoices
Definition rf_visitors.hxx:1532
MultiArray< 2, double > noise
Definition rf_visitors.hxx:1508
Definition rf_visitors.hxx:865
double oob_breiman
Definition rf_visitors.hxx:875
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:997
Definition rf_visitors.hxx:784
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:807
double oobError
Definition rf_visitors.hxx:788
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:836
Definition rf_visitors.hxx:585
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:724
void reset_tree(int tree_id)
Definition rf_visitors.hxx:636
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition rf_visitors.hxx:647
void visit_at_beginning(RF &rf, const PR &)
Definition rf_visitors.hxx:628
Definition rf_visitors.hxx:236
Definition rf_visitors.hxx:1234
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition rf_visitors.hxx:1290
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:1455
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:1447
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:1325
MultiArray< 2, double > variable_importance_
Definition rf_visitors.hxx:1262
Definition rf_visitors.hxx:103
void visit_at_beginning(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:188
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:206
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition rf_visitors.hxx:143
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition rf_visitors.hxx:216
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:164
void visit_at_end(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:176
double return_val()
Definition rf_visitors.hxx:226
Definition rf_visitors.hxx:256
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:684
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:671
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:727
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:345
void writeHDF5(...)
Store array data in an HDF5 file.
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
#define TIC
Definition timing.hxx:322
#define TOCS
Definition timing.hxx:325

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.1