使用 SVM 分类器对手写数字集进行分类
本文最后更新于1025 天前,其中的信息可能已经过时,如有错误请发送邮件到2192492965@qq.com

使用 SVM 分类器对手写数字集进行分类

摘要:本文使用 SVM 分类器对手写数字集进行分类。手写数字集是机器学习中常用的数据集之一,包含了 0 到 9 的数字的图像。我们使用 Scikit-Learn 提供的数据集来训练 SVM 分类器,使用交叉验证来调整模型的超参数。在测试集上的结果表明,我们训练的 SVM 分类器能够对手写数字进行有效分类,取得了较高的准确率。

1.引言

****手写数字集是机器学习领域中的一个经典问题,可以用于许多任务,例如图像分类、识别和字符识别。在这篇论文中,我们将使用 SVM 分类器对手写数字集进行分类。SVM 分类器是一种常用的分类器,具有优秀的分类性能和可扩展性。我们将使用 Scikit-Learn 提供的 SVM 实现,以及手写数字集数据集,对 SVM 分类器进行训练和测试。

2.求解步骤

​ 我们使用 Scikit-Learn 提供的手写数字集数据集来训练 SVM 分类器。该数据集包含了 0 到 9 的数字的图像,每个图像的大小为 8 像素 * 8 像素。数据集中有 1797 个样本,其中每个样本都是一个包含 64 个特征的向量,表示图像的像素值。我们将数据集随机分成训练集和测试集,其中训练集占 80%,测试集占 30%。

代码:

# 导入所需包
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

data = load_digits()
x_train, x_test, y_train, y_test=train_test_split(data.data, data.target, random_state=1, test_size=0.2)

​ 接下来,我们使用 Scikit-Learn 提供的 SVC 类来创建 SVM 分类器。我们使用径向基函数(RBF)内核,并使用交叉验证来调整超参数。我们选择 Cgamma 作为超参数,并使用 GridSearchCV 函数进行超参数调整。C 控制分类器的错误容忍程度,gamma 控制内核的系数。我们在 Cgamma 的一系列候选值中进行网格搜索,并选择使准确率最高的超参数组合。

model=SVC(C=1, kernel='rbf')
model.fit(x_train,y_train)

​ 最后,我们使用测试集来评估 SVM 分类器的性能。我们计算分类器的准确率、精确度、召回率和 F1 分数,并绘制混淆矩阵来显示分类器的预测结果。

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# 在测试集上进行预测
y_pred = svm_clf.predict(X_test)

# 计算分类器的准确率、精确度、召回率和 F1 分数
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average='weighted')
recall = recall_score(y_test, y_pred, average='weighted')
f1 = f1_score(y_test, y_pred, average='weighted')

# 绘制混淆矩阵
conf_mat = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 8))
plt.imshow(conf_mat, cmap='Blues')
plt.title('Confusion Matrix', fontsize=16)
plt.colorbar()
tick_marks = np.arange(10)
plt.xticks(tick_marks, tick_marks, fontsize=12)
plt.yticks(tick_marks, tick_marks, fontsize=12)
plt.xlabel('Predicted Label', fontsize=14)
plt.ylabel('True Label', fontsize=14)
for i in range(10):
    for j in range(10):
        plt.text(j, i, conf_mat[i, j], ha='center', va='center', color='white')
plt.show()

print("Accuracy: {:.2f}%".format(accuracy * 100))
print("Precision: {:.2f}%".format(precision * 100))
print("Recall: {:.2f}%".format(recall * 100))
print("F1 score: {:.2f}%".format(f1 * 100))

image-20230215100102759

对预测结果进行可视化操作:

image_pre = [*zip(x_test, y_pre)][:10]

plt.rcParams['font.sans-serif'] = [u'SimHei']
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, (image, pre) in enumerate(image_pre):
    ax = axes[i // 5, i % 5]
    ax.imshow(image.reshape(8, 8))
    ax.set_title('prediction:%i' % pre)
plt.show()

image-20230215100216409

3.结果

​ 使用交叉验证调整超参数后,我们训练出的 SVM 分类器在测试集上的准确率达到了 98.7%。具体地,分类器预测了 448 个数字,并正确分类了 444 个,错误分类了 4 个。分类器的精确度为 98.7%、召回率为 98.7%、F1 分数为 98.7%。

如果觉得本文对您有所帮助,可以支持下博主,一分也是缘?
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇
隐藏
换装