¶Ô±ÈÐÂÎļþ |
| | |
| | | package org.jeecg.modules.dry.util; |
| | | |
| | | import ai.djl.Device; |
| | | import ai.djl.Model; |
| | | import ai.djl.inference.Predictor; |
| | | import ai.djl.modality.Classifications; |
| | | import ai.djl.modality.cv.Image; |
| | | import ai.djl.modality.cv.ImageFactory; |
| | | import ai.djl.modality.cv.transform.*; |
| | | import ai.djl.modality.cv.translator.ImageClassificationTranslator; |
| | | import ai.djl.translate.Translator; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import javax.imageio.ImageIO; |
| | | import javax.imageio.stream.ImageOutputStream; |
| | | import java.awt.*; |
| | | import java.awt.image.BufferedImage; |
| | | import java.io.*; |
| | | import java.util.ArrayList; |
| | | import java.util.List; |
| | | |
| | | @Component |
| | | public class HerbUtil { |
| | | |
| | | //è§å®è¾å
¥å°ºå¯¸ |
| | | private static final int INPUT_SIZE = 224; |
| | | |
| | | private static final int TARGET_SIZE = 256; |
| | | |
| | | //æ ç¾æä»¶ ä¸ç§ç±»å«ååå ä¸è¡ |
| | | private List<String> herbNames; |
| | | |
| | | //ç¨äºè¯å« |
| | | Predictor<Image, Classifications> predictor; |
| | | |
| | | //模å |
| | | private Model model; |
| | | |
| | | public HerbUtil() { |
| | | //å è½½æ ç¾å°herbNamesä¸ |
| | | this.loadHerbNames(); |
| | | //åå§å模åå·¥ä½ |
| | | this.init(); |
| | | |
| | | |
| | | |
| | | } |
| | | |
| | | public List<Classifications.Classification> predict(InputStream inputStream) { |
| | | List<Classifications.Classification> result = new ArrayList<>(); |
| | | Image input = this.resizeImage(inputStream); |
| | | try { |
| | | Classifications output = predictor.predict(input); |
| | | System.out.println("æ¨æµä¸ºï¼" + output.best().getClassName() |
| | | + ", æ¦çï¼" + output.best().getProbability()); |
| | | System.out.println(output); |
| | | result = output.topK(); |
| | | } catch (Exception e) { |
| | | e.printStackTrace(); |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | private void loadHerbNames() { |
| | | BufferedReader reader = null; |
| | | herbNames = new ArrayList<>(); |
| | | try { |
| | | InputStream in = HerbUtil.class.getClassLoader().getResourceAsStream("class.txt"); |
| | | reader = new BufferedReader(new InputStreamReader(in)); |
| | | String name = null; |
| | | while ((name = reader.readLine()) != null) { |
| | | herbNames.add(name); |
| | | } |
| | | System.out.println(herbNames); |
| | | } catch (Exception e) { |
| | | e.printStackTrace(); |
| | | } finally { |
| | | if (reader != null) { |
| | | try { |
| | | reader.close(); |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | |
| | | private void init() { |
| | | Translator<Image, Classifications> translator = ImageClassificationTranslator.builder() |
| | | //ä¸é¢çtransformæ ¹æ®èªå·±çæ¹ |
| | | .addTransform(new CenterCrop(INPUT_SIZE,INPUT_SIZE)) |
| | | |
| | | .addTransform(new ToTensor()) |
| | | .addTransform(new Normalize( |
| | | new float[] {0.485f, 0.456f, 0.406f}, |
| | | new float[] {0.229f, 0.224f, 0.225f})) |
| | | |
| | | //è½½å
¥æææ ç¾è¿å» |
| | | .optSynset(herbNames) |
| | | //æç»æ¾ç¤ºæ¦çæé«ç5个 |
| | | .optTopK(5) |
| | | .build(); |
| | | //é便起å |
| | | Model model = Model.newInstance("model", Device.cpu()); |
| | | try { |
| | | InputStream inputStream = HerbUtil.class.getClassLoader().getResourceAsStream("model1.pt"); |
| | | if (inputStream == null) { |
| | | throw new RuntimeException("æ¾ä¸å°æ¨¡åæä»¶"); |
| | | } |
| | | model.load(inputStream); |
| | | |
| | | predictor = model.newPredictor(translator); |
| | | } catch (Exception e) { |
| | | e.printStackTrace(); |
| | | } |
| | | } |
| | | |
| | | private Image resizeImage(InputStream inputStream) { |
| | | BufferedImage input = null; |
| | | try { |
| | | input = ImageIO.read(inputStream); |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | } |
| | | int iw = input.getWidth(), ih = input.getHeight(); |
| | | int w = 256, h = 256; |
| | | double scale = Math.max(1. * w / iw, 1. * h / ih); |
| | | int nw = (int) (iw * scale), nh = (int) (ih * scale); |
| | | java.awt.Image img; |
| | | //åªæå¤ªé¿æå¤ªå®½æä¼ä¿ç横纵æ¯ï¼å¡«å
é¢è² |
| | | // boolean needResize = 1. * iw / ih > 1.4 || 1. * ih / iw > 1.4; |
| | | // if (needResize) { |
| | | img = input.getScaledInstance(nw, nh, BufferedImage.SCALE_SMOOTH); |
| | | // } else { |
| | | // img = input.getScaledInstance(INPUT_SIZE, INPUT_SIZE, BufferedImage.SCALE_SMOOTH); |
| | | // } |
| | | BufferedImage out = new BufferedImage(nw, nh, BufferedImage.TYPE_INT_RGB); |
| | | Graphics g = out.getGraphics(); |
| | | //å
å°æ´ä¸ª224*224åºåå¡«å
128 128 128é¢è² |
| | | g.setColor(new Color(255, 255, 255)); |
| | | g.fillRect(0, 0, nw, nh); |
| | | out.getGraphics().drawImage(img, 0, 0, null); |
| | | ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); |
| | | try { |
| | | ImageOutputStream imageOutputStream = ImageIO.createImageOutputStream(outputStream); |
| | | ImageIO.write(out, "jpg", imageOutputStream); |
| | | //å»Dççææ |
| | | ImageIO.write(out, "jpg", new File("E:\\out.jpg")); |
| | | InputStream is = new ByteArrayInputStream(outputStream.toByteArray()); |
| | | return ImageFactory.getInstance().fromInputStream(is); |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | throw new RuntimeException("å¾ç转æ¢å¤±è´¥"); |
| | | } |
| | | } |
| | | } |