2013年6月12日水曜日

決定木のサンプル

決定木のサンプルプログラム

数値パラメータ入力に対する決定木のサンプル.

#include <opencv2/opencv.hpp> //OpenCV関連ヘッダ全部インクルード
#ifdef _DEBUG  //Debugモードの場合
#pragma comment(lib, "opencv_core245d.lib")
#pragma comment(lib, "opencv_ml245d.lib")
#else  //Releaseモードの場合
#pragma comment(lib, "opencv_core245.lib")
#pragma comment(lib, "opencv_ml245.lib")
#endif

using namespace cv;
using std::cout;
using std::endl;

void showSplits(CvDTreeSplit *split)
{
    if(split==NULL) return;

    cout << "Split : vidx=" << split->var_idx;
    cout << ", quality=" << split->quality;
    cout << ", ord.c=" << split->ord.c << endl;
    showSplits(split->next);
}

void showTrees(const CvDTreeNode *root, string text)
{
    if(root==NULL) return;
    for(int d=0; d<root->depth; d++) cout << "-- ";
    cout << text << endl;
    cout << "sample count" << root->sample_count << endl;
    cout << "value:" << root->value << endl;
    showSplits(root->split);
    cout << endl;

    showTrees(root->left, "Left");
    showTrees(root->right, "Right");
}

int main()
{
    //http://www.sakurai.comp.ae.keio.ac.jp/classes/IntInfProc-class/2010/06DecisionTree.pdf より「破産の予測」
    // 1年あたりの支払い遅延回数,支出/収入
    Mat data = (Mat_<float>(14,2) <<
    3, 0.2,
    1, 0.3,
    4, 0.5,
    2, 0.7,
    0, 1.0,
    1, 1.2,
    1, 1.7,
    6, 0.2,
    7, 0.3,
    6, 0.7,
    3, 1.1,
    2, 1.5,
    4, 1.7,
    2, 1.9);
    Mat label = (Mat_<int>(14, 1) << 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1);

    CvDTreeParams param = CvDTreeParams( 
        3, // 最大の深さ
        2, // あるノードに対するサンプル数がこの値よりも少ない場合,分岐しない.
        0, // 別の終了条件 - 回帰木の場合のみ. 
        false, // trueの場合代理分岐が構築される.データの欠損や,変数の重要度の推定に必要
        2, // max_categories 
        0, // このパラメータが >1 の場合,木は cv_folds 分割交差検証法により刈り込まれる.
        false, // true の場合,木は刈り込み手続きによって切り捨てられる. 
        false, // true の場合,カットオフノードが,木から物理的に削除される.
        NULL // クラスラベル値によって保存されたクラス事前確率の配列.
    );

    DecisionTree dtree = DecisionTree();
    dtree.train(data, CV_ROW_SAMPLE, label, Mat(), Mat(), Mat(), Mat(), param);

    //学習データをテストに回してみる
    for( int i = 0; i < data.rows; i++ )
    {
        double r = (dtree.predict(data.row(i)))->value;
        cout << i << ":" << r << endl;
    }

    //ツリーの表示
    showTrees(dtree.get_root(), "Root");

    return 0;
}