import cv2 import numpy as np import matplotlib.pyplot as plt point_all = np.random.randint(0, 100, (20, 2)) # 随机选择20个点 label_all = np.random.randint(0, 2, (20, 1)) # 为随机点随机分配标志 label_0 = point_all[label_all.ravel() == 0] # 分出标志为0的点 plt.scatter(label_0[:, 0], label_0[:, 1], 80, 'b', 's') # 将标志为0的点绘制为蓝色矩形 label_1 = point_all[label_all.ravel() == 1] # 分出标志为1的点 plt.scatter(label_1[:, 0], label_1[:, 1], 80, 'r', '^') # 将标志为1的点绘制为红色三角形 point_new = np.random.randint(0, 100, (1, 2)) # 随机选择一个点,下面确定其分类 plt.scatter(point_new[:, 0], point_new[:, 1], 80, 'g', 'o') # 将待分类新点绘制为绿色圆点 plt.show() # 进一步使用knn算法确认待分类新点的类别、3个最近邻居和距离 knn = cv2.ml.KNearest_create() # 创建kNN分类器 knn.train(point_all.astype(np.float32), cv2.ml.ROW_SAMPLE, label_all.astype(np.float32)) # 训练模型 ret, results, neighbours, dist = knn.findNearest(point_new.astype(np.float32), 3) # 找出3个最近邻居 print("新点标志: %s" % results) print("邻居: %s" % neighbours) print("距离: %s" % dist)
np.random.randint(low, high=None, size=None, dtype=’l’)
low 最小值。
high 最大值。
size 数组维度大小。
dtyp 数据类型,默认的数据类型是np.int。
np.random.randint(5, size=(2, 4)) #2行4列,higt为none,取值是0到5之间的整数。
array([[4, 3, 0, 4],
[3, 1, 1, 3]])