[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_visitors.hxx
1/************************************************************************/
2/* */
3/* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef RF_VISITORS_HXX
36#define RF_VISITORS_HXX
37
#ifdef HasHDF5
# include "vigra/hdf5impex.hxx"
#endif // HasHDF5
#include <vigra/windows.h>

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <vector>

#include <vigra/metaprogramming.hxx>
#include <vigra/multi_pointoperators.hxx>
#include <vigra/timing.hxx>
48
49namespace vigra
50{
51namespace rf
52{
53/** \brief Visitors to extract information during training of \ref vigra::RandomForest version 2.
54
55 \ingroup MachineLearning
56
57 This namespace contains all classes and methods related to extracting information during
58 learning of the random forest. All Visitors share the same interface defined in
59 visitors::VisitorBase. The member methods are invoked at certain points of the main code in
60 the order they were supplied.
61
62 For the Random Forest the Visitor concept is implemented as a statically linked list
63 (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The
64 VisitorNode object calls the Next Visitor after one of its visit() methods have terminated.
65
66 To simplify usage create_visitor() factory methods are supplied.
67 Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
68 It is possible to supply more than one visitor. They will then be invoked in serial order.
69
70 The calculated information is stored in public data members of the class — see the
71 documentation of the individual visitors.
72
73 While creating a new visitor the new class should therefore publicly inherit from this class
74 (i.e.: see visitors::OOB_Error).
75
76 \code
77
78 typedef xxx feature_t; // replace xxx with whichever type
79 typedef yyy label_t;   // likewise.
80 MultiArrayView<2, feature_t> f = get_some_features();
81 MultiArrayView<2, label_t> l = get_some_labels();
82 RandomForest<> rf;
83
84 //calculate OOB Error
85 visitors::OOB_Error oob_v;
86 //calculate Variable Importance
87 visitors::VariableImportanceVisitor varimp_v;
88
89 double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
90 //the data can be found in the attributes of oob_v and varimp_v now
91
92 \endcode
93*/
94namespace visitors
95{
96
97
98/** Base Class from which all Visitors derive. Can be used as a template to create new
99 * Visitors.
100 */
102{
103 public:
104 bool active_;
105 bool is_active()
106 {
107 return active_;
108 }
109
110 bool has_value()
111 {
112 return false;
113 }
114
116 : active_(true)
117 {}
118
119 void deactivate()
120 {
121 active_ = false;
122 }
123 void activate()
124 {
125 active_ = true;
126 }
127
128 /** do something after the the Split has decided how to process the Region
129 * (Stack entry)
130 *
131 * \param tree reference to the tree that is currently being learned
132 * \param split reference to the split object
133 * \param parent current stack entry which was used to decide the split
134 * \param leftChild left stack entry that will be pushed
135 * \param rightChild
136 * right stack entry that will be pushed.
137 * \param features features matrix
138 * \param labels label matrix
139 * \sa RF_Traits::StackEntry_t
140 */
141 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
143 Split & split,
144 Region & parent,
147 Feature_t & features,
148 Label_t & labels)
149 {
150 ignore_argument(tree,split,parent,leftChild,rightChild,features,labels);
151 }
152
153 /** do something after each tree has been learned
154 *
155 * \param rf reference to the random forest object that called this
156 * visitor
157 * \param pr reference to the preprocessor that processed the input
158 * \param sm reference to the sampler object
159 * \param st reference to the first stack entry
160 * \param index index of current tree
161 */
162 template<class RF, class PR, class SM, class ST>
163 void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
164 {
165 ignore_argument(rf,pr,sm,st,index);
166 }
167
168 /** do something after all trees have been learned
169 *
170 * \param rf reference to the random forest object that called this
171 * visitor
172 * \param pr reference to the preprocessor that processed the input
173 */
174 template<class RF, class PR>
175 void visit_at_end(RF const & rf, PR const & pr)
176 {
177 ignore_argument(rf,pr);
178 }
179
180 /** do something before learning starts
181 *
182 * \param rf reference to the random forest object that called this
183 * visitor
184 * \param pr reference to the Processor class used.
185 */
186 template<class RF, class PR>
187 void visit_at_beginning(RF const & rf, PR const & pr)
188 {
189 ignore_argument(rf,pr);
190 }
191 /** do some thing while traversing tree after it has been learned
192 * (external nodes)
193 *
194 * \param tr reference to the tree object that called this visitor
195 * \param index index in the topology_ array we currently are at
196 * \param node_t type of node we have (will be e_.... - )
197 * \param features feature matrix
198 * \sa NodeTags;
199 *
200 * you can create the node by using a switch on node_tag and using the
201 * corresponding Node objects. Or - if you do not care about the type
202 * use the NodeBase class.
203 */
204 template<class TR, class IntT, class TopT,class Feat>
205 void visit_external_node(TR & tr, IntT index, TopT node_t, Feat & features)
206 {
207 ignore_argument(tr,index,node_t,features);
208 }
209
210 /** do something when visiting a internal node after it has been learned
211 *
212 * \sa visit_external_node
213 */
214 template<class TR, class IntT, class TopT,class Feat>
215 void visit_internal_node(TR & /* tr */, IntT /* index */, TopT /* node_t */, Feat & /* features */)
216 {}
217
218 /** return a double value. The value of the first
219 * visitor encountered that has a return value is returned with the
220 * RandomForest::learn() method - or -1.0 if no return value visitor
221 * existed. This functionality basically only exists so that the
222 * OOB - visitor can return the oob error rate like in the old version
223 * of the random forest.
224 */
225 double return_val()
226 {
227 return -1.0;
228 }
229};
230
231
232/** Last Visitor that should be called to stop the recursion.
233 */
235{
236 public:
237 bool has_value()
238 {
239 return true;
240 }
241 double return_val()
242 {
243 return -1.0;
244 }
245};
246namespace detail
247{
248/** Container elements of the statically linked Visitor list.
249 *
250 * use the create_visitor() factory functions to create visitors up to size 10;
251 *
252 */
253template <class Visitor, class Next = StopVisiting>
255{
256 public:
257
258 StopVisiting stop_;
259 Next next_;
260 Visitor & visitor_;
261 VisitorNode(Visitor & visitor, Next & next)
262 :
263 next_(next), visitor_(visitor)
264 {}
265
266 VisitorNode(Visitor & visitor)
267 :
268 next_(stop_), visitor_(visitor)
269 {}
270
271 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
272 void visit_after_split( Tree & tree,
273 Split & split,
274 Region & parent,
277 Feature_t & features,
278 Label_t & labels)
279 {
280 if(visitor_.is_active())
281 visitor_.visit_after_split(tree, split,
282 parent, leftChild, rightChild,
283 features, labels);
284 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
285 features, labels);
286 }
287
288 template<class RF, class PR, class SM, class ST>
289 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
290 {
291 if(visitor_.is_active())
292 visitor_.visit_after_tree(rf, pr, sm, st, index);
293 next_.visit_after_tree(rf, pr, sm, st, index);
294 }
295
296 template<class RF, class PR>
297 void visit_at_beginning(RF & rf, PR & pr)
298 {
299 if(visitor_.is_active())
300 visitor_.visit_at_beginning(rf, pr);
301 next_.visit_at_beginning(rf, pr);
302 }
303 template<class RF, class PR>
304 void visit_at_end(RF & rf, PR & pr)
305 {
306 if(visitor_.is_active())
307 visitor_.visit_at_end(rf, pr);
308 next_.visit_at_end(rf, pr);
309 }
310
311 template<class TR, class IntT, class TopT,class Feat>
312 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
313 {
314 if(visitor_.is_active())
315 visitor_.visit_external_node(tr, index, node_t,features);
316 next_.visit_external_node(tr, index, node_t,features);
317 }
318 template<class TR, class IntT, class TopT,class Feat>
319 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
320 {
321 if(visitor_.is_active())
322 visitor_.visit_internal_node(tr, index, node_t,features);
323 next_.visit_internal_node(tr, index, node_t,features);
324 }
325
326 double return_val()
327 {
328 if(visitor_.is_active() && visitor_.has_value())
329 return visitor_.return_val();
330 return next_.return_val();
331 }
332};
333
334} //namespace detail
335
336//////////////////////////////////////////////////////////////////////////////
337// Visitor Factory function up to 10 visitors //
338//////////////////////////////////////////////////////////////////////////////
339
340/** factory method to to be used with RandomForest::learn()
341 */
342template<class A>
345{
347 _0_t _0(a);
348 return _0;
349}
350
351
352/** factory method to to be used with RandomForest::learn()
353 */
354template<class A, class B>
355detail::VisitorNode<A, detail::VisitorNode<B> >
356create_visitor(A & a, B & b)
357{
359 _1_t _1(b);
361 _0_t _0(a, _1);
362 return _0;
363}
364
365
366/** factory method to to be used with RandomForest::learn()
367 */
368template<class A, class B, class C>
369detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
370create_visitor(A & a, B & b, C & c)
371{
373 _2_t _2(c);
375 _1_t _1(b, _2);
377 _0_t _0(a, _1);
378 return _0;
379}
380
381
382/** factory method to to be used with RandomForest::learn()
383 */
384template<class A, class B, class C, class D>
385detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
386 detail::VisitorNode<D> > > >
387create_visitor(A & a, B & b, C & c, D & d)
388{
390 _3_t _3(d);
392 _2_t _2(c, _3);
394 _1_t _1(b, _2);
396 _0_t _0(a, _1);
397 return _0;
398}
399
400
401/** factory method to to be used with RandomForest::learn()
402 */
403template<class A, class B, class C, class D, class E>
404detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
405 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
406create_visitor(A & a, B & b, C & c,
407 D & d, E & e)
408{
410 _4_t _4(e);
412 _3_t _3(d, _4);
414 _2_t _2(c, _3);
416 _1_t _1(b, _2);
418 _0_t _0(a, _1);
419 return _0;
420}
421
422
423/** factory method to to be used with RandomForest::learn()
424 */
425template<class A, class B, class C, class D, class E,
426 class F>
427detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
428 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
429create_visitor(A & a, B & b, C & c,
430 D & d, E & e, F & f)
431{
433 _5_t _5(f);
435 _4_t _4(e, _5);
437 _3_t _3(d, _4);
439 _2_t _2(c, _3);
441 _1_t _1(b, _2);
443 _0_t _0(a, _1);
444 return _0;
445}
446
447
448/** factory method to to be used with RandomForest::learn()
449 */
450template<class A, class B, class C, class D, class E,
451 class F, class G>
452detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
453 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
454 detail::VisitorNode<G> > > > > > >
455create_visitor(A & a, B & b, C & c,
456 D & d, E & e, F & f, G & g)
457{
459 _6_t _6(g);
461 _5_t _5(f, _6);
463 _4_t _4(e, _5);
465 _3_t _3(d, _4);
467 _2_t _2(c, _3);
469 _1_t _1(b, _2);
471 _0_t _0(a, _1);
472 return _0;
473}
474
475
476/** factory method to to be used with RandomForest::learn()
477 */
478template<class A, class B, class C, class D, class E,
479 class F, class G, class H>
480detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
481 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
482 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
483create_visitor(A & a, B & b, C & c,
484 D & d, E & e, F & f,
485 G & g, H & h)
486{
488 _7_t _7(h);
490 _6_t _6(g, _7);
492 _5_t _5(f, _6);
494 _4_t _4(e, _5);
496 _3_t _3(d, _4);
498 _2_t _2(c, _3);
500 _1_t _1(b, _2);
502 _0_t _0(a, _1);
503 return _0;
504}
505
506
507/** factory method to to be used with RandomForest::learn()
508 */
509template<class A, class B, class C, class D, class E,
510 class F, class G, class H, class I>
511detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
512 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
513 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
514create_visitor(A & a, B & b, C & c,
515 D & d, E & e, F & f,
516 G & g, H & h, I & i)
517{
519 _8_t _8(i);
521 _7_t _7(h, _8);
523 _6_t _6(g, _7);
525 _5_t _5(f, _6);
527 _4_t _4(e, _5);
529 _3_t _3(d, _4);
531 _2_t _2(c, _3);
533 _1_t _1(b, _2);
535 _0_t _0(a, _1);
536 return _0;
537}
538
539/** factory method to to be used with RandomForest::learn()
540 */
541template<class A, class B, class C, class D, class E,
542 class F, class G, class H, class I, class J>
543detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
544 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
545 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
546 detail::VisitorNode<J> > > > > > > > > >
547create_visitor(A & a, B & b, C & c,
548 D & d, E & e, F & f,
549 G & g, H & h, I & i,
550 J & j)
551{
553 _9_t _9(j);
555 _8_t _8(i, _9);
557 _7_t _7(h, _8);
559 _6_t _6(g, _7);
561 _5_t _5(f, _6);
563 _4_t _4(e, _5);
565 _3_t _3(d, _4);
567 _2_t _2(c, _3);
569 _1_t _1(b, _2);
571 _0_t _0(a, _1);
572 return _0;
573}
574
575//////////////////////////////////////////////////////////////////////////////
576// Visitors of communal interest. //
577//////////////////////////////////////////////////////////////////////////////
578
579
580/** Visitor to gain information, later needed for online learning.
581 */
582
584{
585public:
586 //Set if we adjust thresholds
587 bool adjust_thresholds;
588 //Current tree id
589 int tree_id;
590 //Last node id for finding parent
591 int last_node_id;
592 //Need to now the label for interior node visiting
593 vigra::Int32 current_label;
594 //marginal distribution for interior nodes
595 //
597 adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
598 {}
599 struct MarginalDistribution
600 {
601 ArrayVector<Int32> leftCounts;
602 Int32 leftTotalCounts;
603 ArrayVector<Int32> rightCounts;
604 Int32 rightTotalCounts;
605 double gap_left;
606 double gap_right;
607 };
609
610 //All information for one tree
611 struct TreeOnlineInformation
612 {
613 std::vector<MarginalDistribution> mag_distributions;
614 std::vector<IndexList> index_lists;
615 //map for linear index of mag_distributions
616 std::map<int,int> interior_to_index;
617 //map for linear index of index_lists
618 std::map<int,int> exterior_to_index;
619 };
620
621 //All trees
622 std::vector<TreeOnlineInformation> trees_online_information;
623
624 /** Initialize, set the number of trees
625 */
626 template<class RF,class PR>
627 void visit_at_beginning(RF & rf,const PR & /* pr */)
628 {
629 tree_id=0;
630 trees_online_information.resize(rf.options_.tree_count_);
631 }
632
633 /** Reset a tree
634 */
635 void reset_tree(int tree_id)
636 {
637 trees_online_information[tree_id].mag_distributions.clear();
638 trees_online_information[tree_id].index_lists.clear();
639 trees_online_information[tree_id].interior_to_index.clear();
640 trees_online_information[tree_id].exterior_to_index.clear();
641 }
642
643 /** simply increase the tree count
644 */
645 template<class RF, class PR, class SM, class ST>
646 void visit_after_tree(RF & /* rf */, PR & /* pr */, SM & /* sm */, ST & /* st */, int /* index */)
647 {
648 tree_id++;
649 }
650
651 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
652 void visit_after_split( Tree & tree,
653 Split & split,
654 Region & parent,
657 Feature_t & features,
658 Label_t & /* labels */)
659 {
660 int linear_index;
661 int addr=tree.topology_.size();
662 if(split.createNode().typeID() == i_ThresholdNode)
663 {
664 if(adjust_thresholds)
665 {
666 //Store marginal distribution
667 linear_index=trees_online_information[tree_id].mag_distributions.size();
668 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
669 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
670
671 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
672 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
673
674 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
675 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
676 //Store the gap
677 double gap_left,gap_right;
678 int i;
679 gap_left=features(leftChild[0],split.bestSplitColumn());
680 for(i=1;i<leftChild.size();++i)
681 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
682 gap_left=features(leftChild[i],split.bestSplitColumn());
683 gap_right=features(rightChild[0],split.bestSplitColumn());
684 for(i=1;i<rightChild.size();++i)
685 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
686 gap_right=features(rightChild[i],split.bestSplitColumn());
687 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
688 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
689 }
690 }
691 else
692 {
693 //Store index list
694 linear_index=trees_online_information[tree_id].index_lists.size();
695 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
696
697 trees_online_information[tree_id].index_lists.push_back(IndexList());
698
699 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
700 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
701 }
702 }
703 void add_to_index_list(int tree,int node,int index)
704 {
705 if(!this->active_)
706 return;
707 TreeOnlineInformation &ti=trees_online_information[tree];
708 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
709 }
710 void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
711 {
712 if(!this->active_)
713 return;
714 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
715 trees_online_information[src_tree].exterior_to_index.erase(src_index);
716 }
717 /** do something when visiting a internal node during getToLeaf
718 *
719 * remember as last node id, for finding the parent of the last external node
720 * also: adjust class counts and borders
721 */
722 template<class TR, class IntT, class TopT,class Feat>
723 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
724 {
725 last_node_id=index;
726 if(adjust_thresholds)
727 {
728 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
729 //Check if we are in the gap
730 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
731 TreeOnlineInformation &ti=trees_online_information[tree_id];
732 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
733 if(value>m.gap_left && value<m.gap_right)
734 {
735 //Check which site we want to go
736 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
737 {
738 //We want to go left
739 m.gap_left=value;
740 }
741 else
742 {
743 //We want to go right
744 m.gap_right=value;
745 }
746 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
747 }
748 //Adjust class counts
749 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
750 {
751 ++m.rightTotalCounts;
752 ++m.rightCounts[current_label];
753 }
754 else
755 {
756 ++m.leftTotalCounts;
757 ++m.rightCounts[current_label];
758 }
759 }
760 }
761 /** do something when visiting a extern node during getToLeaf
762 *
763 * Store the new index!
764 */
765};
766
767//////////////////////////////////////////////////////////////////////////////
768// Out of Bag Error estimates //
769//////////////////////////////////////////////////////////////////////////////
770
771
772/** Visitor that calculates the oob error of each individual randomized
773 * decision tree.
774 *
775 * After training a tree, all those samples that are OOB for this particular tree
776 * are put down the tree and the error estimated.
777 * the per tree oob error is the average of the individual error estimates.
778 * (oobError = average error of one randomized tree)
779 * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
780 * visitor)
781 */
783{
784public:
785 /** Average error of one randomized decision tree
786 */
787 double oobError;
788
789 int totalOobCount;
790 ArrayVector<int> oobCount,oobErrorCount;
791
793 : oobError(0.0),
794 totalOobCount(0)
795 {}
796
797
798 bool has_value()
799 {
800 return true;
801 }
802
803
804 /** does the basic calculation per tree*/
805 template<class RF, class PR, class SM, class ST>
806 void visit_after_tree(RF & rf, PR & pr, SM & sm, ST &, int index)
807 {
808 //do the first time called.
809 if(int(oobCount.size()) != rf.ext_param_.row_count_)
810 {
811 oobCount.resize(rf.ext_param_.row_count_, 0);
812 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
813 }
814 // go through the samples
815 for(int l = 0; l < rf.ext_param_.row_count_; ++l)
816 {
817 // if the lth sample is oob...
818 if(!sm.is_used()[l])
819 {
820 ++oobCount[l];
821 if( rf.tree(index)
822 .predictLabel(rowVector(pr.features(), l))
823 != pr.response()(l,0))
824 {
825 ++oobErrorCount[l];
826 }
827 }
828
829 }
830 }
831
832 /** Does the normalisation
833 */
834 template<class RF, class PR>
835 void visit_at_end(RF & rf, PR &)
836 {
837 // do some normalisation
838 for(int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
839 {
840 if(oobCount[l])
841 {
842 oobError += double(oobErrorCount[l]) / oobCount[l];
843 ++totalOobCount;
844 }
845 }
846 oobError/=totalOobCount;
847 }
848
849};
850
851/** Visitor that calculates the oob error of the ensemble
852 *
853 * This rate serves as a quick estimate for the crossvalidation
854 * error rate.
855 * Here, each sample is put down the trees for which this sample
856 * is OOB, i.e., if sample #1 is OOB for trees 1, 3 and 5, we calculate
857 * the output using the ensemble consisting only of trees 1 3 and 5.
858 *
859 * Using normal bagged sampling each sample is OOB for approx. 33% of trees.
860 * The error rate obtained as such therefore corresponds to a crossvalidation
861 * rate obtained using a ensemble containing 33% of the trees.
862 */
863class OOB_Error : public VisitorBase
864{
866 int class_count;
867 bool is_weighted;
868 MultiArray<2,double> tmp_prob;
869 public:
870
871 MultiArray<2, double> prob_oob;
872 /** Ensemble oob error rate
873 */
875
876 MultiArray<2, double> oobCount;
877 ArrayVector< int> indices;
878 OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
879#ifdef HasHDF5
880 void save(std::string filen, std::string pathn)
881 {
882 if(*(pathn.end()-1) != '/')
883 pathn += "/";
884 const char* filename = filen.c_str();
885 MultiArray<2, double> temp(Shp(1,1), 0.0);
886 temp[0] = oob_breiman;
887 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
888 }
889#endif
890 // negative value if sample was ib, number indicates how often.
891 // value >=0 if sample was oob, 0 means fail 1, correct
892
893 template<class RF, class PR>
894 void visit_at_beginning(RF & rf, PR &)
895 {
896 class_count = rf.class_count();
897 tmp_prob.reshape(Shp(1, class_count), 0);
898 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
899 is_weighted = rf.options().predict_weighted_;
900 indices.resize(rf.ext_param().row_count_);
901 if(int(oobCount.size()) != rf.ext_param_.row_count_)
902 {
903 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
904 }
905 for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
906 {
907 indices[ii] = ii;
908 }
909 }
910
911 template<class RF, class PR, class SM, class ST>
912 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
913 {
914 // go through the samples
915 int total_oob =0;
916 // FIXME: magic number 10000: invoke special treatment when when msample << sample_count
917 // (i.e. the OOB sample ist very large)
918 // 40000: use at most 40000 OOB samples per class for OOB error estimate
919 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
920 {
921 ArrayVector<int> oob_indices;
922 ArrayVector<int> cts(class_count, 0);
923 std::random_shuffle(indices.begin(), indices.end());
924 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
925 {
926 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
927 {
928 oob_indices.push_back(indices[ii]);
929 ++cts[pr.response()(indices[ii], 0)];
930 }
931 }
932 for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
933 {
934 // update number of trees in which current sample is oob
935 ++oobCount[oob_indices[ll]];
936
937 // update number of oob samples in this tree.
938 ++total_oob;
939 // get the predicted votes ---> tmp_prob;
940 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
941 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
942 rf.tree(index).parameters_,
943 pos);
944 tmp_prob.init(0);
945 for(int ii = 0; ii < class_count; ++ii)
946 {
947 tmp_prob[ii] = node.prob_begin()[ii];
948 }
949 if(is_weighted)
950 {
951 for(int ii = 0; ii < class_count; ++ii)
952 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
953 }
954 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
955
956 }
957 }else
958 {
959 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
960 {
961 // if the lth sample is oob...
962 if(!sm.is_used()[ll])
963 {
964 // update number of trees in which current sample is oob
965 ++oobCount[ll];
966
967 // update number of oob samples in this tree.
968 ++total_oob;
969 // get the predicted votes ---> tmp_prob;
970 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
971 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
972 rf.tree(index).parameters_,
973 pos);
974 tmp_prob.init(0);
975 for(int ii = 0; ii < class_count; ++ii)
976 {
977 tmp_prob[ii] = node.prob_begin()[ii];
978 }
979 if(is_weighted)
980 {
981 for(int ii = 0; ii < class_count; ++ii)
982 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
983 }
984 rowVector(prob_oob, ll) += tmp_prob;
985 }
986 }
987 }
988 // go through the ib samples;
989 }
990
991 /** Normalise variable importance after the number of trees is known.
992 */
993 template<class RF, class PR>
994 void visit_at_end(RF & rf, PR & pr)
995 {
996 // ullis original metric and breiman style stuff
997 int totalOobCount =0;
998 int breimanstyle = 0;
999 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1000 {
1001 if(oobCount[ll])
1002 {
1003 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1004 ++breimanstyle;
1005 ++totalOobCount;
1006 }
1007 }
1008 oob_breiman = double(breimanstyle)/totalOobCount;
1009 }
1010};
1011
1012
1013/** Visitor that calculates different OOB error statistics
1014 */
1016{
1018 int class_count;
1019 bool is_weighted;
1020 MultiArray<2,double> tmp_prob;
1021 public:
1022
1023 /** OOB Error rate of each individual tree
1024 */
1026 /** Mean of oob_per_tree
1027 */
1028 double oob_mean;
1029 /**Standard deviation of oob_per_tree
1030 */
1031 double oob_std;
1032
1033 MultiArray<2, double> prob_oob;
1034 /** Ensemble OOB error
1035 *
1036 * \sa OOB_Error
1037 */
1039
1040 MultiArray<2, double> oobCount;
1041 MultiArray<2, double> oobErrorCount;
1042 /** Per Tree OOB error calculated as in OOB_PerTreeError
1043 * (Ulli's version)
1044 */
1046
1047 /**Column containing the development of the Ensemble
1048 * error rate with increasing number of trees
1049 */
1051 /** 4 dimensional array containing the development of confusion matrices
1052 * with number of trees - can be used to estimate ROC curves etc.
1053 *
1054 * oobroc_per_tree(ii,jj,kk,ll)
1055 * corresponds true label = ii
1056 * predicted label = jj
1057 * confusion matrix after ll trees
1058 *
1059 * explanation of third index:
1060 *
1061 * Two class case:
1062 * kk = 0 - (treeCount-1)
1063 * Threshold is on Probability for class 0 is kk/(treeCount-1);
1064 * More classes:
1065 * kk = 0. Threshold on probability set by argMax of the probability array.
1066 */
1068
1070
1071#ifdef HasHDF5
1072 /** save to HDF5 file
1073 */
1074 void save(std::string filen, std::string pathn)
1075 {
1076 if(*(pathn.end()-1) != '/')
1077 pathn += "/";
1078 const char* filename = filen.c_str();
1079 MultiArray<2, double> temp(Shp(1,1), 0.0);
1080 writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1081 writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1082 writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1083 temp[0] = oob_mean;
1084 writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1085 temp[0] = oob_std;
1086 writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1087 temp[0] = oob_breiman;
1088 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1089 temp[0] = oob_per_tree2;
1090 writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1091 }
1092#endif
1093 // negative value if sample was ib, number indicates how often.
1094 // value >=0 if sample was oob, 0 means fail 1, correct
1095
// Allocate/reset all per-forest accumulators before learning starts.
// For the 2-class case the ROC array gets one threshold slice per tree
// (kk dimension); otherwise a single argMax-based slice is used.
1096 template<class RF, class PR>
1097 void visit_at_beginning(RF & rf, PR &)
1098 {
1099 class_count = rf.class_count();
1100 if(class_count == 2)
1101 oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1102 else
1103 oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
// scratch vote vector for one sample, and accumulated votes per sample
1104 tmp_prob.reshape(Shp(1, class_count), 0);
1105 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1106 is_weighted = rf.options().predict_weighted_;
1107 oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1108 breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1109 //do the first time called.
// per-sample counters survive across calls; only (re)allocate when the
// sample count changed
1110 if(int(oobCount.size()) != rf.ext_param_.row_count_)
1111 {
1112 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1113 oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1114 }
1115 }
1116
// Update OOB statistics after tree number 'index' has been learned:
// accumulate per-sample leaf probabilities into prob_oob, count OOB
// samples and OOB errors for this tree (oob_per_tree), compute the
// Breiman-style ensemble error over all trees so far (breiman_per_tree),
// and fill this tree's slice of the confusion/ROC array oobroc_per_tree.
1117 template<class RF, class PR, class SM, class ST>
1118 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST &, int index)
1119 {
1120 // go through the samples
1121 int total_oob =0;
1122 int wrong_oob =0;
1123 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1124 {
1125 // if the lth sample is oob...
1126 if(!sm.is_used()[ll])
1127 {
1128 // update number of trees in which current sample is oob
1129 ++oobCount[ll];
1130
1131 // update number of oob samples in this tree.
1132 ++total_oob;
1133 // get the predicted votes ---> tmp_prob;
1134 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
1135 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1136 rf.tree(index).parameters_,
1137 pos);
1138 tmp_prob.init(0);
1139 for(int ii = 0; ii < class_count; ++ii)
1140 {
1141 tmp_prob[ii] = node.prob_begin()[ii];
1142 }
1143 if(is_weighted)
1144 {
// weighted prediction: scale by the leaf weight stored just
// before the probability block
1145 for(int ii = 0; ii < class_count; ++ii)
1146 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1147 }
1148 rowVector(prob_oob, ll) += tmp_prob;
1149 int label = argMax(tmp_prob);
1150
1151 if(label != pr.response()(ll, 0))
1152 {
1153 // update number of wrong oob samples in this tree.
1154 ++wrong_oob;
1155 // update number of trees in which current sample is wrong oob
1156 ++oobErrorCount[ll];
1157 }
1158 }
1159 }
// second pass: ensemble (majority-vote) error over all samples that
// were OOB at least once so far
1160 int breimanstyle = 0;
1161 int totalOobCount = 0;
1162 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1163 {
1164 if(oobCount[ll])
1165 {
1166 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1167 ++breimanstyle;
1168 ++totalOobCount;
1169 if(oobroc_per_tree.shape(2) == 1)
1170 {
// multi-class case: single confusion-matrix slice per tree
1171 oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
1172 }
1173 }
1174 }
1175 if(oobroc_per_tree.shape(2) == 1)
1176 oobroc_per_tree.bindOuter(index)/=totalOobCount;
1177 if(oobroc_per_tree.shape(2) > 1)
1178 {
// two-class case: sweep a threshold on the class-1 vote mass to
// build one confusion matrix per threshold step gg
1179 MultiArrayView<3, double> current_roc
1180 = oobroc_per_tree.bindOuter(index);
1181 for(int gg = 0; gg < current_roc.shape(2); ++gg)
1182 {
1183 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1184 {
1185 if(oobCount[ll])
1186 {
// NOTE(review): prob_oob holds accumulated (unnormalised)
// votes, compared against a threshold in [0,1) — confirm
// this is the intended scaling
1187 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1188 1 : 0;
1189 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1190 }
1191 }
1192 current_roc.bindOuter(gg)/= totalOobCount;
1193 }
1194 }
1195 breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
1196 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1197 // go through the ib samples;
1198 }
1199
1200 /** Compute the final OOB error summaries once all trees are learned:
* oob_per_tree2 (mean per-sample OOB error rate), oob_breiman
* (majority-vote ensemble error), and mean/std-dev of the per-tree
* errors (oob_mean / oob_std via rowStatistics).
1201 */
1202 template<class RF, class PR>
1203 void visit_at_end(RF & rf, PR & pr)
1204 {
1205 // ullis original metric and breiman style stuff
1206 oob_per_tree2 = 0;
1207 int totalOobCount =0;
1208 int breimanstyle = 0;
1209 for(int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
1210 {
1211 if(oobCount[ll])
1212 {
1213 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1214 ++breimanstyle;
1215 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1216 ++totalOobCount;
1217 }
1218 }
1219 oob_per_tree2 /= totalOobCount;
1220 oob_breiman = double(breimanstyle)/totalOobCount;
1221 // mean error of each tree
// stdDev is a 1x1 view aliasing the member oob_std; the companion
// 'mean' view (aliasing oob_mean) is declared on a line elided from
// this listing
1223 MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1224 rowStatistics(oob_per_tree, mean, stdDev);
1225 }
1226};
1227
1228/** calculate variable importance while learning.
1229 */
1231{
1232 public:
1233
1234 /** This Array has the same entries as the R - random forest variable
1235 * importance.
1236 * Matrix is featureCount by (classCount +2)
1237 * variable_importance_(ii,jj) is the variable importance measure of
1238 * the ii-th variable according to:
1239 * jj = 0 - (classCount-1)
1240 * classwise permutation importance
1241 * jj = rowCount(variable_importance_) -2
1242 * permutation importance
1243 * jj = rowCount(variable_importance_) -1
1244 * gini decrease importance.
1245 *
1246 * permutation importance:
1247 * The difference between the fraction of OOB samples classified correctly
1248 * before and after permuting (randomizing) the ii-th column is calculated.
1249 * The ii-th column is permuted rep_cnt times.
1250 *
1251 * class wise permutation importance:
1252 * same as permutation importance. We only look at those OOB samples whose
1253 * response corresponds to class jj.
1254 *
1255 * gini decrease importance:
1256 * row ii corresponds to the sum of all gini decreases induced by variable ii
1257 * in each node of the random forest.
1258 */
1260 int repetition_count_;
1261 bool in_place_;
1262
1263#ifdef HasHDF5
// Write the variable importance matrix to an HDF5 file under the
// dataset name "variable_importance_" + prefix. (The array argument
// of writeHDF5 sits on a line elided from this listing.)
1264 void save(std::string filename, std::string prefix)
1265 {
1266 prefix = "variable_importance_" + prefix;
1267 writeHDF5(filename.c_str(),
1268 prefix.c_str(),
1270 }
1271#endif
1272
1273 /* Constructor
1274 * \param rep_cnt (default: 10) how often should
1275 * the permutation take place. Set to 1 to make calculation faster (but
1276 * possibly more unstable)
1277 */
1279 : repetition_count_(rep_cnt)
1280
1281 {}
1282
1283 /** calculates impurity decrease based variable importance after every
1284 * split.
* The Gini decrease of the chosen split column is accumulated into
* the last column (index class_count+1) of variable_importance_.
1285 */
1286 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1288 Split & split,
1289 Region & /* parent */,
1290 Region & /* leftChild */,
1291 Region & /* rightChild */,
1292 Feature_t & /* features */,
1293 Label_t & /* labels */)
1294 {
1295 //resize to right size when called the first time
1296
1297 Int32 const class_count = tree.ext_param_.class_count_;
1298 Int32 const column_count = tree.ext_param_.column_count_;
1299 if(variable_importance_.size() == 0)
1300 {
// lazily allocate column_count x (class_count+2): one column per
// class, one for permutation importance, one for gini decrease
1301
1303 .reshape(MultiArrayShape<2>::type(column_count,
1304 class_count+2));
1305 }
1306
// only threshold splits carry a single split column to credit
1307 if(split.createNode().typeID() == i_ThresholdNode)
1308 {
1309 Node<i_ThresholdNode> node(split.createNode());
1310 variable_importance_(node.column(),class_count+1)
1311 += split.region_gini_ - split.minGini();
1312 }
1313 }
1314
1315 /**compute permutation based var imp.
1316 * (Only an Array of size oob_sample_count x 1 is created.
1317 * - opposed to oob_sample_count x feature_count in the other method.
1318 *
* For each feature column: record the OOB success rate, permute the
* column repetition_count_ times among the OOB samples (Fisher-Yates),
* re-measure, and accumulate the drop in accuracy into
* variable_importance_. The column is restored afterwards.
* NOTE: the declarations of oob_indices, backup_column, randint and
* the result arrays sit on lines elided from this listing.
1319 * \sa FieldProxy
1320 */
1321 template<class RF, class PR, class SM, class ST>
1322 void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & /* st */, int index)
1323 {
1325 Int32 column_count = rf.ext_param_.column_count_;
1326 Int32 class_count = rf.ext_param_.class_count_;
1327
1328 /* This solution saves memory uptake but not multithreading
1329 * compatible
1330 */
1331 // remove the const cast on the features (yep , I know what I am
1332 // doing here.) data is not destroyed.
1333 //typename PR::Feature_t & features
1334 // = const_cast<typename PR::Feature_t &>(pr.features());
1335
1336 typedef typename PR::FeatureWithMemory_t FeatureArray;
1337 typedef typename FeatureArray::value_type FeatureValue;
1338
// working copy of the feature matrix — permutations happen here
1339 FeatureArray features = pr.features();
1340
1341 //find the oob indices of current tree.
1343 ArrayVector<Int32>::iterator
1344 iter;
1345 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1346 if(!sm.is_used()[ii])
1347 oob_indices.push_back(ii);
1348
1349 //create space to back up a column
1351
1352 // Random foo
// fixed seed under CLASSIFIER_TEST for reproducible unit tests
1353#ifdef CLASSIFIER_TEST
1354 RandomMT19937 random(1);
1355#else
1356 RandomMT19937 random(RandomSeed);
1357#endif
1359 randint(random);
1360
1361
1362 //make some space for the results
// index class_count holds the total, indices 0..class_count-1 the
// per-class counts
1364 oob_right(Shp_t(1, class_count + 1));
1366 perm_oob_right (Shp_t(1, class_count + 1));
1367
1368
1369 // get the oob success rate with the original samples
1370 for(iter = oob_indices.begin();
1371 iter != oob_indices.end();
1372 ++iter)
1373 {
1374 if(rf.tree(index)
1375 .predictLabel(rowVector(features, *iter))
1376 == pr.response()(*iter, 0))
1377 {
1378 //per class
1379 ++oob_right[pr.response()(*iter,0)];
1380 //total
1381 ++oob_right[class_count];
1382 }
1383 }
1384 //get the oob rate after permuting the ii'th dimension.
1385 for(int ii = 0; ii < column_count; ++ii)
1386 {
1387 perm_oob_right.init(0.0);
1388 //make backup of original column
1389 backup_column.clear();
1390 for(iter = oob_indices.begin();
1391 iter != oob_indices.end();
1392 ++iter)
1393 {
1394 backup_column.push_back(features(*iter,ii));
1395 }
1396
1397 //get the oob rate after permuting the ii'th dimension.
1398 for(int rr = 0; rr < repetition_count_; ++rr)
1399 {
1400 //permute dimension.
// Fisher-Yates shuffle restricted to the OOB rows
1401 int n = oob_indices.size();
1402 for(int jj = n-1; jj >= 1; --jj)
1403 std::swap(features(oob_indices[jj], ii),
1404 features(oob_indices[randint(jj+1)], ii));
1405
1406 //get the oob success rate after permuting
1407 for(iter = oob_indices.begin();
1408 iter != oob_indices.end();
1409 ++iter)
1410 {
1411 if(rf.tree(index)
1412 .predictLabel(rowVector(features, *iter))
1413 == pr.response()(*iter, 0))
1414 {
1415 //per class
1416 ++perm_oob_right[pr.response()(*iter, 0)];
1417 //total
1418 ++perm_oob_right[class_count];
1419 }
1420 }
1421 }
1422
1423
1424 //normalise and add to the variable_importance array.
// importance = (rate before) - (rate after permutation); implemented
// as -(permuted) + (original) — the addition of oob_right happens on
// lines elided from this listing
1425 perm_oob_right /= repetition_count_;
1427 perm_oob_right *= -1;
1430 .subarray(Shp_t(ii,0),
1431 Shp_t(ii+1,class_count+1)) += perm_oob_right;
1432 //copy back permuted dimension
1433 for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1434 features(oob_indices[jj], ii) = backup_column[jj];
1435 }
1436 }
1437
1438 /** Permutation-based importance is computed after each tree has been
1439 * learned. By default this works on a copy of the feature matrix
1440 * (out of place); for very large data sets set the in_place_ flag
1441 * to true to avoid the copy.
1442 */
1443 template<class RF, class PR, class SM, class ST>
1444 void visit_after_tree(RF& random_forest, PR & preproc, SM & sample_mask, ST & stat, int tree_index)
1445 {
// delegate to the shared implementation
1446 after_tree_ip_impl(random_forest, preproc, sample_mask, stat, tree_index);
1447 }
1448
1449 /** Normalise variable importance after the number of trees is known.
* Each accumulated importance value is divided by the tree count so
* the result is a per-tree average.
1450 */
1451 template<class RF, class PR>
1452 void visit_at_end(RF & rf, PR & /* pr */)
1453 {
1454 variable_importance_ /= rf.trees_.size();
1455 }
1456};
1457
1458/** Verbose output
1459 */
1461 public:
1463
// Print a carriage-return progress line after each tree; the final
// tree prints 100% and ends the line.
1464 template<class RF, class PR, class SM, class ST>
1465 void visit_after_tree(RF& rf, PR &, SM &, ST &, int index){
1466 if(index != rf.options().tree_count_-1) {
1467 std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1468 << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1469 }
1470 else {
1471 std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1472 }
1473 }
1474
// Report total learning time; TOCS (timing.hxx) yields the elapsed
// time as a string for the timer started by TIC in visit_at_beginning.
1475 template<class RF, class PR>
1476 void visit_at_end(RF const & rf, PR const &) {
1477 std::string a = TOCS;
1478 std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
1479 }
1480
// Start the timer (TIC, timing.hxx) and announce the forest size.
1481 template<class RF, class PR>
1482 void visit_at_beginning(RF const & rf, PR const &) {
1483 TIC;
1484 std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
1485 }
1486
1487 private:
1488 USETICTOC;
1489};
1490
1491
1492/** Computes Correlation/Similarity Matrix of features while learning
1493 * random forest.
1494 */
1496{
1497 public:
1498 /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1499 * created on variable ii(when variable ii was chosen)
1500 */
1502 MultiArray<2, int> tmp_labels;
1503 /** additional noise features.
1504 */
1506 MultiArray<2, double> noise_l;
1507 /** how well can a noise column describe a partition created on variable ii.
1508 */
1510 MultiArray<2, double> corr_l;
1511
1512 /** Similarity Matrix
1513 *
1514 * (numberOfFeatures + 1) by (number Of Features + 1) Matrix
1515 * gini_missc
1516 * - row normalized by the number of times the column was chosen
1517 * - mean of corr_noise subtracted
1518 * - and symmetrised.
1519 *
1520 */
1522 /** Distance Matrix 1-similarity
1523 */
1525 ArrayVector<int> tmp_cc;
1526
1527 /** How often was variable ii chosen
1528 */
// Currently a no-op: the former HDF5 export of the correlation
// matrices is kept below in commented-out form for reference.
1532 void save(std::string, std::string)
1533 {
1534 /*
1535 std::string tmp;
1536#define VAR_WRITE(NAME) \
1537 tmp = #NAME;\
1538 tmp += "_";\
1539 tmp += prefix;\
1540 vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
1541 VAR_WRITE(gini_missc);
1542 VAR_WRITE(corr_noise);
1543 VAR_WRITE(distance);
1544 VAR_WRITE(similarity);
1545 vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
1546#undef VAR_WRITE
1547*/
1548 }
1549
// Allocate the (n+1)x(n+1) correlation matrices (extra row/column for
// the response), create 10 random noise columns (continuous and
// binary) for baseline correlation, and reset the choice counters.
1550 template<class RF, class PR>
1551 void visit_at_beginning(RF const & rf, PR & pr)
1552 {
1553 typedef MultiArrayShape<2>::type Shp;
1554 int n = rf.ext_param_.column_count_;
1555 gini_missc.reshape(Shp(n +1,n+ 1));
1556 corr_noise.reshape(Shp(n + 1, 10));
1557 corr_l.reshape(Shp(n +1, 10));
1558
// one noise value per sample and noise column
1559 noise.reshape(Shp(pr.features().shape(0), 10));
1560 noise_l.reshape(Shp(pr.features().shape(0), 10));
1561 RandomMT19937 random(RandomSeed);
1562 for(int ii = 0; ii < noise.size(); ++ii)
1563 {
1564 noise[ii] = random.uniform53();
// binary noise: Bernoulli(0.5)
1565 noise_l[ii] = random.uniform53() > 0.5;
1566 }
1567 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1568 tmp_labels.reshape(pr.response().shape());
1569 tmp_cc.resize(2);
1570 numChoices.resize(n+1);
1571 // look at all axes
1572 }
// Post-process gini_missc into the similarity matrix: normalise each
// row by how often the row's variable was chosen, subtract the mean
// noise correlation, take absolute values, symmetrise, and derive the
// distance matrix. (The declarations of mean_noise and the initial
// copy into 'similarity' sit on lines elided from this listing.)
1573 template<class RF, class PR>
1574 void visit_at_end(RF const &, PR const &)
1575 {
1576 typedef MultiArrayShape<2>::type Shp;
1580 rowStatistics(corr_noise, mean_noise);
1582 int rC = similarity.shape(0);
1583 for(int jj = 0; jj < rC-1; ++jj)
1584 {
1585 rowVector(similarity, jj) /= numChoices[jj];
1586 rowVector(similarity, jj) -= mean_noise(jj, 0);
1587 }
// last row (response) is normalised column-wise by choice counts
1588 for(int jj = 0; jj < rC; ++jj)
1589 {
1590 similarity(rC -1, jj) /= numChoices[jj];
1591 }
1592 rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
1593 similarity = abs(similarity);
1594 FindMinMax<double> minmax;
1595 inspectMultiArray(srcMultiArrayRange(similarity), minmax);
1596
// put the max on the diagonal so the symmetrisation below is not
// distorted by the (meaningless) self-similarity entries
1597 for(int jj = 0; jj < rC; ++jj)
1598 similarity(jj, jj) = minmax.max;
1599
1600 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1601 += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1602 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1603 columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
1604 for(int jj = 0; jj < rC; ++jj)
1605 similarity(jj, jj) = 0;
1606
// set diagonal to the new overall maximum and size the distance matrix
1607 FindMinMax<double> minmax2;
1608 inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
1609 for(int jj = 0; jj < rC; ++jj)
1610 similarity(jj, jj) = minmax2.max;
1611 distance.reshape(gini_missc.shape(), minmax2.max);
1613 }
1614
// After each threshold split: build a binary partition label from the
// chosen split, then measure how well every other feature column (and
// the 10 noise columns, and the response) can reproduce that partition
// via the gini decrease — filling gini_missc / corr_noise / corr_l.
// Finally re-partition the region samples according to the split.
// (Several accumulation targets begin on lines elided from this
// listing, e.g. the gini_missc(...) expressions at original lines
// 1676, 1682 and 1692.)
1615 template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1616 void visit_after_split( Tree &,
1617 Split & split,
1618 Region & parent,
1619 Region &,
1620 Region &,
1621 Feature_t & features,
1622 Label_t & labels)
1623 {
1624 if(split.createNode().typeID() == i_ThresholdNode)
1625 {
1626 double wgini;
1627 tmp_cc.init(0);
// tmp_labels[s] = 1 iff sample s falls below the split threshold
1628 for(int ii = 0; ii < parent.size(); ++ii)
1629 {
1630 tmp_labels[parent[ii]]
1631 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1632 ++tmp_cc[tmp_labels[parent[ii]]];
1633 }
1634 double region_gini = bgfunc.loss_of_region(tmp_labels,
1635 parent.begin(),
1636 parent.end(),
1637 tmp_cc);
1638
1639 int n = split.bestSplitColumn();
1640 ++numChoices[n];
// last counter tallies total number of splits
1641 ++(*(numChoices.end()-1));
1642 //this functor does all the work
1643 for(int k = 0; k < features.shape(1); ++k)
1644 {
1645 bgfunc(columnVector(features, k),
1646 tmp_labels,
1647 parent.begin(), parent.end(),
1648 tmp_cc);
1649 wgini = (region_gini - bgfunc.min_gini_);
1650 gini_missc(n, k)
1651 += wgini;
1652 }
// baseline: how well random continuous noise explains the partition
1653 for(int k = 0; k < 10; ++k)
1654 {
1655 bgfunc(columnVector(noise, k),
1656 tmp_labels,
1657 parent.begin(), parent.end(),
1658 tmp_cc);
1659 wgini = (region_gini - bgfunc.min_gini_);
1660 corr_noise(n, k)
1661 += wgini;
1662 }
1663
// baseline: random binary noise
1664 for(int k = 0; k < 10; ++k)
1665 {
1666 bgfunc(columnVector(noise_l, k),
1667 tmp_labels,
1668 parent.begin(), parent.end(),
1669 tmp_cc);
1670 wgini = (region_gini - bgfunc.min_gini_);
1671 corr_l(n, k)
1672 += wgini;
1673 }
// how well the true response explains the partition
1674 bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1675 wgini = (region_gini - bgfunc.min_gini_);
1677 += wgini;
1678
1679 region_gini = split.region_gini_;
1680#if 1
1681 Node<i_ThresholdNode> node(split.createNode());
1683 node.column())
1684 +=split.region_gini_ - split.minGini();
1685#endif
1686 for(int k = 0; k < 10; ++k)
1687 {
1688 split.bgfunc(columnVector(noise, k),
1689 labels,
1690 parent.begin(), parent.end(),
1691 parent.classCounts());
1693 k)
1694 += wgini;
1695 }
// disabled alternative: credit all candidate mtry columns
1696#if 0
1697 for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1698 {
1699 wgini = region_gini - split.min_gini_[k];
1700
1702 split.splitColumns[k])
1703 += wgini;
1704 }
1705
1706 for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1707 {
1708 split.bgfunc(columnVector(features, split.splitColumns[k]),
1709 labels,
1710 parent.begin(), parent.end(),
1711 parent.classCounts());
1712 wgini = region_gini - split.bgfunc.min_gini_;
1714 split.splitColumns[k]) += wgini;
1715 }
1716#endif
1717 // remember to partition the data according to the best.
1720 += region_gini;
1721 SortSamplesByDimensions<Feature_t>
1722 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1723 std::partition(parent.begin(), parent.end(), sorter);
1724 }
1725 }
1726};
1727
1728
1729} // namespace visitors
1730} // namespace rf
1731} // namespace vigra
1732
1733#endif // RF_VISITORS_HXX
const_pointer data() const
Definition array_vector.hxx:209
const_iterator end() const
Definition array_vector.hxx:237
MultiArrayView subarray(difference_type p, difference_type q) const
Definition multi_array.hxx:1528
const difference_type & shape() const
Definition multi_array.hxx:1648
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const
Definition multi_array.hxx:2184
difference_type_1 size() const
Definition multi_array.hxx:1641
MultiArrayView< N, T, StridedArrayTag > transpose() const
Definition multi_array.hxx:1567
void reshape(const difference_type &shape)
Definition multi_array.hxx:2861
Class for a single RGB value.
Definition rgbvalue.hxx:128
void init(Iterator i, Iterator end)
Definition tinyvector.hxx:708
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Class for fixed size vectors.
Definition tinyvector.hxx:1008
Definition rf_visitors.hxx:1016
double oob_per_tree2
Definition rf_visitors.hxx:1045
MultiArray< 2, double > breiman_per_tree
Definition rf_visitors.hxx:1050
double oob_mean
Definition rf_visitors.hxx:1028
double oob_breiman
Definition rf_visitors.hxx:1038
MultiArray< 2, double > oob_per_tree
Definition rf_visitors.hxx:1025
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:1203
MultiArray< 4, double > oobroc_per_tree
Definition rf_visitors.hxx:1067
double oob_std
Definition rf_visitors.hxx:1031
Definition rf_visitors.hxx:1496
MultiArray< 2, double > distance
Definition rf_visitors.hxx:1524
MultiArray< 2, double > corr_noise
Definition rf_visitors.hxx:1509
MultiArray< 2, double > gini_missc
Definition rf_visitors.hxx:1501
MultiArray< 2, double > similarity
Definition rf_visitors.hxx:1521
ArrayVector< int > numChoices
Definition rf_visitors.hxx:1529
MultiArray< 2, double > noise
Definition rf_visitors.hxx:1505
Definition rf_visitors.hxx:864
double oob_breiman
Definition rf_visitors.hxx:874
void visit_at_end(RF &rf, PR &pr)
Definition rf_visitors.hxx:994
Definition rf_visitors.hxx:783
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:806
double oobError
Definition rf_visitors.hxx:787
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:835
Definition rf_visitors.hxx:584
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:723
void reset_tree(int tree_id)
Definition rf_visitors.hxx:635
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition rf_visitors.hxx:646
void visit_at_beginning(RF &rf, const PR &)
Definition rf_visitors.hxx:627
Definition rf_visitors.hxx:235
Definition rf_visitors.hxx:1231
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition rf_visitors.hxx:1287
void visit_at_end(RF &rf, PR &)
Definition rf_visitors.hxx:1452
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:1444
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_visitors.hxx:1322
MultiArray< 2, double > variable_importance_
Definition rf_visitors.hxx:1259
Definition rf_visitors.hxx:102
void visit_at_beginning(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:187
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition rf_visitors.hxx:205
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition rf_visitors.hxx:142
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition rf_visitors.hxx:215
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_visitors.hxx:163
void visit_at_end(RF const &rf, PR const &pr)
Definition rf_visitors.hxx:175
double return_val()
Definition rf_visitors.hxx:225
Definition rf_visitors.hxx:255
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:684
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:671
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:727
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:344
void writeHDF5(...)
Store array data in an HDF5 file.
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
#define TIC
Definition timing.hxx:322
#define TOCS
Definition timing.hxx:325

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1