数学推导+纯Python实现机器学习算法4:决策树之ID3算法

浏览: 519

作者:鲁伟,热爱数据,坚信数据技术和代码改变世界。R语言和Python的忠实拥趸,为成为一名未来的数据科学家而奋斗终生。

个人公众号:机器学习实验室  (微信ID:louwill12)

作为机器学习中的一大类模型,树模型一直以来都颇受学界和业界的重视。目前无论是各大比赛各种大杀器的XGBoost、lightgbm还是像随机森林、Adaboost等典型集成学习模型,都是以决策树模型为基础的。传统的经典决策树算法包括ID3算法、C4.5算法以及GBDT的基分类器CART算法。

     三大经典决策树算法最主要的区别在于其特征选择准则的不同。ID3算法选择特征的依据是信息增益、C4.5是信息增益比,而CART则是Gini指数。作为一种基础的分类和回归方法,决策树可以有如下两种理解方式。一种是我们可以将决策树看作是一组if-then规则的集合,另一种则是给定特征条件下类的条件概率分布。关于这两种理解方式,读者朋友可深入阅读相关教材进行理解,笔者这里补详细展开。

     根据上述两种理解方式,我们既可以将决策树的本质视作从训练数据集中归纳出一组分类规则,也可以将其看作是根据训练数据集估计条件概率模型。整个决策树的学习过程就是一个递归地选择最优特征,并根据该特征对数据集进行划分,使得各个样本都得到一个最好的分类的过程。

ID3算法理论

     所以这里的关键在于如何选择最优特征对数据集进行划分。答案就是前面提到的信息增益、信息增益比和Gini指数。因为本篇针对的是ID3算法,所以这里笔者仅对信息增益进行详细的表述。

     在讲信息增益之前,这里我们必须先介绍下熵的概念。在信息论里面,熵是一种表示随机变量不确定性的度量方式。若离散随机变量X的概率分布为:


     则随机变量X的熵定义为:

     同理,对于连续型随机变量Y,其熵可定义为:

     当给定随机变量X的条件下随机变量Y的熵可定义为条件熵H(Y|X):

     所谓信息增益就是数据在得到特征X的信息时使得类Y的信息不确定性减少的程度。假设数据集D的信息熵为H(D),给定特征A之后的条件熵为H(D|A),则特征A对于数据集的信息增益g(D,A)可表示为:

g(D,A) = H(D) - H(D|A)

     信息增益越大,则该特征对数据集确定性贡献越大,表示该特征对数据有较强的分类能力。信息增益的计算示例如下:
1)计算目标特征的信息熵

2)计算加入某个特征之后的条件熵

3)计算信息增益

     以上就是ID3算法的核心理论部分,至于如何基于ID3构造决策树,我们在代码实例中来看。

ID3算法实现

     先读入示例数据集:

1import numpy as np
2import pandas as pd
3from math import log
4
5df = pd.read_csv('./example_data.csv')
6df

定义熵的计算函数:

1def entropy(ele):    
2    '''
3    function: Calculating entropy value.
4    input: A list contain categorical value.
5    output: Entropy value.
6    entropy = - sum(p * log(p)), p is a prob value.
7    '''

8    # Calculating the probability distribution of list value
9    probs = [ele.count(i)/len(ele) for i in set(ele)]    
10    # Calculating entropy value
11    entropy = -sum([prob*log(prob, 2for prob in probs])    
12    return entropy

计算示例:


然后我们需要定义根据特征和特征值进行数据划分的方法:

1def split_dataframe(data, col):    
2    '''
3    function: split pandas dataframe to sub-df based on data and column.
4    input: dataframe, column name.
5    output: a dict of splited dataframe.
6    '''
7    # unique value of column
8    unique_values = data[col].unique()    
9    # empty dict of dataframe
10    result_dict = {elem : pd.DataFrame for elem in unique_values}    
11    # split dataframe based on column value
12    for key in result_dict.keys():
13        result_dict[key] = data[:][data[col] == key]    
14    return result_dict

根据temp和其三个特征值的数据集划分示例:

     然后就是根据熵计算公式和数据集划分方法计算信息增益来选择最佳特征的过程:

1def choose_best_col(df, label):    
2    '''
3    funtion: choose the best column based on infomation gain.
4    input: datafram, label
5    output: max infomation gain, best column, 
6            splited dataframe dict based on best column.
7    '''

8    # Calculating label's entropy
9    entropy_D = entropy(df[label].tolist())    
10    # columns list except label
11    cols = [col for col in df.columns if col not in [label]]    
12    # initialize the max infomation gain, best column and best splited dict
13    max_value, best_col = -999None
14    max_splited = None
15    # split data based on different column
16    for col in cols:
17        splited_set = split_dataframe(df, col)
18        entropy_DA = 0
19        for subset_col, subset in splited_set.items():            
20            # calculating splited dataframe label's entropy
21            entropy_Di = entropy(subset[label].tolist())            
22            # calculating entropy of current feature
23            entropy_DA += len(subset)/len(df) * entropy_Di        
24        # calculating infomation gain of current feature
25        info_gain = entropy_D - entropy_DA        
26        if info_gain > max_value:
27            max_value, best_col = info_gain, col
28            max_splited = splited_set    
29        return max_value, best_col, max_splited

最先选到的信息增益最大的特征是outlook:

     决策树基本要素定义好后,我们即可根据以上函数来定义一个ID3算法类,在类里面定义构造ID3决策树的方法:

1class ID3Tree:    
2    # define a Node class
3    class Node:        
4        def __init__(self, name):
5            self.name = name
6            self.connections = {}    
7
8        def connect(self, label, node):
9            self.connections[label] = node    
10
11    def __init__(self, data, label):
12        self.columns = data.columns
13        self.data = data
14        self.label = label
15        self.root = self.Node("Root")    
16
17    # print tree method
18    def print_tree(self, node, tabs):
19        print(tabs + node.name)        
20        for connection, child_node in node.connections.items():
21            print(tabs + "\t" + "(" + connection + ")")
22            self.print_tree(child_node, tabs + "\t\t")    
23
24    def construct_tree(self):
25        self.construct(self.root, ""self.data, self.columns)    
26
27    # construct tree
28    def construct(self, parent_node, parent_connection_label, input_data, columns):
29        max_value, best_col, max_splited = choose_best_col(input_data[columns], self.label)        
30        if not best_col:
31            node = self.Node(input_data[self.label].iloc[0])
32            parent_node.connect(parent_connection_label, node)            
33        return
34
35        node = self.Node(best_col)
36        parent_node.connect(parent_connection_label, node)
37
38        new_columns = [col for col in columns if col != best_col]        
39        # Recursively constructing decision trees
40        for splited_value, splited_data in max_splited.items():
41            self.construct(node, splited_value, splited_data, new_columns)

根据上述代码和示例数据集构造一个ID3决策树:

     以上便是ID3算法的手写过程。sklearn中tree模块为我们提供了决策树的实现方式,参考代码如下:

1from sklearn.datasets import load_iris
2from sklearn import tree
3import graphviz
4
5iris = load_iris()
6# criterion选择entropy,这里表示选择ID3算法
7clf = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')
8clf = clf.fit(iris.data, iris.target)
9
10dot_data = tree.export_graphviz(clf, out_file=None,
11                               feature_names=iris.feature_names,
12                               class_names=iris.target_names,
13                               filled=True
14                               rounded=True,
15                               special_characters=True)
16graph = graphviz.Source(dot_data)
17graph

    

 以上便是本篇的全部内容,完整版代码和数据请移步本人github:

https://github.com/luwill/machine-learning-code-writing

参考资料:

往期精选

数学推导+纯Python实现机器学习算法2:逻辑回归

数据分析到底对企业有什么用?

总结 | 基于代码的数学符号释义(一)

2018年终精心整理|人工智能爱好者社区历史文章合集(作者篇)

2018年终精心整理 | 人工智能爱好者社区历史文章合集(类型篇)

公众号后台回复关键词学习

回复 免费                获取免费课程

回复 直播                获取系列直播课

回复 Python           1小时破冰入门Python

回复 人工智能         从零入门人工智能

回复 深度学习         手把手教你用Python深度学习

回复 机器学习         小白学数据挖掘与机器学习

回复 贝叶斯算法      贝叶斯与新闻分类实战

回复 数据分析师      数据分析师八大能力培养

回复 自然语言处理  自然语言处理之AI深度学习

推荐 0
本文由 人工智能爱好者社区 创作,采用 知识共享署名-相同方式共享 3.0 中国大陆许可协议 进行许可。
转载、引用前需联系作者,并署名作者且注明文章出处。
本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责。本站是一个个人学习交流的平台,并不用于任何商业目的,如果有任何问题,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

0 个评论

要回复文章请先登录注册