From 92e356d1ee1b877bc17d3aee3a137c234f01c477 Mon Sep 17 00:00:00 2001 From: bsw215583320 <baoshiwei121@163.com> Date: 星期三, 20 十二月 2023 16:04:59 +0800 Subject: [PATCH] 增加药材识别神经网络模型及接口 --- jeecg-module-dry/jeecg-module-dry-api/pom.xml | 15 +++ jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt | 0 jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java | 156 +++++++++++++++++++++++++++++++++++++++ jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt | 0 jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java | 27 ++++++ jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt | 20 +++++ 6 files changed, 218 insertions(+), 0 deletions(-) diff --git a/jeecg-module-dry/jeecg-module-dry-api/pom.xml b/jeecg-module-dry/jeecg-module-dry-api/pom.xml index 3ad456f..8bce88b 100644 --- a/jeecg-module-dry/jeecg-module-dry-api/pom.xml +++ b/jeecg-module-dry/jeecg-module-dry-api/pom.xml @@ -47,6 +47,21 @@ <artifactId>milo-spring-boot-starter</artifactId> <version>3.0.4</version> </dependency> + <dependency> + <groupId>ai.djl.pytorch</groupId> + <artifactId>pytorch-engine</artifactId> + <version>0.16.0</version> + </dependency> + <dependency> + <groupId>ai.djl.pytorch</groupId> + <artifactId>pytorch-native-auto</artifactId> + <version>1.9.1</version> + </dependency> + <dependency> + <groupId>ai.djl.pytorch</groupId> + <artifactId>pytorch-jni</artifactId> + <version>1.9.1-0.16.0</version> + </dependency> </dependencies> <build> <plugins> diff --git a/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java new file mode 100644 index 0000000..eba00b2 --- /dev/null +++ b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java @@ -0,0 +1,156 @@ +package org.jeecg.modules.dry.util; + +import ai.djl.Device; +import ai.djl.Model; +import ai.djl.inference.Predictor; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.transform.*; +import ai.djl.modality.cv.translator.ImageClassificationTranslator; +import ai.djl.translate.Translator; +import org.springframework.stereotype.Component; + +import javax.imageio.ImageIO; +import javax.imageio.stream.ImageOutputStream; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.*; +import java.util.ArrayList; +import java.util.List; + +@Component +public class HerbUtil { + + //瑙勫畾杈撳叆灏哄 + private static final int INPUT_SIZE = 224; + + private static final int TARGET_SIZE = 256; + + //鏍囩鏂囦欢 涓�绉嶇被鍒悕瀛楀崰涓�琛� + private List<String> herbNames; + + //鐢ㄤ簬璇嗗埆 + Predictor<Image, Classifications> predictor; + + //妯″瀷 + private Model model; + + public HerbUtil() { + //鍔犺浇鏍囩鍒癶erbNames涓� + this.loadHerbNames(); + //鍒濆鍖栨ā鍨嬪伐浣� + this.init(); + + + + } + + public List<Classifications.Classification> predict(InputStream inputStream) { + List<Classifications.Classification> result = new ArrayList<>(); + Image input = this.resizeImage(inputStream); + try { + Classifications output = predictor.predict(input); + System.out.println("鎺ㄦ祴涓猴細" + output.best().getClassName() + + ", 姒傜巼锛�" + output.best().getProbability()); + System.out.println(output); + result = output.topK(); + } catch (Exception e) { + e.printStackTrace(); + } + return result; + } + + private void loadHerbNames() { + BufferedReader reader = null; + herbNames = new ArrayList<>(); + try { + InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt"); + reader = new BufferedReader(new InputStreamReader(in)); + String name = null; + while ((name = reader.readLine()) != null) { + herbNames.add(name); + } + System.out.println(herbNames); + } catch (Exception e) { + e.printStackTrace(); + } finally { + if (reader != null) { + try { + reader.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + } + + private void init() { + Translator<Image, Classifications> translator = ImageClassificationTranslator.builder() + //涓嬮潰鐨則ransform鏍规嵁鑷繁鐨勬敼 + .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE)) + + .addTransform(new ToTensor()) + .addTransform(new Normalize( + new float[] {0.485f, 0.456f, 0.406f}, + new float[] {0.229f, 0.224f, 0.225f})) + + //杞藉叆鎵�鏈夋爣绛捐繘鍘� + .optSynset(herbNames) + //鏈�缁堟樉绀烘鐜囨渶楂樼殑5涓� + .optTopK(5) + .build(); + //闅忎究璧峰悕 + Model model = Model.newInstance("model", Device.cpu()); + try { + InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model1.pt"); + if (inputStream == null) { + throw new RuntimeException("鎵句笉鍒版ā鍨嬫枃浠�"); + } + model.load(inputStream); + + predictor = model.newPredictor(translator); + } catch (Exception e) { + e.printStackTrace(); + } + } + + private Image resizeImage(InputStream inputStream) { + BufferedImage input = null; + try { + input = ImageIO.read(inputStream); + } catch (IOException e) { + e.printStackTrace(); + } + int iw = input.getWidth(), ih = input.getHeight(); + int w = 256, h = 256; + double scale = Math.max(1. * w / iw, 1. * h / ih); + int nw = (int) (iw * scale), nh = (int) (ih * scale); + java.awt.Image img; + //鍙湁澶暱鎴栧お瀹芥墠浼氫繚鐣欐í绾垫瘮锛屽~鍏呴鑹� + // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; + // if (needResize) { + img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); + // } else { + // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); + // } + BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB); + Graphics g = out.getGraphics(); + //鍏堝皢鏁翠釜224*224鍖哄煙濉厖128 128 128棰滆壊 + g.setColor(new Color(255, 255, 255)); + g.fillRect(0, 0, nw, nh); + out.getGraphics().drawImage(img, 0, 0, null); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); + ImageIO.write(out, "jpg", imageOutputStream); + //鍘籇鐩樼湅鏁堟灉 + ImageIO.write(out, "jpg", new File("E:\\out.jpg")); + InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); + return ImageFactory.getInstance().fromInputStream(is); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("鍥剧墖杞崲澶辫触"); + } + } +} \ No newline at end of file diff --git a/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java b/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java index 46ee064..a6dbf31 100644 --- a/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java +++ b/jeecg-module-dry/jeecg-module-dry-biz/src/main/java/org/jeecg/modules/dry/controller/DryRealTimeDataController.java @@ -1,6 +1,7 @@ package org.jeecg.modules.dry.controller; +import ai.djl.modality.Classifications; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import lombok.extern.slf4j.Slf4j; @@ -8,11 +9,16 @@ import org.jeecg.modules.dry.service.*; +import org.jeecg.modules.dry.util.HerbUtil; import org.jeecg.modules.dry.vo.CommandMessageVo; import org.jeecg.modules.dry.vo.RealTimeDataVo; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; + +import java.io.InputStream; +import java.util.List; @Api(tags = "瀹炴椂鏁版嵁澶勭悊鎺у埗鍣�") @@ -23,6 +29,9 @@ @Autowired private IDryRealTimeDataService dryRealTimeDataService; + + @Autowired + private HerbUtil herbUtil; @ApiOperation(value="娴嬭瘯", notes="杩斿洖Hello") @@ -66,4 +75,22 @@ public Result<?> sendCommand(@RequestBody CommandMessageVo msgVo) { return dryRealTimeDataService.sendSocketMsg(msgVo); } + + + @ApiOperation(value = "鑽潗璇嗗埆") + @PostMapping("/identify") + public Result<?> identify(@RequestParam("file") MultipartFile file) throws Exception { + try { + if (file.isEmpty()) { + throw new RuntimeException("涓婁紶鏂囦欢涓嶈兘涓虹┖"); + } + InputStream inputStream = file.getInputStream(); + List<Classifications.Classification> predict = herbUtil.predict(inputStream); + return Result.ok(predict); + } catch (Exception e) { + e.printStackTrace(); + return Result.error("AI璇嗗埆鏈嶅姟寮傚父"); + } + } + } diff --git a/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt new file mode 100644 index 0000000..6a15e51 --- /dev/null +++ b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/class.txt @@ -0,0 +1,20 @@ +baihe +baihuasheshecao +danggui +dangshen +fangfeng +gancao +gouqi +huaihua +huangqi +jinyinhua +juhua +machixian +mohanlian +nuodaogen +sangbaipi +sangzhipian +shudihuang +yinyanghuo +zaojiaoci +zisugeng diff --git a/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt new file mode 100644 index 0000000..cd39814 --- /dev/null +++ b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model.pt Binary files differ diff --git a/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt new file mode 100644 index 0000000..dfca629 --- /dev/null +++ b/jeecg-module-dry/jeecg-module-dry-start/src/main/resources/model1.pt Binary files differ -- Gitblit v1.9.3