/*
 * HerbUtil: herb image classifier built on DJL. It loads a PyTorch model file plus a
 * classpath label file, and classifies images read from an InputStream.
 *
 * NOTE(review): the whole class is commented out, so nothing in this file is compiled.
 * The line breaks below were reconstructed from a whitespace-collapsed copy. A few
 * spots are ambiguous: a blank line versus a nested "//" comment around the ToTensor
 * transform, which of the two `inputStream` declarations in init() was active, and
 * whether the needResize branch in resizeImage() was active. Confirm these against
 * version control before re-enabling.
 *
 * Issues visible in the commented code that should be fixed before re-enabling:
 *  - init(): `Model model = ...` is a local that shadows the `model` field. The field
 *    is never assigned and the Model is never closed.
 *  - init(): the FileInputStream opened on "../pytorch/model34.pt" is never closed.
 *    The `inputStream == null` check can never be true for `new FileInputStream(f)`,
 *    because a missing file throws FileNotFoundException instead. The relative path
 *    also depends on the process working directory.
 *  - Constructor: loadHerbNames() and init() swallow every exception with
 *    printStackTrace(). The bean can therefore start with no labels and a null
 *    `predictor`.
 *  - predict(): if `predictor` is null, the catch block itself throws a
 *    NullPointerException from `predictor.toString()`.
 *  - loadHerbNames(): a missing "class.txt" gives a null stream, which fails inside
 *    the swallowed try. InputStreamReader is built without a charset, so the platform
 *    default is used. That can garble non-ASCII label names.
 *  - resizeImage(): ImageIO.read returns null for unsupported formats, and the
 *    IOException path leaves `input` null. In both cases `input.getWidth()` throws NPE.
 *  - resizeImage(): every call also writes a debug JPEG to a hard-coded path on
 *    drive E:. The Graphics objects from getGraphics() are never disposed.
 *  - TARGET_SIZE is unused; resizeImage() hard-codes 256 instead.
 *  - Raw types (List, Predictor, Translator). Generic arguments may have been lost in
 *    this copy; restore them.
 *  - Use the @Slf4j `log` instead of System.out and printStackTrace().
 */
//package org.jeecg.modules.dry.util;
//
//import ai.djl.Device;
//import ai.djl.Model;
//import ai.djl.inference.Predictor;
//import ai.djl.modality.Classifications;
//import ai.djl.modality.cv.Image;
//import ai.djl.modality.cv.ImageFactory;
//import ai.djl.modality.cv.transform.*;
//import ai.djl.modality.cv.translator.ImageClassificationTranslator;
//import ai.djl.translate.Translator;
//import lombok.extern.slf4j.Slf4j;
//import org.springframework.core.io.Resource;
//import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
//import org.springframework.core.io.support.ResourcePatternResolver;
//import org.springframework.stereotype.Component;
//
//import javax.imageio.ImageIO;
//import javax.imageio.stream.ImageOutputStream;
//import java.awt.*;
//import java.awt.image.BufferedImage;
//import java.io.*;
//import java.util.ArrayList;
//import java.util.List;
//
//@Slf4j
//@Component
//public class HerbUtil {
//
//    // Model input size (side length of the center crop applied in init()).
//    private static final int INPUT_SIZE = 224;
//
//    private static final int TARGET_SIZE = 256;
//
//    // Class labels, read from a label file that holds one class name per line.
//    private List herbNames;
//
//    // Used to run recognition.
//    Predictor predictor;
//
//    // The model.
//    private Model model;
//
//    public HerbUtil() {
//        // Load the labels into herbNames.
//        this.loadHerbNames();
//        // Initialize the model.
//        this.init();
//
//
//
//    }
//
//    // Resizes the image, runs the predictor and returns output.topK().
//    // On a prediction error it logs and returns the empty list it started with.
//    public List predict(InputStream inputStream) {
//        List result = new ArrayList<>();
//        Image input = this.resizeImage(inputStream);
//        try {
//            Classifications output = predictor.predict(input);
//            System.out.println("推测为:" + output.best().getClassName()
//                    + ", 概率:" + output.best().getProbability());
//            System.out.println(output);
//            result = output.topK();
//        } catch (Exception e) {
//            log.error("药材识别异常!!");
//            log.error(input.toString());
//            log.error(predictor.toString());
//            e.printStackTrace();
//        }
//        return result;
//    }
//
//    // Reads class labels, one per line, from the classpath resource "class.txt" into herbNames.
//    private void loadHerbNames() {
//        BufferedReader reader = null;
//        herbNames = new ArrayList<>();
//        try {
//            InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
//            reader = new BufferedReader(new InputStreamReader(in));
//            String name = null;
//            while ((name = reader.readLine()) != null) {
//                herbNames.add(name);
//            }
//            System.out.println(herbNames);
//        } catch (Exception e) {
//            e.printStackTrace();
//        } finally {
//            if (reader != null) {
//                try {
//                    reader.close();
//                } catch (IOException e) {
//                    e.printStackTrace();
//                }
//            }
//        }
//    }
//
//    // Builds the classification translator (center crop, ToTensor, mean/std normalize, top-5)
//    // and loads the model from ../pytorch/model34.pt.
//    private void init() {
//        Translator translator = ImageClassificationTranslator.builder()
//                // Adjust the transforms below to match your own model's preprocessing.
//                .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
//
//                .addTransform(new ToTensor())
//                // Per-channel mean/std (the commonly used ImageNet values).
//                .addTransform(new Normalize(
//                        new float[] {0.485f, 0.456f, 0.406f},
//                        new float[] {0.229f, 0.224f, 0.225f}))
//
//                // Register all class labels.
//                .optSynset(herbNames)
//                // Report the 5 most probable classes.
//                .optTopK(5)
//                .build();
//        // Arbitrary model name.
//        // NOTE(review): this local shadows the `model` field, which therefore stays null.
//        Model model = Model.newInstance("model", Device.cpu());
//        try {
////            ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
////            Resource[] resources = resolver.getResources("../pytorch/model34.pt");
//            // Resource resource = resources[0];
//            File f = new File("../pytorch/model34.pt");
//
//            InputStream inputStream = new FileInputStream(f);
//
//            // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
//            // NOTE(review): only reachable for the classpath variant above; new FileInputStream never yields null.
//            if (inputStream == null) {
//                throw new RuntimeException("找不到模型文件");
//            }
//            model.load(inputStream);
//
//            predictor = model.newPredictor(translator);
//        } catch (Exception e) {
//            e.printStackTrace();
//        }
//    }
//
//    // Scales the image (keeping its aspect ratio) so it roughly covers a 256x256 box,
//    // re-encodes it as JPEG and wraps the bytes as a DJL Image.
//    private Image resizeImage(InputStream inputStream) {
//        BufferedImage input = null;
//        try {
//            input = ImageIO.read(inputStream);
//        } catch (IOException e) {
//            e.printStackTrace();
//        }
//        // NOTE(review): `input` is null here if reading failed or the format is unsupported.
//        int iw = input.getWidth(), ih = input.getHeight();
//        int w = 256, h = 256;
//        double scale = Math.max(1. * w / iw, 1. * h / ih);
//        int nw = (int) (iw * scale), nh = (int) (ih * scale);
//        java.awt.Image img;
//        // Only keep the aspect ratio (padding with the fill color) when the image is too long or too wide.
//        // NOTE(review): as reconstructed here the needResize check is disabled, so the nw x nh scaling always runs.
//        // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
//        // if (needResize) {
//            img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
//        // } else {
//        //     img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
//        // }
//        BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
//        Graphics g = out.getGraphics();
//        // First fill the whole nw x nh canvas with white (255, 255, 255).
//        // NOTE(review): an older comment said "224*224 area, color 128 128 128", which does not match this code.
//        g.setColor(new Color(255, 255, 255));
//        g.fillRect(0, 0, nw, nh);
//        out.getGraphics().drawImage(img, 0, 0, null);
//        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
//        try {
//            ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
//            ImageIO.write(out, "jpg", imageOutputStream);
//            // Debug only: dump the resized image to disk for inspection on every call.
//            // NOTE(review): the older comment said drive D, but the path is on drive E:; remove before production.
//            ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
//            InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
//            return ImageFactory.getInstance().fromInputStream(is);
//        } catch (IOException e) {
//            e.printStackTrace();
//            throw new RuntimeException("图片转换失败");
//        }
//    }
//}