From 6628f663b636675bcaea316f2deaddf337de480e Mon Sep 17 00:00:00 2001
From: baoshiwei <baoshiwei@shlanbao.cn>
Date: 星期五, 13 三月 2026 10:23:31 +0800
Subject: [PATCH] feat(米重分析): 新增稳态识别和预测功能页面并优化现有模型

---
 app/pages/metered_weight_prediction.py |  208 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 208 insertions(+), 0 deletions(-)

diff --git a/app/pages/metered_weight_prediction.py b/app/pages/metered_weight_prediction.py
new file mode 100644
index 0000000..a413ea8
--- /dev/null
+++ b/app/pages/metered_weight_prediction.py
@@ -0,0 +1,208 @@
+import streamlit as st
+import plotly.express as px
+import plotly.graph_objects as go
+import pandas as pd
+import numpy as np
+import joblib
+import os
+from datetime import datetime
+
+# 灏濊瘯瀵煎叆torch锛屽鏋滃け璐ュ垯绂佺敤娣卞害瀛︿範妯″瀷鏀寔
+try:
+    import torch
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+
+# 椤甸潰鍑芥暟瀹氫箟
+def show_metered_weight_prediction():
+    # 椤甸潰鏍囬
+    st.title("绫抽噸缁熶竴棰勬祴")
+    
+    # 鍒濆鍖栦細璇濈姸鎬�
+    if 'selected_model' not in st.session_state:
+        st.session_state['selected_model'] = None
+    
+    # 鍒涘缓妯″瀷鐩綍锛堝鏋滀笉瀛樺湪锛�
+    model_dir = "saved_models"
+    os.makedirs(model_dir, exist_ok=True)
+    
+    # 鑾峰彇鎵�鏈夊凡淇濆瓨鐨勬ā鍨嬫枃浠�
+    model_files = [f for f in os.listdir(model_dir) if f.endswith('.joblib')]
+    model_files.sort(reverse=True)  # 鏈�鏂扮殑妯″瀷鎺掑湪鍓嶉潰
+    
+    # 妯″瀷閫夋嫨鍖哄煙
+    with st.expander("馃搧 閫夋嫨妯″瀷", expanded=True):
+        if not model_files:
+            st.warning("灏氭湭淇濆瓨浠讳綍妯″瀷锛岃鍏堣缁冩ā鍨嬪苟淇濆瓨銆�")
+        else:
+            # 妯″瀷閫夋嫨涓嬫媺妗�
+            selected_model_file = st.selectbox(
+                "閫夋嫨宸蹭繚瀛樼殑妯″瀷",
+                options=model_files,
+                help="閫夋嫨瑕佺敤浜庨娴嬬殑妯″瀷鏂囦欢"
+            )
+            
+            # 鍔犺浇骞舵樉绀烘ā鍨嬩俊鎭�
+            if selected_model_file:
+                model_path = os.path.join(model_dir, selected_model_file)
+                model_info = joblib.load(model_path)
+                
+                # 鏄剧ず妯″瀷鍩烘湰淇℃伅
+                st.subheader("馃搳 妯″瀷淇℃伅")
+                info_cols = st.columns(2)
+                
+                with info_cols[0]:
+                    st.metric("妯″瀷绫诲瀷", model_info['model_type'])
+                    st.metric("鍒涘缓鏃堕棿", model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'))
+                    st.metric("浣跨敤绋虫�佹暟鎹�", "鏄�" if model_info.get('use_steady_data', False) else "鍚�")
+                
+                with info_cols[1]:
+                    st.metric("R虏 寰楀垎", f"{model_info['r2_score']:.4f}")
+                    st.metric("鍧囨柟璇樊 (MSE)", f"{model_info['mse']:.6f}")
+                    st.metric("鍧囨柟鏍硅宸� (RMSE)", f"{model_info['rmse']:.6f}")
+                
+                # 鏄剧ず妯″瀷鐗瑰緛
+                st.write("馃攽 妯″瀷浣跨敤鐨勭壒寰�:")
+                st.code(", ".join(model_info['features']))
+                
+                # 濡傛灉鏄繁搴﹀涔犳ā鍨嬶紝鏄剧ず搴忓垪闀垮害
+                if 'sequence_length' in model_info:
+                    st.metric("搴忓垪闀垮害", model_info['sequence_length'])
+                
+                # 淇濆瓨妯″瀷淇℃伅鍒颁細璇濈姸鎬�
+                st.session_state['selected_model'] = model_info
+                st.session_state['selected_model_file'] = selected_model_file
+    
+    # 棰勬祴鍔熻兘鍖哄煙
+    st.subheader("馃敭 绫抽噸棰勬祴")
+    
+    if st.session_state['selected_model']:
+        model_info = st.session_state['selected_model']
+        
+        # 鑾峰彇妯″瀷闇�瑕佺殑鐗瑰緛
+        required_features = model_info['features']
+        
+        # 鍒涘缓棰勬祴琛ㄥ崟
+        st.write("杈撳叆鐗瑰緛鍊艰繘琛岀背閲嶉娴�:")
+        predict_cols = st.columns(2)
+        input_features = {}
+        
+        # 鏄剧ず杈撳叆琛ㄥ崟
+        for i, feature in enumerate(required_features):
+            with predict_cols[i % 2]:
+                input_features[feature] = st.number_input(
+                    f"{feature}",
+                    key=f"pred_{feature}",
+                    value=0.0,
+                    step=0.0001,
+                    format="%.4f"
+                )
+        
+        # 棰勬祴鎸夐挳
+        if st.button("馃殌 寮�濮嬮娴�"):
+            try:
+                # 鍑嗗棰勬祴鏁版嵁
+                input_df = pd.DataFrame([input_features])
+                
+                # 鏍规嵁妯″瀷绫诲瀷鎵ц涓嶅悓鐨勯娴嬮�昏緫
+                predicted_weight = None
+                
+                # 鑾峰彇妯″瀷
+                model = model_info['model']
+                
+                # 妫�鏌ユā鍨嬬被鍨嬪苟鎵ц棰勬祴
+                if model_info['model_type'] in ['LSTM', 'GRU', 'BiLSTM']:
+                    # 娣卞害瀛︿範妯″瀷棰勬祴
+                    if not TORCH_AVAILABLE:
+                        st.error("PyTorch 鏈畨瑁咃紝鏃犳硶浣跨敤娣卞害瀛︿範妯″瀷杩涜棰勬祴銆�")
+                        return
+                    
+                    # 鏁版嵁鏍囧噯鍖�
+                    scaler_X = model_info['scaler_X']
+                    scaler_y = model_info['scaler_y']
+                    input_scaled = scaler_X.transform(input_df)
+                    
+                    # 鑾峰彇搴忓垪闀垮害
+                    sequence_length = model_info['sequence_length']
+                    
+                    # 涓烘繁搴﹀涔犳ā鍨嬪垱寤哄簭鍒�
+                    input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
+                    
+                    # 杞崲涓篜yTorch寮犻噺
+                    import torch
+                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+                    input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device)
+                    
+                    # 棰勬祴
+                    model.eval()
+                    with torch.no_grad():
+                        y_pred_scaled_tensor = model(input_tensor)
+                        y_pred_scaled = y_pred_scaled_tensor.cpu().numpy().ravel()[0]
+                        
+                        # 鍙嶅綊涓�鍖�
+                        predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
+                    
+                elif model_info['model_type'] in ['SVR', 'MLP']:
+                    # 鏀寔鍚戦噺鏈烘垨澶氬眰鎰熺煡鍣ㄩ娴�
+                    
+                    # 鏁版嵁鏍囧噯鍖�
+                    scaler_X = model_info['scaler_X']
+                    scaler_y = model_info['scaler_y']
+                    input_scaled = scaler_X.transform(input_df)
+                    
+                    # 棰勬祴
+                    y_pred_scaled = model.predict(input_scaled)[0]
+                    
+                    # 鍙嶅綊涓�鍖�
+                    predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
+                    
+                else:
+                    # 鍏朵粬妯″瀷锛堝闅忔満妫灄銆佹搴︽彁鍗囥�佺嚎鎬у洖褰掔瓑锛�
+                    predicted_weight = model.predict(input_df)[0]
+                
+                # 鏄剧ず棰勬祴缁撴灉
+                st.success(f"棰勬祴绫抽噸: {predicted_weight:.4f} Kg/m")
+                
+                
+            except Exception as e:
+                st.error(f"棰勬祴澶辫触: {str(e)}")
+    else:
+        st.warning("璇峰厛閫夋嫨涓�涓ā鍨嬨��")
+    
+    # 妯″瀷绠$悊鍖哄煙
+    if model_files:
+        with st.expander("馃棏锔� 妯″瀷绠$悊", expanded=False):
+            st.write("绠$悊宸蹭繚瀛樼殑妯″瀷鏂囦欢:")
+            
+            # 鏄剧ず鎵�鏈夋ā鍨嬫枃浠�
+            for model_file in model_files:
+                cols = st.columns([3, 1, 1])
+                cols[0].write(model_file)
+                
+                # 鏌ョ湅妯″瀷淇℃伅鎸夐挳
+                if cols[1].button("鏌ョ湅", key=f"view_{model_file}", help="鏌ョ湅妯″瀷淇℃伅"):
+                    model_path = os.path.join(model_dir, model_file)
+                    model_info = joblib.load(model_path)
+                    st.write("妯″瀷璇︾粏淇℃伅:")
+                    st.json({
+                        'model_type': model_info['model_type'],
+                        'created_at': model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'),
+                        'r2_score': f"{model_info['r2_score']:.4f}",
+                        'mse': f"{model_info['mse']:.6f}",
+                        'mae': f"{model_info['mae']:.6f}",
+                        'rmse': f"{model_info['rmse']:.6f}",
+                        'features': model_info['features'],
+                        'use_steady_data': model_info.get('use_steady_data', False)
+                    })
+                
+                # 鍒犻櫎妯″瀷鎸夐挳
+                if cols[2].button("鍒犻櫎", key=f"delete_{model_file}", help="鍒犻櫎妯″瀷鏂囦欢", type="primary"):
+                    model_path = os.path.join(model_dir, model_file)
+                    os.remove(model_path)
+                    st.success(f"宸插垹闄ゆā鍨�: {model_file}")
+                    st.rerun()
+
+# 椤甸潰鍏ュ彛
+if __name__ == "__main__":
+    show_metered_weight_prediction()

--
Gitblit v1.9.3