GBDT 用于分类和回归的spark示例

浏览: 5939




GBDT是推荐系统中应用非常广泛的算法。

GBDT 是由决策树集成而来的,这种算法不断地迭代式训练决策树算法,目标是最小化损失函数。跟决策树类似,GBDT能够捕捉到非线性特征,也能发掘特征交互作用。

spark.mllib 支持GBDT的二分类问题和回归问题,spark.mllib中的GBDT是基于决策树来实现的,能够处理连续性和离散性的特征。

基本算法

GBDT 迭代式训练一系列决策树,每次迭代中,算法利用当前的集成算法来预测每个训练样本的标签,然后跟真实的标签相比较。为了重点关注预测效果较差的训练样本,数据集会被重打标签。所以,在下次迭代中,决策树能够有助于纠正之前的分错样本相应的预测。

重打标签的具体机制是由误差函数定义的。每次迭代中,GBDT 都能够减小训练集上的误差函数。

误差函数

下表给出了spark.mllib中GBDT支持的目标函数(损失函数),需要注意的是,每个目标函数要么适用于分类问题,要么适用于回归问题,并不是都适用的。


使用建议

下面给出了几个使用GBDT的建议,

  • loss: 目标函数。不同的目标函数适用不同的问题,比如分类或回归。不同的误差函数可能给出差别很大的最终结果,这依赖于数据集。

  • numIterations: 迭代次数。这个参数设置了集成算法中决策树的个数。每次迭代都会生成一棵决策树,增加迭代次数可以增加模型的表达能力,提高训练集上的准确率,然而,如果迭代次数过多,测试集上的准确率可能会降低 。

  • learningRate: 学习率。这个参数不需要调节,如果算法不稳定,减小学习率可能会提高稳定性。

  • algo: 算法。算法或任务(分类或回归)。

训练时的验证

GBDT 中树的个数较多时,可能会过拟合。为了防止过拟合,通常需要训练时加以验证。GBDT 中提供的的 runWithValidation 方法可以用来验证效果。需要一对 RDD 作为参数,第一个是训练集,第二个是验证集。

训练结束的条件是验证误差的提升量不超过某个限定值,这个限定值即为BoostingStrategy中的validationTol。实际问题中,验证误差会先减小后增大。验证误差不是单调变化的,用户可以设置一个足够大的负限定值,然后利用 evaluationEachIteration 来检测验证曲线,进而调节训练次数,其中evalutionnEachIteration 表示每次迭代中的误差或损失。

示例

分类问题

示例中给出了如何加载  LIBSVM data file,然后将其解析成LabeledPoint RDD,然后利用GBDT来分类,其中目标函数是对数误差,测试误差用来衡量算法的准确率。

API细节可以参考 GradientBoostedTrees  docs 和 GradientBoostedTreesModel docs 。.

Scala 代码如下

import org.apache.spark.mllib.tree.GradientBoostedTrees

import org.apache.spark.mllib.tree.configuration.BoostingStrategy

import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Split the data into training and test sets (30% held out for testing)

val splits = data.randomSplit(Array(0.7, 0.3))

val (trainingData, testData) = (splits(0), splits(1))

// Train a GradientBoostedTrees model.

// The defaultParams for Classification use LogLoss by default.

val boostingStrategy = BoostingStrategy.defaultParams("Classification")

boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.

boostingStrategy.treeStrategy.numClasses = 2

boostingStrategy.treeStrategy.maxDepth = 5

// Empty categoricalFeaturesInfo indicates all features are continuous.

boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// Evaluate model on test instances and compute test error

val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)}

val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()

println("Test Error = " + testErr)

println("Learned classification GBT model:\n" + model.toDebugString)

// Save and load model

model.save(sc, "target/tmp/myGradientBoostingClassificationModel")

val sameModel = GradientBoostedTreesModel.load(sc,
"target/tmp/myGradientBoostingClassificationModel")

完整代码可以参考spark repo中的"examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala" 。

python 代码如下

from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel

from pyspark.mllib.util import MLUtils

# Load and parse the data file.

data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

# Split the data into training and test sets (30% held out for testing)

(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a GradientBoostedTrees model.

#  Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous.

#         (b) Use more iterations in practice.

model = GradientBoostedTrees.trainClassifier(trainingData,
categoricalFeaturesInfo={}, numIterations=3)

# Evaluate model on test instances and compute test error

predictions = model.predict(testData.map(lambda x: x.features))l
abelsAndPredictions
= testData.map(lambda lp: lp.label).zip(predictions)

testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())

print('Test Error = ' + str(testErr))

print('Learned classification GBT model:')

print(model.toDebugString())

# Save and load model

model.save(sc, "target/tmp/myGradientBoostingClassificationModel")

sameModel = GradientBoostedTreesModel.load(sc,
"target/tmp/myGradientBoostingClassificationModel")

完整代码可以参考spark repo中的 "examples/src/main/python/mllib/gradient_boosting_classification_example.py" 。

完整代码

回归问题

下面给出了如何加载,将其解析成 LabeledPoint RDD,然后用GBDT来实现回归,其中目标函数是均方误差,均方误差用来衡量拟合的好坏。

API细节可以参考 GradientBoostedTrees docs 和 GradientBoostedTreesModel docs

Scala 代码如下

import org.apache.spark.mllib.tree.GradientBoostedTrees

import org.apache.spark.mllib.tree.configuration.BoostingStrategy

import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Split the data into training and test sets (30% held out for testing)

val splits = data.randomSplit(Array(0.7, 0.3))

val (trainingData, testData) = (splits(0), splits(1))

// Train a GradientBoostedTrees model.

// The defaultParams for Regression use SquaredError by default.

val boostingStrategy = BoostingStrategy.defaultParams("Regression")

boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.

boostingStrategy.treeStrategy.maxDepth = 5

// Empty categoricalFeaturesInfo indicates all features are continuous.

boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// Evaluate model on test instances and compute test error

val labelsAndPredictions = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)}

val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()

println("Test Mean Squared Error = " + testMSE)

println("Learned regression GBT model:\n" + model.toDebugString)

// Save and load model

model.save(sc, "target/tmp/myGradientBoostingRegressionModel")

val sameModel = GradientBoostedTreesModel.load(sc,
"target/tmp/myGradientBoostingRegressionModel")

完整代码参加spark repo中的 "examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala" 。

python 代码如下:

from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel

from pyspark.mllib.util import MLUtils# Load and parse the data file.

data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

# Split the data into training and test sets (30% held out for testing)

(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a GradientBoostedTrees model.

#  Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous.

#         (b) Use more iterations in practice.

model = GradientBoostedTrees.trainRegressor(trainingData,
categoricalFeaturesInfo={}, numIterations=3)

# Evaluate model on test instances and compute test error

predictions = model.predict(testData.map(lambda x: x.features))

labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ float(testData.count())

print('Test Mean Squared Error = ' + str(testMSE))

print('Learned regression GBT model:')

print(model.toDebugString())

# Save and load model

model.save(sc, "target/tmp/myGradientBoostingRegressionModel")

sameModel = GradientBoostedTreesModel.load(sc, "target/tmp/myGradientBoostingRegressionModel")


完整代码参见spark repo中的 "examples/src/main/python/mllib/gradient_boosting_regression_example.py" 。

优缺点

优点:

非线性变换能力强,表达能力强,

不需要复杂的特征工程和特征变换

缺点

串行过程,不易并行化

计算复杂度高,不适合高维稀疏特征

参考资料

http://journal.frontiersin.org/article/10.3389/fnbot.2013.00021/full

http://www.slideshare.net/Hadoop_Summit/surge-rise-of-scalable-machine-learning-at-yahoo

http://spark.apache.org/docs/latest/mllib-ensembles.html#gradient-boosted-trees-gbts

推荐 0
本文由 深度学习 创作,采用 知识共享署名-相同方式共享 3.0 中国大陆许可协议 进行许可。
转载、引用前需联系作者,并署名作者且注明文章出处。
本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责。本站是一个个人学习交流的平台,并不用于任何商业目的,如果有任何问题,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

0 个评论

要回复文章请先登录注册