| | |
| | | package org.jeecg.modules.dry.util; |
| | | |
| | | import ai.djl.Device; |
| | | import ai.djl.Model; |
| | | import ai.djl.inference.Predictor; |
| | | import ai.djl.modality.Classifications; |
| | | import ai.djl.modality.cv.Image; |
| | | import ai.djl.modality.cv.ImageFactory; |
| | | import ai.djl.modality.cv.transform.*; |
| | | import ai.djl.modality.cv.translator.ImageClassificationTranslator; |
| | | import ai.djl.translate.Translator; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.core.io.Resource; |
| | | import org.springframework.core.io.support.PathMatchingResourcePatternResolver; |
| | | import org.springframework.core.io.support.ResourcePatternResolver; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import javax.imageio.ImageIO; |
| | | import javax.imageio.stream.ImageOutputStream; |
| | | import java.awt.*; |
| | | import java.awt.image.BufferedImage; |
| | | import java.io.*; |
| | | import java.util.ArrayList; |
| | | import java.util.List; |
| | | |
| | | @Slf4j |
| | | @Component |
| | | public class HerbUtil { |
| | | |
| | | //规定输入尺寸 |
| | | private static final int INPUT_SIZE = 224; |
| | | |
| | | private static final int TARGET_SIZE = 256; |
| | | |
| | | //标签文件 一种类别名字占一行 |
| | | private List<String> herbNames; |
| | | |
| | | //用于识别 |
| | | Predictor<Image, Classifications> predictor; |
| | | |
| | | //模型 |
| | | private Model model; |
| | | |
| | | public HerbUtil() { |
| | | //加载标签到herbNames中 |
| | | this.loadHerbNames(); |
| | | //初始化模型工作 |
| | | this.init(); |
| | | |
| | | |
| | | |
| | | } |
| | | |
| | | public List<Classifications.Classification> predict(InputStream inputStream) { |
| | | List<Classifications.Classification> result = new ArrayList<>(); |
| | | Image input = this.resizeImage(inputStream); |
| | | try { |
| | | Classifications output = predictor.predict(input); |
| | | System.out.println("推测为:" + output.best().getClassName() |
| | | + ", 概率:" + output.best().getProbability()); |
| | | System.out.println(output); |
| | | result = output.topK(); |
| | | } catch (Exception e) { |
| | | log.error("药材识别异常!!"); |
| | | log.error(input.toString()); |
| | | log.error(predictor.toString()); |
| | | e.printStackTrace(); |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | private void loadHerbNames() { |
| | | BufferedReader reader = null; |
| | | herbNames = new ArrayList<>(); |
| | | try { |
| | | InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt"); |
| | | reader = new BufferedReader(new InputStreamReader(in)); |
| | | String name = null; |
| | | while ((name = reader.readLine()) != null) { |
| | | herbNames.add(name); |
| | | } |
| | | System.out.println(herbNames); |
| | | } catch (Exception e) { |
| | | e.printStackTrace(); |
| | | } finally { |
| | | if (reader != null) { |
| | | try { |
| | | reader.close(); |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | |
| | | private void init() { |
| | | Translator<Image, Classifications> translator = ImageClassificationTranslator.builder() |
| | | //下面的transform根据自己的改 |
| | | .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE)) |
| | | |
| | | .addTransform(new ToTensor()) |
| | | .addTransform(new Normalize( |
| | | new float[] {0.485f, 0.456f, 0.406f}, |
| | | new float[] {0.229f, 0.224f, 0.225f})) |
| | | |
| | | //载入所有标签进去 |
| | | .optSynset(herbNames) |
| | | //最终显示概率最高的5个 |
| | | .optTopK(5) |
| | | .build(); |
| | | //随便起名 |
| | | Model model = Model.newInstance("model", Device.cpu()); |
| | | try { |
| | | // ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(); |
| | | // Resource[] resources = resolver.getResources("../pytorch/model34.pt"); |
| | | // Resource resource = resources[0]; |
| | | File f = new File("../pytorch/model34.pt"); |
| | | |
| | | InputStream inputStream = new FileInputStream(f); |
| | | // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt"); |
| | | if (inputStream == null) { |
| | | throw new RuntimeException("找不到模型文件"); |
| | | } |
| | | model.load(inputStream); |
| | | |
| | | predictor = model.newPredictor(translator); |
| | | } catch (Exception e) { |
| | | e.printStackTrace(); |
| | | } |
| | | } |
| | | |
| | | private Image resizeImage(InputStream inputStream) { |
| | | BufferedImage input = null; |
| | | try { |
| | | input = ImageIO.read(inputStream); |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | } |
| | | int iw = input.getWidth(), ih = input.getHeight(); |
| | | int w = 256, h = 256; |
| | | double scale = Math.max(1. * w / iw, 1. * h / ih); |
| | | int nw = (int) (iw * scale), nh = (int) (ih * scale); |
| | | java.awt.Image img; |
| | | //只有太长或太宽才会保留横纵比,填充颜色 |
| | | // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; |
| | | // if (needResize) { |
| | | img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); |
| | | // } else { |
| | | // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); |
| | | // } |
| | | BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB); |
| | | Graphics g = out.getGraphics(); |
| | | //先将整个224*224区域填充128 128 128颜色 |
| | | g.setColor(new Color(255, 255, 255)); |
| | | g.fillRect(0, 0, nw, nh); |
| | | out.getGraphics().drawImage(img, 0, 0, null); |
| | | ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); |
| | | try { |
| | | ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); |
| | | ImageIO.write(out, "jpg", imageOutputStream); |
| | | //去D盘看效果 |
| | | ImageIO.write(out, "jpg", new File("E:\\out.jpg")); |
| | | InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); |
| | | return ImageFactory.getInstance().fromInputStream(is); |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | throw new RuntimeException("图片转换失败"); |
| | | } |
| | | } |
| | | } |
| | | //package org.jeecg.modules.dry.util; |
| | | // |
| | | //import ai.djl.Device; |
| | | //import ai.djl.Model; |
| | | //import ai.djl.inference.Predictor; |
| | | //import ai.djl.modality.Classifications; |
| | | //import ai.djl.modality.cv.Image; |
| | | //import ai.djl.modality.cv.ImageFactory; |
| | | //import ai.djl.modality.cv.transform.*; |
| | | //import ai.djl.modality.cv.translator.ImageClassificationTranslator; |
| | | //import ai.djl.translate.Translator; |
| | | //import lombok.extern.slf4j.Slf4j; |
| | | //import org.springframework.core.io.Resource; |
| | | //import org.springframework.core.io.support.PathMatchingResourcePatternResolver; |
| | | //import org.springframework.core.io.support.ResourcePatternResolver; |
| | | //import org.springframework.stereotype.Component; |
| | | // |
| | | //import javax.imageio.ImageIO; |
| | | //import javax.imageio.stream.ImageOutputStream; |
| | | //import java.awt.*; |
| | | //import java.awt.image.BufferedImage; |
| | | //import java.io.*; |
| | | //import java.util.ArrayList; |
| | | //import java.util.List; |
| | | // |
| | | //@Slf4j |
| | | //@Component |
| | | //public class HerbUtil { |
| | | // |
| | | // //规定输入尺寸 |
| | | // private static final int INPUT_SIZE = 224; |
| | | // |
| | | // private static final int TARGET_SIZE = 256; |
| | | // |
| | | // //标签文件 一种类别名字占一行 |
| | | // private List<String> herbNames; |
| | | // |
| | | // //用于识别 |
| | | // Predictor<Image, Classifications> predictor; |
| | | // |
| | | // //模型 |
| | | // private Model model; |
| | | // |
| | | // public HerbUtil() { |
| | | // //加载标签到herbNames中 |
| | | // this.loadHerbNames(); |
| | | // //初始化模型工作 |
| | | // this.init(); |
| | | // |
| | | // |
| | | // |
| | | // } |
| | | // |
| | | // public List<Classifications.Classification> predict(InputStream inputStream) { |
| | | // List<Classifications.Classification> result = new ArrayList<>(); |
| | | // Image input = this.resizeImage(inputStream); |
| | | // try { |
| | | // Classifications output = predictor.predict(input); |
| | | // System.out.println("推测为:" + output.best().getClassName() |
| | | // + ", 概率:" + output.best().getProbability()); |
| | | // System.out.println(output); |
| | | // result = output.topK(); |
| | | // } catch (Exception e) { |
| | | // log.error("药材识别异常!!"); |
| | | // log.error(input.toString()); |
| | | // log.error(predictor.toString()); |
| | | // e.printStackTrace(); |
| | | // } |
| | | // return result; |
| | | // } |
| | | // |
| | | // private void loadHerbNames() { |
| | | // BufferedReader reader = null; |
| | | // herbNames = new ArrayList<>(); |
| | | // try { |
| | | // InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt"); |
| | | // reader = new BufferedReader(new InputStreamReader(in)); |
| | | // String name = null; |
| | | // while ((name = reader.readLine()) != null) { |
| | | // herbNames.add(name); |
| | | // } |
| | | // System.out.println(herbNames); |
| | | // } catch (Exception e) { |
| | | // e.printStackTrace(); |
| | | // } finally { |
| | | // if (reader != null) { |
| | | // try { |
| | | // reader.close(); |
| | | // } catch (IOException e) { |
| | | // e.printStackTrace(); |
| | | // } |
| | | // } |
| | | // } |
| | | // } |
| | | // |
| | | // private void init() { |
| | | // Translator<Image, Classifications> translator = ImageClassificationTranslator.builder() |
| | | // //下面的transform根据自己的改 |
| | | // .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE)) |
| | | // |
| | | // .addTransform(new ToTensor()) |
| | | // .addTransform(new Normalize( |
| | | // new float[] {0.485f, 0.456f, 0.406f}, |
| | | // new float[] {0.229f, 0.224f, 0.225f})) |
| | | // |
| | | // //载入所有标签进去 |
| | | // .optSynset(herbNames) |
| | | // //最终显示概率最高的5个 |
| | | // .optTopK(5) |
| | | // .build(); |
| | | // //随便起名 |
| | | // Model model = Model.newInstance("model", Device.cpu()); |
| | | // try { |
| | | //// ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(); |
| | | //// Resource[] resources = resolver.getResources("../pytorch/model34.pt"); |
| | | // // Resource resource = resources[0]; |
| | | // File f = new File("../pytorch/model34.pt"); |
| | | // |
| | | // InputStream inputStream = new FileInputStream(f); |
| | | // // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt"); |
| | | // if (inputStream == null) { |
| | | // throw new RuntimeException("找不到模型文件"); |
| | | // } |
| | | // model.load(inputStream); |
| | | // |
| | | // predictor = model.newPredictor(translator); |
| | | // } catch (Exception e) { |
| | | // e.printStackTrace(); |
| | | // } |
| | | // } |
| | | // |
| | | // private Image resizeImage(InputStream inputStream) { |
| | | // BufferedImage input = null; |
| | | // try { |
| | | // input = ImageIO.read(inputStream); |
| | | // } catch (IOException e) { |
| | | // e.printStackTrace(); |
| | | // } |
| | | // int iw = input.getWidth(), ih = input.getHeight(); |
| | | // int w = 256, h = 256; |
| | | // double scale = Math.max(1. * w / iw, 1. * h / ih); |
| | | // int nw = (int) (iw * scale), nh = (int) (ih * scale); |
| | | // java.awt.Image img; |
| | | // //只有太长或太宽才会保留横纵比,填充颜色 |
| | | // // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; |
| | | // // if (needResize) { |
| | | // img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); |
| | | // // } else { |
| | | // // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); |
| | | // // } |
| | | // BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB); |
| | | // Graphics g = out.getGraphics(); |
| | | // //先将整个224*224区域填充128 128 128颜色 |
| | | // g.setColor(new Color(255, 255, 255)); |
| | | // g.fillRect(0, 0, nw, nh); |
| | | // out.getGraphics().drawImage(img, 0, 0, null); |
| | | // ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); |
| | | // try { |
| | | // ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); |
| | | // ImageIO.write(out, "jpg", imageOutputStream); |
| | | // //去D盘看效果 |
| | | // ImageIO.write(out, "jpg", new File("E:\\out.jpg")); |
| | | // InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); |
| | | // return ImageFactory.getInstance().fromInputStream(is); |
| | | // } catch (IOException e) { |
| | | // e.printStackTrace(); |
| | | // throw new RuntimeException("图片转换失败"); |
| | | // } |
| | | // } |
| | | //} |