Keras对多维Tensor的argmax()解析

基础理论

argmax中的axis参数表示在该维度上比较各元素。并且,张量各维度对换,不影响在该维度取argmax()的结果。

a = tf.constant([[[1, 2, 3], [3, 2, 2]]
using System; // 基础系统功能(Console、异常等) using System.Linq; // LINQ查询(Select、ToArray等) using System.Reflection; // 反射(当前代码未直接使用,可移除) using Tensorflow.Keras.Engine; // Keras引擎接口(IModel) using Tensorflow.Keras.Models; // Keras模型操作(load_model) using static Tensorflow.Binding; // TensorFlow静态绑定(tf对象) using static Tensorflow.KerasApi; // Keras静态API(如keras对象) class ModelPredictor { static IModel? _loadedModel; // 声明可空模型变量 static void Main() { _loadedModel = keras.models.load_model("trained_model.h5");// 从当前目录加载训练好的模型(确保模型文件与程序在同一目录) Console.WriteLine("===== Keras.NET 模型预测程序 ===="); Console.WriteLine("请输入待预测数据(格式:特征1,特征2,特征3,特征4,示例:0.2,0.5,0.3,0.7)"); // 输入循环(持续接收用户输入) while (true) { Console.Write("输入数据(输入exit退出): "); // 提醒用户输入操作 string input = Console.ReadLine(); // 读取用户输入 if (input?.ToLower() == "exit") break; // 用户输入退出条件 try { float[] features = [.. input.Split(',').Select(float.Parse)]; // 解析输入数据 if (features.Length != 4) throw new ArgumentException("需要4个特征值,用逗号分隔"); // 输入验证 var inputTensor = tf.convert_to_tensor(new[] { features }, dtype: tf.float32); // 转换为模型输入张量(关键数据处理),二维张量[1,4] if (_loadedModel == null) throw new InvalidOperationException("模型未成功加载"); // 执行预测 var predictions = _loadedModel.predict(inputTensor); // int classIndex = predictions.numpy().argmax(axis: 1)[0]; // 获取第一个样本的最大概率类别 Console.WriteLine($"预测结果:类别 {classIndex + 1}\n"); // +1适配1-based标签 } catch (Exception ex) { Console.WriteLine($"输入错误:{ex.Message}\n");// 捕获并提示具体错误(如格式错误、模型空引用) } } } }CS1061 “NDArray”未包含“argmax”的定义,并且找不到可接受第一个“NDArray”类型参数的可访问扩展方法“argmax”(是否缺少 using 指令或程序集引用?)
最新发布
07-21
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值