c++ OpenCV knn 手写数字识别 模型保存及加载

c++ OpenCV knn 手写数字识别 模型保存及加载

一、原图
opencv4.7\sources\samples\data\digits.png

//读取OpenCV自带的一张手写体数字图,尺寸为Size(2000,1000),其中每个数字为(20,20)的区域,总共有 [(1000/20)x(2000/20)] 共5000个数字。

二、分割图片

// 使用math库里的宏常量
#define _USE_MATH_DEFINES

#include <iostream>
#include <filesystem>
#include <string>
#include <windows.h>
#include <io.h>
#include <direct.h>
#include <opencv2/opencv.hpp>

namespace fs = std::filesystem;
using namespace cv;
using namespace std;

// 分割数字图片
void split_digital_img()
{
    int  filename = 0, filenum = 0;
    Mat img = imread("D:/opencv/opencv4.7/sources/samples/data/digits.png");
    Mat gray;
    cvtColor(img, gray, COLOR_BGR2GRAY);
    int b = 20;
    int m = gray.rows / b; // 原图为1000*2000
    int n = gray.cols / b; // 裁剪为5000个20*20的小图块

    // 创建文件夹
    for (int i = 0; i <= 9; i++)
    {
        // 文件夹路径
        string dir = "D:/opencv/mnist/pic/" + to_string(i);
        // 判断该文件夹是否存在
        if (_access(dir.c_str(), 0) == -1) {
            // Windows 创建文件夹
            int flag = _mkdir(dir.c_str());
        }
    }

    for (int i = 0; i < m; i++)
    {
        // 行上的偏移量
        int offsetRow = i * b;
        // 原图中每5行存储相同数字,因此过了5行要递增文件名
        if (i % 5 == 0 && i != 0)
        {
            filename++; // 递增文件名
            filenum = 0; // 清零文件计数器
        }

        for (int j = 0; j < n; j++)
        {
            int offsetCol = j * b; // 列上的偏移量
            string file_savepath = "D:/opencv/mnist/pic/" + to_string(filename) + "/" + to_string(filenum++) + ".png";
            // 截取20*20的小块
            Mat tmp;
            gray(Range(offsetRow, offsetRow + b), Range(offsetCol, offsetCol + b)).copyTo(tmp);
            imwrite(file_savepath, tmp); // 将对应的数字图像块保存到对应名字的文件夹中
        }
    }
}

三、训练、测试

// knn手写数字识别
void test_knn_train()
{
    string file_savepath;
    int testnum = 0, truenum = 0;
    const int K = 3;   //设置K值为3
    cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create(); // 创建KNN类
    knn->setDefaultK(K);
    knn->setIsClassifier(true); // 设置KNN用于分类 
    knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE); // 设置寻找距离最近的K个样本的方式为遍历所有训练样本,即暴力破解的方式。  

    Mat traindata, trainlabel;

    for (int i = 0; i < 10; i++)
    {
        for (int j = 0; j < 400; j++)
        {
            file_savepath = "D:/opencv/mnist/pic/" + to_string(i) + "/" + to_string(j) + ".png";
            Mat srcimage = imread(file_savepath);
            // 把二维数据转换为一维数据
            srcimage = srcimage.reshape(1, 1); // Mat Mat::reshape(int cn, int rows = 0),cn表示通道数,为0则通道不变,rows表示矩阵行数。
            traindata.push_back(srcimage); // srcimage为Mat类型的1行n列的一维矩阵,将该矩阵保存到训练矩阵中。
            trainlabel.push_back(i); // i为srcimage的标签,同时将i保存到标签矩阵中。
        }
    }

    traindata.convertTo(traindata, CV_32F); //重要:训练矩阵必须是浮点型数据
    knn->train(traindata, cv::ml::ROW_SAMPLE, trainlabel); // 导入训练数据和标签

    // 标签
    for (int i = 0; i < 10; i++)
    {
        for (int j = 400; j < 500; j++)
        {
            testnum++; //统计总的分类次数
            file_savepath = "D:/opencv/mnist/pic/" + to_string(i) + "/" + to_string(j) + ".png";
            Mat testdata = imread(file_savepath);
            testdata = testdata.reshape(1, 1); // 将二维数据转换成一维数据
            testdata.convertTo(testdata, CV_32F); // 将数据转换成浮点数据

            Mat result;
            // 寻找K个最邻近样本,并统计K个样本的分类数量,返回数量最多的分类的标签。
            int response = knn->findNearest(testdata, K, result);
            // 如果得到的分类标签与真实标签一致,则分类正确
            if (response == i)
            {
                truenum++;
            }
        }
    }
    cout << "测试总数" << testnum << endl;
    cout << "正确分类数" << truenum << endl;
    cout << "准确率:" << (float)truenum / testnum * 100 << "%" << endl;
}

(用时7秒左右)

四、保存模型

// knn手写数字识别
void test_knn_train()
{
    string file_savepath;
    int testnum = 0, truenum = 0;
    const int K = 3;   //设置K值为3
    cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create(); // 创建KNN类
    knn->setDefaultK(K);
    knn->setIsClassifier(true); // 设置KNN用于分类 
    knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE); // 设置寻找距离最近的K个样本的方式为遍历所有训练样本,即暴力破解的方式。

    Mat traindata, trainlabel;

    for (int i = 0; i < 10; i++)
    {
        for (int j = 0; j < 400; j++)
        {
            file_savepath = "D:/opencv/mnist/pic/" + to_string(i) + "/" + to_string(j) + ".png";
            Mat srcimage = imread(file_savepath);
            // 把二维数据转换为一维数据
            srcimage = srcimage.reshape(1, 1); // Mat Mat::reshape(int cn, int rows = 0),cn表示通道数,为0则通道不变,rows表示矩阵行数。
            traindata.push_back(srcimage); // srcimage为Mat类型的1行n列的一维矩阵,将该矩阵保存到训练矩阵中。
            trainlabel.push_back(i); // i为srcimage的标签,同时将i保存到标签矩阵中。
        }
    }

    traindata.convertTo(traindata, CV_32F); //重要:训练矩阵必须是浮点型数据
    knn->train(traindata, cv::ml::ROW_SAMPLE, trainlabel); // 导入训练数据和标签
    knn->save("D:/opencv/mnist/knn_digits_model.yml"); // 保存模型
}

五、加载模型

// knn手写数字识别
void test_knn_model()
{
    string file_savepath;
    int testnum = 0, truenum = 0;
    const int k = 3; // 设置k值为3

    // 模型加载
    cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::load("D:/opencv/mnist/knn_digits_model.yml"); // 创建KNN类
    knn->setDefaultK(k);
    knn->setIsClassifier(true); // 设置KNN用于分类 
    knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE); // 设置寻找距离最近的K个样本的方式为遍历所有训练样本,即暴力破解的方式。

    Mat traindata, trainlabel;

    // 标签
    for (int i = 0; i < 10; i++)
    {
        for (int j = 400; j < 500; j++)
        {
            testnum++; //统计总的分类次数
            file_savepath = "D:/opencv/mnist/pic/" + to_string(i) + "/" + to_string(j) + ".png";
            Mat testdata = imread(file_savepath);
            testdata = testdata.reshape(1, 1); // 将二维数据转换成一维数据
            testdata.convertTo(testdata, CV_32F); // 将数据转换成浮点数据

            Mat result;
            // 寻找k个最邻近样本,并统计k个样本的分类数量,返回数量最多的分类的标签。
            int response = knn->findNearest(testdata, k, result);
            // 如果得到的分类标签与真实标签一致,则分类正确
            if (response == i)
            {
                truenum++;
            }
        }
    }
    cout << "测试总数" << testnum << endl;
    cout << "正确分类数" << truenum << endl;
    cout << "准确率:" << (float)truenum / testnum * 100 << "%" << endl;
}
int main()
{
    // knn手写数字识别
    //split_digital_img();
    //test_knn_train();
    test_knn_model();

    waitKey(0);
    destroyAllWindows();
    return 0;
}

 

发表回复

您的电子邮箱地址不会被公开。