c# OpenCvSharp knn 手写数字识别

c# OpenCvSharp knn 手写数字识别

using OpenCvSharp;
using OpenCvSharp.Extensions;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;
using System.Threading;
using System.IO;

namespace app
{
    public partial class FrmMain : Form
    {
        public FrmMain()
        {
            InitializeComponent();
        }

        /// <summary>
        /// knn手写数字识别
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void btnKNN_Click(object sender, EventArgs e)
        {
            // 训练数据数量
            int train_sample_count = 60000;
            // 测试数据数量
            int test_sample_count = 10000;
            // 声明训练数据集合 mat,60000行,784列
            Mat trainData = new Mat(train_sample_count, 28 * 28, MatType.CV_32FC1);
            // 声明测试数据集合 mat,10000行,784列
            Mat testData = new Mat(test_sample_count, 28 * 28, MatType.CV_32FC1);
            // 声明训练数据标签 mat,60000行,1列
            Mat trainLabel = new Mat(train_sample_count, 1, MatType.CV_32FC1);
            // 声明测试数据标签 mat,10000行,1列
            Mat testLabel = new Mat(test_sample_count, 1, MatType.CV_32FC1);

            string trainPath = @"img\mnist\train_images";
            string testPath = @"img\mnist\test_images";

            // 组织训练数据,循环训练文件夹内所有图片。
            int trainNum = 0;
            for (int i = 0; i < 10; i++)
            {
                string path = trainPath + "\\" + i;
                DirectoryInfo TheFolder = new DirectoryInfo(path);
                foreach (FileInfo NextFile in TheFolder.GetFiles())
                {
                    // 读入单通道灰度图
                    Mat temp = new Mat(NextFile.FullName, ImreadModes.Grayscale);
                    // 转换CV_32FC1,因为下面训练函数需要这个格式
                    temp.ConvertTo(temp, MatType.CV_32FC1);
                    // 写入到训练数据集合的mat内,注意reshape的用法。
                    /*
                    reshape有两个参数:
                    其中,参数:cn为新的通道数,如果cn = 0,表示通道数不会改变。
                    参数rows为新的行数,如果rows = 0,表示行数不会改变。
                    注意:新的行* 列必须与原来的行*列相等。就是说,如果原来是5行3列,新的行和列可以是1行15列,3行5列,5行3列,15行1列。
                    */
                    temp.Reshape(0, 1).CopyTo(trainData.Row(trainNum));
                    // 写入到训练标签集合的mat内
                    trainLabel.Set<float>(trainNum, i);
                    trainNum++;
                }
            }

            // 组织测试数据
            int testNum = 0;
            for (int i = 0; i < 10; i++)
            {
                string path = testPath + "\\" + i;
                DirectoryInfo TheFolder = new DirectoryInfo(path);
                foreach (FileInfo NextFile in TheFolder.GetFiles())
                {
                    Mat temp = new Mat(NextFile.FullName, ImreadModes.Grayscale);
                    temp.ConvertTo(temp, MatType.CV_32FC1);
                    temp.Reshape(0, 1).CopyTo(testData.Row(testNum));
                    testLabel.Set<float>(testNum, i);
                    testNum++;
                }
            }

            // 创建knn模型
            OpenCvSharp.ML.KNearest knn = OpenCvSharp.ML.KNearest.Create();
            // k 可以根据需要自行调整
            int k = 3;
            // 设置K值
            knn.DefaultK = k;
            // 设置KNN是进行分类还是回归
            knn.IsClassifier = true;
            // 设置算法类型 BruteForce 或 KdTree
            knn.AlgorithmType = OpenCvSharp.ML.KNearest.Types.BruteForce;

            // 训练
            knn.Train(trainData, OpenCvSharp.ML.SampleTypes.RowSample, trainLabel);

            // 测试
            Mat result = new Mat(test_sample_count, 1, MatType.CV_32FC1);
            knn.FindNearest(testData, k, result);
            int t = 0;
            int f = 0;
            for (int i = 0; i < test_sample_count; i++)
            {
                int predict = (int)result.At<float>(i);
                int actual = (int)testLabel.At<float>(i);

                if (predict == actual)
                {
                    System.Console.WriteLine("正确:" + predict + "-" + actual);
                    t++;
                }
                else
                {
                    System.Console.WriteLine("错误------:" + predict + "-" + actual);
                    f++;
                }
            }

            double accuracy = (t * 1.0) / (t + f);
            System.Console.WriteLine("准确率:" + accuracy);
        }
    }
}

 

发表回复

您的电子邮箱地址不会被公开。