OpenCV ml 模块 knn算法 Python版

OpenCV ml 模块 knn算法 Python版

import cv2
import numpy as np
import matplotlib.pyplot as plt

point_all = np.random.randint(0, 100, (20, 2))                # 随机选择20个点
label_all = np.random.randint(0, 2, (20, 1))                  # 为随机点随机分配标志
label_0 = point_all[label_all.ravel() == 0]                   # 分出标志为0的点
plt.scatter(label_0[:, 0], label_0[:, 1], 80, 'b', 's')      # 将标志为0的点绘制为蓝色矩形
label_1 = point_all[label_all.ravel() == 1]                   # 分出标志为1的点
plt.scatter(label_1[:, 0], label_1[:, 1], 80, 'r', '^')      # 将标志为1的点绘制为红色三角形
point_new = np.random.randint(0, 100, (1, 2))                 # 随机选择一个点,下面确定其分类
plt.scatter(point_new[:, 0], point_new[:, 1], 80, 'g', 'o')   # 将待分类新点绘制为绿色圆点
plt.show()                                                    # 进一步使用knn算法确认待分类新点的类别、3个最近邻居和距离
knn = cv2.ml.KNearest_create()                                # 创建kNN分类器
knn.train(point_all.astype(np.float32), cv2.ml.ROW_SAMPLE, label_all.astype(np.float32))  # 训练模型
ret, results, neighbours, dist = knn.findNearest(point_new.astype(np.float32), 3)         # 找出3个最近邻居
print("新点标志: %s" % results)
print("邻居: %s" % neighbours)
print("距离: %s" % dist)

 

np.random.randint(low, high=None, size=None, dtype=’l’)
low 最小值。
high 最大值。
size 数组维度大小。
dtyp 数据类型,默认的数据类型是np.int。

np.random.randint(5, size=(2, 4)) #2行4列,higt为none,取值是0到5之间的整数。
array([[4, 3, 0, 4],
[3, 1, 1, 3]])

发表回复

您的电子邮箱地址不会被公开。