package org.jeecg.modules.dry.util; import ai.djl.Device; import ai.djl.Model; import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.modality.cv.transform.*; import ai.djl.modality.cv.translator.ImageClassificationTranslator; import ai.djl.translate.Translator; import org.springframework.stereotype.Component; import javax.imageio.ImageIO; import javax.imageio.stream.ImageOutputStream; import java.awt.*; import java.awt.image.BufferedImage; import java.io.*; import java.util.ArrayList; import java.util.List; @Component public class HerbUtil { //规定输入尺寸 private static final int INPUT_SIZE = 224; private static final int TARGET_SIZE = 256; //标签文件 一种类别名字占一行 private List herbNames; //用于识别 Predictor predictor; //模型 private Model model; public HerbUtil() { //加载标签到herbNames中 this.loadHerbNames(); //初始化模型工作 this.init(); } public List predict(InputStream inputStream) { List result = new ArrayList<>(); Image input = this.resizeImage(inputStream); try { Classifications output = predictor.predict(input); System.out.println("推测为:" + output.best().getClassName() + ", 概率:" + output.best().getProbability()); System.out.println(output); result = output.topK(); } catch (Exception e) { e.printStackTrace(); } return result; } private void loadHerbNames() { BufferedReader reader = null; herbNames = new ArrayList<>(); try { InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt"); reader = new BufferedReader(new InputStreamReader(in)); String name = null; while ((name = reader.readLine()) != null) { herbNames.add(name); } System.out.println(herbNames); } catch (Exception e) { e.printStackTrace(); } finally { if (reader != null) { try { reader.close(); } catch (IOException e) { e.printStackTrace(); } } } } private void init() { Translator translator = ImageClassificationTranslator.builder() //下面的transform根据自己的改 .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE)) .addTransform(new ToTensor()) .addTransform(new Normalize( new float[] {0.485f, 0.456f, 0.406f}, new float[] {0.229f, 0.224f, 0.225f})) //载入所有标签进去 .optSynset(herbNames) //最终显示概率最高的5个 .optTopK(5) .build(); //随便起名 Model model = Model.newInstance("model", Device.cpu()); try { InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model1.pt"); if (inputStream == null) { throw new RuntimeException("找不到模型文件"); } model.load(inputStream); predictor = model.newPredictor(translator); } catch (Exception e) { e.printStackTrace(); } } private Image resizeImage(InputStream inputStream) { BufferedImage input = null; try { input = ImageIO.read(inputStream); } catch (IOException e) { e.printStackTrace(); } int iw = input.getWidth(), ih = input.getHeight(); int w = 256, h = 256; double scale = Math.max(1. * w / iw, 1. * h / ih); int nw = (int) (iw * scale), nh = (int) (ih * scale); java.awt.Image img; //只有太长或太宽才会保留横纵比,填充颜色 // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; // if (needResize) { img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); // } else { // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); // } BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB); Graphics g = out.getGraphics(); //先将整个224*224区域填充128 128 128颜色 g.setColor(new Color(255, 255, 255)); g.fillRect(0, 0, nw, nh); out.getGraphics().drawImage(img, 0, 0, null); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); try { ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); ImageIO.write(out, "jpg", imageOutputStream); //去D盘看效果 ImageIO.write(out, "jpg", new File("E:\\out.jpg")); InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); return ImageFactory.getInstance().fromInputStream(is); } catch (IOException e) { e.printStackTrace(); throw new RuntimeException("图片转换失败"); } } }