近邻算法API介绍

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')

  • n_neighbors:
    • int,可选(默认= 5),k_neighbors查询默认使用的邻居数
  • algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’}
    • 快速k近邻搜索算法,默认参数为auto,可以理解为算法自己决定合适的搜索算法。除此之外,用户也可以自己指定搜索算法ball_tree、kd_tree、brute方法进行搜索,
      • brute是蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时。
      • kd_tree,构造kd树存储数据以便对其进行快速检索的树形数据结构,kd树也就是数据结构中的二叉树。以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高。
      • ball tree是为了克服kd树高纬失效而发明的,其构造过程是以质心C和半径r分割样本空间,每个节点是一个超球体。

 

案例:鸢尾花种类预测

数据集介绍

Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。关于数据集的具体介绍:

鸢尾花种类预测

步骤分析

  • 1.获取数据集
  • 2.数据基本处理
  • 3.特征工程
  • 4.机器学习(模型训练)
  • 5.模型评估
'''
1.获取数据集
2.数据基本处理
3.特征工程
4.机器学习(模型训练)
5.模型评估
'''
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

#1.获取数据集
iris=load_iris()

#2.数据基础处理
#2.1数据分割
# 训练集的特征值x_train 测试集的特征值x_test 训练集的目标值y_train 测试集的目标值y_test
x_train, x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=22,test_size=0.2)

#3.特征工程
#3.1实例化一个转换器
transfer=StandardScaler()
#3.2调用fit_trainform方法
x_train=transfer.fit_transform(x_train)
x_test=transfer.fit_transform(x_test)

#4.机器学习(模拟训练)
#4.1实例化一个估计器
estimator=KNeighborsClassifier(n_neighbors=5)
#4.2模型训练
estimator.fit(x_train,y_train)

#5,模型训练
#5.1输出预测值
y_pre=estimator.predict(x_test)
print("预测值是:\n",y_pre)
print("预测值和真实值对比:\n",y_pre==y_test)

#5.2输出准确率
ret=estimator.score(x_test,y_test)
print("准确率是:\n",ret)

'''
预测值是:
 [0 2 1 1 1 1 1 1 1 0 2 1 2 2 0 2 1 1 1 1 0 2 0 1 1 0 1 1 2 1]
预测值和真实值对比:
 [ True  True  True False  True  True  True False  True  True  True  True
  True  True  True  True  True  True False  True  True  True  True  True
 False  True False False  True False]
准确率是:
 0.7666666666666667

进程已结束,退出代码0

'''

 

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐