聊聊GBDT和随机森林

浏览: 2146

                                 

GBDT 和 随机森林都是基于决策树而得到的。决策树比较容易理解,它具有比较直观的图示。利用决策树可以发现比较重要的变量,也可以挖掘变量之间的关系。决策树也比较不易受到离群点和缺失值的影响。由于决策树不考虑空间分布,也不考虑分类器的结构,它是一种无参算法。但是决策树比较容易过拟合,另外,决策树不易处理连续型变量。

Gradient Boosting 是一种提升的框架,可以用于决策树算法,即GBDT。通常Boosting基于弱学习器,弱分类器具有高偏差,低方差。决策树的深度较浅时,通常是弱学习器,比如一种比较极端的例子,只有一个根节点和两个叶子节点。Boosting这种策略主要通过减小偏差来降低总体误差,一般来讲,通过集成多个模型的结果也会减小方差。这里的总体误差可以看作由偏差和方差构成。由于GBDT是基于Boosting策略的,所以这种算法具有序贯性,不容易并行实现。

关于偏差和方差随模型复杂度变化,可以参见下图。

                                          


随机森林主要通过减小方差来降低总体误差。随机森林是由多个决策树构成的,因此需要基于原始数据集随机生成多个数据集,用于生成多个决策树。这些决策树之间的相关性越小,方差降低得越多。虽然随机森林可以减小方差,但是这种组合策略不能降低偏差,它会使得总体偏差大于森林中单个决策树的偏差。

随机森林利用bagging来组合多个决策树,容易过拟合。由于这种方法基于bagging思想,因此这种算法比较容易并行实现。随机森林能够较好地应对缺失值和非平衡集的情形。

下面给出基于scikit-learn的随机森林示例:

from sklearn.ensemble import RandomForestClassifier

X
= [[0, 0], [1, 1]]

Y = [0, 1]

clf = RandomForestClassifier(n_estimators=10)

clf = clf.fit(X, Y)

spark也内嵌了随机森林算法,示例如下:

from pyspark.mllib.tree import RandomForest, RandomForestModel

from pyspark.mllib.util import MLUtils

# Load and parse the data file into an RDD of LabeledPoint.

data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')

# Split the data into training and test sets (30% held out for testing)

(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a RandomForest model.

#  Empty categoricalFeaturesInfo indicates all features are continuous.

#  Note: Use larger numTrees in practice.

#  Setting featureSubsetStrategy="auto" lets the algorithm choose.

model = RandomForest.trainClassifier(

trainingData, numClasses=2, categoricalFeaturesInfo={},

numTrees=3, featureSubsetStrategy="auto",

impurity='gini', maxDepth=4, maxBins=32)


# Evaluate model on test instances and compute test error

predictions = model.predict(testData.map(lambda x: x.features))

labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())

print('Test Error = ' + str(testErr))

print('Learned classification forest model:')

print(model.toDebugString())

# Save and load model

model.save(sc, "target/tmp/myRandomForestClassificationModel")

sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel")

这两种算法也可以用于客户管理或营销领域,比如客户流失预测(Bagging and boosting classification trees to predict churn  

https://pure.uvt.nl/portal/files/1425373/lemmens_bagging.pdf) 和点击率预估(

Feature Selection in Click-Through Rate Prediction Based on Gradient Boosting

https://link.springer.com/chapter/10.1007/978-3-319-46257-8_15)等。

推荐 0
本文由 深度学习 创作,采用 知识共享署名-相同方式共享 3.0 中国大陆许可协议 进行许可。
转载、引用前需联系作者,并署名作者且注明文章出处。
本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责。本站是一个个人学习交流的平台,并不用于任何商业目的,如果有任何问题,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

0 个评论

要回复文章请先登录注册