| | """ |
| | tcn_app.py |
| | Gradio app to serve the TCN fault classification model. |
| | |
| | Usage: |
| | - Place a local model file named by LOCAL_MODEL_FILE in the same repo, or |
| | - Set HUB_REPO and HUB_FILENAME to a public Hugging Face model repo + filename, |
| | and the app will download it at startup using hf_hub_download. |
| | |
| | This file is ready to push to a Hugging Face Space (Gradio). |
| | """ |
| | import os |
| | import numpy as np |
| | import pandas as pd |
| | import gradio as gr |
| | from tensorflow.keras.models import load_model |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
# Model source configuration. Each value may be overridden by an environment variable
# of the same name (e.g. a Space variable), so the model can be switched without editing
# this file; the hard-coded values below are the defaults.
LOCAL_MODEL_FILE = os.environ.get("LOCAL_MODEL_FILE") or "tcn_model.h5"
HUB_REPO = os.environ.get("HUB_REPO", "")  # e.g. "user/repo"; empty disables hub download
HUB_FILENAME = os.environ.get("HUB_FILENAME", "")  # file name inside HUB_REPO
| |
|
def get_model_path():
    """Work out which model file to load.

    Resolution order:
      1. ``LOCAL_MODEL_FILE``, if that path exists (relative paths resolve against the
         current working directory).
      2. ``HUB_FILENAME`` fetched from ``HUB_REPO`` with ``hf_hub_download``, when both
         are non-empty.

    Returns:
        The local path string, the path returned by ``hf_hub_download``, or ``None`` when
        neither source is available. A failed download is printed, not raised.
    """
    if os.path.exists(LOCAL_MODEL_FILE):
        return LOCAL_MODEL_FILE

    hub_configured = bool(HUB_REPO) and bool(HUB_FILENAME)
    if not hub_configured:
        return None

    try:
        print(f"Downloading {HUB_FILENAME} from {HUB_REPO} ...")
        return hf_hub_download(repo_id=HUB_REPO, filename=HUB_FILENAME)
    except Exception as download_error:
        # Best-effort: report and fall back to "no model" instead of crashing the app.
        print("Failed to download from hub:", download_error)
        return None
| |
|
# Resolve and load the model once, at import time. MODEL stays None whenever no model
# is found or loading fails; the predict functions check for that before predicting.
MODEL_PATH = get_model_path()
MODEL = None
if not MODEL_PATH:
    print("No model found. Please upload a model named", LOCAL_MODEL_FILE, "or set HUB_REPO/HUB_FILENAME.")
else:
    try:
        # NOTE(review): if the saved model contains the keras-tcn `TCN` layer,
        # load_model likely needs custom_objects={"TCN": TCN}; confirm how it was built.
        MODEL = load_model(MODEL_PATH)
        print("Loaded model:", MODEL_PATH)
    except Exception as load_error:
        print("Failed to load model:", load_error)
        MODEL = None
| |
|
def prepare_input_array(arr, n_timesteps=1, n_features=None):
    """Convert raw feature values into a batch array for the model.

    - 1-D input becomes a single-sample batch of shape ``(1, n_timesteps, n_features)``;
      when ``n_features`` is ``None`` the last axis is inferred from the input length.
    - 2-D input with exactly one row, shape ``(1, k)``, becomes ``(1, k, 1)``. This
      branch ignores ``n_timesteps`` and ``n_features``.
    - Any other input is returned as an ndarray without reshaping.

    Raises:
        ValueError: if the values cannot be reshaped into the requested shape.
    """
    data = np.array(arr)

    if data.ndim == 1:
        last_axis = -1 if n_features is None else n_features
        return data.reshape(1, n_timesteps, last_axis)

    if data.ndim == 2 and data.shape[0] == 1:
        return data.reshape(1, data.shape[1], -1)

    return data
| |
|
def predict_text(text, n_timesteps=1, n_features=None):
    """Classify one sample given as a comma-separated line of numbers.

    Args:
        text: Feature values separated by commas. Newlines are also accepted as
            separators, and empty fields (e.g. a trailing comma) are skipped.
        n_timesteps: Timesteps used to reshape the values; ``None`` (a cleared
            number field) falls back to 1.
        n_features: Features per timestep; a falsy value lets the reshape infer it.

    Returns:
        A user-facing string with the predicted class and its probability. If the model
        is not loaded, or the input cannot be parsed or reshaped, it returns an error
        message instead.
    """
    if MODEL is None:
        return "模型未加载,请上传或配置模型。"

    # Parse explicitly: np.fromstring stops at the first malformed token (only warning),
    # which would silently predict on a truncated feature vector.
    tokens = [tok.strip() for tok in text.replace("\n", ",").split(",")]
    try:
        values = [float(tok) for tok in tokens if tok]
    except ValueError:
        return "输入格式错误:请输入逗号分隔的数字。"
    if not values:
        return "未解析到任何特征值。"

    timesteps = 1 if n_timesteps is None else int(n_timesteps)
    features = int(n_features) if n_features else None
    try:
        x = prepare_input_array(np.asarray(values, dtype=float), n_timesteps=timesteps, n_features=features)
    except ValueError as exc:
        # e.g. the value count is not divisible by timesteps/features.
        return f"输入形状错误: {exc}"

    probs = MODEL.predict(x)
    label = int(np.argmax(probs, axis=1)[0])
    return f"预测类别: {label} (概率: {float(np.max(probs)):.4f})"
| |
|
def predict_csv(file, n_timesteps=1, n_features=None):
    """Classify every row of an uploaded CSV file.

    Args:
        file: The upload, as either a path string or an object exposing ``.name``
            (as Gradio file uploads do).
        n_timesteps: Timesteps per sample; ``None`` (a cleared number field) falls back to 1.
        n_features: Features per timestep; a falsy value lets the reshape infer it.

    Returns:
        ``{"labels": [...], "probs": [...]}`` on success, otherwise ``{"error": message}``.
    """
    if MODEL is None:
        return {"error": "模型未加载,请上传或配置模型。"}

    path = getattr(file, "name", file)
    # NOTE(review): read_csv treats the first row as a header, so a header-less CSV
    # loses its first sample — confirm the expected upload format.
    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        return {"error": "CSV 文件为空。"}
    X = df.values

    timesteps = 1 if n_timesteps is None else int(n_timesteps)
    # Always build (samples, timesteps, features), like the single-sample path does;
    # -1 infers the feature count when the field is left blank, as the UI promises.
    features = int(n_features) if n_features else -1
    try:
        X = X.reshape(X.shape[0], timesteps, features)
    except ValueError as exc:
        return {"error": f"输入形状错误: {exc}"}

    preds = MODEL.predict(X)
    labels = preds.argmax(axis=1).tolist()
    return {"labels": labels, "probs": preds.tolist()}
| |
|
def run_predict(file, text, n_timesteps, n_features):
    """Handle a click on the predict button.

    A CSV upload takes precedence over pasted text. Returns a ``(status, json)`` pair
    for the ``out_text`` and ``out_json`` components.
    """
    if file is not None:
        result = predict_csv(file, n_timesteps, n_features)
        # predict_csv reports failures as {"error": ...}; show that in the status box
        # instead of claiming the CSV prediction succeeded.
        if "error" in result:
            return result["error"], result
        return "CSV 预测完成", result
    if text:
        return predict_text(text, n_timesteps, n_features), {}
    return "请提供 CSV 或特征文本", {}


with gr.Blocks() as demo:
    gr.Markdown("# TCN Fault Classification")
    gr.Markdown("上传 CSV(每行一个样本)或粘贴逗号分隔的一行特征进行预测。")
    with gr.Row():
        file_in = gr.File(label="上传 CSV(每行 = 一个样本)")
        text_in = gr.Textbox(lines=2, placeholder="粘贴逗号分隔的一行特征,例如: 0.1,0.2,0.3,...")
    # precision=0 makes Gradio pass integers to the callback, matching the integer labels.
    n_ts = gr.Number(value=1, precision=0, label="timesteps (整型)")
    n_feat = gr.Number(value=None, precision=0, label="features (可选,留空尝试自动推断)")
    btn = gr.Button("预测")
    out_text = gr.Textbox(label="单样本预测输出")
    out_json = gr.JSON(label="批量预测结果 (labels & probs)")

    # Event listeners must be registered inside the Blocks context.
    btn.click(run_predict, inputs=[file_in, text_in, n_ts, n_feat], outputs=[out_text, out_json])
| |
|
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 when run as a script.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
| |
|