From 92e356d1ee1b877bc17d3aee3a137c234f01c477 Mon Sep 17 00:00:00 2001
From: bsw215583320 <baoshiwei121@163.com>
Date: 星期三, 20 十二月 2023 16:04:59 +0800
Subject: [PATCH] 增加药材识别神经网络模型及接口

---
 jeecg-module-dry/jeecg-module-dry-api/pom.xml                                                                       |   15 +++
 jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt                                                 |    0 
 jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java                        |  156 +++++++++++++++++++++++++++++++++++++++
 jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt                                                |    0 
 jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java |   27 ++++++
 jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt                                                |   20 +++++
 6 files changed, 218 insertions(+), 0 deletions(-)

diff --git a/jeecg-module-dry/jeecg-module-dry-api/pom.xml b/jeecg-module-dry/jeecg-module-dry-api/pom.xml
index 3ad456f..8bce88b 100644
--- a/jeecg-module-dry/jeecg-module-dry-api/pom.xml
+++ b/jeecg-module-dry/jeecg-module-dry-api/pom.xml
@@ -47,6 +47,21 @@
             <artifactId>milo-spring-boot-starter</artifactId>
             <version>3.0.4</version>
         </dependency>
+        <dependency>
+            <groupId>ai.djl.pytorch</groupId>
+            <artifactId>pytorch-engine</artifactId>
+            <version>0.16.0</version>
+        </dependency>
+        <dependency>
+            <groupId>ai.djl.pytorch</groupId>
+            <artifactId>pytorch-native-auto</artifactId>
+            <version>1.9.1</version>
+        </dependency>
+        <dependency>
+            <groupId>ai.djl.pytorch</groupId>
+            <artifactId>pytorch-jni</artifactId>
+            <version>1.9.1-0.16.0</version>
+        </dependency>
     </dependencies>
     <build>
         <plugins>
diff --git a/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
new file mode 100644
index 0000000..eba00b2
--- /dev/null
+++ b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
@@ -0,0 +1,156 @@
+package org.jeecg.modules.dry.util;
+
+import ai.djl.Device;
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
+import ai.djl.modality.cv.transform.*;
+import ai.djl.modality.cv.translator.ImageClassificationTranslator;
+import ai.djl.translate.Translator;
+import org.springframework.stereotype.Component;
+
+import javax.imageio.ImageIO;
+import javax.imageio.stream.ImageOutputStream;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.*;
+import java.util.ArrayList;
+import java.util.List;
+
+@Component
+public class HerbUtil {
+ 
+    //瑙勫畾杈撳叆灏哄
+    private static final int INPUT_SIZE = 224;
+
+    private static final int TARGET_SIZE = 256;
+ 
+    //鏍囩鏂囦欢 涓�绉嶇被鍒悕瀛楀崰涓�琛�
+    private List<String> herbNames;
+ 
+    //鐢ㄤ簬璇嗗埆
+    Predictor<Image, Classifications> predictor;
+ 
+    //妯″瀷
+    private Model model;
+ 
+    public HerbUtil() {
+        //鍔犺浇鏍囩鍒癶erbNames涓�
+        this.loadHerbNames();
+        //鍒濆鍖栨ā鍨嬪伐浣�
+        this.init();
+
+
+
+    }
+
+    public List<Classifications.Classification> predict(InputStream inputStream) {
+        List<Classifications.Classification> result = new ArrayList<>();
+        Image input = this.resizeImage(inputStream);
+        try {
+            Classifications output = predictor.predict(input);
+            System.out.println("鎺ㄦ祴涓猴細" + output.best().getClassName()
+                    + ", 姒傜巼锛�" + output.best().getProbability());
+            System.out.println(output);
+            result = output.topK();
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+        return result;
+    }
+
+    private void loadHerbNames() {
+        BufferedReader reader = null;
+        herbNames = new ArrayList<>();
+        try {
+            InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
+            reader = new BufferedReader(new InputStreamReader(in));
+            String name = null;
+            while ((name = reader.readLine()) != null) {
+                herbNames.add(name);
+            }
+            System.out.println(herbNames);
+        } catch (Exception e) {
+            e.printStackTrace();
+        } finally {
+            if (reader != null) {
+                try {
+                    reader.close();
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+            }
+        }
+    }
+
+    private void init() {
+        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
+                //涓嬮潰鐨則ransform鏍规嵁鑷繁鐨勬敼
+                .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
+
+                .addTransform(new ToTensor())
+                .addTransform(new Normalize(
+                        new float[] {0.485f, 0.456f, 0.406f},
+                        new float[] {0.229f, 0.224f, 0.225f}))
+
+                //杞藉叆鎵�鏈夋爣绛捐繘鍘�
+                .optSynset(herbNames)
+                //鏈�缁堟樉绀烘鐜囨渶楂樼殑5涓�
+                .optTopK(5)
+                .build();
+        //闅忎究璧峰悕
+        Model model = Model.newInstance("model", Device.cpu());
+        try {
+            InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model1.pt");
+            if (inputStream == null) {
+                throw new RuntimeException("鎵句笉鍒版ā鍨嬫枃浠�");
+            }
+            model.load(inputStream);
+
+            predictor = model.newPredictor(translator);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    private Image resizeImage(InputStream inputStream) {
+        BufferedImage input = null;
+        try {
+            input = ImageIO.read(inputStream);
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+        int iw = input.getWidth(), ih = input.getHeight();
+        int w = 256, h = 256;
+        double scale = Math.max(1. *  w / iw, 1. * h / ih);
+        int nw = (int) (iw * scale), nh = (int) (ih * scale);
+        java.awt.Image img;
+        //鍙湁澶暱鎴栧お瀹芥墠浼氫繚鐣欐í绾垫瘮锛屽~鍏呴鑹�
+       // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
+      //  if (needResize) {
+            img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
+      //  } else {
+       //     img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
+      //  }
+        BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
+        Graphics g = out.getGraphics();
+        //鍏堝皢鏁翠釜224*224鍖哄煙濉厖128 128 128棰滆壊
+        g.setColor(new Color(255, 255, 255));
+        g.fillRect(0, 0, nw, nh);
+        out.getGraphics().drawImage(img, 0, 0, null);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        try {
+            ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
+            ImageIO.write(out, "jpg", imageOutputStream);
+            //鍘籇鐩樼湅鏁堟灉
+            ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
+            InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
+            return ImageFactory.getInstance().fromInputStream(is);
+        } catch (IOException e) {
+            e.printStackTrace();
+            throw new RuntimeException("鍥剧墖杞崲澶辫触");
+        }
+    }
+}
\ No newline at end of file
diff --git a/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java b/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java
index 46ee064..a6dbf31 100644
--- a/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java
+++ b/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java
@@ -1,6 +1,7 @@
 package org.jeecg.modules.dry.controller;
 
 
+import ai.djl.modality.Classifications;
 import io.swagger.annotations.Api;
 import io.swagger.annotations.ApiOperation;
 import lombok.extern.slf4j.Slf4j;
@@ -8,11 +9,16 @@
 
 import org.jeecg.modules.dry.service.*;
 
+import org.jeecg.modules.dry.util.HerbUtil;
 import org.jeecg.modules.dry.vo.CommandMessageVo;
 import org.jeecg.modules.dry.vo.RealTimeDataVo;
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.web.bind.annotation.*;
+import org.springframework.web.multipart.MultipartFile;
+
+import java.io.InputStream;
+import java.util.List;
 
 
 @Api(tags = "瀹炴椂鏁版嵁澶勭悊鎺у埗鍣�")
@@ -23,6 +29,9 @@
 
     @Autowired
     private IDryRealTimeDataService dryRealTimeDataService;
+
+    @Autowired
+    private HerbUtil herbUtil;
 
 
     @ApiOperation(value="娴嬭瘯", notes="杩斿洖Hello")
@@ -66,4 +75,22 @@
     public Result<?> sendCommand(@RequestBody CommandMessageVo msgVo) {
         return dryRealTimeDataService.sendSocketMsg(msgVo);
     }
+
+
+    @ApiOperation(value = "鑽潗璇嗗埆")
+    @PostMapping("/identify")
+    public Result<?> identify(@RequestParam("file") MultipartFile file) throws Exception {
+        try {
+            if (file.isEmpty()) {
+                throw new RuntimeException("涓婁紶鏂囦欢涓嶈兘涓虹┖");
+            }
+            InputStream inputStream = file.getInputStream();
+            List<Classifications.Classification> predict = herbUtil.predict(inputStream);
+            return Result.ok(predict);
+        } catch (Exception e) {
+            e.printStackTrace();
+            return Result.error("AI璇嗗埆鏈嶅姟寮傚父");
+        }
+    }
+
 }
diff --git a/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt
new file mode 100644
index 0000000..6a15e51
--- /dev/null
+++ b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt
@@ -0,0 +1,20 @@
+baihe
+baihuasheshecao
+danggui
+dangshen
+fangfeng
+gancao
+gouqi
+huaihua
+huangqi
+jinyinhua
+juhua
+machixian
+mohanlian
+nuodaogen
+sangbaipi
+sangzhipian
+shudihuang
+yinyanghuo
+zaojiaoci
+zisugeng
diff --git a/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt
new file mode 100644
index 0000000..cd39814
--- /dev/null
+++ b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt
Binary files differ
diff --git a/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt
new file mode 100644
index 0000000..dfca629
--- /dev/null
+++ b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt
Binary files differ

--
Gitblit v1.9.3