[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_algorithm.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by Rahul Nair                             */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
// NOTE(review): no matching "#ifndef VIGRA_RF_ALGORTIHM_HXX" is visible
// before this #define, so the include guard appears incomplete (the macro
// name also misspells ALGORITHM) -- confirm against the full header.
#define VIGRA_RF_ALGORTIHM_HXX

#include <vector>
#include "splices.hxx"
#include <queue>
00040 namespace vigra
00041 {
00042  
00043 namespace rf
00044 {
00045 /** This namespace contains all algorithms developed for feature 
00046  * selection
00047  *
00048  */
00049 namespace algorithms
00050 {
00051 
00052 namespace detail
00053 {
00054     /** create a MultiArray containing only columns supplied between iterators
00055         b and e
00056     */
00057     template<class OrigMultiArray,
00058              class Iter,
00059              class DestMultiArray>
00060     void choose(OrigMultiArray     const & in,
00061                 Iter               const & b,
00062                 Iter               const & e,
00063                 DestMultiArray        & out)
00064     {
00065         int columnCount = std::distance(b, e);
00066         int rowCount     = in.shape(0);
00067         out.reshape(MultiArrayShape<2>::type(rowCount, columnCount));
00068         int ii = 0;
00069         for(Iter iter = b; iter != e; ++iter, ++ii)
00070         {
00071             columnVector(out, ii) = columnVector(in, *iter);
00072         }
00073     }
00074 }
00075 
00076 
00077 
00078 /** Standard random forest Errorrate callback functor
00079  *
00080  * returns the random forest error estimate when invoked. 
00081  */
00082 class RFErrorCallback
00083 {
00084     RandomForestOptions options;
00085     
00086     public:
00087     /** Default constructor
00088      *
00089      * optionally supply options to the random forest classifier
00090      * \sa RandomForestOptions
00091      */
00092     RFErrorCallback(RandomForestOptions opt = RandomForestOptions())
00093     : options(opt)
00094     {}
00095 
00096     /** returns the RF OOB error estimate given features and 
00097      * labels
00098      */
00099     template<class Feature_t, class Response_t>
00100     double operator() (Feature_t const & features,
00101                        Response_t const & response)
00102     {
00103         RandomForest<>             rf(options);
00104         visitors::OOB_Error        oob;
00105         rf.learn(features, 
00106                  response, 
00107                  visitors::create_visitor(oob ));
00108         return oob.oob_breiman;
00109     }
00110 };
00111 
00112 
00113 /** Structure to hold Variable Selection results
00114  */
00115 class VariableSelectionResult
00116 {
00117     bool initialized;
00118 
00119   public:
00120     VariableSelectionResult()
00121     : initialized(false)
00122     {}
00123 
00124     typedef std::vector<int> FeatureList_t;
00125     typedef std::vector<double> ErrorList_t;
00126     typedef FeatureList_t::iterator Pivot_t;
00127 
00128     Pivot_t pivot;
00129 
00130     /** list of features. 
00131      */
00132     FeatureList_t selected;
00133     
00134     /** vector of size (number of features)
00135      *
00136      * the i-th entry encodes the error rate obtained
00137      * while using features [0 - i](including i) 
00138      *
00139      * if the i-th entry is -1 then no error rate was obtained
00140      * this may happen if more than one feature is added to the
00141      * selected list in one step of the algorithm.
00142      *
00143      * during initialisation error[m+n-1] is always filled
00144      */
00145     ErrorList_t errors;
00146     
00147 
00148     /** errorrate using no features
00149      */
00150     double no_features;
00151 
00152     template<class FeatureT, 
00153              class ResponseT, 
00154              class Iter,
00155              class ErrorRateCallBack>
00156     bool init(FeatureT const & all_features,
00157               ResponseT const & response,
00158               Iter b,
00159               Iter e,
00160               ErrorRateCallBack errorcallback)
00161     {
00162         bool ret_ = init(all_features, response, errorcallback); 
00163         if(!ret_)
00164             return false;
00165         vigra_precondition(std::distance(b, e) == selected.size(),
00166                            "Number of features in ranking != number of features matrix");
00167         std::copy(b, e, selected.begin());
00168         return true;
00169     }
00170     
00171     template<class FeatureT, 
00172              class ResponseT, 
00173              class Iter>
00174     bool init(FeatureT const & all_features,
00175               ResponseT const & response,
00176               Iter b,
00177               Iter e)
00178     {
00179         RFErrorCallback ecallback;
00180         return init(all_features, response, b, e, ecallback);
00181     }
00182 
00183 
00184     template<class FeatureT, 
00185              class ResponseT>
00186     bool init(FeatureT const & all_features,
00187               ResponseT const & response)
00188     {
00189         return init(all_features, response, RFErrorCallback());
00190     }
00191     /**initialization routine. Will be called only once in the lifetime
00192      * of a VariableSelectionResult. Subsequent calls will not reinitialize
00193      * member variables.
00194      *
00195      * This is intended, to allow continuing variable selection at a point 
00196      * stopped in an earlier iteration. 
00197      *
00198      * returns true if initialization was successful and false if 
00199      * the object was already initialized before.
00200      */
00201     template<class FeatureT, 
00202              class ResponseT,
00203              class ErrorRateCallBack>
00204     bool init(FeatureT const & all_features,
00205               ResponseT const & response,
00206               ErrorRateCallBack errorcallback)
00207     {
00208         if(initialized)
00209         {
00210             return false;
00211         }
00212         // calculate error with all features
00213         selected.resize(all_features.shape(1), 0);
00214         for(unsigned int ii = 0; ii < selected.size(); ++ii)
00215             selected[ii] = ii;
00216         errors.resize(all_features.shape(1), -1);
00217         errors.back() = errorcallback(all_features, response);
00218 
00219         // calculate error rate if no features are chosen 
00220         // corresponds to max(prior probability) of the classes
00221         std::map<typename ResponseT::value_type, int>     res_map;
00222         std::vector<int>                                 cts;
00223         int                                             counter = 0;
00224         for(int ii = 0; ii < response.shape(0); ++ii)
00225         {
00226             if(res_map.find(response(ii, 0)) == res_map.end())
00227             {
00228                 res_map[response(ii, 0)] = counter;
00229                 ++counter;
00230                 cts.push_back(0);
00231             }
00232             cts[res_map[response(ii,0)]] +=1;
00233         }
00234         no_features = double(*(std::max_element(cts.begin(),
00235                                                  cts.end())))
00236                     / double(response.shape(0));
00237 
00238         /*init not_selected vector;
00239         not_selected.resize(all_features.shape(1), 0);
00240         for(int ii = 0; ii < not_selected.size(); ++ii)
00241         {
00242             not_selected[ii] = ii;
00243         }
00244         initialized = true;
00245         */
00246         pivot = selected.begin();
00247         return true;
00248     }
00249 };
00250 
00251 
00252     
00253 /** Perform forward selection
00254  *
00255  * \param features    IN:     n x p matrix containing n instances with p attributes/features
00256  *                             used in the variable selection algorithm
00257  * \param response  IN:     n x 1 matrix containing the corresponding response
00258  * \param result    IN/OUT: VariableSelectionResult struct which will contain the results
00259  *                             of the algorithm. 
00260  *                             Features between result.selected.begin() and result.pivot will
00261  *                             be left untouched.
00262  *                             \sa VariableSelectionResult
00263  * \param errorcallback
00264  *                     IN, OPTIONAL: 
00265  *                             Functor that returns the error rate given a set of 
00266  *                             features and labels. Default is the RandomForest OOB Error.
00267  *
00268  * Forward selection subsequently chooses the next feature that decreases the Error rate most.
00269  *
00270  * usage:
00271  * \code
00272  *         MultiArray<2, double>     features = createSomeFeatures();
00273  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
00274  *         VariableSelectionResult  result;
00275  *         forward_selection(features, labels, result);
00276  * \endcode
00277  * To use forward selection but ensure that a specific feature e.g. feature 5 is always 
00278  * included one would do the following
00279  *
00280  * \code
00281  *         VariableSelectionResult result;
00282  *         result.init(features, labels);
00283  *         std::swap(result.selected[0], result.selected[5]);
00284  *         result.setPivot(1);
00285  *         forward_selection(features, labels, result);
00286  * \endcode
00287  *
00288  * \sa VariableSelectionResult
00289  *
00290  */                    
00291 template<class FeatureT, class ResponseT, class ErrorRateCallBack>
00292 void forward_selection(FeatureT          const & features,
00293                        ResponseT          const & response,
00294                        VariableSelectionResult & result,
00295                        ErrorRateCallBack          errorcallback)
00296 {
00297     VariableSelectionResult::FeatureList_t & selected         = result.selected;
00298     VariableSelectionResult::ErrorList_t &     errors            = result.errors;
00299     VariableSelectionResult::Pivot_t       & pivot            = result.pivot;    
00300     int featureCount = features.shape(1);
00301     // initialize result struct if in use for the first time
00302     if(!result.init(features, response, errorcallback))
00303     {
00304         //result is being reused just ensure that the number of features is
00305         //the same.
00306         vigra_precondition(selected.size() == featureCount,
00307                            "forward_selection(): Number of features in Feature "
00308                            "matrix and number of features in previously used "
00309                            "result struct mismatch!");
00310     }
00311     
00312 
00313     int not_selected_size = std::distance(pivot, selected.end());
00314     while(not_selected_size > 1)
00315     {
00316         std::vector<int> current_errors;
00317         VariableSelectionResult::Pivot_t next = pivot;
00318         for(int ii = 0; ii < not_selected_size; ++ii, ++next)
00319         {
00320             std::swap(*pivot, *next);
00321             MultiArray<2, double> cur_feats;
00322             detail::choose( features, 
00323                             selected.begin(), 
00324                             pivot+1, 
00325                             cur_feats);
00326             double error = errorcallback(cur_feats, response);
00327             current_errors.push_back(error);
00328             std::swap(*pivot, *next);
00329         }
00330         int pos = std::distance(current_errors.begin(),
00331                                 std::min_element(current_errors.begin(),
00332                                                    current_errors.end()));
00333         next = pivot;
00334         std::advance(next, pos);
00335         std::swap(*pivot, *next);
00336         errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
00337         ++pivot;
00338         not_selected_size = std::distance(pivot, selected.end());
00339     }
00340 }
00341 template<class FeatureT, class ResponseT>
00342 void forward_selection(FeatureT          const & features,
00343                        ResponseT          const & response,
00344                        VariableSelectionResult & result)
00345 {
00346     forward_selection(features, response, result, RFErrorCallback());
00347 }
00348 
00349 
00350 /** Perform backward elimination
00351  *
00352  * \param features    IN:     n x p matrix containing n instances with p attributes/features
00353  *                             used in the variable selection algorithm
00354  * \param response  IN:     n x 1 matrix containing the corresponding response
00355  * \param result    IN/OUT: VariableSelectionResult struct which will contain the results
00356  *                             of the algorithm. 
00357  *                             Features between result.pivot and result.selected.end() will
00358  *                             be left untouched.
00359  *                             \sa VariableSelectionResult
00360  * \param errorcallback
00361  *                     IN, OPTIONAL: 
00362  *                             Functor that returns the error rate given a set of 
00363  *                             features and labels. Default is the RandomForest OOB Error.
00364  *
00365  * Backward elimination subsequently eliminates features that have the least influence
00366  * on the error rate
00367  *
00368  * usage:
00369  * \code
00370  *         MultiArray<2, double>     features = createSomeFeatures();
00371  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
00372  *         VariableSelectionResult  result;
00373  *         backward_elimination(features, labels, result);
00374  * \endcode
00375  * To use backward elimination but ensure that a specific feature e.g. feature 5 is always 
00376  * excluded one would do the following:
00377  *
00378  * \code
00379  *         VariableSelectionResult result;
00380  *         result.init(features, labels);
00381  *         std::swap(result.selected[result.selected.size()-1], result.selected[5]);
00382  *         result.setPivot(result.selected.size()-1);
00383  *         backward_elimination(features, labels, result);
00384  * \endcode
00385  *
00386  * \sa VariableSelectionResult
00387  *
00388  */                    
00389 template<class FeatureT, class ResponseT, class ErrorRateCallBack>
00390 void backward_elimination(FeatureT              const & features,
00391                              ResponseT         const & response,
00392                           VariableSelectionResult & result,
00393                           ErrorRateCallBack         errorcallback)
00394 {
00395     int featureCount = features.shape(1);
00396     VariableSelectionResult::FeatureList_t & selected         = result.selected;
00397     VariableSelectionResult::ErrorList_t &     errors            = result.errors;
00398     VariableSelectionResult::Pivot_t       & pivot            = result.pivot;    
00399     
00400     // initialize result struct if in use for the first time
00401     if(!result.init(features, response, errorcallback))
00402     {
00403         //result is being reused just ensure that the number of features is
00404         //the same.
00405         vigra_precondition(selected.size() == featureCount,
00406                            "backward_elimination(): Number of features in Feature "
00407                            "matrix and number of features in previously used "
00408                            "result struct mismatch!");
00409     }
00410     pivot = selected.end() - 1;    
00411 
00412     int selected_size = std::distance(selected.begin(), pivot);
00413     while(selected_size > 1)
00414     {
00415         VariableSelectionResult::Pivot_t next = selected.begin();
00416         std::vector<int> current_errors;
00417         for(int ii = 0; ii < selected_size; ++ii, ++next)
00418         {
00419             std::swap(*pivot, *next);
00420             MultiArray<2, double> cur_feats;
00421             detail::choose( features, 
00422                             selected.begin(), 
00423                             pivot, 
00424                             cur_feats);
00425             double error = errorcallback(cur_feats, response);
00426             current_errors.push_back(error);
00427             std::swap(*pivot, *next);
00428         }
00429         int pos = std::distance(current_errors.begin(),
00430                                 std::max_element(current_errors.begin(),
00431                                                    current_errors.end()));
00432         next = selected.begin();
00433         std::advance(next, pos);
00434         std::swap(*pivot, *next);
00435 //        std::cerr << std::distance(selected.begin(), pivot) << " " << pos << " " << current_errors.size() << " " << errors.size() << std::endl;
00436         errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
00437         selected_size = std::distance(selected.begin(), pivot);
00438         --pivot;
00439     }
00440 }
00441 
00442 template<class FeatureT, class ResponseT>
00443 void backward_elimination(FeatureT              const & features,
00444                              ResponseT         const & response,
00445                           VariableSelectionResult & result)
00446 {
00447     backward_elimination(features, response, result, RFErrorCallback());
00448 }
00449 
00450 /** Perform rank selection using a predefined ranking
00451  *
00452  * \param features    IN:     n x p matrix containing n instances with p attributes/features
00453  *                             used in the variable selection algorithm
00454  * \param response  IN:     n x 1 matrix containing the corresponding response
00455  * \param result    IN/OUT: VariableSelectionResult struct which will contain the results
00456  *                             of the algorithm. The struct should be initialized with the
00457  *                             predefined ranking.
00458  *                         
00459  *                             \sa VariableSelectionResult
00460  * \param errorcallback
00461  *                     IN, OPTIONAL: 
00462  *                             Functor that returns the error rate given a set of 
00463  *                             features and labels. Default is the RandomForest OOB Error.
00464  *
00465  * Often some variable importance, score measure is used to create the ordering in which
00466  * variables have to be selected. This method takes such a ranking and calculates the 
00467  * corresponding error rates. 
00468  *
00469  * usage:
00470  * \code
00471  *         MultiArray<2, double>     features = createSomeFeatures();
00472  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
00473  *         std::vector<int>        ranking  = createRanking(features);
00474  *         VariableSelectionResult  result;
00475  *         result.init(features, labels, ranking.begin(), ranking.end());
00476  *         backward_elimination(features, labels, result);
00477  * \endcode
00478  *
00479  * \sa VariableSelectionResult
00480  *
00481  */                    
00482 template<class FeatureT, class ResponseT, class ErrorRateCallBack>
00483 void rank_selection      (FeatureT              const & features,
00484                              ResponseT         const & response,
00485                           VariableSelectionResult & result,
00486                           ErrorRateCallBack         errorcallback)
00487 {
00488     VariableSelectionResult::FeatureList_t & selected         = result.selected;
00489     VariableSelectionResult::ErrorList_t &     errors            = result.errors;
00490     VariableSelectionResult::Pivot_t       & iter            = result.pivot;
00491     int featureCount = features.shape(1);
00492     // initialize result struct if in use for the first time
00493     if(!result.init(features, response, errorcallback))
00494     {
00495         //result is being reused just ensure that the number of features is
00496         //the same.
00497         vigra_precondition(selected.size() == featureCount,
00498                            "forward_selection(): Number of features in Feature "
00499                            "matrix and number of features in previously used "
00500                            "result struct mismatch!");
00501     }
00502     
00503     int ii = 0;
00504     for(; iter != selected.end(); ++iter)
00505     {
00506 //        std::cerr << ii<< std::endl;
00507         ++ii;
00508         MultiArray<2, double> cur_feats;
00509         detail::choose( features, 
00510                         selected.begin(), 
00511                         iter, 
00512                         cur_feats);
00513         double error = errorcallback(cur_feats, response);
00514         errors[std::distance(selected.begin(), iter)] = error;
00515 
00516     }
00517 }
00518 
00519 template<class FeatureT, class ResponseT>
00520 void rank_selection      (FeatureT              const & features,
00521                              ResponseT         const & response,
00522                           VariableSelectionResult & result)
00523 {
00524     rank_selection(features, response, result, RFErrorCallback());
00525 }
00526 
00527 
00528 
/** Tags stored in a ClusterNode's typeID field: c_Leaf marks a node
 * representing a single feature (nCol == 1), c_Node an interior merge node.
 */
enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
00530 
/* View of a node in the hierarchical clustering tree.
 * For internal use only -
 * \sa NodeBase
 */
class ClusterNode
: public NodeBase
{
    public:

    typedef NodeBase BT;

        /**constructors **/
    ClusterNode():NodeBase(){}
    // Create a new node holding nCol column indices; appends nCol + 5
    // topology entries and 5 parameter entries to the given containers.
    // A single-column node is tagged c_Leaf, anything larger c_Node.
    ClusterNode(    int                      nCol,
                    BT::T_Container_type    &   topology,
                    BT::P_Container_type    &   split_param)
                :   BT(nCol + 5, 5,topology, split_param)
    {
        status() = 0; 
        BT::column_data()[0] = nCol;
        if(nCol == 1)
            BT::typeID() = c_Leaf;
        else
            BT::typeID() = c_Node;
    }

    // View of an existing node starting at offset n in the containers.
    // column_data()[0] (the stored column count) is added to
    // topology_size_ so the view spans the variable-length column list.
    ClusterNode(           BT::T_Container_type  const  &   topology,
                    BT::P_Container_type  const  &   split_param,
                    int                  n             )
                :   NodeBase(5 , 5,topology, split_param, n)
    {
        //TODO : is there a more elegant way to do this?
        BT::topology_size_ += BT::column_data()[0];
    }

    // View constructed from another node (shares the same containers);
    // the sizes are widened the same way as in the offset constructor.
    ClusterNode( BT & node_)
        :   BT(5, 5, node_) 
    {
        //TODO : is there a more elegant way to do this?
        BT::topology_size_ += BT::column_data()[0];
        BT::parameter_size_ += 0;
    }
    // The accessors below map fixed parameter slots to named fields:
    // [1] index, [2] mean, [3] stdev, [4] status.
    // Slot [0] holds the merge distance (written by HClustering::cluster).
    int index()
    {
        return static_cast<int>(BT::parameters_begin()[1]);
    }
    void set_index(int in)
    {
        BT::parameters_begin()[1] = in;
    }
    double& mean()
    {
        return BT::parameters_begin()[2];
    }
    double& stdev()
    {
        return BT::parameters_begin()[3];
    }
    double& status()
    {
        return BT::parameters_begin()[4];
    }
};
00595 
/** Work item for HClustering::breadth_first_traversal.
 *
 * Records a node address together with its parent's address, its depth
 * in the cluster tree, and the flag propagated from the parent's visit.
 */
struct HC_Entry
{
    int parent;   // address of the parent node (-1 for the root)
    int level;    // depth in the tree (root == 0)
    int addr;     // address of this node in the topology array
    bool infm;    // flag returned by the functor at the parent node
    HC_Entry(int parent_, int level_, int addr_, bool infm_)
        : parent(parent_), level(level_), addr(addr_), infm(infm_)
    {}
};
00608 
00609 
/** Hierarchical Clustering class. 
 * Performs single linkage clustering
 * \code
 *      HClustering linkage;
 *      Matrix<double> distance = get_distance_matrix();
 *      linkage.cluster(distance);
 *      // Draw clustering tree.
 *      Draw<double, int> draw(features, labels, "linkagetree.graph");
 *      linkage.breadth_first_traversal(draw);
 * \endcode
 * \sa ClusterImportanceVisitor
 *
 * Once the clustering has taken place, information queries can be made
 * using the breadth_first_traversal() and iterate() methods.
 *
 */
00625 class HClustering
00626 {
00627 public:
00628     typedef MultiArrayShape<2>::type Shp;
00629     ArrayVector<int>         topology_;
00630     ArrayVector<double>     parameters_;
00631     int                     begin_addr;
00632 
00633     // Calculates the distance between two 
00634     double dist_func(double a, double b)
00635     {
00636         return std::min(a, b); 
00637     }
00638 
00639     /** Visit each node with a Functor 
00640      * in creation order (should be depth first)
00641      */
00642     template<class Functor>
00643     void iterate(Functor & tester)
00644     {
00645 
00646         std::vector<int> stack; 
00647         stack.push_back(begin_addr); 
00648         while(!stack.empty())
00649         {
00650             ClusterNode node(topology_, parameters_, stack.back());
00651             stack.pop_back();
00652             if(!tester(node))
00653             {
00654                 if(node.columns_size() != 1)
00655                 {
00656                     stack.push_back(node.child(0));
00657                     stack.push_back(node.child(1));
00658                 }
00659             }
00660         }
00661     }
00662 
00663     /** Perform breadth first traversal of hierarchical cluster tree
00664      */
00665     template<class Functor>
00666     void breadth_first_traversal(Functor & tester)
00667     {
00668 
00669         std::queue<HC_Entry> queue; 
00670         int level = 0;
00671         int parent = -1;
00672         int addr   = -1;
00673         bool infm  = false;
00674         queue.push(HC_Entry(parent,level,begin_addr, infm)); 
00675         while(!queue.empty())
00676         {
00677             level  = queue.front().level;
00678             parent = queue.front().parent;
00679             addr   = queue.front().addr;
00680             infm   = queue.front().infm;
00681             ClusterNode node(topology_, parameters_, queue.front().addr);
00682             ClusterNode parnt;
00683             if(parent != -1)
00684             {
00685                 parnt = ClusterNode(topology_, parameters_, parent); 
00686             }
00687             queue.pop();
00688             bool istrue = tester(node, level, parnt, infm);
00689             if(node.columns_size() != 1)
00690             {
00691                 queue.push(HC_Entry(addr, level +1,node.child(0),istrue));
00692                 queue.push(HC_Entry(addr, level +1,node.child(1),istrue));
00693             }
00694         }
00695     }
    /**save to HDF5 - defunct - has to be updated to new HDF5 interface
     *
     * Writes the raw cluster-tree state (topology_, parameters_,
     * begin_addr) as three datasets named prefix + "topology",
     * prefix + "parameters" and prefix + "begin_addr" into \a file.
     * NOTE(review): marked defunct by the original author -- confirm
     * against the current writeHDF5 signature before use.
     */
    void save(std::string file, std::string prefix)
    {
        
        vigra::writeHDF5(file.c_str(), (prefix + "topology").c_str(), 
                               MultiArrayView<2, int>(
                                    Shp(topology_.size(),1),
                                    topology_.data()));
        vigra::writeHDF5(file.c_str(), (prefix + "parameters").c_str(), 
                               MultiArrayView<2, double>(
                                    Shp(parameters_.size(), 1),
                                    parameters_.data()));
        vigra::writeHDF5(file.c_str(), (prefix + "begin_addr").c_str(), 
                               MultiArrayView<2, int>(Shp(1,1), &begin_addr));
                               
    }
00713 
00714     /**Perform single linkage clustering
00715      * \param distance distance matrix used. \sa CorrelationVisitor
00716      */
00717     template<class T, class C>
00718     void cluster(MultiArrayView<2, T, C> distance)
00719     {
00720         MultiArray<2, T> dist(distance); 
00721         std::vector<std::pair<int, int> > addr; 
00722         typedef std::pair<int, int>  Entry;
00723         int index = 0;
00724         for(int ii = 0; ii < distance.shape(0); ++ii)
00725         {
00726             addr.push_back(std::make_pair(topology_.size(), ii));
00727             ClusterNode leaf(1, topology_, parameters_);
00728             leaf.set_index(index);
00729             ++index;
00730             leaf.columns_begin()[0] = ii;
00731         }
00732 
00733         while(addr.size() != 1)
00734         {
00735             //find the two nodes with the smallest distance
00736             int ii_min = 0;
00737             int jj_min = 1;
00738             double min_dist = dist((addr.begin()+ii_min)->second, 
00739                               (addr.begin()+jj_min)->second);
00740             for(unsigned int ii = 0; ii < addr.size(); ++ii)
00741             {
00742                 for(unsigned int jj = ii+1; jj < addr.size(); ++jj)
00743                 {
00744                     if(  dist((addr.begin()+ii_min)->second, 
00745                               (addr.begin()+jj_min)->second)
00746                        > dist((addr.begin()+ii)->second, 
00747                               (addr.begin()+jj)->second))
00748                     {
00749                         min_dist = dist((addr.begin()+ii)->second, 
00750                               (addr.begin()+jj)->second);
00751                         ii_min = ii; 
00752                         jj_min = jj;
00753                     }
00754                 }
00755             }
00756 
00757             //merge two nodes
00758             int col_size = 0;
00759             // The problem is that creating a new node invalidates the iterators stored
00760             // in firstChild and secondChild.
00761             {
00762                 ClusterNode firstChild(topology_, 
00763                                        parameters_, 
00764                                        (addr.begin() +ii_min)->first);
00765                 ClusterNode secondChild(topology_, 
00766                                        parameters_, 
00767                                        (addr.begin() +jj_min)->first);
00768                 col_size = firstChild.columns_size() + secondChild.columns_size();
00769             }
00770             int cur_addr = topology_.size();
00771             begin_addr = cur_addr;
00772 //            std::cerr << col_size << std::endl;
00773             ClusterNode parent(col_size,
00774                                topology_,
00775                                parameters_); 
00776             ClusterNode firstChild(topology_, 
00777                                    parameters_, 
00778                                    (addr.begin() +ii_min)->first);
00779             ClusterNode secondChild(topology_, 
00780                                    parameters_, 
00781                                    (addr.begin() +jj_min)->first);
00782             parent.parameters_begin()[0] = min_dist;
00783             parent.set_index(index);
00784             ++index;
00785             std::merge(firstChild.columns_begin(), firstChild.columns_end(),
00786                        secondChild.columns_begin(),secondChild.columns_end(),
00787                        parent.columns_begin());
00788             //merge nodes in addr
00789             int to_keep;
00790             int to_desc;
00791             int ii_keep;
00792             if(*parent.columns_begin() ==  *firstChild.columns_begin())
00793             {
00794                 parent.child(0) = (addr.begin()+ii_min)->first;
00795                 parent.child(1) = (addr.begin()+jj_min)->first;
00796                 (addr.begin()+ii_min)->first = cur_addr;
00797                 ii_keep = ii_min;
00798                 to_keep = (addr.begin()+ii_min)->second;
00799                 to_desc = (addr.begin()+jj_min)->second;
00800                 addr.erase(addr.begin()+jj_min);
00801             }
00802             else
00803             {
00804                 parent.child(1) = (addr.begin()+ii_min)->first;
00805                 parent.child(0) = (addr.begin()+jj_min)->first;
00806                 (addr.begin()+jj_min)->first = cur_addr;
00807                 ii_keep = jj_min;
00808                 to_keep = (addr.begin()+jj_min)->second;
00809                 to_desc = (addr.begin()+ii_min)->second;
00810                 addr.erase(addr.begin()+ii_min);
00811             }
00812             //update distances;
00813             
00814             for(unsigned int jj = 0 ; jj < addr.size(); ++jj)
00815             {
00816                 if(jj == ii_keep)
00817                     continue;
00818                 double bla = dist_func(
00819                                   dist(to_desc, (addr.begin()+jj)->second),
00820                                   dist((addr.begin()+ii_keep)->second,
00821                                         (addr.begin()+jj)->second));
00822 
00823                 dist((addr.begin()+ii_keep)->second,
00824                      (addr.begin()+jj)->second) = bla;
00825                 dist((addr.begin()+jj)->second,
00826                      (addr.begin()+ii_keep)->second) = bla;
00827             }
00828         }
00829     }
00830 
00831 };
00832 
00833 
/** Normalize the status value in the HClustering tree (HClustering Visitor)
 */
class NormalizeStatus
{
public:
    /** divisor applied to every visited node's status() */
    double n;

    /** Constructor
     * \param m normalize status() by m
     */
    NormalizeStatus(double m)
        : n(m)
    {}

    /** Divide the node's status by n.
     * \return false -- the traversal is never aborted by this visitor
     */
    template<class Node>
    bool operator()(Node& node)
    {
        node.status() = node.status() / n;
        return false;
    }
};
00853 
00854 
/** Perform Permutation importance on HClustering clusters
 * (See visit_after_tree() method of visitors::VariableImportance to 
 * see the basic idea. (Just that we apply the permutation not only to
 * variables but also to clusters))
 *
 * \tparam Iter iterator over sample (row) indices, e.g. into the OOB set
 * \tparam DT   decision tree type providing predictLabel()
 */
template<class Iter, class DT>
class PermuteCluster
{
public:
    typedef MultiArrayShape<2>::type Shp;
    // scratch copy of feats_ that receives the per-cluster permutations
    Matrix<double> tmp_mem_;
    // accumulated hits after permutation: one row per visited cluster node,
    // one column per class plus a final "total" column
    MultiArrayView<2, double> perm_imp;
    // hits of the unpermuted data (row 0 used), same column layout
    MultiArrayView<2, double> orig_imp;
    // features of the selected samples, copied in the constructor
    Matrix<double> feats_;
    // corresponding labels (one column)
    Matrix<int>    labels_;
    // number of permutation repetitions per cluster node
    const int      nPerm;
    // tree used for the predictions
    DT const &           dt;
    // row of perm_imp to fill next; advanced once per visited node
    int index;
    // number of samples in [a, b)
    int oob_size;

    /** Constructor: copies the rows [a, b) of feats/labls into local
     * matrices so they can be permuted without touching the originals.
     * \param a, b   range of sample indices to evaluate on
     * \param feats  full feature matrix
     * \param labls  full response matrix
     * \param p_imp  output: permuted-accuracy accumulator (see perm_imp)
     * \param o_imp  input: unpermuted accuracy (see orig_imp)
     * \param np     number of permutation repetitions
     * \param dt_    tree used for prediction
     */
    template<class Feat_T, class Label_T>
    PermuteCluster(Iter  a, 
                   Iter  b,
                   Feat_T const & feats,
                   Label_T const & labls, 
                   MultiArrayView<2, double> p_imp, 
                   MultiArrayView<2, double> o_imp, 
                   int np,
                   DT const  & dt_)
        :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
         perm_imp(p_imp),
         orig_imp(o_imp),
         feats_(_spl(a,b).size(), feats.shape(1)),
         labels_(_spl(a,b).size(),1),
         nPerm(np),
         dt(dt_),
         index(0),
         oob_size(b-a)
    {
        // splice the selected rows (all columns) out of the full matrices
        copy_splice(_spl(a,b),
                    _spl(feats.shape(1)),
                    feats,
                    feats_);
        copy_splice(_spl(a,b),
                    _spl(labls.shape(1)),
                    labls,
                    labels_);
    }

    /** Visit one cluster node: permute all of its columns together
     * (rows are swapped as a unit, preserving within-row correlation of
     * the cluster's variables), re-predict, and add the resulting
     * importance to node.status().
     * \return false -- never aborts the traversal
     */
    template<class Node>
    bool operator()(Node& node)
    {
        tmp_mem_ = feats_;
        RandomMT19937 random;
        // last column of perm_imp holds the "total" count
        int class_count = perm_imp.shape(1) - 1;
        //permute columns together
        for(int kk = 0; kk < nPerm; ++kk)
        {
            // restart from the unpermuted features for every repetition
            tmp_mem_ = feats_;
            for(int ii = 0; ii < rowCount(feats_); ++ii)
            {
                // NOTE: this local 'index' (a Fisher-Yates partner row)
                // shadows the member 'index' (the perm_imp row) inside
                // this loop only -- do not confuse the two.
                int index = random.uniformInt(rowCount(feats_) - ii) +ii;
                for(int jj = 0; jj < node.columns_size(); ++jj)
                {
                    // a column id equal to feats_.shape(1) is a sentinel
                    // entry (out of range of feats_) and is skipped --
                    // presumably the artificial extra column, TODO confirm
                    if(node.columns_begin()[jj] != feats_.shape(1))
                        tmp_mem_(ii, node.columns_begin()[jj]) 
                            = tmp_mem_(index, node.columns_begin()[jj]);
                }
            }
            
            // count correct predictions on the permuted data
            for(int ii = 0; ii < rowCount(tmp_mem_); ++ii)
            {
                if(dt
                        .predictLabel(rowVector(tmp_mem_, ii)) 
                    ==  labels_(ii, 0))
                {
                    //per class
                    ++perm_imp(index,labels_(ii, 0));
                    //total
                    ++perm_imp(index, class_count);
                }
            }
        }
        // importance = (orig accuracy - mean permuted accuracy) / #samples
        double node_status  = perm_imp(index, class_count);
        node_status /= nPerm;
        node_status -= orig_imp(0, class_count);
        node_status *= -1;
        node_status /= oob_size;
        node.status() += node_status;
        ++index;
         
        return false;
    }
};
00949 
00950 /** Convert ClusteringTree into a list (HClustering visitor)
00951  */
00952 class GetClusterVariables
00953 {
00954 public:
00955     /** NumberOfClusters x NumberOfVariables MultiArrayView containing
00956      * in each row the variable belonging to a cluster
00957      */
00958     MultiArrayView<2, int>    variables;
00959     int index;
00960     GetClusterVariables(MultiArrayView<2, int> vars)
00961         :variables(vars), index(0)
00962     {}
00963     void save(std::string file, std::string prefix)
00964     {
00965         vigra::writeHDF5(file.c_str(), (prefix + "_variables").c_str(), 
00966                                variables);
00967     }
00968 
00969     template<class Node>
00970     bool operator()(Node& node)
00971     {
00972         for(int ii = 0; ii < node.columns_size(); ++ii)
00973             variables(index, ii) = node.columns_begin()[ii];
00974         ++index;
00975         return false;
00976     }
00977 };
/** corrects the status fields of a linkage Clustering (HClustering Visitor)
 *  
 *  such that status(currentNode) = min(status(parent), status(currentNode))
 *  \sa cluster_permutation_importance()
 */
class CorrectStatus
{
public:
    /** Clamp cur's status to its parent's status (when a parent exists).
     * \return true -- continue the traversal
     */
    template<class Nde>
    bool operator()(Nde & cur, int level, Nde parent, bool infm)
    {
        bool const has_parent = parent.hasData_;
        if(has_parent && parent.status() < cur.status())
            cur.status() = parent.status();
        return true;
    }
};
00994 
00995 
00996 /** draw current linkage Clustering (HClustering Visitor)
00997  *
00998  * create a graphviz .dot file
00999  * usage:
01000  * \code
01001  *         Matrix<double> distance = get_distance_matrix();
01002  *      linkage.cluster(distance);
01003  *      Draw<double, int> draw(features, labels, "linkagetree.graph");
01004  *      linkage.breadth_first_traversal(draw);
01005  * \endcode 
01006  */
01007 template<class T1,
01008          class T2, 
01009          class C1 = UnstridedArrayTag,
01010          class C2 = UnstridedArrayTag> 
01011 class Draw
01012 {
01013 public:
01014     typedef MultiArrayShape<2>::type Shp;
01015     MultiArrayView<2, T1, C1> const &   features_;
01016     MultiArrayView<2, T2, C2> const &   labels_;
01017     std::ofstream graphviz;
01018 
01019 
01020     Draw(MultiArrayView<2, T1, C1> const & features, 
01021          MultiArrayView<2, T2, C2> const& labels,
01022          std::string const  gz)
01023         :features_(features), labels_(labels), 
01024         graphviz(gz.c_str(), std::ios::out)
01025     {
01026         graphviz << "digraph G\n{\n node [shape=\"record\"]";
01027     }
01028     ~Draw()
01029     {
01030         graphviz << "\n}\n";
01031         graphviz.close();
01032     }
01033 
01034     template<class Nde>
01035     bool operator()(Nde & cur, int level, Nde parent, bool infm)
01036     {
01037         graphviz << "node" << cur.index() << " [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() << "\\n";
01038         graphviz << " status: " << cur.status() << "\\n";
01039         for(int kk = 0; kk < cur.columns_size(); ++kk)
01040         {
01041                 graphviz  << cur.columns_begin()[kk] << " ";
01042                 if(kk % 15 == 14)
01043                     graphviz << "\\n";
01044         }
01045         graphviz << "\"] [color = \"" <<cur.status() << " 1.000 1.000\"];\n";
01046         if(parent.hasData_)
01047         graphviz << "\"node" << parent.index() << "\" -> \"node" << cur.index() <<"\";\n";
01048         return true;
01049     }
01050 };
01051 
01052 /** calculate Cluster based permutation importance while learning. (RandomForestVisitor)
01053  */
01054 class ClusterImportanceVisitor : public visitors::VisitorBase
01055 {
01056     public:
01057 
01058     /** List of variables as produced by GetClusterVariables
01059      */
01060     MultiArray<2, int>          variables;
01061     /** Corresponding importance measures
01062      */
01063     MultiArray<2, double>       cluster_importance_;
01064     /** Corresponding error
01065      */
01066     MultiArray<2, double>       cluster_stdev_;
01067     int                         repetition_count_;
01068     bool                        in_place_;
01069     HClustering            &    clustering;
01070 
01071 
01072 #ifdef HasHDF5
01073     void save(std::string filename, std::string prefix)
01074     {
01075         std::string prefix1 = "cluster_importance_" + prefix;
01076         writeHDF5(filename.c_str(), 
01077                         prefix1.c_str(), 
01078                         cluster_importance_);
01079         prefix1 = "vars_" + prefix;
01080         writeHDF5(filename.c_str(), 
01081                         prefix1.c_str(), 
01082                         variables);
01083     }
01084 #endif
01085 
01086     ClusterImportanceVisitor(HClustering & clst, int rep_cnt = 10) 
01087     :   repetition_count_(rep_cnt), clustering(clst)
01088 
01089     {}
01090 
01091     /** Allocate enough memory 
01092      */
01093     template<class RF, class PR>
01094     void visit_at_beginning(RF const & rf, PR const & pr)
01095     {
01096         Int32 const  class_count = rf.ext_param_.class_count_;
01097         Int32 const  column_count = rf.ext_param_.column_count_+1;
01098         cluster_importance_
01099             .reshape(MultiArrayShape<2>::type(2*column_count-1, 
01100                                                 class_count+1));
01101         cluster_stdev_
01102             .reshape(MultiArrayShape<2>::type(2*column_count-1, 
01103                                                 class_count+1));
01104         variables
01105             .reshape(MultiArrayShape<2>::type(2*column_count-1, 
01106                                                 column_count), -1);
01107         GetClusterVariables gcv(variables);
01108         clustering.iterate(gcv);
01109         
01110     }
01111 
01112     /**compute permutation based var imp. 
01113      * (Only an Array of size oob_sample_count x 1 is created.
01114      *  - apposed to oob_sample_count x feature_count in the other method.
01115      * 
01116      * \sa FieldProxy
01117      */
01118     template<class RF, class PR, class SM, class ST>
01119     void after_tree_ip_impl(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01120     {
01121         typedef MultiArrayShape<2>::type Shp_t;
01122         Int32                   column_count = rf.ext_param_.column_count_ +1;
01123         Int32                   class_count  = rf.ext_param_.class_count_;  
01124         
01125         // remove the const cast on the features (yep , I know what I am 
01126         // doing here.) data is not destroyed.
01127         typename PR::Feature_t & features 
01128             = const_cast<typename PR::Feature_t &>(pr.features());
01129 
01130         //find the oob indices of current tree. 
01131         ArrayVector<Int32>      oob_indices;
01132         ArrayVector<Int32>::iterator
01133                                 iter;
01134         
01135         if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
01136         {
01137             ArrayVector<int> cts(2, 0);
01138             ArrayVector<Int32> indices(pr.features().shape(0));
01139             for(int ii = 0; ii < pr.features().shape(0); ++ii)
01140                indices.push_back(ii); 
01141             std::random_shuffle(indices.begin(), indices.end());
01142             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01143             {
01144                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
01145                 {
01146                     oob_indices.push_back(indices[ii]);
01147                     ++cts[pr.response()(indices[ii], 0)];
01148                 }
01149             }
01150         }
01151         else
01152         {
01153             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01154                 if(!sm.is_used()[ii])
01155                     oob_indices.push_back(ii);
01156         }
01157 
01158         // Random foo
01159         RandomMT19937           random(RandomSeed);
01160         UniformIntRandomFunctor<RandomMT19937>  
01161                                 randint(random);
01162 
01163         //make some space for the results
01164         MultiArray<2, double>
01165                     oob_right(Shp_t(1, class_count + 1)); 
01166         
01167         // get the oob success rate with the original samples
01168         for(iter = oob_indices.begin(); 
01169             iter != oob_indices.end(); 
01170             ++iter)
01171         {
01172             if(rf.tree(index)
01173                     .predictLabel(rowVector(features, *iter)) 
01174                 ==  pr.response()(*iter, 0))
01175             {
01176                 //per class
01177                 ++oob_right[pr.response()(*iter,0)];
01178                 //total
01179                 ++oob_right[class_count];
01180             }
01181         }
01182         
01183         MultiArray<2, double>
01184                     perm_oob_right (Shp_t(2* column_count-1, class_count + 1)); 
01185         
01186         PermuteCluster<ArrayVector<Int32>::iterator,typename RF::DecisionTree_t>
01187             pc(oob_indices.begin(), oob_indices.end(), 
01188                             pr.features(),
01189                             pr.response(),
01190                             perm_oob_right,
01191                             oob_right,
01192                             repetition_count_,
01193                             rf.tree(index));
01194         clustering.iterate(pc);
01195 
01196         perm_oob_right  /=  repetition_count_;
01197         for(int ii = 0; ii < rowCount(perm_oob_right); ++ii)
01198             rowVector(perm_oob_right, ii) -= oob_right;
01199 
01200         perm_oob_right       *= -1;
01201         perm_oob_right       /= oob_indices.size();
01202         cluster_importance_  += perm_oob_right;
01203     }
01204 
01205     /** calculate permutation based impurity after every tree has been 
01206      * learned  default behaviour is that this happens out of place.
01207      * If you have very big data sets and want to avoid copying of data 
01208      * set the in_place_ flag to true. 
01209      */
01210     template<class RF, class PR, class SM, class ST>
01211     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01212     {    
01213             after_tree_ip_impl(rf, pr, sm, st, index);
01214     }
01215 
01216     /** Normalise variable importance after the number of trees is known.
01217      */
01218     template<class RF, class PR>
01219     void visit_at_end(RF & rf, PR & pr)
01220     {
01221         NormalizeStatus nrm(rf.tree_count());
01222         clustering.iterate(nrm);
01223         cluster_importance_ /= rf.trees_.size();
01224     }
01225 };
01226 
01227 /** Perform hierarchical clustering of variables and assess importance of clusters
01228  *
01229  * \param features    IN:     n x p matrix containing n instances with p attributes/features
01230  *                             used in the variable selection algorithm
01231  * \param response  IN:     n x 1 matrix containing the corresponding response
01232  * \param linkage    OUT:    Hierarchical grouping of variables.
01233  * \param distance  OUT:    distance matrix used for creating the linkage
01234  *
01235  * Performs Hierarchical clustering of variables. And calculates the permutation importance 
01236  * measures of each of the clusters. Use the Draw functor to create human readable output
01237  * The cluster-permutation importance measure corresponds to the normal permutation importance
01238  * measure with all columns corresponding to a cluster permuted. 
01239  * The importance measure for each cluster is stored as the status() field of each clusternode
01240  * \sa HClustering
01241  *
01242  * usage:
01243  * \code
01244  *         MultiArray<2, double>     features = createSomeFeatures();
01245  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
01246  *         HClustering                linkage;
01247  *         MultiArray<2, double>    distance;
01248  *         cluster_permutation_importance(features, labels, linkage, distance)
01249  *        // create graphviz output
01250  *
01251  *      Draw<double, int> draw(features, labels, "linkagetree.graph");
01252  *      linkage.breadth_first_traversal(draw);
01253  *
01254  * \endcode
01255  *
01256  *
01257  */                    
01258 template<class FeatureT, class ResponseT>
01259 void cluster_permutation_importance(FeatureT              const & features,
01260                                          ResponseT         const &     response,
01261                                     HClustering               & linkage,
01262                                     MultiArray<2, double>      & distance)
01263 {
01264 
01265         RandomForestOptions opt;
01266         opt.tree_count(100);
01267         if(features.shape(0) > 40000)
01268             opt.samples_per_tree(20000).use_stratification(RF_EQUAL);
01269 
01270 
01271         vigra::RandomForest<int> RF(opt); 
01272         visitors::RandomForestProgressVisitor             progress;
01273         visitors::CorrelationVisitor                     missc;
01274         RF.learn(features, response,
01275                  create_visitor(missc, progress));
01276         distance = missc.distance;
01277         /*
01278            missc.save(exp_dir + dset.name() + "_result.h5", dset.name()+"MACH");
01279            */
01280 
01281 
01282         // Produce linkage
01283         linkage.cluster(distance);
01284         
01285         //linkage.save(exp_dir + dset.name() + "_result.h5", "_linkage_CC/");
01286         vigra::RandomForest<int> RF2(opt); 
01287         ClusterImportanceVisitor          ci(linkage);
01288         RF2.learn(features, 
01289                   response,
01290                   create_visitor(progress, ci));
01291         
01292         
01293         CorrectStatus cs;
01294         linkage.breadth_first_traversal(cs);
01295 
01296         //ci.save(exp_dir + dset.name() + "_result.h5", dset.name());
01297         //Draw<double, int> draw(dset.features(), dset.response(), exp_dir+ dset.name() + ".graph");
01298         //linkage.breadth_first_traversal(draw);
01299 
01300 }
01301 
01302     
01303 template<class FeatureT, class ResponseT>
01304 void cluster_permutation_importance(FeatureT              const & features,
01305                                          ResponseT         const &     response,
01306                                     HClustering               & linkage)
01307 {
01308     MultiArray<2, double> distance;
01309     cluster_permutation_importance(features, response, linkage, distance);
01310 }
01311 }//namespace algorithms
01312 }//namespace rf
01313 }//namespace vigra

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (3 Dec 2010)