From aa562d3b26d8b6de0f0fc0b842ba3894ebcf0945 Mon Sep 17 00:00:00 2001
From: bsw215583320 <baoshiwei121@163.com>
Date: 星期一, 08 一月 2024 08:45:30 +0800
Subject: [PATCH] 优化模型调用

---
 jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java                        |   16 +++++++++++++++-
 jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java |   32 +++++++++++++++++++-------------
 2 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
index 8635429..90c6db2 100644
--- a/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
+++ b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
@@ -9,6 +9,10 @@
 import ai.djl.modality.cv.transform.*;
 import ai.djl.modality.cv.translator.ImageClassificationTranslator;
 import ai.djl.translate.Translator;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.core.io.Resource;
+import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
+import org.springframework.core.io.support.ResourcePatternResolver;
 import org.springframework.stereotype.Component;
 
 import javax.imageio.ImageIO;
@@ -19,6 +23,7 @@
 import java.util.ArrayList;
 import java.util.List;
 
+@Slf4j
 @Component
 public class HerbUtil {
  
@@ -56,6 +61,9 @@
             System.out.println(output);
             result = output.topK();
         } catch (Exception e) {
+            log.error("鑽潗璇嗗埆寮傚父锛侊紒");
+            log.error(input.toString());
+            log.error(predictor.toString());
             e.printStackTrace();
         }
         return result;
@@ -103,7 +111,13 @@
         //闅忎究璧峰悕
         Model model = Model.newInstance("model", Device.cpu());
         try {
-            InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
+//            ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
+//            Resource[] resources = resolver.getResources("../pytorch/model34.pt");
+            //            Resource resource = resources[0];
+            File f = new File("../pytorch/model34.pt");
+
+            InputStream inputStream = new FileInputStream(f);
+           // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
             if (inputStream == null) {
                 throw new RuntimeException("鎵句笉鍒版ā鍨嬫枃浠�");
             }
diff --git a/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java b/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java
index 0ca23e1..2878189 100644
--- a/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java
+++ b/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java
@@ -95,23 +95,29 @@
             }
             InputStream inputStream = file.getInputStream();
             List<Classifications.Classification> predict = herbUtil.predict(inputStream);
-            Map<String, Double> collect = predict.stream().collect(Collectors.toMap(Classifications.Classification::getClassName, Classifications.Classification::getProbability));
+
+            if(predict.size()>0) {
+                Map<String, Double> collect = predict.stream().collect(Collectors.toMap(Classifications.Classification::getClassName, Classifications.Classification::getProbability));
 
 
-            List<DryHerbInfoVo> voList = new ArrayList<>();
-            Set<String> strings = collect.keySet();
-            List<DryHerbInfo> list = herbInfoService.list(new LambdaQueryWrapper<DryHerbInfo>().in(DryHerbInfo::getPinyin, strings));
-            list.forEach(item -> {
-                DryHerbInfoVo dryHerbInfoVo = new DryHerbInfoVo();
-                BeanUtil.copyProperties(item, dryHerbInfoVo);
-                dryHerbInfoVo.setProbabily(collect.get(item.getPinyin()));
-                voList.add(dryHerbInfoVo);
-            });
+                List<DryHerbInfoVo> voList = new ArrayList<>();
+                Set<String> strings = collect.keySet();
+                List<DryHerbInfo> list = herbInfoService.list(new LambdaQueryWrapper<DryHerbInfo>().in(DryHerbInfo::getPinyin, strings));
+                list.forEach(item -> {
+                    DryHerbInfoVo dryHerbInfoVo = new DryHerbInfoVo();
+                    BeanUtil.copyProperties(item, dryHerbInfoVo);
+                    dryHerbInfoVo.setProbabily(collect.get(item.getPinyin()));
+                    voList.add(dryHerbInfoVo);
+                });
 
-            List<DryHerbInfoVo> collect1 = voList.stream().sorted(Comparator.comparing(DryHerbInfoVo::getProbabily, Comparator.reverseOrder())).
-                    collect(Collectors.toList());
+                List<DryHerbInfoVo> collect1 = voList.stream().sorted(Comparator.comparing(DryHerbInfoVo::getProbabily, Comparator.reverseOrder())).
+                        collect(Collectors.toList());
 
-            return Result.ok(collect1);
+                return Result.ok(collect1);
+            } else {
+                return Result.error("AI璇嗗埆鏈嶅姟寮傚父");
+            }
+
         } catch (Exception e) {
             e.printStackTrace();
             return Result.error("AI璇嗗埆鏈嶅姟寮傚父");

--
Gitblit v1.9.3