3.2.2

# -*- coding: utf-8 -*-
# @Time    : 2025/7/23 8:18 下午
# @Author  : Hoey
import onnxruntime
import numpy as np
from PIL import Image


# 加载ONNX模型  2分
ort_session = onnxruntime.InferenceSession('mnist.onnx')


# 加载图像 2分
image = Image.open('img_test.png').convert('L')  # 转为灰度图


#图像预处理
image = image.resize((28, 28))  # 调整大小为MNIST模型的输入尺寸2分
image_array = np.array(image, dtype=np.float32)  # 转为numpy数组2分
image_array = np.expand_dims(image_array, axis=0)  # 添加batch维度2分
image_array = np.expand_dims(image_array, axis=1)  # 添加通道维度2分


#返回模型输入列表 2分
ort_inputs = {ort_session.get_inputs()[0].name: image_array}
# 执行预测 2分
ort_outs = ort_session.run(None, ort_inputs)

# 获取预测结果 2分
predicted_class = np.argmax(ort_outs)


# 输出预测结果
print(f"Predicted class: {predicted_class}")