From c2fccb01b972176dc3da5a497b5e904025e9e98d Mon Sep 17 00:00:00 2001
From: bsw215583320 <baoshiwei121@163.com>
Date: 星期二, 16 四月 2024 15:06:51 +0800
Subject: [PATCH] Merge branch 'master' of http://210.22.126.130:1111/r/dry/herb

---
 jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java |  170 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 170 insertions(+), 0 deletions(-)

diff --git a/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
new file mode 100644
index 0000000..90c6db2
--- /dev/null
+++ b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
@@ -0,0 +1,170 @@
+package org.jeecg.modules.dry.util;
+
+import ai.djl.Device;
+import ai.djl.Model;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.Classifications;
+import ai.djl.modality.cv.Image;
+import ai.djl.modality.cv.ImageFactory;
+import ai.djl.modality.cv.transform.*;
+import ai.djl.modality.cv.translator.ImageClassificationTranslator;
+import ai.djl.translate.Translator;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.core.io.Resource;
+import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
+import org.springframework.core.io.support.ResourcePatternResolver;
+import org.springframework.stereotype.Component;
+
+import javax.imageio.ImageIO;
+import javax.imageio.stream.ImageOutputStream;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.*;
+import java.util.ArrayList;
+import java.util.List;
+
+@Slf4j
+@Component
+public class HerbUtil {
+ 
+    //瑙勫畾杈撳叆灏哄
+    private static final int INPUT_SIZE = 224;
+
+    private static final int TARGET_SIZE = 256;
+ 
+    //鏍囩鏂囦欢 涓�绉嶇被鍒悕瀛楀崰涓�琛�
+    private List<String> herbNames;
+ 
+    //鐢ㄤ簬璇嗗埆
+    Predictor<Image, Classifications> predictor;
+ 
+    //妯″瀷
+    private Model model;
+ 
+    public HerbUtil() {
+        //鍔犺浇鏍囩鍒癶erbNames涓�
+        this.loadHerbNames();
+        //鍒濆鍖栨ā鍨嬪伐浣�
+        this.init();
+
+
+
+    }
+
+    public List<Classifications.Classification> predict(InputStream inputStream) {
+        List<Classifications.Classification> result = new ArrayList<>();
+        Image input = this.resizeImage(inputStream);
+        try {
+            Classifications output = predictor.predict(input);
+            System.out.println("鎺ㄦ祴涓猴細" + output.best().getClassName()
+                    + ", 姒傜巼锛�" + output.best().getProbability());
+            System.out.println(output);
+            result = output.topK();
+        } catch (Exception e) {
+            log.error("鑽潗璇嗗埆寮傚父锛侊紒");
+            log.error(input.toString());
+            log.error(predictor.toString());
+            e.printStackTrace();
+        }
+        return result;
+    }
+
+    private void loadHerbNames() {
+        BufferedReader reader = null;
+        herbNames = new ArrayList<>();
+        try {
+            InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
+            reader = new BufferedReader(new InputStreamReader(in));
+            String name = null;
+            while ((name = reader.readLine()) != null) {
+                herbNames.add(name);
+            }
+            System.out.println(herbNames);
+        } catch (Exception e) {
+            e.printStackTrace();
+        } finally {
+            if (reader != null) {
+                try {
+                    reader.close();
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+            }
+        }
+    }
+
+    private void init() {
+        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
+                //涓嬮潰鐨則ransform鏍规嵁鑷繁鐨勬敼
+                .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
+
+                .addTransform(new ToTensor())
+                .addTransform(new Normalize(
+                        new float[] {0.485f, 0.456f, 0.406f},
+                        new float[] {0.229f, 0.224f, 0.225f}))
+
+                //杞藉叆鎵�鏈夋爣绛捐繘鍘�
+                .optSynset(herbNames)
+                //鏈�缁堟樉绀烘鐜囨渶楂樼殑5涓�
+                .optTopK(5)
+                .build();
+        //闅忎究璧峰悕
+        Model model = Model.newInstance("model", Device.cpu());
+        try {
+//            ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
+//            Resource[] resources = resolver.getResources("../pytorch/model34.pt");
+            //            Resource resource = resources[0];
+            File f = new File("../pytorch/model34.pt");
+
+            InputStream inputStream = new FileInputStream(f);
+           // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
+            if (inputStream == null) {
+                throw new RuntimeException("鎵句笉鍒版ā鍨嬫枃浠�");
+            }
+            model.load(inputStream);
+
+            predictor = model.newPredictor(translator);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    private Image resizeImage(InputStream inputStream) {
+        BufferedImage input = null;
+        try {
+            input = ImageIO.read(inputStream);
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+        int iw = input.getWidth(), ih = input.getHeight();
+        int w = 256, h = 256;
+        double scale = Math.max(1. *  w / iw, 1. * h / ih);
+        int nw = (int) (iw * scale), nh = (int) (ih * scale);
+        java.awt.Image img;
+        //鍙湁澶暱鎴栧お瀹芥墠浼氫繚鐣欐í绾垫瘮锛屽~鍏呴鑹�
+       // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
+      //  if (needResize) {
+            img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
+      //  } else {
+       //     img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
+      //  }
+        BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
+        Graphics g = out.getGraphics();
+        //鍏堝皢鏁翠釜224*224鍖哄煙濉厖128 128 128棰滆壊
+        g.setColor(new Color(255, 255, 255));
+        g.fillRect(0, 0, nw, nh);
+        out.getGraphics().drawImage(img, 0, 0, null);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        try {
+            ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
+            ImageIO.write(out, "jpg", imageOutputStream);
+            //鍘籇鐩樼湅鏁堟灉
+            ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
+            InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
+            return ImageFactory.getInstance().fromInputStream(is);
+        } catch (IOException e) {
+            e.printStackTrace();
+            throw new RuntimeException("鍥剧墖杞崲澶辫触");
+        }
+    }
+}
\ No newline at end of file

--
Gitblit v1.9.3