From 6628f663b636675bcaea316f2deaddf337de480e Mon Sep 17 00:00:00 2001
From: baoshiwei <baoshiwei@shlanbao.cn>
Date: 星期五, 13 三月 2026 10:23:31 +0800
Subject: [PATCH] feat(米重分析): 新增稳态识别和预测功能页面并优化现有模型

---
 app/pages/metered_weight_advanced pytorch.py |  873 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 873 insertions(+), 0 deletions(-)

diff --git a/app/pages/metered_weight_advanced pytorch.py b/app/pages/metered_weight_advanced pytorch.py
new file mode 100644
index 0000000..2f35692
--- /dev/null
+++ b/app/pages/metered_weight_advanced pytorch.py
@@ -0,0 +1,873 @@
+import streamlit as st
+import plotly.express as px
+import plotly.graph_objects as go
+import pandas as pd
+import numpy as np
+from datetime import datetime, timedelta
+from app.services.extruder_service import ExtruderService
+from app.services.main_process_service import MainProcessService
+from sklearn.preprocessing import StandardScaler, MinMaxScaler
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
+from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
+from sklearn.svm import SVR
+from sklearn.neural_network import MLPRegressor
+
+# 灏濊瘯瀵煎叆娣卞害瀛︿範搴�
+use_deep_learning = False
+try:
+    
+    import torch
+    import torch.nn as nn
+    import torch.optim as optim
+    use_deep_learning = True
+    # 妫�娴婫PU鏄惁鍙敤
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    st.success(f"浣跨敤璁惧: {device}")
+except ImportError:
+    st.warning("鏈娴嬪埌PyTorch锛屾繁搴﹀涔犳ā鍨嬪皢涓嶅彲鐢ㄣ�傝瀹夎pytorch浠ヤ娇鐢↙STM/GRU妯″瀷銆�")
+
+
+# PyTorch娣卞害瀛︿範妯″瀷瀹氫箟
+class LSTMModel(nn.Module):
+    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
+        super(LSTMModel, self).__init__()
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
+        self.fc1 = nn.Linear(hidden_dim, 32)
+        self.dropout = nn.Dropout(0.2)
+        self.fc2 = nn.Linear(32, 1)
+    
+    def forward(self, x):
+        out, _ = self.lstm(x)
+        out = out[:, -1, :]
+        out = torch.relu(self.fc1(out))
+        out = self.dropout(out)
+        out = self.fc2(out)
+        return out
+
+class GRUModel(nn.Module):
+    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
+        super(GRUModel, self).__init__()
+        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
+        self.fc1 = nn.Linear(hidden_dim, 32)
+        self.dropout = nn.Dropout(0.2)
+        self.fc2 = nn.Linear(32, 1)
+    
+    def forward(self, x):
+        out, _ = self.gru(x)
+        out = out[:, -1, :]
+        out = torch.relu(self.fc1(out))
+        out = self.dropout(out)
+        out = self.fc2(out)
+        return out
+
+class BiLSTMModel(nn.Module):
+    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
+        super(BiLSTMModel, self).__init__()
+        self.bilstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
+        self.fc1 = nn.Linear(hidden_dim * 2, 32)
+        self.dropout = nn.Dropout(0.2)
+        self.fc2 = nn.Linear(32, 1)
+    
+    def forward(self, x):
+        out, _ = self.bilstm(x)
+        out = out[:, -1, :]
+        out = torch.relu(self.fc1(out))
+        out = self.dropout(out)
+        out = self.fc2(out)
+        return out
+
+def show_metered_weight_advanced():
+    # 鍒濆鍖栨湇鍔�
+    extruder_service = ExtruderService()
+    main_process_service = MainProcessService()
+
+    # 椤甸潰鏍囬
+    st.title("绫抽噸楂樼骇棰勬祴鍒嗘瀽")
+
+    # 鍒濆鍖栦細璇濈姸鎬�
+    if 'ma_start_date' not in st.session_state:
+        st.session_state['ma_start_date'] = datetime.now().date() - timedelta(days=7)
+    if 'ma_end_date' not in st.session_state:
+        st.session_state['ma_end_date'] = datetime.now().date()
+    if 'ma_quick_select' not in st.session_state:
+        st.session_state['ma_quick_select'] = "鏈�杩�7澶�"
+    if 'ma_model_type' not in st.session_state:
+        st.session_state['ma_model_type'] = 'RandomForest'
+    if 'ma_sequence_length' not in st.session_state:
+        st.session_state['ma_sequence_length'] = 10
+    
+    # 榛樿鐗瑰緛鍒楄〃锛堜笉鍐嶅厑璁哥敤鎴烽�夋嫨锛�
+    default_features = ['铻烘潌杞��', '鏈哄ご鍘嬪姏', '娴佺▼涓婚��', '铻烘潌娓╁害', 
+                       '鍚庢満绛掓俯搴�', '鍓嶆満绛掓俯搴�', '鏈哄ご娓╁害']
+
+    # 瀹氫箟鍥炶皟鍑芥暟
+    def update_dates(qs):
+        st.session_state['ma_quick_select'] = qs
+        today = datetime.now().date()
+        if qs == "浠婂ぉ":
+            st.session_state['ma_start_date'] = today
+            st.session_state['ma_end_date'] = today
+        elif qs == "鏈�杩�3澶�":
+            st.session_state['ma_start_date'] = today - timedelta(days=3)
+            st.session_state['ma_end_date'] = today
+        elif qs == "鏈�杩�7澶�":
+            st.session_state['ma_start_date'] = today - timedelta(days=7)
+            st.session_state['ma_end_date'] = today
+        elif qs == "鏈�杩�30澶�":
+            st.session_state['ma_start_date'] = today - timedelta(days=30)
+            st.session_state['ma_end_date'] = today
+
+    def on_date_change():
+        st.session_state['ma_quick_select'] = "鑷畾涔�"
+
+    # 鏌ヨ鏉′欢鍖哄煙
+    with st.expander("馃攳 鏌ヨ閰嶇疆", expanded=True):
+        # 娣诲姞鑷畾涔� CSS 瀹炵幇鍝嶅簲寮忔崲琛�
+        st.markdown("""
+            <style>
+            /* 寮哄埗鍒楀鍣ㄦ崲琛� */
+            [data-testid="stExpander"] [data-testid="column"] {
+                flex: 1 1 120px !important;
+                min-width: 120px !important;
+            }
+            /* 閽堝鏃ユ湡杈撳叆妗嗗垪绋嶅井鍔犲涓�鐐� */
+            @media (min-width: 768px) {
+                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
+                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
+                    flex: 2 1 180px !important;
+                    min-width: 180px !important;
+                }
+            }
+            </style>
+            """, unsafe_allow_html=True)
+
+        # 鍒涘缓甯冨眬
+        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
+
+        options = ["浠婂ぉ", "鏈�杩�3澶�", "鏈�杩�7澶�", "鏈�杩�30澶�", "鑷畾涔�"]
+        for i, option in enumerate(options):
+            with cols[i]:
+                # 鏍规嵁褰撳墠閫夋嫨鐘舵�佸喅瀹氭寜閽被鍨�
+                button_type = "primary" if st.session_state['ma_quick_select'] == option else "secondary"
+                if st.button(option, key=f"btn_ma_{option}", width='stretch', type=button_type):
+                    update_dates(option)
+                    st.rerun()
+
+        with cols[5]:
+            start_date = st.date_input(
+                "寮�濮嬫棩鏈�",
+                label_visibility="collapsed",
+                key="ma_start_date",
+                on_change=on_date_change
+            )
+
+        with cols[6]:
+            end_date = st.date_input(
+                "缁撴潫鏃ユ湡",
+                label_visibility="collapsed",
+                key="ma_end_date",
+                on_change=on_date_change
+            )
+
+        with cols[7]:
+            query_button = st.button("馃殌 寮�濮嬪垎鏋�", key="ma_query", width='stretch')
+
+        # 妯″瀷閰嶇疆
+        st.markdown("---")
+        st.write("馃 **妯″瀷閰嶇疆**")
+        model_cols = st.columns(2)
+        
+        with model_cols[0]:
+            # 妯″瀷绫诲瀷閫夋嫨
+            model_options = ['RandomForest', 'GradientBoosting', 'SVR', 'MLP']
+            if use_deep_learning:
+                model_options.extend(['LSTM', 'GRU', 'BiLSTM'])
+            
+            model_type = st.selectbox(
+                "妯″瀷绫诲瀷",
+                options=model_options,
+                key="ma_model_type",
+                help="閫夋嫨鐢ㄤ簬棰勬祴鐨勬ā鍨嬬被鍨�"
+            )
+        
+        with model_cols[1]:
+            # 搴忓垪闀垮害锛堜粎閫傜敤浜庢繁搴﹀涔犳ā鍨嬶級
+            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
+                sequence_length = st.slider(
+                    "搴忓垪闀垮害",
+                    min_value=5,
+                    max_value=30,
+                    value=st.session_state['ma_sequence_length'],
+                    step=1,
+                    help="鐢ㄤ簬娣卞害瀛︿範妯″瀷鐨勬椂闂村簭鍒楅暱搴�",
+                    key="ma_sequence_length"
+                )
+            else:
+                st.session_state['ma_sequence_length'] = 10
+                st.write("搴忓垪闀垮害: 10 (榛樿锛屼粎閫傜敤浜庢繁搴﹀涔犳ā鍨�)")
+
+    # 杞崲涓篸atetime瀵硅薄
+    start_dt = datetime.combine(start_date, datetime.min.time())
+    end_dt = datetime.combine(end_date, datetime.max.time())
+
+    # 鏌ヨ澶勭悊
+    if query_button:
+        with st.spinner("姝e湪鑾峰彇鏁版嵁..."):
+            # 1. 鑾峰彇瀹屾暣鐨勬尋鍑烘満鏁版嵁
+            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
+
+            # 2. 鑾峰彇涓绘祦绋嬫帶鍒舵暟鎹�
+            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
+
+            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
+
+            # 妫�鏌ユ槸鍚︽湁鏁版嵁
+            has_data = any([
+                df_extruder_full is not None and not df_extruder_full.empty,
+                df_main_speed is not None and not df_main_speed.empty,
+                df_temp is not None and not df_temp.empty
+            ])
+
+            if not has_data:
+                st.warning("鎵�閫夋椂闂存鍐呮湭鎵惧埌浠讳綍鏁版嵁锛岃灏濊瘯璋冩暣鏌ヨ鏉′欢銆�")
+                # 娓呴櫎缂撳瓨鏁版嵁
+                for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp', 'last_query_start', 'last_query_end']:
+                    if key in st.session_state:
+                        del st.session_state[key]
+                return
+
+            # 缂撳瓨鏁版嵁鍒颁細璇濈姸鎬�
+            st.session_state['cached_extruder_full'] = df_extruder_full
+            st.session_state['cached_main_speed'] = df_main_speed
+            st.session_state['cached_temp'] = df_temp
+            st.session_state['last_query_start'] = start_dt
+            st.session_state['last_query_end'] = end_dt
+
+    # 鏁版嵁澶勭悊鍜屽垎鏋�
+    if all(key in st.session_state for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp']):
+        with st.spinner("姝e湪鍒嗘瀽鏁版嵁..."):
+            # 鑾峰彇缂撳瓨鏁版嵁
+            df_extruder_full = st.session_state['cached_extruder_full']
+            df_main_speed = st.session_state['cached_main_speed']
+            df_temp = st.session_state['cached_temp']
+
+           
+
+            # 妫�鏌ユ槸鍚︽湁鏁版嵁
+            has_data = any([
+                df_extruder_full is not None and not df_extruder_full.empty,
+                df_main_speed is not None and not df_main_speed.empty,
+                df_temp is not None and not df_temp.empty
+            ])
+
+            if not has_data:
+                st.warning("鎵�閫夋椂闂存鍐呮湭鎵惧埌浠讳綍鏁版嵁锛岃灏濊瘯璋冩暣鏌ヨ鏉′欢銆�")
+                return
+
+            # 鏁版嵁鏁村悎涓庨澶勭悊
+            def integrate_data(df_extruder_full, df_main_speed, df_temp):
+                # 纭繚鎸ゅ嚭鏈烘暟鎹瓨鍦�
+                if df_extruder_full is None or df_extruder_full.empty:
+                    return None
+
+                # 鍒涘缓鍙寘鍚背閲嶅拰鏃堕棿鐨勪富鏁版嵁闆�
+                df_merged = df_extruder_full[['time', 'metered_weight', 'screw_speed_actual', 'head_pressure']].copy()
+              
+
+                # 鏁村悎涓绘祦绋嬫暟鎹�
+                if df_main_speed is not None and not df_main_speed.empty:
+                    df_main_speed = df_main_speed[['time', 'process_main_speed']]
+                    df_merged = pd.merge_asof(
+                        df_merged.sort_values('time'),
+                        df_main_speed.sort_values('time'),
+                        on='time',
+                        direction='nearest',
+                        tolerance=pd.Timedelta('1min')
+                    )
+
+                # 鏁村悎娓╁害鏁版嵁
+                if df_temp is not None and not df_temp.empty:
+                    temp_cols = ['time', 'nakata_extruder_screw_display_temp',
+                               'nakata_extruder_rear_barrel_display_temp',
+                               'nakata_extruder_front_barrel_display_temp',
+                               'nakata_extruder_head_display_temp']
+                    df_temp_subset = df_temp[temp_cols].copy()
+                    df_merged = pd.merge_asof(
+                        df_merged.sort_values('time'),
+                        df_temp_subset.sort_values('time'),
+                        on='time',
+                        direction='nearest',
+                        tolerance=pd.Timedelta('1min')
+                    )
+
+                # 閲嶅懡鍚嶅垪浠ユ彁楂樺彲璇绘��
+                df_merged.rename(columns={
+                    'screw_speed_actual': '铻烘潌杞��',
+                    'head_pressure': '鏈哄ご鍘嬪姏',
+                    'process_main_speed': '娴佺▼涓婚��',
+                    'nakata_extruder_screw_display_temp': '铻烘潌娓╁害',
+                    'nakata_extruder_rear_barrel_display_temp': '鍚庢満绛掓俯搴�',
+                    'nakata_extruder_front_barrel_display_temp': '鍓嶆満绛掓俯搴�',
+                    'nakata_extruder_head_display_temp': '鏈哄ご娓╁害'
+                }, inplace=True)
+
+                # 娓呯悊鏁版嵁
+                df_merged.dropna(subset=['metered_weight'], inplace=True)
+
+                return df_merged
+
+            # 鎵ц鏁版嵁鏁村悎
+            df_analysis = integrate_data(df_extruder_full, df_main_speed, df_temp)
+
+            if df_analysis is None or df_analysis.empty:
+                st.warning("鏁版嵁鏁村悎澶辫触锛岃妫�鏌ユ暟鎹川閲忔垨璋冩暣鏃堕棿鑼冨洿銆�")
+                return
+
+            # 閲嶅懡鍚嶇背閲嶅垪
+            df_analysis.rename(columns={'metered_weight': '绫抽噸'}, inplace=True)
+
+            # --- 鍘熷鏁版嵁瓒嬪娍鍥� ---
+            st.subheader("馃搱 鍘熷鏁版嵁瓒嬪娍鍥�")
+
+            # 鍒涘缓瓒嬪娍鍥�
+            fig_trend = go.Figure()
+
+            # 娣诲姞绫抽噸鏁版嵁
+            if df_extruder_full is not None and not df_extruder_full.empty:
+                fig_trend.add_trace(go.Scatter(
+                    x=df_extruder_full['time'],
+                    y=df_extruder_full['metered_weight'],
+                    name='绫抽噸 (Kg/m)',
+                    mode='lines',
+                    line=dict(color='blue', width=2)
+                ))
+
+                # 娣诲姞铻烘潌杞��
+                fig_trend.add_trace(go.Scatter(
+                    x=df_extruder_full['time'],
+                    y=df_extruder_full['screw_speed_actual'],
+                    name='铻烘潌杞�� (RPM)',
+                    mode='lines',
+                    line=dict(color='green', width=1.5),
+                    yaxis='y2'
+                ))
+
+                # 娣诲姞鏈哄ご鍘嬪姏
+                fig_trend.add_trace(go.Scatter(
+                    x=df_extruder_full['time'],
+                    y=df_extruder_full['head_pressure'],
+                    name='鏈哄ご鍘嬪姏',
+                    mode='lines',
+                    line=dict(color='orange', width=1.5),
+                    yaxis='y3'
+                ))
+
+            # 娣诲姞娴佺▼涓婚��
+            if df_main_speed is not None and not df_main_speed.empty:
+                fig_trend.add_trace(go.Scatter(
+                    x=df_main_speed['time'],
+                    y=df_main_speed['process_main_speed'],
+                    name='娴佺▼涓婚�� (M/Min)',
+                    mode='lines',
+                    line=dict(color='red', width=1.5),
+                    yaxis='y4'
+                ))
+
+            # 娣诲姞娓╁害鏁版嵁
+            if df_temp is not None and not df_temp.empty:
+                # 铻烘潌娓╁害
+                fig_trend.add_trace(go.Scatter(
+                    x=df_temp['time'],
+                    y=df_temp['nakata_extruder_screw_display_temp'],
+                    name='铻烘潌娓╁害 (掳C)',
+                    mode='lines',
+                    line=dict(color='purple', width=1),
+                    yaxis='y5'
+                ))
+
+            # 閰嶇疆瓒嬪娍鍥惧竷灞�
+            fig_trend.update_layout(
+                title='鍘熷鏁版嵁瓒嬪娍',
+                xaxis=dict(
+                    title='鏃堕棿',
+                    rangeslider=dict(visible=True),
+                    type='date'
+                ),
+                yaxis=dict(
+                    title='绫抽噸 (Kg/m)',
+                    title_font=dict(color='blue'),
+                    tickfont=dict(color='blue')
+                ),
+                yaxis2=dict(
+                    title='铻烘潌杞�� (RPM)',
+                    title_font=dict(color='green'),
+                    tickfont=dict(color='green'),
+                    overlaying='y',
+                    side='right'
+                ),
+                yaxis3=dict(
+                    title='鏈哄ご鍘嬪姏',
+                    title_font=dict(color='orange'),
+                    tickfont=dict(color='orange'),
+                    overlaying='y',
+                    side='right',
+                    anchor='free',
+                    position=0.85
+                ),
+                yaxis4=dict(
+                    title='娴佺▼涓婚�� (M/Min)',
+                    title_font=dict(color='red'),
+                    tickfont=dict(color='red'),
+                    overlaying='y',
+                    side='right',
+                    anchor='free',
+                    position=0.75
+                ),
+                yaxis5=dict(
+                    title='娓╁害 (掳C)',
+                    title_font=dict(color='purple'),
+                    tickfont=dict(color='purple'),
+                    overlaying='y',
+                    side='left',
+                    anchor='free',
+                    position=0.15
+                ),
+                legend=dict(
+                    orientation="h",
+                    yanchor="bottom",
+                    y=1.02,
+                    xanchor="right",
+                    x=1
+                ),
+                height=600,
+                margin=dict(l=100, r=200, t=100, b=100),
+                hovermode='x unified'
+            )
+
+            # 鏄剧ず瓒嬪娍鍥�
+            st.plotly_chart(fig_trend, width='stretch', config={'scrollZoom': True})
+
+            # --- 楂樼骇棰勬祴鍒嗘瀽 ---
+            st.subheader("馃搳 楂樼骇棰勬祴鍒嗘瀽")
+
+            # 妫�鏌ユ墍鏈夐粯璁ょ壒寰佹槸鍚﹀湪鏁版嵁涓�
+            missing_features = [f for f in default_features if f not in df_analysis.columns]
+            if missing_features:
+                st.warning(f"鏁版嵁涓己灏戜互涓嬬壒寰�: {', '.join(missing_features)}")
+            else:
+                # 鍑嗗鏁版嵁
+                X = df_analysis[default_features]
+                y = df_analysis['绫抽噸']
+
+                # 娓呯悊鏁版嵁涓殑NaN鍊�
+                combined = pd.concat([X, y], axis=1)
+                combined_clean = combined.dropna()
+                
+                # 妫�鏌ユ竻鐞嗗悗鐨勬暟鎹噺
+                if len(combined_clean) < 30:
+                    st.warning("鏁版嵁閲忎笉瓒筹紝鏃犳硶杩涜鏈夋晥鐨勯娴嬪垎鏋�")
+                else:
+                    # 閲嶆柊鍒嗙X鍜寉
+                    X_clean = combined_clean[default_features]
+                    y_clean = combined_clean['绫抽噸']
+                    
+                    # 鐗瑰緛宸ョ▼锛氭坊鍔犳椂闂寸浉鍏崇壒寰�
+                    # 纭繚浣跨敤鏃堕棿鍒椾綔涓虹储寮�
+                    if 'time' in combined_clean.columns:
+                        # 灏唗ime鍒楄缃负绱㈠紩
+                        combined_clean = combined_clean.set_index('time')
+                        
+                        # 鍒涘缓鏃堕棿鐗瑰緛
+                        time_features = pd.DataFrame(index=combined_clean.index)
+                        time_features['hour'] = combined_clean.index.hour
+                        time_features['minute'] = combined_clean.index.minute
+                        time_features['second'] = combined_clean.index.second
+                        time_features['time_of_day'] = time_features['hour'] * 3600 + time_features['minute'] * 60 + time_features['second']
+                    else:
+                        # 濡傛灉娌℃湁time鍒楋紝鍒涘缓绌虹殑鏃堕棿鐗瑰緛
+                        time_features = pd.DataFrame(index=combined_clean.index)
+                        time_features['hour'] = 0
+                        time_features['minute'] = 0
+                        time_features['second'] = 0
+                        time_features['time_of_day'] = 0
+                    
+                    # 娣诲姞婊炲悗鐗瑰緛
+                    for feature in default_features:
+                        for lag in [1, 2, 3]:
+                            time_features[f'{feature}_lag{lag}'] = X_clean[feature].shift(lag)
+                            time_features[f'{feature}_diff{lag}'] = X_clean[feature].diff(lag)
+                    
+                    # 娣诲姞婊氬姩缁熻鐗瑰緛
+                    for feature in default_features:
+                        time_features[f'{feature}_rolling_mean'] = X_clean[feature].rolling(window=5).mean()
+                        time_features[f'{feature}_rolling_std'] = X_clean[feature].rolling(window=5).std()
+                        time_features[f'{feature}_rolling_min'] = X_clean[feature].rolling(window=5).min()
+                        time_features[f'{feature}_rolling_max'] = X_clean[feature].rolling(window=5).max()
+                    
+                    # 娓呯悊婊炲悗鐗瑰緛鍜屾粴鍔ㄧ粺璁$壒寰佷骇鐢熺殑NaN鍊�
+                    time_features.dropna(inplace=True)
+                    
+                    # 瀵归綈X鍜寉
+                    common_index = time_features.index.intersection(y_clean.index)
+                    X_final = pd.concat([X_clean.loc[common_index], time_features.loc[common_index]], axis=1)
+                    y_final = y_clean.loc[common_index]
+                    
+                    # 妫�鏌ユ渶缁堟暟鎹噺
+                    if len(X_final) < 20:
+                        st.warning("鐗瑰緛宸ョ▼鍚庢暟鎹噺涓嶈冻锛屾棤娉曡繘琛屾湁鏁堢殑棰勬祴鍒嗘瀽")
+                    else:
+                        # 鍒嗗壊璁粌闆嗗拰娴嬭瘯闆�
+                        X_train, X_test, y_train, y_test = train_test_split(X_final, y_final, test_size=0.2, random_state=42)
+                        
+                        # 鏁版嵁鏍囧噯鍖�
+                        scaler_X = StandardScaler()
+                        scaler_y = MinMaxScaler()
+                        
+                        X_train_scaled = scaler_X.fit_transform(X_train)
+                        X_test_scaled = scaler_X.transform(X_test)
+                        y_train_scaled = scaler_y.fit_transform(y_train.values.reshape(-1, 1)).ravel()
+                        y_test_scaled = scaler_y.transform(y_test.values.reshape(-1, 1)).ravel()
+                        
+                        # 妯″瀷璁粌
+                        model = None
+                        y_pred = None
+                        
+                        try:
+                            if model_type == 'RandomForest':
+                                # 闅忔満妫灄鍥炲綊
+                                model = RandomForestRegressor(n_estimators=100, random_state=42)
+                                model.fit(X_train, y_train)
+                                y_pred = model.predict(X_test)
+                            
+                            elif model_type == 'GradientBoosting':
+                                # 姊害鎻愬崌鍥炲綊
+                                model = GradientBoostingRegressor(n_estimators=100, random_state=42)
+                                model.fit(X_train, y_train)
+                                y_pred = model.predict(X_test)
+                            
+                            elif model_type == 'SVR':
+                                # 鏀寔鍚戦噺鍥炲綊
+                                model = SVR(kernel='rbf', C=1.0, gamma='scale')
+                                model.fit(X_train_scaled, y_train_scaled)
+                                y_pred_scaled = model.predict(X_test_scaled)
+                                y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
+                            
+                            elif model_type == 'MLP':
+                                # 澶氬眰鎰熺煡鍣ㄥ洖褰�
+                                model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
+                                model.fit(X_train_scaled, y_train_scaled)
+                                y_pred_scaled = model.predict(X_test_scaled)
+                                y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
+                            
+                            elif use_deep_learning and model_type in ['LSTM', 'GRU', 'BiLSTM']:
+                                # 鍑嗗鏃堕棿搴忓垪鏁版嵁
+                                sequence_length = st.session_state['ma_sequence_length']
+                                
+                                def create_sequences(X, y, seq_length):
+                                    X_seq = []
+                                    y_seq = []
+                                    # 纭繚X鍜寉鐨勯暱搴︿竴鑷�
+                                    min_len = min(len(X), len(y))
+                                    # 纭繚X鍜寉鐨勯暱搴﹁嚦灏戜负seq_length + 1
+                                    if min_len <= seq_length:
+                                        return np.array([]), np.array([])
+                                    # 鎴柇X鍜寉鍒扮浉鍚岄暱搴�
+                                    X_trimmed = X[:min_len]
+                                    y_trimmed = y[:min_len]
+                                    # 鍒涘缓搴忓垪
+                                    for i in range(len(X_trimmed) - seq_length):
+                                        X_seq.append(X_trimmed[i:i+seq_length])
+                                        y_seq.append(y_trimmed[i+seq_length])
+                                    return np.array(X_seq), np.array(y_seq)
+                                
+                                # 涓烘繁搴﹀涔犳ā鍨嬪垱寤哄簭鍒�
+                                X_train_seq, y_train_seq = create_sequences(X_train_scaled, y_train_scaled, sequence_length)
+                                X_test_seq, y_test_seq = create_sequences(X_test_scaled, y_test_scaled, sequence_length)
+                                
+                                # 妫�鏌ュ垱寤虹殑搴忓垪鏄惁涓虹┖
+                                if len(X_train_seq) == 0 or len(y_train_seq) == 0:
+                                    st.warning(f"鏁版嵁閲忎笉瓒筹紝鏃犳硶鍒涘缓鏈夋晥鐨凩STM搴忓垪銆傞渶瑕佽嚦灏� {sequence_length + 1} 涓牱鏈紝褰撳墠鍙湁 {min(len(X_train_scaled), len(y_train_scaled))} 涓牱鏈��")
+                                    # 浣跨敤闅忔満妫灄浣滀负澶囬�夋ā鍨�
+                                    model = RandomForestRegressor(n_estimators=100, random_state=42)
+                                    model.fit(X_train, y_train)
+                                    y_pred = model.predict(X_test)
+                                else:
+                                    # 杞崲涓篜yTorch寮犻噺骞剁Щ鍔ㄥ埌璁惧
+                                    X_train_tensor = torch.tensor(X_train_seq, dtype=torch.float32).to(device)
+                                    y_train_tensor = torch.tensor(y_train_seq, dtype=torch.float32).unsqueeze(1).to(device)
+                                    X_test_tensor = torch.tensor(X_test_seq, dtype=torch.float32).to(device)
+                                    y_test_tensor = torch.tensor(y_test_seq, dtype=torch.float32).unsqueeze(1).to(device)
+                                    
+                                    # 鏋勫缓PyTorch妯″瀷骞剁Щ鍔ㄥ埌璁惧
+                                    input_dim = X_train_scaled.shape[1]
+                                    
+                                    if model_type == 'LSTM':
+                                        deep_model = LSTMModel(input_dim).to(device)
+                                    elif model_type == 'GRU':
+                                        deep_model = GRUModel(input_dim).to(device)
+                                    elif model_type == 'BiLSTM':
+                                        deep_model = BiLSTMModel(input_dim).to(device)
+                                    
+                                    # 瀹氫箟鎹熷け鍑芥暟鍜屼紭鍖栧櫒
+                                    criterion = nn.MSELoss()
+                                    optimizer = optim.Adam(deep_model.parameters(), lr=0.001)
+                                    
+                                    # 鏄剧ず浣跨敤鐨勮澶�
+                                    st.info(f"浣跨敤璁惧: {device}")
+                                    
+                                    # 璁粌妯″瀷
+                                    num_epochs = 50
+                                    batch_size = 32
+                                    
+                                    for epoch in range(num_epochs):
+                                        deep_model.train()
+                                        optimizer.zero_grad()
+                                        
+                                        # 鍓嶅悜浼犳挱
+                                        outputs = deep_model(X_train_tensor)
+                                        loss = criterion(outputs, y_train_tensor)
+                                        
+                                        # 鍙嶅悜浼犳挱鍜屼紭鍖�
+                                        loss.backward()
+                                        optimizer.step()
+                                    
+                                    # 棰勬祴
+                                    deep_model.eval()
+                                    with torch.no_grad():
+                                        y_pred_scaled_tensor = deep_model(X_test_tensor)
+                                        y_pred_scaled = y_pred_scaled_tensor.numpy().ravel()
+                                        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
+                                    
+                                    # 灏唝_test_seq杞崲鍥炲師濮嬪昂搴�
+                                    y_test_actual = scaler_y.inverse_transform(y_test_seq.reshape(-1, 1)).ravel()
+                                
+                                # 淇濆瓨妯″瀷
+                                model = deep_model
+                            
+                            # 璁$畻璇勪及鎸囨爣
+                            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
+                                # 浣跨敤杞崲鍚庣殑y_test_seq浣滀负鐪熷疄鍊�
+                                r2 = r2_score(y_test_actual, y_pred)
+                                mse = mean_squared_error(y_test_actual, y_pred)
+                                mae = mean_absolute_error(y_test_actual, y_pred)
+                                rmse = np.sqrt(mse)
+                            else:
+                                # 浣跨敤鍘熷鐨剏_test浣滀负鐪熷疄鍊�
+                                r2 = r2_score(y_test, y_pred)
+                                mse = mean_squared_error(y_test, y_pred)
+                                mae = mean_absolute_error(y_test, y_pred)
+                                rmse = np.sqrt(mse)
+
+                            # 鏄剧ず妯″瀷鎬ц兘
+                            metrics_cols = st.columns(2)
+                            with metrics_cols[0]:
+                                st.metric("R虏 寰楀垎", f"{r2:.4f}")
+                                st.metric("鍧囨柟璇樊 (MSE)", f"{mse:.6f}")
+                            with metrics_cols[1]:
+                                st.metric("骞冲潎缁濆璇樊 (MAE)", f"{mae:.6f}")
+                                st.metric("鍧囨柟鏍硅宸� (RMSE)", f"{rmse:.6f}")
+
+                            # --- 瀹為檯鍊间笌棰勬祴鍊煎姣� ---
+                            st.subheader("馃攧 瀹為檯鍊间笌棰勬祴鍊煎姣�")
+
+                            # 鍒涘缓瀵规瘮鏁版嵁
+                            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
+                                # 浣跨敤杞崲鍚庣殑y_test_actual
+                                compare_df = pd.DataFrame({
+                                    '瀹為檯鍊�': y_test_actual,
+                                    '棰勬祴鍊�': y_pred
+                                })
+                            else:
+                                # 浣跨敤鍘熷鐨剏_test
+                                compare_df = pd.DataFrame({
+                                    '瀹為檯鍊�': y_test,
+                                    '棰勬祴鍊�': y_pred
+                                })
+                            compare_df = compare_df.sort_index()
+
+                            # 鍒涘缓瀵规瘮鍥�
+                            fig_compare = go.Figure()
+                            fig_compare.add_trace(go.Scatter(
+                                x=compare_df.index,
+                                y=compare_df['瀹為檯鍊�'],
+                                name='瀹為檯鍊�',
+                                mode='lines+markers',
+                                line=dict(color='blue', width=2)
+                            ))
+                            fig_compare.add_trace(go.Scatter(
+                                x=compare_df.index,
+                                y=compare_df['棰勬祴鍊�'],
+                                name='棰勬祴鍊�',
+                                mode='lines+markers',
+                                line=dict(color='red', width=2, dash='dash')
+                            ))
+                            fig_compare.update_layout(
+                                title=f'娴嬭瘯闆�: 瀹為檯绫抽噸 vs 棰勬祴绫抽噸 ({model_type})',
+                                xaxis=dict(title='鏃堕棿'),
+                                yaxis=dict(title='绫抽噸 (Kg/m)'),
+                                legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
+                                height=400
+                            )
+                            st.plotly_chart(fig_compare, width='stretch')
+
+                            # --- 娈嬪樊鍒嗘瀽 ---
+                            st.subheader("馃搲 娈嬪樊鍒嗘瀽")
+
+                            # 璁$畻娈嬪樊
+                            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
+                                # 浣跨敤杞崲鍚庣殑y_test_actual
+                                residuals = y_test_actual - y_pred
+                            else:
+                                # 浣跨敤鍘熷鐨剏_test
+                                residuals = y_test - y_pred
+
+                            # 鍒涘缓娈嬪樊鍥�
+                            fig_residual = go.Figure()
+                            fig_residual.add_trace(go.Scatter(
+                                x=y_pred,
+                                y=residuals,
+                                mode='markers',
+                                marker=dict(color='green', size=8, opacity=0.6)
+                            ))
+                            fig_residual.add_shape(
+                                type="line",
+                                x0=y_pred.min(),
+                                y0=0,
+                                x1=y_pred.max(),
+                                y1=0,
+                                line=dict(color="red", width=2, dash="dash")
+                            )
+                            fig_residual.update_layout(
+                                title='娈嬪樊鍥�',
+                                xaxis=dict(title='棰勬祴鍊�'),
+                                yaxis=dict(title='娈嬪樊'),
+                                height=400
+                            )
+                            st.plotly_chart(fig_residual, width='stretch')
+
+                            # --- 鐗瑰緛閲嶈鎬э紙濡傛灉妯″瀷鏀寔锛� ---
+                            if model_type in ['RandomForest', 'GradientBoosting']:
+                                st.subheader("鈿栵笍 鐗瑰緛閲嶈鎬у垎鏋�")
+
+                                # 璁$畻鐗瑰緛閲嶈鎬�
+                                feature_importance = pd.DataFrame({
+                                    '鐗瑰緛': X_train.columns,
+                                    '閲嶈鎬�': model.feature_importances_
+                                })
+                                feature_importance = feature_importance.sort_values('閲嶈鎬�', ascending=False)
+
+                                # 鍒涘缓鐗瑰緛閲嶈鎬у浘
+                                fig_importance = px.bar(
+                                    feature_importance,
+                                    x='鐗瑰緛',
+                                    y='閲嶈鎬�',
+                                    title='鐗瑰緛閲嶈鎬�',
+                                    color='閲嶈鎬�',
+                                    color_continuous_scale='viridis'
+                                )
+                                fig_importance.update_layout(
+                                    xaxis=dict(tickangle=-45),
+                                    height=400
+                                )
+                                st.plotly_chart(fig_importance, width='stretch')
+
+                            # --- 棰勬祴鍔熻兘 ---
+                            st.subheader("馃敭 绫抽噸棰勬祴")
+
+                            # 鍒涘缓棰勬祴琛ㄥ崟
+                            st.write("杈撳叆鐗瑰緛鍊艰繘琛岀背閲嶉娴�:")
+                            predict_cols = st.columns(2)
+                            input_features = {}
+
+                            for i, feature in enumerate(default_features):
+                                with predict_cols[i % 2]:
+                                    # 鑾峰彇鐗瑰緛鐨勭粺璁′俊鎭�
+                                    min_val = X_clean[feature].min()
+                                    max_val = X_clean[feature].max()
+                                    mean_val = X_clean[feature].mean()
+
+                                    input_features[feature] = st.number_input(
+                                        f"{feature}",
+                                        key=f"ma_pred_{feature}",
+                                        value=float(mean_val),
+                                        min_value=float(min_val),
+                                        max_value=float(max_val),
+                                        step=0.1
+                                    )
+
+                            if st.button("棰勬祴绫抽噸"):
+                                # 鍑嗗棰勬祴鏁版嵁
+                                input_df = pd.DataFrame([input_features])
+                                
+                                # 娣诲姞鏃堕棿鐗瑰緛锛堜娇鐢ㄥ綋鍓嶆椂闂达級
+                                current_time = datetime.now()
+                                time_features_input = pd.DataFrame({
+                                    'hour': [current_time.hour],
+                                    'minute': [current_time.minute],
+                                    'second': [current_time.second],
+                                    'time_of_day': [current_time.hour * 3600 + current_time.minute * 60 + current_time.second]
+                                })
+                                
+                                # 娣诲姞婊炲悗鐗瑰緛锛堜娇鐢ㄨ緭鍏ュ�间綔涓烘浛浠o級
+                                for feature in default_features:
+                                    for lag in [1, 2, 3]:
+                                        time_features_input[f'{feature}_lag{lag}'] = input_features[feature]
+                                        time_features_input[f'{feature}_diff{lag}'] = 0.0
+                                
+                                # 娣诲姞婊氬姩缁熻鐗瑰緛锛堜娇鐢ㄨ緭鍏ュ�间綔涓烘浛浠o級
+                                for feature in default_features:
+                                    time_features_input[f'{feature}_rolling_mean'] = input_features[feature]
+                                    time_features_input[f'{feature}_rolling_std'] = 0.0
+                                    time_features_input[f'{feature}_rolling_min'] = input_features[feature]
+                                    time_features_input[f'{feature}_rolling_max'] = input_features[feature]
+                                
+                                # 鍚堝苟鐗瑰緛
+                                input_combined = pd.concat([input_df, time_features_input], axis=1)
+                                
+                                # 棰勬祴
+                                if model_type in ['SVR', 'MLP']:
+                                    input_scaled = scaler_X.transform(input_combined)
+                                    prediction_scaled = model.predict(input_scaled)
+                                    predicted_weight = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).ravel()[0]
+                                elif use_deep_learning and model_type in ['LSTM', 'GRU', 'BiLSTM']:
+                                    # 涓烘繁搴﹀涔犳ā鍨嬪垱寤哄簭鍒�
+                                    input_scaled = scaler_X.transform(input_combined)
+                                    # 閲嶅杈撳叆浠ュ垱寤哄簭鍒�
+                                    input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
+                                    # 杞崲涓篜yTorch寮犻噺骞剁Щ鍔ㄥ埌璁惧
+                                    input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device)
+                                    # 棰勬祴
+                                    model.eval()
+                                    with torch.no_grad():
+                                        prediction_scaled_tensor = model(input_tensor)
+                                        prediction_scaled = prediction_scaled_tensor.cpu().numpy().ravel()[0]
+                                    predicted_weight = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).ravel()[0]
+                                else:
+                                    predicted_weight = model.predict(input_combined)[0]
+                                
+                                # 鏄剧ず棰勬祴缁撴灉
+                                st.success(f"棰勬祴绫抽噸: {predicted_weight:.4f} Kg/m")
+
+                            # --- 鏁版嵁棰勮 ---
+                            st.subheader("馃攳 鏁版嵁棰勮")
+                            st.dataframe(df_analysis.head(20), width='stretch')
+                            
+                            # --- 瀵煎嚭鏁版嵁 ---
+                            st.subheader("馃捑 瀵煎嚭鏁版嵁")
+                            # 灏嗘暟鎹浆鎹负CSV鏍煎紡
+                            csv = df_analysis.to_csv(index=False)
+                            # 鍒涘缓涓嬭浇鎸夐挳
+                            st.download_button(
+                                label="瀵煎嚭鏁村悎鍚庣殑鏁版嵁 (CSV)",
+                                data=csv,
+                                file_name=f"metered_weight_advanced_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
+                                mime="text/csv",
+                                help="鐐瑰嚮鎸夐挳瀵煎嚭鏁村悎鍚庣殑绫抽噸鍒嗘瀽鏁版嵁"
+                            )
+                        except Exception as e:
+                            st.error(f"妯″瀷璁粌鎴栭娴嬪け璐�: {str(e)}")
+            
+    else:
+        # 鎻愮ず鐢ㄦ埛鐐瑰嚮寮�濮嬪垎鏋愭寜閽�
+        st.info("璇烽�夋嫨鏃堕棿鑼冨洿骞剁偣鍑�'寮�濮嬪垎鏋�'鎸夐挳鑾峰彇鏁版嵁銆�")

--
Gitblit v1.9.3