package com.drone.poppy.detection.service.impl;
import com.drone.poppy.detection.service.ImageRecognitionService;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * {@link ImageRecognitionService} backed by a frozen TensorFlow graph.
 *
 * <p>The serialized graph ({@value #MODEL_PATH}) and the label list ({@value #LABEL_PATH}) are read
 * once in the constructor. Both paths are relative, so they resolve against the process working
 * directory. If loading fails, the failure is logged and every recognition call reports it as an
 * {@link IOException}, with the original exception as the cause.
 *
 * <p>Each recognition call imports the graph into its own {@link Graph}/{@link Session}. No
 * TensorFlow objects are shared between calls. The fields are final and only written in the
 * constructor.
 */
@Service
public class TensorFlowImageRecognitionServiceImpl implements ImageRecognitionService {

    private static final Logger LOG =
            Logger.getLogger(TensorFlowImageRecognitionServiceImpl.class.getName());

    private static final String MODEL_PATH = "model/poppy_detection_model.pb";
    private static final String LABEL_PATH = "model/labels.txt";
    private static final int INPUT_SIZE = 224;
    private static final int CHANNELS = 3;

    /** Serialized GraphDef bytes; {@code null} when loading failed. */
    private final byte[] graphDef;
    /** Labels indexed by model output position; {@code null} when loading failed. */
    private final String[] labels;
    /** Why model/label loading failed, or {@code null} if it succeeded. */
    private final IOException loadFailure;

    /**
     * Loads the model graph and labels from disk.
     *
     * <p>A missing or unreadable file does not prevent the bean from being created. The failure is
     * logged here and re-reported from every recognition call.
     */
    public TensorFlowImageRecognitionServiceImpl() {
        byte[] model = null;
        String[] labelList = null;
        IOException failure = null;
        try {
            model = loadModel();
            labelList = loadLabels();
        } catch (IOException e) {
            failure = e;
            LOG.log(Level.SEVERE,
                    "Failed to load poppy detection model/labels; recognition calls will fail", e);
        }
        this.graphDef = model;
        this.labels = labelList;
        this.loadFailure = failure;
    }

    /** Reads the serialized GraphDef bytes from {@link #MODEL_PATH}. */
    private static byte[] loadModel() throws IOException {
        return Files.readAllBytes(Paths.get(MODEL_PATH));
    }

    /**
     * Reads one label per line from {@link #LABEL_PATH} as UTF-8.
     *
     * <p>Each line is trimmed, so CRLF files do not leave a trailing {@code '\r'} on a label.
     * Trailing blank lines are dropped. Interior blank lines are kept so that label indices stay
     * aligned with the model's output positions.
     */
    private static String[] loadLabels() throws IOException {
        List<String> lines = Files.readAllLines(Paths.get(LABEL_PATH), StandardCharsets.UTF_8);
        int end = lines.size();
        while (end > 0 && lines.get(end - 1).trim().isEmpty()) {
            end--;
        }
        String[] result = new String[end];
        for (int i = 0; i < end; i++) {
            result[i] = lines.get(i).trim();
        }
        return result;
    }

    /**
     * Classifies an uploaded image.
     *
     * <p>The upload is read into memory directly. No temporary file is created, so nothing is left
     * behind on disk.
     *
     * @param imageFile uploaded image (decoded with TensorFlow's {@code DecodeJpeg} op)
     * @return map with keys {@code "hasPoppy"} (Boolean), {@code "probability"} (Float) and
     *     {@code "label"} (String)
     * @throws IOException if the upload cannot be read or the model failed to load
     */
    @Override
    public Map<String, Object> recognizePoppy(MultipartFile imageFile) throws IOException {
        return recognizeImageBytes(imageFile.getBytes());
    }

    /**
     * Classifies the image stored at {@code imagePath}.
     *
     * @param imagePath filesystem path of the image (decoded with TensorFlow's {@code DecodeJpeg}
     *     op)
     * @return map with keys {@code "hasPoppy"} (Boolean), {@code "probability"} (Float) and
     *     {@code "label"} (String)
     * @throws IOException if the file cannot be read or the model failed to load
     */
    @Override
    public Map<String, Object> recognizePoppyByPath(String imagePath) throws IOException {
        return recognizeImageBytes(Files.readAllBytes(Paths.get(imagePath)));
    }

    /** Runs the model on raw encoded image bytes and converts its output into the result map. */
    private Map<String, Object> recognizeImageBytes(byte[] imageBytes) throws IOException {
        ensureModelLoaded();
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            // NOTE(review): assumes the frozen graph names its input node "input" and its
            // softmax/logits node "output" -- confirm against the exported model.
            try (Session s = new Session(g);
                 Tensor<Float> inputTensor = constructAndExecuteGraphToNormalizeImage(imageBytes);
                 Tensor<Float> output = s.runner()
                         .feed("input", inputTensor)
                         .fetch("output")
                         .run()
                         .get(0)
                         .expect(Float.class)) {
                return buildResult(toProbabilities(output));
            }
        }
    }

    /** Throws the deferred constructor failure, if any, so callers get a meaningful error. */
    private void ensureModelLoaded() throws IOException {
        if (loadFailure != null) {
            throw new IOException("Poppy detection model could not be loaded from "
                    + MODEL_PATH + " / " + LABEL_PATH, loadFailure);
        }
    }

    /**
     * Extracts the per-label scores from a {@code [1, N]} output tensor.
     *
     * @throws IllegalStateException if the tensor is not shaped {@code [1, N]} with {@code N >= 1}
     */
    private static float[] toProbabilities(Tensor<Float> output) {
        final long[] shape = output.shape();
        if (output.numDimensions() != 2 || shape[0] != 1 || shape[1] < 1) {
            throw new IllegalStateException(
                    String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                            Arrays.toString(shape)));
        }
        return output.copyTo(new float[1][(int) shape[1]])[0];
    }

    /** Picks the highest-scoring label and packages it into the service's result map. */
    private Map<String, Object> buildResult(float[] probabilities) {
        int best = argMax(probabilities);
        if (best >= labels.length) {
            throw new IllegalStateException(String.format(
                    "Model predicted class index %d but only %d labels were loaded from %s",
                    best, labels.length, LABEL_PATH));
        }
        String label = labels[best];
        Map<String, Object> result = new HashMap<>();
        result.put("hasPoppy", "poppy".equals(label));
        result.put("probability", probabilities[best]);
        result.put("label", label);
        return result;
    }

    /** Returns the index of the first maximum in a non-empty array. */
    private static int argMax(float[] values) {
        int maxIndex = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[maxIndex]) {
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Decodes and resizes an encoded image into a {@code [1, INPUT_SIZE, INPUT_SIZE, CHANNELS]}
     * float tensor. The caller owns the returned tensor and must close it.
     *
     * <p>The pipeline is: decode, cast to float, add a batch dimension, then resize bilinearly. No
     * mean subtraction or scaling is applied.
     * TODO(review): confirm whether the model was trained on normalized pixels.
     */
    private Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
        try (Graph g = new Graph()) {
            GraphBuilder b = new GraphBuilder(g);
            // Build the preprocessing graph: decode the image, batch it, and resize it.
            Output<String> contents = b.constant("image_bytes", imageBytes, String.class);
            Output<Float> pixels = b.castToFloat(b.decodeJpeg(contents, CHANNELS));
            // ResizeBilinear requires a 4-D [batch, height, width, channels] input.
            Output<Float> batched =
                    b.expandDims(pixels, b.constant("batch_dim", 0, Integer.class));
            Output<Float> resized = b.resizeBilinear(batched,
                    b.constant("size", new int[] {INPUT_SIZE, INPUT_SIZE}, Integer.class));
            try (Session s = new Session(g)) {
                return s.runner().fetch(resized.op().name()).run().get(0).expect(Float.class);
            }
        }
    }

    /**
     * Small helper for adding operations to a TensorFlow {@link Graph}.
     *
     * <p>Every method adds exactly one op with a fixed name. Each method may therefore be called
     * only once per graph, because op names must be unique.
     */
    static final class GraphBuilder {
        private final Graph g;

        GraphBuilder(Graph g) {
            this.g = g;
        }

        /** Adds a {@code Const} op holding {@code value} with the element type {@code type}. */
        <T> Output<T> constant(String name, Object value, Class<T> type) {
            // setAttr copies the tensor's contents into the graph, so it is safe to close here.
            try (Tensor<T> t = Tensor.create(value, type)) {
                return g.opBuilder("Const", name)
                        .setAttr("dtype", DataType.fromClass(type))
                        .setAttr("value", t)
                        .build()
                        .<T>output(0);
            }
        }

        /** Adds a {@code DecodeJpeg} op producing a uint8 {@code [height, width, channels]} image. */
        Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
            return g.opBuilder("DecodeJpeg", "decode_jpeg")
                    .addInput(contents)
                    .setAttr("channels", channels)
                    .build()
                    .<UInt8>output(0);
        }

        /** Adds a {@code Cast} op converting {@code value} to float32. */
        <T> Output<Float> castToFloat(Output<T> value) {
            return g.opBuilder("Cast", "cast_to_float")
                    .addInput(value)
                    .setAttr("DstT", DataType.FLOAT)
                    .build()
                    .<Float>output(0);
        }

        /** Adds an {@code ExpandDims} op inserting a size-1 axis at position {@code dim}. */
        <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
            return g.opBuilder("ExpandDims", "make_batch")
                    .addInput(input)
                    .addInput(dim)
                    .build()
                    .<T>output(0);
        }

        /** Adds a {@code ResizeBilinear} op; {@code size} is the int32 pair {@code [height, width]}. */
        Output<Float> resizeBilinear(Output<Float> images, Output<Integer> size) {
            return g.opBuilder("ResizeBilinear", "resize")
                    .addInput(images)
                    .addInput(size)
                    .build()
                    .<Float>output(0);
        }
    }
}