参考在这里插入图片描述


```python
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns
#import plot_svc_decision_function from plot_svm_boundary
from plot_svm_boundary import plot_svc_decision_function

# Generate a toy dataset with scikit-learn's `datasets` module.
# `sklearn.datasets.samples_generator` was deprecated in scikit-learn 0.22 and
# removed in 0.24, so `make_blobs` is imported from `sklearn.datasets` directly.
from sklearn.datasets import make_blobs

# X: the generated samples; y: the cluster label of each sample (two clusters).
X, y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.6)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')


# Binary classification task.
# Draw a few arbitrary separating lines by hand -- which one is best?

xfit = np.linspace(-1, 3.5)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')

# Each candidate is a (slope, intercept) pair describing y = slope * x + intercept.
for slope, intercept in [(1, 0.65), (0.5, 1.6), (-0.2, 2.9)]:
    plt.plot(xfit, slope * xfit + intercept, '-k')

# Restrict the visible x range to the interval the lines were sampled on.
plt.xlim(-1, 3.5)
plt.show()

# Find the widest margin and draw it:
# let the support vector machine decide where the widest boundary lies.
from sklearn.svm import SVC

# A linear kernel applies no transformation to the data.
# `fit` returns the estimator itself, so construction and training are chained.
model = SVC(kernel='linear').fit(X, y)

#绘图函数
# def plot_svc_decision_function(model, ax=None, plt_support=True, plot_support=None):
#     """Plot the decision function for a 2D SVC"""
#     if ax is None:
#         ax=plt.gca() # get current axes 获取当前坐标轴
#     xlim=ax.get_xlim()
#     ylim=ax.get_ylim()
#
#     #用SVM自带的decision_function函数来绘制
#     x=np.linspace(xlim[0],xlim[1],30)
#     y=np.linspace(ylim[0],ylim[1],30)
#     Y,X=np.meshgrid(y,x) # 绘制 x,y 构成的网格点
#
#     xy=np.vstack([X.ravel(),Y.ravel()]).T
#
#     P=model.decision_function(xy).reshape(X.shape)
#
#     #绘制决策边界
#     # plot decision boundary and margins
#     ax.contour(X, Y, P, colors='k',
#                levels=[-1, 0, 1], alpha=0.5,
#                linestyles=['--', '-', '--'])
#
#     # plot support vectors
#     if plot_support:
#         ax.scatter(model.support_vectors_[:, 0],
#                    model.support_vectors_[:, 1],
#                    s=300, linewidth=1, facecolors='none');
#     ax.set_xlim(xlim)
#     ax.set_ylim(ylim)

# Draw the data points together with the SVM decision boundary.
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
# Helper imported from the local `plot_svm_boundary` module at the top of the file.
plot_svc_decision_function(model);

# Print the support vectors selected by the fitted SVC.
# The label says "support vectors" (支持向量): `support_vectors_` holds the
# support vectors themselves, not the machine (支持向量机).
print(f"支持向量为:{model.support_vectors_}")

# As long as the support vectors stay the same, the decision boundary stays the same.

# Load the Labeled Faces in the Wild dataset.
import matplotlib.pyplot as plt
# The package is `sklearn.datasets`; the previous `sklearn.datases` was a typo
# that made this import fail with ModuleNotFoundError.
from sklearn.datasets import fetch_lfw_people
import seaborn as sns

# Keep only people with at least 40 pictures in the dataset.
faces = fetch_lfw_people(min_faces_per_person=40)
print(faces.target_names)
print(faces.images.shape)

# Plot a few of these faces to see what we're working with.
fig, ax = plt.subplots(3, 5)
for i, axi in enumerate(ax.flat):
    axi.imshow(faces.images[i], cmap='bone')
    axi.set(xticks=[], yticks=[],
            xlabel=faces.target_names[faces.target[i]])

# Model: PCA feature extraction followed by an RBF-kernel SVM classifier.
# (`RandomizedPCA` is no longer available in scikit-learn; plain `PCA` is used.)
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC

# Project each flattened image onto 150 whitened principal components.
pca = PCA(n_components=150, whiten=True, random_state=42)
# RBF-kernel SVM; 'balanced' class weights are meant to offset classes of unequal size.
svc = SVC(kernel='rbf', class_weight='balanced')
# make_pipeline names its steps after the lowercased class names: 'pca' and 'svc'.
model = make_pipeline(pca, svc)

# Hold out part of the data as a test set; the fixed seed keeps the split reproducible.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    faces.data,
    faces.target,
    random_state=40,
)

# Use grid-search cross-validation to choose the SVC hyperparameters.
from sklearn.model_selection import GridSearchCV

# Parameters of a pipeline step are addressed as '<step name>__<parameter>'
# (double underscore). The former keys 'svc_C' / 'svc_gamma' are not valid
# pipeline parameters, and GridSearchCV raised an error on them.
param_grid = {'svc__C': [1, 5, 10],
              'svc__gamma': [0.0001, 0.0005, 0.001]}

# Cross-validated search over every combination in param_grid.
grid = GridSearchCV(model, param_grid)

# Train the model and report the best parameter combination found.
grid.fit(Xtrain, ytrain)
print(grid.best_params_)

# Predict on the held-out test set with the best estimator from the grid search.
model = grid.best_estimator_
yfit = model.predict(Xtest)

# Show the results. `plt.subplots` (plural) returns a (figure, axes-array) pair;
# `plt.subplot` returns a single Axes and cannot be unpacked into `fig, ax`.
fig, ax = plt.subplots(4, 6)
# Rows of Xtest are flattened images: restore them using the dataset's own
# image size instead of a hard-coded (62, 47).
image_shape = faces.images.shape[1:]
for i, axi in enumerate(ax.flat):
    axi.imshow(Xtest[i].reshape(image_shape), cmap='bone')
    axi.set(xticks=[], yticks=[])
    # Label with the last word of the predicted name; wrong predictions in red.
    axi.set_ylabel(faces.target_names[yfit[i]].split()[-1],
                   color='black' if yfit[i] == ytest[i] else 'red')

fig.suptitle('Predicted Names; Incorrect Labels in Red', size=14);


# Evaluation metrics on the test set, plus a confusion matrix.
from sklearn.metrics import classification_report, confusion_matrix

print(classification_report(ytest, yfit, target_names=faces.target_names))

# The confusion matrix is transposed before plotting so that true labels run
# along the x axis and predicted labels along the y axis.
mat = confusion_matrix(ytest, yfit)
sns.heatmap(
    mat.T,
    square=True,
    annot=True,
    fmt='d',
    cbar=False,
    xticklabels=faces.target_names,
    yticklabels=faces.target_names,
)
plt.xlabel('true label')
plt.ylabel('predicted label');
```

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐