//package org.jeecg.modules.dry.util;
|
//
|
//import ai.djl.Device;
|
//import ai.djl.Model;
|
//import ai.djl.inference.Predictor;
|
//import ai.djl.modality.Classifications;
|
//import ai.djl.modality.cv.Image;
|
//import ai.djl.modality.cv.ImageFactory;
|
//import ai.djl.modality.cv.transform.*;
|
//import ai.djl.modality.cv.translator.ImageClassificationTranslator;
|
//import ai.djl.translate.Translator;
|
//import lombok.extern.slf4j.Slf4j;
|
//import org.springframework.core.io.Resource;
|
//import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
|
//import org.springframework.core.io.support.ResourcePatternResolver;
|
//import org.springframework.stereotype.Component;
|
//
|
//import javax.imageio.ImageIO;
|
//import javax.imageio.stream.ImageOutputStream;
|
//import java.awt.*;
|
//import java.awt.image.BufferedImage;
|
//import java.io.*;
|
//import java.util.ArrayList;
|
//import java.util.List;
|
//
|
//@Slf4j
|
//@Component
|
//public class HerbUtil {
|
//
|
// // Input size (pixels) used for the center crop applied before inference
|
// private static final int INPUT_SIZE = 224;
|
//
|
// private static final int TARGET_SIZE = 256;
|
//
|
// // Class labels read from the label file, one class name per line
|
// private List<String> herbNames;
|
//
|
// // Predictor used to run the classification
|
// Predictor<Image, Classifications> predictor;
|
//
|
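// // Model field (note: init() declares a local variable named model, so this field is not assigned there)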
|
// private Model model;
|
//
|
// public HerbUtil() {
|
// // Load the labels into herbNames
|
// this.loadHerbNames();
|
// // Initialize the model and the predictor
|
// this.init();
|
//
|
//
|
//
|
// }
|
//
|
// public List<Classifications.Classification> predict(InputStream inputStream) {
|
// List<Classifications.Classification> result = new ArrayList<>();
|
// Image input = this.resizeImage(inputStream);
|
// try {
|
// Classifications output = predictor.predict(input);
|
// System.out.println("推测为:" + output.best().getClassName()
|
// + ", 概率:" + output.best().getProbability());
|
// System.out.println(output);
|
// result = output.topK();
|
// } catch (Exception e) {
|
// log.error("药材识别异常!!");
|
// log.error(input.toString());
|
// log.error(predictor.toString());
|
// e.printStackTrace();
|
// }
|
// return result;
|
// }
|
//
|
// private void loadHerbNames() {
|
// BufferedReader reader = null;
|
// herbNames = new ArrayList<>();
|
// try {
|
// InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
|
// reader = new BufferedReader(new InputStreamReader(in));
|
// String name = null;
|
// while ((name = reader.readLine()) != null) {
|
// herbNames.add(name);
|
// }
|
// System.out.println(herbNames);
|
// } catch (Exception e) {
|
// e.printStackTrace();
|
// } finally {
|
// if (reader != null) {
|
// try {
|
// reader.close();
|
// } catch (IOException e) {
|
// e.printStackTrace();
|
// }
|
// }
|
// }
|
// }
|
//
|
// private void init() {
|
// Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
|
// // Adjust the transforms below to match your own model's preprocessing
|
// .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
|
//
|
// .addTransform(new ToTensor())
|
// .addTransform(new Normalize(
|
// new float[] {0.485f, 0.456f, 0.406f},
|
// new float[] {0.229f, 0.224f, 0.225f}))
|
//
|
// // Supply all labels as the synset
|
// .optSynset(herbNames)
|
// // Report only the 5 most probable classes
|
// .optTopK(5)
|
// .build();
|
// // The model name is arbitrary
|
// Model model = Model.newInstance("model", Device.cpu());
|
// try {
|
//// ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
|
//// Resource[] resources = resolver.getResources("../pytorch/model34.pt");
|
// // Resource resource = resources[0];
|
// File f = new File("../pytorch/model34.pt");
|
//
|
// InputStream inputStream = new FileInputStream(f);
|
// // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
|
// if (inputStream == null) {
|
// throw new RuntimeException("找不到模型文件");
|
// }
|
// model.load(inputStream);
|
//
|
// predictor = model.newPredictor(translator);
|
// } catch (Exception e) {
|
// e.printStackTrace();
|
// }
|
// }
|
//
|
// private Image resizeImage(InputStream inputStream) {
|
// BufferedImage input = null;
|
// try {
|
// input = ImageIO.read(inputStream);
|
// } catch (IOException e) {
|
// e.printStackTrace();
|
// }
|
// int iw = input.getWidth(), ih = input.getHeight();
|
// int w = 256, h = 256;
|
// double scale = Math.max(1. * w / iw, 1. * h / ih);
|
// int nw = (int) (iw * scale), nh = (int) (ih * scale);
|
// java.awt.Image img;
|
// // The aspect-ratio check below is disabled: the image is always scaled by a single factor, which keeps its aspect ratio
|
// // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
|
// // if (needResize) {
|
// img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
|
// // } else {
|
// // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
|
// // }
|
// BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
|
// Graphics g = out.getGraphics();
|
// // First fill the whole nw x nh canvas with white (255, 255, 255)
|
// g.setColor(new Color(255, 255, 255));
|
// g.fillRect(0, 0, nw, nh);
|
// out.getGraphics().drawImage(img, 0, 0, null);
|
// ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
|
// try {
|
// ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
|
// ImageIO.write(out, "jpg", imageOutputStream);
|
// // Debug aid: also write the result to E:\out.jpg so it can be inspected visually
|
// ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
|
// InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
|
// return ImageFactory.getInstance().fromInputStream(is);
|
// } catch (IOException e) {
|
// e.printStackTrace();
|
// throw new RuntimeException("图片转换失败");
|
// }
|
// }
|
//}
|