From 6628f663b636675bcaea316f2deaddf337de480e Mon Sep 17 00:00:00 2001
From: baoshiwei <baoshiwei@shlanbao.cn>
Date: 星期五, 13 三月 2026 10:23:31 +0800
Subject: [PATCH] feat(米重分析): 新增稳态识别和预测功能页面并优化现有模型
---
app/pages/metered_weight_prediction.py | 208 ++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 208 insertions(+), 0 deletions(-)
diff --git a/app/pages/metered_weight_prediction.py b/app/pages/metered_weight_prediction.py
new file mode 100644
index 0000000..a413ea8
--- /dev/null
+++ b/app/pages/metered_weight_prediction.py
@@ -0,0 +1,208 @@
+import streamlit as st
+import plotly.express as px
+import plotly.graph_objects as go
+import pandas as pd
+import numpy as np
+import joblib
+import os
+from datetime import datetime
+
+# 灏濊瘯瀵煎叆torch锛屽鏋滃け璐ュ垯绂佺敤娣卞害瀛︿範妯″瀷鏀寔
+try:
+ import torch
+ TORCH_AVAILABLE = True
+except ImportError:
+ TORCH_AVAILABLE = False
+
+# 椤甸潰鍑芥暟瀹氫箟
+def show_metered_weight_prediction():
+ # 椤甸潰鏍囬
+ st.title("绫抽噸缁熶竴棰勬祴")
+
+ # 鍒濆鍖栦細璇濈姸鎬�
+ if 'selected_model' not in st.session_state:
+ st.session_state['selected_model'] = None
+
+ # 鍒涘缓妯″瀷鐩綍锛堝鏋滀笉瀛樺湪锛�
+ model_dir = "saved_models"
+ os.makedirs(model_dir, exist_ok=True)
+
+ # 鑾峰彇鎵�鏈夊凡淇濆瓨鐨勬ā鍨嬫枃浠�
+ model_files = [f for f in os.listdir(model_dir) if f.endswith('.joblib')]
+ model_files.sort(reverse=True) # 鏈�鏂扮殑妯″瀷鎺掑湪鍓嶉潰
+
+ # 妯″瀷閫夋嫨鍖哄煙
+ with st.expander("馃搧 閫夋嫨妯″瀷", expanded=True):
+ if not model_files:
+ st.warning("灏氭湭淇濆瓨浠讳綍妯″瀷锛岃鍏堣缁冩ā鍨嬪苟淇濆瓨銆�")
+ else:
+ # 妯″瀷閫夋嫨涓嬫媺妗�
+ selected_model_file = st.selectbox(
+ "閫夋嫨宸蹭繚瀛樼殑妯″瀷",
+ options=model_files,
+ help="閫夋嫨瑕佺敤浜庨娴嬬殑妯″瀷鏂囦欢"
+ )
+
+ # 鍔犺浇骞舵樉绀烘ā鍨嬩俊鎭�
+ if selected_model_file:
+ model_path = os.path.join(model_dir, selected_model_file)
+ model_info = joblib.load(model_path)
+
+ # 鏄剧ず妯″瀷鍩烘湰淇℃伅
+ st.subheader("馃搳 妯″瀷淇℃伅")
+ info_cols = st.columns(2)
+
+ with info_cols[0]:
+ st.metric("妯″瀷绫诲瀷", model_info['model_type'])
+ st.metric("鍒涘缓鏃堕棿", model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'))
+ st.metric("浣跨敤绋虫�佹暟鎹�", "鏄�" if model_info.get('use_steady_data', False) else "鍚�")
+
+ with info_cols[1]:
+ st.metric("R虏 寰楀垎", f"{model_info['r2_score']:.4f}")
+ st.metric("鍧囨柟璇樊 (MSE)", f"{model_info['mse']:.6f}")
+ st.metric("鍧囨柟鏍硅宸� (RMSE)", f"{model_info['rmse']:.6f}")
+
+ # 鏄剧ず妯″瀷鐗瑰緛
+ st.write("馃攽 妯″瀷浣跨敤鐨勭壒寰�:")
+ st.code(", ".join(model_info['features']))
+
+ # 濡傛灉鏄繁搴﹀涔犳ā鍨嬶紝鏄剧ず搴忓垪闀垮害
+ if 'sequence_length' in model_info:
+ st.metric("搴忓垪闀垮害", model_info['sequence_length'])
+
+ # 淇濆瓨妯″瀷淇℃伅鍒颁細璇濈姸鎬�
+ st.session_state['selected_model'] = model_info
+ st.session_state['selected_model_file'] = selected_model_file
+
+ # 棰勬祴鍔熻兘鍖哄煙
+ st.subheader("馃敭 绫抽噸棰勬祴")
+
+ if st.session_state['selected_model']:
+ model_info = st.session_state['selected_model']
+
+ # 鑾峰彇妯″瀷闇�瑕佺殑鐗瑰緛
+ required_features = model_info['features']
+
+ # 鍒涘缓棰勬祴琛ㄥ崟
+ st.write("杈撳叆鐗瑰緛鍊艰繘琛岀背閲嶉娴�:")
+ predict_cols = st.columns(2)
+ input_features = {}
+
+ # 鏄剧ず杈撳叆琛ㄥ崟
+ for i, feature in enumerate(required_features):
+ with predict_cols[i % 2]:
+ input_features[feature] = st.number_input(
+ f"{feature}",
+ key=f"pred_{feature}",
+ value=0.0,
+ step=0.0001,
+ format="%.4f"
+ )
+
+ # 棰勬祴鎸夐挳
+ if st.button("馃殌 寮�濮嬮娴�"):
+ try:
+ # 鍑嗗棰勬祴鏁版嵁
+ input_df = pd.DataFrame([input_features])
+
+ # 鏍规嵁妯″瀷绫诲瀷鎵ц涓嶅悓鐨勯娴嬮�昏緫
+ predicted_weight = None
+
+ # 鑾峰彇妯″瀷
+ model = model_info['model']
+
+ # 妫�鏌ユā鍨嬬被鍨嬪苟鎵ц棰勬祴
+ if model_info['model_type'] in ['LSTM', 'GRU', 'BiLSTM']:
+ # 娣卞害瀛︿範妯″瀷棰勬祴
+ if not TORCH_AVAILABLE:
+ st.error("PyTorch 鏈畨瑁咃紝鏃犳硶浣跨敤娣卞害瀛︿範妯″瀷杩涜棰勬祴銆�")
+ return
+
+ # 鏁版嵁鏍囧噯鍖�
+ scaler_X = model_info['scaler_X']
+ scaler_y = model_info['scaler_y']
+ input_scaled = scaler_X.transform(input_df)
+
+ # 鑾峰彇搴忓垪闀垮害
+ sequence_length = model_info['sequence_length']
+
+ # 涓烘繁搴﹀涔犳ā鍨嬪垱寤哄簭鍒�
+ input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
+
+ # 杞崲涓篜yTorch寮犻噺
+ import torch
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device)
+
+ # 棰勬祴
+ model.eval()
+ with torch.no_grad():
+ y_pred_scaled_tensor = model(input_tensor)
+ y_pred_scaled = y_pred_scaled_tensor.cpu().numpy().ravel()[0]
+
+ # 鍙嶅綊涓�鍖�
+ predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
+
+ elif model_info['model_type'] in ['SVR', 'MLP']:
+ # 鏀寔鍚戦噺鏈烘垨澶氬眰鎰熺煡鍣ㄩ娴�
+
+ # 鏁版嵁鏍囧噯鍖�
+ scaler_X = model_info['scaler_X']
+ scaler_y = model_info['scaler_y']
+ input_scaled = scaler_X.transform(input_df)
+
+ # 棰勬祴
+ y_pred_scaled = model.predict(input_scaled)[0]
+
+ # 鍙嶅綊涓�鍖�
+ predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
+
+ else:
+ # 鍏朵粬妯″瀷锛堝闅忔満妫灄銆佹搴︽彁鍗囥�佺嚎鎬у洖褰掔瓑锛�
+ predicted_weight = model.predict(input_df)[0]
+
+ # 鏄剧ず棰勬祴缁撴灉
+ st.success(f"棰勬祴绫抽噸: {predicted_weight:.4f} Kg/m")
+
+
+ except Exception as e:
+ st.error(f"棰勬祴澶辫触: {str(e)}")
+ else:
+ st.warning("璇峰厛閫夋嫨涓�涓ā鍨嬨��")
+
+ # 妯″瀷绠$悊鍖哄煙
+ if model_files:
+ with st.expander("馃棏锔� 妯″瀷绠$悊", expanded=False):
+ st.write("绠$悊宸蹭繚瀛樼殑妯″瀷鏂囦欢:")
+
+ # 鏄剧ず鎵�鏈夋ā鍨嬫枃浠�
+ for model_file in model_files:
+ cols = st.columns([3, 1, 1])
+ cols[0].write(model_file)
+
+ # 鏌ョ湅妯″瀷淇℃伅鎸夐挳
+ if cols[1].button("鏌ョ湅", key=f"view_{model_file}", help="鏌ョ湅妯″瀷淇℃伅"):
+ model_path = os.path.join(model_dir, model_file)
+ model_info = joblib.load(model_path)
+ st.write("妯″瀷璇︾粏淇℃伅:")
+ st.json({
+ 'model_type': model_info['model_type'],
+ 'created_at': model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'),
+ 'r2_score': f"{model_info['r2_score']:.4f}",
+ 'mse': f"{model_info['mse']:.6f}",
+ 'mae': f"{model_info['mae']:.6f}",
+ 'rmse': f"{model_info['rmse']:.6f}",
+ 'features': model_info['features'],
+ 'use_steady_data': model_info.get('use_steady_data', False)
+ })
+
+ # 鍒犻櫎妯″瀷鎸夐挳
+ if cols[2].button("鍒犻櫎", key=f"delete_{model_file}", help="鍒犻櫎妯″瀷鏂囦欢", type="primary"):
+ model_path = os.path.join(model_dir, model_file)
+ os.remove(model_path)
+ st.success(f"宸插垹闄ゆā鍨�: {model_file}")
+ st.rerun()
+
+# 椤甸潰鍏ュ彛
+if __name__ == "__main__":
+ show_metered_weight_prediction()
--
Gitblit v1.9.3