关于tensorflow.net推理mnist图片问题
using System;
using Tensorflow;
using Tensorflow.Lite;
using OpenCvSharp;
using System.Runtime.InteropServices;
namespace ConsoleApp1
{
class Program
{
public static unsafe void test()
{
var modelPath = "mnist.tflite";
var imagePath = "1.jpg";
var model = c_api_lite.TfLiteModelCreateFromFile(modelPath);
var options = c_api_lite.TfLiteInterpreterOptionsCreate();
var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options);
TfLiteStatus status = c_api_lite.TfLiteInterpreterAllocateTensors(interpreter);
TfLiteTensor inputTensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
Mat img = Cv2.ImRead(imagePath, ImreadModes.Grayscale);
var bytes = new byte[img.Total() * 1];
Marshal.Copy(img.Data, bytes, 0, bytes.Length);
fixed (byte* addr = &bytes[0])
{
c_api_lite.TfLiteTensorCopyFromBuffer(inputTensor, new IntPtr(addr), img.Cols*img.Rows*img.Channels() * sizeof(byte));
}
c_api_lite.TfLiteInterpreterInvoke(interpreter);
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
var output = new byte[10];
fixed (byte* addr = &output[0])
{
c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 10 * sizeof(byte));
}
for (int i = 0; i< output.Length;i++)
{
Console.WriteLine(output[i]);
}
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
}
static void Main(string[] args)
{
test();
}
}
}
推理结果为什么为零啊
这份代码使用了 Tensorflow Lite 进行机器学习模型的预测,根据输入的灰度图像(例如这里的 `1.jpg` ),模型会输出一个长度为 10 的一维数组,表示图像所属的数字类别。这里的模型路径 `mnist.tflite` 是经过训练后的 MNIST 手写数字识别模型。
在代码中,先通过 `c_api_lite.TfLiteModelCreateFromFile()` 从文件中创建模型。然后通过 `c_api_lite.TfLiteInterpreterCreate()` 创建解释器,并对解释器进行配置和内存分配等操作后,调用 `c_api_lite.TfLiteInterpreterInvoke()` 运行模型进行推理,最后通过 `c_api_lite.TfLiteInterpreterGetOutputTensor()` 和 `c_api_lite.TfLiteTensorCopyToBuffer()` 获取输出结果。其中涉及到了 C# 操作指针的部分。
需要注意的是,这份代码使用了 OpenCvSharp 库对图像进行处理,并用 `ImreadModes.Grayscale` 模式读取灰度图像,需要先安装 OpenCvSharp 库,并引入相应的命名空间。
同时,上述代码中不包括相关的库文件和训练好的模型文件,请自行下载和配置。
# Tensorflow MNIST 数据集使用
import tensorflow as tf
tf.__version__
'2.1.0'
import numpy as np
import matplotlib.pyplot as plt