/*
 * 干燥机配套车间生产管理系统/云平台服务端
 * (Drying-machine workshop production management system / cloud platform server)
 * bsw215583320
 * 2024-01-08 aa562d3b26d8b6de0f0fc0b842ba3894ebcf0945
 */
package org.jeecg.modules.dry.util;
 
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.translate.Translator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.stereotype.Component;
 
import javax.imageio.ImageIO;
import javax.imageio.stream.ImageOutputStream;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
 
@Slf4j
@Component
public class HerbUtil {
 
    //规定输入尺寸
    private static final int INPUT_SIZE = 224;
 
    private static final int TARGET_SIZE = 256;
 
    //标签文件 一种类别名字占一行
    private List<String> herbNames;
 
    //用于识别
    Predictor<Image, Classifications> predictor;
 
    //模型
    private Model model;
 
    public HerbUtil() {
        //加载标签到herbNames中
        this.loadHerbNames();
        //初始化模型工作
        this.init();
 
 
 
    }
 
    public List<Classifications.Classification> predict(InputStream inputStream) {
        List<Classifications.Classification> result = new ArrayList<>();
        Image input = this.resizeImage(inputStream);
        try {
            Classifications output = predictor.predict(input);
            System.out.println("推测为:" + output.best().getClassName()
                    + ", 概率:" + output.best().getProbability());
            System.out.println(output);
            result = output.topK();
        } catch (Exception e) {
            log.error("药材识别异常!!");
            log.error(input.toString());
            log.error(predictor.toString());
            e.printStackTrace();
        }
        return result;
    }
 
    private void loadHerbNames() {
        BufferedReader reader = null;
        herbNames = new ArrayList<>();
        try {
            InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt");
            reader = new BufferedReader(new InputStreamReader(in));
            String name = null;
            while ((name = reader.readLine()) != null) {
                herbNames.add(name);
            }
            System.out.println(herbNames);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
 
    private void init() {
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                //下面的transform根据自己的改
                .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE))
 
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
 
                //载入所有标签进去
                .optSynset(herbNames)
                //最终显示概率最高的5个
                .optTopK(5)
                .build();
        //随便起名
        Model model = Model.newInstance("model", Device.cpu());
        try {
//            ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
//            Resource[] resources = resolver.getResources("../pytorch/model34.pt");
            //            Resource resource = resources[0];
            File f = new File("../pytorch/model34.pt");
 
            InputStream inputStream = new FileInputStream(f);
           // InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model34.pt");
            if (inputStream == null) {
                throw new RuntimeException("找不到模型文件");
            }
            model.load(inputStream);
 
            predictor = model.newPredictor(translator);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
 
    private Image resizeImage(InputStream inputStream) {
        BufferedImage input = null;
        try {
            input = ImageIO.read(inputStream);
        } catch (IOException e) {
            e.printStackTrace();
        }
        int iw = input.getWidth(), ih = input.getHeight();
        int w = 256, h = 256;
        double scale = Math.max(1. *  w / iw, 1. * h / ih);
        int nw = (int) (iw * scale), nh = (int) (ih * scale);
        java.awt.Image img;
        //只有太长或太宽才会保留横纵比,填充颜色
       // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4;
      //  if (needResize) {
            img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH);
      //  } else {
       //     img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH);
      //  }
        BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB);
        Graphics g = out.getGraphics();
        //先将整个224*224区域填充128 128 128颜色
        g.setColor(new Color(255, 255, 255));
        g.fillRect(0, 0, nw, nh);
        out.getGraphics().drawImage(img, 0, 0, null);
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        try {
            ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream);
            ImageIO.write(out, "jpg", imageOutputStream);
            //去D盘看效果
            ImageIO.write(out, "jpg", new File("E:\\out.jpg"));
            InputStream is = new ByteArrayInputStream(outputStream.toByteArray());
            return ImageFactory.getInstance().fromInputStream(is);
        } catch (IOException e) {
            e.printStackTrace();
            throw new RuntimeException("图片转换失败");
        }
    }
}