返回 登录
0

比较R语言机器学习算法的性能

你如何有效地计算出不同机器学习算法的估计准确性?在这篇文章中,你将会学到8种技术,用来比较R语言机器学习算法。你可以使用这些技术来选择最精准的模型,并能够给出统计意义方面的评价,以及相比其它算法的绝对优势。

选择最好的机器学习模型

你如何根据需求选择最好的模型?

在你进行机器学习项目的时候,往往会有许多良好模型可供选择。每个模型都有不同的性能特点。

使用重采样方法,如交叉验证,就可以得到每个模型在未知数据上精准度的估计。你需要利用这些估计从你创建的一系列模型中选择一到两个最好的模型。

仔细比较机器学习模型

当你有了新数据集,使用多种不同的图形技术可视化数据是个好主意,你可以从不同角度来观察数据。

这种想法也可以用于模型选择。你应该使用不同的方法来进行估计机器学习算法的准确率,依此来选择一到两个模型。

你可以使用不同的可视化方法来显示平均准确率、方差和模型精度分布的其他性质。

比较并选择R语言的机器学习模型

在本节中,你将会学到如何客观地比较R语言机器学习模型。

通过本节中的案例研究,你将为皮马印第安人糖尿病数据集创建一些机器学习模型。然后你将会使用一系列不同的可视化技术来比较这些模型的估计准确率。

本案例研究分为三个部分:

  1. 准备数据集:加载库文件和数据集,准备训练模型。
  2. 训练模型:在数据集上训练标准机器学习模型,准备进行评估。
  3. 比较模型:使用8种不同的技术比较训练得到的模型。

准备数据集

本研究案例中使用的数据集是皮马印第安人糖尿病数据集,可在UCI机器学习库中获取。也可在R中的mlbench包中获取。

这是一个二元分类问题,预测患者在五年之内糖尿病是否会发作。入参是数值型,描述了女性患者的医疗信息。

现在来加载库文件和数据集。

# load libraries
library(mlbench)
library(caret)
# load the dataset
data(PimaIndiansDiabetes)

训练模型

在本节中,我们将会训练在下一节中将要比较的5个机器学习模型。

我们将使用重复交叉验证,folds为10,repeats为3,这是比较模型的常用标准配置。评估指标是精度和kappa,因为它们很容易解释。

根据算法的代表性和学习风格方式进行半随机选择。它们有:

  • 分类和回归树
  • 线性判别分析
  • 使用径向基函数的支持向量机
  • K-近邻
  • 随机森林

训练完模型之后,将它们添加到一个list中,然后调用resamples()函数。此函数可以检查模型是可比较的,并且模型都使用同样的训练方案(训练控制配置)。这个对象包含每个待评估算法每次折叠和重复的评估指标。

下一节中我们使用到的函数都需要包含这种数据的对象。

# prepare training scheme
control <- trainControl(method="repeatedcv", number=10, repeats=3)
# CART
set.seed(7)
fit.cart <- train(diabetes~., data=PimaIndiansDiabetes, method="rpart", trControl=control)
# LDA
set.seed(7)
fit.lda <- train(diabetes~., data=PimaIndiansDiabetes, method="lda", trControl=control)
# SVM
set.seed(7)
fit.svm <- train(diabetes~., data=PimaIndiansDiabetes, method="svmRadial", trControl=control)
# kNN
set.seed(7)
fit.knn <- train(diabetes~., data=PimaIndiansDiabetes, method="knn", trControl=control)
# Random Forest
set.seed(7)
fit.rf <- train(diabetes~., data=PimaIndiansDiabetes, method="rf", trControl=control)
# collect resamples
results <- resamples(list(CART=fit.cart, LDA=fit.lda, SVM=fit.svm, KNN=fit.knn, RF=fit.rf))

比较模型

在本节中,我们将看到8种不同的技术用来比较构建模型的估计精度。

汇总表(Table Summary)

这是你可以做的最简单的比较,只需要调用summary()函数,并传入resamples()函数值。它会创建一个表格,每行是一种算法,每列是评估指标。在这里我们已经整理好了结果。

Accuracy 
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.6234  0.7115 0.7403 0.7382  0.7760 0.8442    0
LDA  0.6711  0.7532 0.7662 0.7759  0.8052 0.8701    0
SVM  0.6711  0.7403 0.7582 0.7651  0.7890 0.8961    0
KNN  0.6184  0.6984 0.7321 0.7299  0.7532 0.8182    0
RF   0.6711  0.7273 0.7516 0.7617  0.7890 0.8571    0

Kappa 
       Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
CART 0.1585  0.3296 0.3765 0.3934  0.4685 0.6393    0
LDA  0.2484  0.4196 0.4516 0.4801  0.5512 0.7048    0
SVM  0.2187  0.3889 0.4167 0.4520  0.5003 0.7638    0
KNN  0.1113  0.3228 0.3867 0.3819  0.4382 0.5867    0
RF   0.2624  0.3787 0.4516 0.4588  0.5193 0.6781    0

箱线图(Box and Whisker Plots)

这是查看不同模型评估精度伸展和联系的有效方式。

# box and whisker plots to compare models
scales <- list(x=list(relation="free"), y=list(relation="free"))
bwplot(results, scales=scales)

注意到箱线图以平均精度降序排序。我发现观察平均值(点)和箱线图的重叠(中间50%)很有用。

图片描述

用箱线图比较R语言机器学习算法

密度图(Density Plots)

你可以将模型精度分布显示成密度图。这是种评估算法估计行为重叠的有效方式。

# density plots of accuracy
scales <- list(x=list(relation="free"), y=list(relation="free"))
densityplot(results, scales=scales, pch = "|")

我喜欢观察波峰以及分布伸展或分布底部的差异。

图片描述

比较R语言机器学习算法的密度图

点图(Dot Plots)

这些点非常有用,它显示了平均估计精度以及95%的置信区间(例如,95%观测点所落入的范围)。

# dot plots of accuracy
scales <- list(x=list(relation="free"), y=list(relation="free"))
dotplot(results, scales=scales)

我发现比较均值和目测算法间的重叠伸展很有用。

图片描述

比较R语言机器学习算法的点图

平行线图(Parallel Plots)

这是另一种查看数据的方式。它显示了每个被测算法每次交叉验证折叠试验的行为。它可以帮助你查看一个算法中子集相对其他算法的线性走势。

# parallel plots to compare models
parallelplot(results)

要对此进行解释需要一些技巧。我认为这在以后对分析不同方法如何在组合预测中结合很有帮助(例如堆叠),尤其当你在相反方向看到有相关运动时。

图片描述

比较R语言机器学习算法的平行线图

散点图矩阵(Scatterplot Matrix)

这创建了一个算法的所有折叠试验结果与其他算法相同折叠试验结果比较的散点图矩阵。每一对都进行了比较。

# pair-wise scatterplots of predictions to compare models
splom(results)

这种做法对于考虑两个不同算法的预测是否相关时非常宝贵。如果弱相关,它们可以很好地用于组合预测。

比如,目测图表,好像LDA和SVM呈强相关性,SVM和RF也一样。SVM与CART似乎呈弱相关性。

图片描述

比较R语言机器学习算法的散点图矩阵

成对XY图(Pairwise xyPlots)

你可以使用xy图,对两种机器学习算法的折叠试验精度进行成对比较。

# xyplot plots to compare models
xyplot(results, models=c("LDA", "SVM"))

在这种情况下,我们可以看到LDA和SVM模型看似相关的精度。

图片描述

比较R语言机器学习算法的成对散点图

统计意义检测(Statistical Significance Tests)

你可以计算不同机器学习算法间指标分布差异的意义。我们可以直接调用summary()函数汇总结果。

# difference in model predictions
diffs <- diff(results)
# summarize p-values for pair-wise comparisons
summary(diffs)

我们可以得到一个表格,记录了每对算法的统计意义分数。表格对角线下方显示的是零假设的p值(分布是相同的),值越小越好。我们可以看到CART和kNN之间没有区别,同样能看出LDA和SVM分布相差不大。

表格对角线上方显示的是不同分布的估计差异。观察前面的图表,如果我们认为LDA是最精准的模型,我们可以得出它比其他模型要具体精准多少的估计。

这些分数可以帮助你计算具体算法之间任何精度。

p-value adjustment: bonferroni 
Upper diagonal: estimates of the difference
Lower diagonal: p-value for H0: difference = 0

Accuracy 
     CART      LDA       SVM       KNN       RF       
CART           -0.037759 -0.026908  0.008248 -0.023473
LDA  0.0050068            0.010851  0.046007  0.014286
SVM  0.0919580 0.3390336            0.035156  0.003435
KNN  1.0000000 1.218e-05 0.0007092           -0.031721
RF   0.1722106 0.1349151 1.0000000  0.0034441

一个好技巧是增加试验次数,来增加种群,获取可以得到更精准p值。你也可以画出它们之间的差异,但是我发现与上面的汇总表相比并没多大用处。

总结

在这篇文章中你学会了8种不同的技术,可以用来比较R语言机器学习算法模型的估计精度。

这8种技术是:

  • 表汇总
  • 箱线图
  • 密度图
  • 点图
  • 平行线图
  • 散点图矩阵
  • 成对XY图
  • 统计意义检测

我漏掉了你在比较R语言机器学习算法估计精度时喜欢使用的方法吗?请留下评论,我很乐意倾听!

原文链接:Compare The Performance of Machine Learning Algorithms in R
作者: Jason Brownlee
译者:刘翔宇 审校:赵屹华
责编:周建丁


2016年3月18日-19日,由CSDN重磅打造的数据库核心技术与实战应用峰会、互联网应用架构实战峰会将在上海举行。这两场峰会将邀请业内顶尖的架构师和技术专家,共同探讨高可用/高并发系统架构设计、新技术应用、移动应用架构、微服务、智能硬件架构、云数据库实战、新一代数据库平台、产品选型、性能调优、大数据应用实战等领域的热点话题与技术。

2月29日24点前仍处于最低六折优惠票价阶段,单场峰会(含餐)门票只需799元,5人以上团购或者购买两场峰会通票更有特惠,限量供应,预购从速。(票务详情链接)。

评论