From 88fc0f9f9b7fd3eb81c958ca41ed822cf3657c47 Mon Sep 17 00:00:00 2001
From: baoshiwei <baoshiwei@shlanbao.cn>
Date: 星期二, 22 四月 2025 15:50:22 +0800
Subject: [PATCH] refactor: 重构中药识别项目 分为onnx版和openvino版

---
 openvino/herb_ai.py |   46 ++++++++++++++++++----------------------------
 1 files changed, 18 insertions(+), 28 deletions(-)

diff --git a/herb_ai.py b/openvino/herb_ai.py
similarity index 93%
copy from herb_ai.py
copy to openvino/herb_ai.py
index 66be1e8..c916d9e 100644
--- a/herb_ai.py
+++ b/openvino/herb_ai.py
@@ -11,6 +11,7 @@
 import multiprocessing
 from safety_detect import SAFETY_DETECT
 from identifier import IDENTIFIER
+
 import os
 from logger_config import logger
 import threading
@@ -46,7 +47,7 @@
 
 # 璋冪敤鍙︿竴涓暱鐒﹂暅澶达紝鎷嶆憚娓呮櫚鐨勫眬閮ㄨ嵂鏉愬浘鐗�
 def get_image():
-    herb_identifier = IDENTIFIER("model/herb_identify.onnx")
+    herb_identifier = IDENTIFIER("./model/herb_id")
     logger.info("璇嗗埆绾跨▼鍚姩")
     global is_loaded, class_count, class_count_max, class_sum
     camera2_index = config['cam']['cam2']
@@ -179,9 +180,7 @@
     print("鎽勫儚澶村垎杈ㄧ巼:", width, "x", height)
     logger.info(f"鎽勫儚澶村垎杈ㄧ巼:, {width}, x, {height}")
     # 鐩爣鍥惧儚灏哄
-    # 璁℃椂鍣�
-    frame_count = 0
-    start_time = time.time()
+
     stime = time.time()
     if not os.path.exists(save_path):
         os.makedirs(save_path)
@@ -197,7 +196,7 @@
 
     # 寰幆璇诲彇鎽勫儚澶寸敾闈�
     while True:
-        logger.info("寰幆璇诲彇鎽勫儚澶寸敾闈�")
+        start_time = time.time()
         # 鐫$湢100姣
         time.sleep(config['cam']['sleep'])
         ret, frame = cap.read()
@@ -211,7 +210,7 @@
 
         # 瀹夊叏妫�娴�
         boxes, scores, class_ids = safety_detect(frame)
-        draw_img = safety_detect.draw_detections(frame, boxes, scores, class_ids)
+        draw_img = safety_detect.draw_detections(frame, class_ids, scores,boxes )
 
         det_res = {}
         if class_ids is not None:
@@ -277,7 +276,7 @@
         # print(status)
 
         # 涓婃枡鏈轰綅缃瘑鍒�
-        probabilities2 = hoister_position(frame);
+        probabilities2 = hoister_position(frame)
         predicted_class2 = np.argmax(probabilities2, axis=1)[0]
         max_probability2 = np.max(probabilities2, axis=1)[0]
         class_2 = hoister_position.class_names[predicted_class2]
@@ -291,10 +290,8 @@
             logger.info("鍙戦�佷笂鏂欐満浣嶇疆璇嗗埆缁撴灉锛�"+str(class_feeder))
             l.send_msg(class_feeder)
         # 璁$畻甯ч�熺巼
-        frame_count += 1
         end_time = time.time()
-        elapsed_time = end_time - start_time
-        fps = frame_count / elapsed_time
+        fps = (1 / (end_time - start_time))
         # print(f"FPS: {fps:.2f}")
         # 灏咶PS缁樺埗鍦ㄥ浘鍍忎笂
         cv2.putText(draw_img, f"FPS: {fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2,
@@ -303,6 +300,11 @@
         # 鑾峰彇褰撳墠绐楀彛澶у皬
         width = cv2.getWindowImageRect("AICamera")[2]
         height = cv2.getWindowImageRect("AICamera")[3]
+        # print("width", width, "height", height)
+
+        # 濡傛灉height灏忎簬1鍒欒祴鍊�100
+        if height < 1:
+            height = 100
 
         # 璋冩暣鍥惧儚澶у皬浠ラ�傚簲绐楀彛
         resized_frame = cv2.resize(draw_img, (width, height))
@@ -335,21 +337,6 @@
         ('cbData', ctypes.wintypes.DWORD),
         ('lpData', ctypes.c_char_p)
     ]
-
-
-
-# logging.info("鍑嗗鍔犺浇瀹夊叏妫�娴嬫ā鍨�..")
-# print("鍑嗗鍔犺浇瀹夊叏妫�娴嬫ā鍨�..")
-# model_safe = SAFETY_DETECT(config['model']['safe'])
-#
-# logging.info("瀹夊叏妫�娴嬫ā鍨嬪姞杞芥垚鍔熴��")
-# print("瀹夊叏妫�娴嬫ā鍨嬪姞杞芥垚鍔熴��")
-# logging.info("鍑嗗鍔犺浇鑽潗璇嗗埆妯″瀷..")
-# print("鍑嗗鍔犺浇鑽潗璇嗗埆妯″瀷..")
-# model_cls = HERB_IDENTIFY(config['model']['cls'])
-# logging.info("鑽潗璇嗗埆妯″瀷鍔犺浇鎴愬姛銆�")
-# print("鑽潗璇嗗埆妯″瀷鍔犺浇鎴愬姛銆�")
-
 
 class Listener:
     def __init__(self):
@@ -462,6 +449,8 @@
         return 0
 
 if __name__ == '__main__':
+
+
     # 绱姣忕鑽潗涓嶈鍚嶆鍑虹幇鐨勬鏁�
     class_count = {}
     # 绱姣忕鑽潗缃俊搴︽渶楂樼殑娆℃暟
@@ -478,9 +467,10 @@
     is_loaded = False
     # 鍔犺浇ONNX妯″瀷
 
-    load_identifier = IDENTIFIER("model/loading.onnx")
-    hoister_position = IDENTIFIER("model/hl.onnx")
-    safety_detect = SAFETY_DETECT("model/safety_det.onnx")
+    print("鍔犺浇妯″瀷===============")
+    load_identifier = IDENTIFIER("./model/load_id")
+    hoister_position = IDENTIFIER("./model/feeder_id")
+    safety_detect = SAFETY_DETECT("./model/safe_det")
     config = read_config()
     PCOPYDATASTRUCT = ctypes.POINTER(COPYDATASTRUCT)
 

--
Gitblit v1.9.3