好的,以下是使用 Matplotlib 中的 'meshgrid'、'scatter3D' 和 'plot_surface' 函数来实现三维散点和超平面可视化的示例代码:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
from sklearn.svm import SVC

# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, [0, 1, 2]]  # 选择特征(萼片长度、萼片宽度、花瓣长度)
y = iris.target

# 创建SVM模型并拟合数据
model = SVC(kernel='linear')
model.fit(X, y)

# 绘制散点图
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(X[:, 0], X[:, 1], X[:, 2], c=y)

# 生成超平面的网格点
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
z_min, z_max = X[:, 2].min() - 1, X[:, 2].max() + 1
xx, yy, zz = np.meshgrid(np.arange(x_min, x_max, 0.2),
                         np.arange(y_min, y_max, 0.2),
                         np.arange(z_min, z_max, 0.2))

# 预测网格点的标签
Z = model.predict(np.c_[xx.ravel(), yy.ravel(), zz.ravel()])
Z = Z.reshape(xx.shape)

# 绘制超平面
ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, cmap='cool')
ax.contour3D(xx, yy, zz, Z, cmap='cool', alpha=0.5)

# 设置坐标轴标签
ax.set_xlabel('Sepal Length')
ax.set_ylabel('Sepal Width')
ax.set_zlabel('Petal Length')

plt.show()

这段代码使用了鸢尾花数据集的前三个特征进行分类和可视化。首先,使用 'scatter3D' 函数绘制了三维散点图,不同类别的数据点使用不同的颜色进行区分。

然后,使用 'meshgrid' 函数生成超平面的网格点,通过预测网格点的标签并使用 'plot_surface' 函数绘制超平面。同时,使用 'contour3D' 函数绘制超平面的等高线,以增加可视化的效果。

最后,通过设置坐标轴的标签,使可视化结果更加清晰。

请注意,这只是一个示例代码,您可以根据实际需求调整特征选择、网格精度和其他参数,以获得更好的分类和可视化效果。

标签: 常规


原文地址: https://gggwd.com/t/topic/cmN5 著作权归作者所有。请勿转载和采集!