干燥机配套车间生产管理系统/云平台服务端
baoshiwei
2024-12-11 7c585586e9bea943161676bd9d127e81123891c3
jeecg-module-dry/jeecg-module-dry-api/src/main/java/org/jeecg/modules/dry/util/HerbUtil.java
old mode 100644 new mode 100755
@@ -1,156 +1,170 @@
package org.jeecg.modules.dry.util;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.translate.Translator;
import org.springframework.stereotype.Component;
import javax.imageio.ImageIO;
import javax.imageio.stream.ImageOutputStream;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
@Component
public class HerbUtil {
    //规定输入尺寸
    private static final int INPUT_SIZE = 224;
    private static final int TARGET_SIZE = 256;
    //标签文件 一种类别名字占一行
    private List<String> herbNames;
    //用于识别
    Predictor<Image, Classifications> predictor;
    //模型
    private Model model;
    public HerbUtil() {
        //加载标签到herbNames中
        this.loadHerbNames();
        //初始化模型工作
        this.init();
    }
    public List<Classifications.Classification> predict(InputStream inputStream) {
        List<Classifications.Classification> result = new ArrayList<>();
        Image input = this.resizeImage(inputStream);
        try {
            Classifications output = predictor.predict(input);
            System.out.println("推测为:" + output.best().getClassName()
                    + ", 概率:" + output.best().getProbability());
            System.out.println(output);
            result = output.topK();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return result;
    }
    private void loadHerbNames() {
        BufferedReader reader = null;
        herbNames = new ArrayList<>();
        try {
            InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
            reader = new BufferedReader(new InputStreamReader(in));
            String name = null;
            while ((name = reader.readLine()) != null) {
                herbNames.add(name);
            }
            System.out.println(herbNames);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
    private void init() {
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                //下面的transform根据自己的改
                .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                //载入所有标签进去
                .optSynset(herbNames)
                //最终显示概率最高的5个
                .optTopK(5)
                .build();
        //随便起名
        Model model = Model.newInstance("model", Device.cpu());
        try {
            InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
            if (inputStream == null) {
                throw new RuntimeException("找不到模型文件");
            }
            model.load(inputStream);
            predictor = model.newPredictor(translator);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    private Image resizeImage(InputStream inputStream) {
        BufferedImage input = null;
        try {
            input = ImageIO.read(inputStream);
        } catch (IOException e) {
            e.printStackTrace();
        }
        int iw = input.getWidth(), ih = input.getHeight();
        int w = 256, h = 256;
        double scale = Math.max(1. *  w / iw, 1. * h / ih);
        int nw = (int) (iw * scale), nh = (int) (ih * scale);
        java.awt.Image img;
        //只有太长或太宽才会保留横纵比,填充颜色
       // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
      //  if (needResize) {
            img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
      //  } else {
       //     img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
      //  }
        BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
        Graphics g = out.getGraphics();
        //先将整个224*224区域填充128 128 128颜色
        g.setColor(new Color(255, 255, 255));
        g.fillRect(0, 0, nw, nh);
        out.getGraphics().drawImage(img, 0, 0, null);
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        try {
            ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
            ImageIO.write(out, "jpg", imageOutputStream);
            //去D盘看效果
            ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
            InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
            return ImageFactory.getInstance().fromInputStream(is);
        } catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException("图片转换失败");
        }
    }
}
//package org.jeecg.modules.dry.util;
//
//import ai.djl.Device;
//import ai.djl.Model;
//import ai.djl.inference.Predictor;
//import ai.djl.modality.Classifications;
//import ai.djl.modality.cv.Image;
//import ai.djl.modality.cv.ImageFactory;
//import ai.djl.modality.cv.transform.*;
//import ai.djl.modality.cv.translator.ImageClassificationTranslator;
//import ai.djl.translate.Translator;
//import lombok.extern.slf4j.Slf4j;
//import org.springframework.core.io.Resource;
//import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
//import org.springframework.core.io.support.ResourcePatternResolver;
//import org.springframework.stereotype.Component;
//
//import javax.imageio.ImageIO;
//import javax.imageio.stream.ImageOutputStream;
//import java.awt.*;
//import java.awt.image.BufferedImage;
//import java.io.*;
//import java.util.ArrayList;
//import java.util.List;
//
//@Slf4j
//@Component
//public class HerbUtil {
//
//    //规定输入尺寸
//    private static final int INPUT_SIZE = 224;
//
//    private static final int TARGET_SIZE = 256;
//
//    //标签文件 一种类别名字占一行
//    private List<String> herbNames;
//
//    //用于识别
//    Predictor<Image, Classifications> predictor;
//
//    //模型
//    private Model model;
//
//    public HerbUtil() {
//        //加载标签到herbNames中
//        this.loadHerbNames();
//        //初始化模型工作
//        this.init();
//
//
//
//    }
//
//    public List<Classifications.Classification> predict(InputStream inputStream) {
//        List<Classifications.Classification> result = new ArrayList<>();
//        Image input = this.resizeImage(inputStream);
//        try {
//            Classifications output = predictor.predict(input);
//            System.out.println("推测为:" + output.best().getClassName()
//                    + ", 概率:" + output.best().getProbability());
//            System.out.println(output);
//            result = output.topK();
//        } catch (Exception e) {
//            log.error("药材识别异常!!");
//            log.error(input.toString());
//            log.error(predictor.toString());
//            e.printStackTrace();
//        }
//        return result;
//    }
//
//    private void loadHerbNames() {
//        BufferedReader reader = null;
//        herbNames = new ArrayList<>();
//        try {
//            InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
//            reader = new BufferedReader(new InputStreamReader(in));
//            String name = null;
//            while ((name = reader.readLine()) != null) {
//                herbNames.add(name);
//            }
//            System.out.println(herbNames);
//        } catch (Exception e) {
//            e.printStackTrace();
//        } finally {
//            if (reader != null) {
//                try {
//                    reader.close();
//                } catch (IOException e) {
//                    e.printStackTrace();
//                }
//            }
//        }
//    }
//
//    private void init() {
//        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
//                //下面的transform根据自己的改
//                .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
//
//                .addTransform(new ToTensor())
//                .addTransform(new Normalize(
//                        new float[] {0.485f, 0.456f, 0.406f},
//                        new float[] {0.229f, 0.224f, 0.225f}))
//
//                //载入所有标签进去
//                .optSynset(herbNames)
//                //最终显示概率最高的5个
//                .optTopK(5)
//                .build();
//        //随便起名
//        Model model = Model.newInstance("model", Device.cpu());
//        try {
////            ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
////            Resource[] resources = resolver.getResources("../pytorch/model34.pt");
//            //            Resource resource = resources[0];
//            File f = new File("../pytorch/model34.pt");
//
//            InputStream inputStream = new FileInputStream(f);
//           // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
//            if (inputStream == null) {
//                throw new RuntimeException("找不到模型文件");
//            }
//            model.load(inputStream);
//
//            predictor = model.newPredictor(translator);
//        } catch (Exception e) {
//            e.printStackTrace();
//        }
//    }
//
//    private Image resizeImage(InputStream inputStream) {
//        BufferedImage input = null;
//        try {
//            input = ImageIO.read(inputStream);
//        } catch (IOException e) {
//            e.printStackTrace();
//        }
//        int iw = input.getWidth(), ih = input.getHeight();
//        int w = 256, h = 256;
//        double scale = Math.max(1. *  w / iw, 1. * h / ih);
//        int nw = (int) (iw * scale), nh = (int) (ih * scale);
//        java.awt.Image img;
//        //只有太长或太宽才会保留横纵比,填充颜色
//       // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
//      //  if (needResize) {
//            img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
//      //  } else {
//       //     img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
//      //  }
//        BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
//        Graphics g = out.getGraphics();
//        //先将整个224*224区域填充128 128 128颜色
//        g.setColor(new Color(255, 255, 255));
//        g.fillRect(0, 0, nw, nh);
//        out.getGraphics().drawImage(img, 0, 0, null);
//        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
//        try {
//            ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
//            ImageIO.write(out, "jpg", imageOutputStream);
//            //去D盘看效果
//            ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
//            InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
//            return ImageFactory.getInstance().fromInputStream(is);
//        } catch (IOException e) {
//            e.printStackTrace();
//            throw new RuntimeException("图片转换失败");
//        }
//    }
//}