From c2fccb01b972176dc3da5a497b5e904025e9e98d Mon Sep 17 00:00:00 2001 From: bsw215583320 <baoshiwei121@163.com> Date: 星期二, 16 四月 2024 15:06:51 +0800 Subject: [PATCH] Merge branch 'master' of http://210.22.126.130:1111/r/dry/herb --- jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java | 170 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 files changed, 170 insertions(+), 0 deletions(-) diff --git a/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java new file mode 100644 index 0000000..90c6db2 --- /dev/null +++ b/jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java @@ -0,0 +1,170 @@ +package org.jeecg.modules.dry.util; + +import ai.djl.Device; +import ai.djl.Model; +import ai.djl.inference.Predictor; +import ai.djl.modality.Classifications; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.transform.*; +import ai.djl.modality.cv.translator.ImageClassificationTranslator; +import ai.djl.translate.Translator; +import lombok.extern.slf4j.Slf4j; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.core.io.support.ResourcePatternResolver; +import org.springframework.stereotype.Component; + +import javax.imageio.ImageIO; +import javax.imageio.stream.ImageOutputStream; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.*; +import java.util.ArrayList; +import java.util.List; + +@Slf4j +@Component +public class HerbUtil { + + //瑙勫畾杈撳叆灏哄 + private static final int INPUT_SIZE = 224; + + private static final int TARGET_SIZE = 256; + + //鏍囩鏂囦欢 涓�绉嶇被鍒悕瀛楀崰涓�琛� + private List<String> herbNames; + + //鐢ㄤ簬璇嗗埆 + Predictor<Image, Classifications> predictor; + + //妯″瀷 + private Model model; + + public HerbUtil() { + //鍔犺浇鏍囩鍒癶erbNames涓� + this.loadHerbNames(); + //鍒濆鍖栨ā鍨嬪伐浣� + this.init(); + + + + } + + public List<Classifications.Classification> predict(InputStream inputStream) { + List<Classifications.Classification> result = new ArrayList<>(); + Image input = this.resizeImage(inputStream); + try { + Classifications output = predictor.predict(input); + System.out.println("鎺ㄦ祴涓猴細" + output.best().getClassName() + + ", 姒傜巼锛�" + output.best().getProbability()); + System.out.println(output); + result = output.topK(); + } catch (Exception e) { + log.error("鑽潗璇嗗埆寮傚父锛侊紒"); + log.error(input.toString()); + log.error(predictor.toString()); + e.printStackTrace(); + } + return result; + } + + private void loadHerbNames() { + BufferedReader reader = null; + herbNames = new ArrayList<>(); + try { + InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt"); + reader = new BufferedReader(new InputStreamReader(in)); + String name = null; + while ((name = reader.readLine()) != null) { + herbNames.add(name); + } + System.out.println(herbNames); + } catch (Exception e) { + e.printStackTrace(); + } finally { + if (reader != null) { + try { + reader.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + } + + private void init() { + Translator<Image, Classifications> translator = ImageClassificationTranslator.builder() + //涓嬮潰鐨則ransform鏍规嵁鑷繁鐨勬敼 + .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE)) + + .addTransform(new ToTensor()) + .addTransform(new Normalize( + new float[] {0.485f, 0.456f, 0.406f}, + new float[] {0.229f, 0.224f, 0.225f})) + + //杞藉叆鎵�鏈夋爣绛捐繘鍘� + .optSynset(herbNames) + //鏈�缁堟樉绀烘鐜囨渶楂樼殑5涓� + .optTopK(5) + .build(); + //闅忎究璧峰悕 + Model model = Model.newInstance("model", Device.cpu()); + try { +// ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(); +// Resource[] resources = resolver.getResources("../pytorch/model34.pt"); + // Resource resource = resources[0]; + File f = new File("../pytorch/model34.pt"); + + InputStream inputStream = new FileInputStream(f); + // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt"); + if (inputStream == null) { + throw new RuntimeException("鎵句笉鍒版ā鍨嬫枃浠�"); + } + model.load(inputStream); + + predictor = model.newPredictor(translator); + } catch (Exception e) { + e.printStackTrace(); + } + } + + private Image resizeImage(InputStream inputStream) { + BufferedImage input = null; + try { + input = ImageIO.read(inputStream); + } catch (IOException e) { + e.printStackTrace(); + } + int iw = input.getWidth(), ih = input.getHeight(); + int w = 256, h = 256; + double scale = Math.max(1. * w / iw, 1. * h / ih); + int nw = (int) (iw * scale), nh = (int) (ih * scale); + java.awt.Image img; + //鍙湁澶暱鎴栧お瀹芥墠浼氫繚鐣欐í绾垫瘮锛屽~鍏呴鑹� + // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; + // if (needResize) { + img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); + // } else { + // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); + // } + BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB); + Graphics g = out.getGraphics(); + //鍏堝皢鏁翠釜224*224鍖哄煙濉厖128 128 128棰滆壊 + g.setColor(new Color(255, 255, 255)); + g.fillRect(0, 0, nw, nh); + out.getGraphics().drawImage(img, 0, 0, null); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); + ImageIO.write(out, "jpg", imageOutputStream); + //鍘籇鐩樼湅鏁堟灉 + ImageIO.write(out, "jpg", new File("E:\\out.jpg")); + InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); + return ImageFactory.getInstance().fromInputStream(is); + } catch (IOException e) { + e.printStackTrace(); + throw new RuntimeException("鍥剧墖杞崲澶辫触"); + } + } +} \ No newline at end of file -- Gitblit v1.9.3