# mamba413 — Update src/app.py (commit 7bc1c43, verified)
import os
from pathlib import Path
# -----------------
# Get the directory where app.py is located
# -----------------
APP_DIR = Path(__file__).parent.resolve()
account_name = 'mamba413'

# -----------------
# Fix Streamlit Permission Issues
# -----------------
# On a Hugging Face Space the home directory is read-only, so point every
# cache/config path that Streamlit and huggingface_hub touch at /tmp.
# NOTE(review): the cache/config lines below are assumed to run only when
# SPACE_ID is set (original indentation was ambiguous) — confirm.
if os.environ.get('SPACE_ID'):
    os.environ['STREAMLIT_SERVER_FILE_WATCHER_TYPE'] = 'none'
    os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS'] = 'false'
    os.environ['STREAMLIT_SERVER_ENABLE_CORS'] = 'false'

    # Redirect every HuggingFace cache to a writable directory.
    CACHE_DIR = '/tmp/huggingface_cache'
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR  # legacy alias of HF_HOME
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
    os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR

    # Writable Streamlit config directory.
    streamlit_dir = Path('/tmp/.streamlit')
    streamlit_dir.mkdir(exist_ok=True, parents=True)
    # os.environ['STREAMLIT_HOME'] = '/tmp/.streamlit'
import streamlit as st
from FineTune.model import ComputeStat
import time
# Light form-control theme: grey background and subtle border for text
# areas, text inputs, and selectboxes.
_FORM_CSS = """
<style>
/* Text area & text input */
textarea, input[type="text"] {
background-color: #f8fafc !important;
border: 1px solid #e5e7eb !important;
color: #111827 !important;
}
textarea::placeholder {
color: #9ca3af !important;
}
/* Selectbox */
div[data-testid="stSelectbox"] > div {
background-color: #f8fafc !important;
border: 1px solid #e5e7eb !important;
}
</style>
"""
st.markdown(_FORM_CSS, unsafe_allow_html=True)
# Primary ("Detect") button styling: orange fill, darker shade on hover/press.
# NOTE(review): `border: white` sets only the border *color* (style stays
# `none`, so no border is drawn) — confirm that is the intent.
_DETECT_BUTTON_CSS = """
<style>
/* Detect button */
div.stButton > button[kind="primary"] {
background-color: #fdae6b;
border: white;
color: black;
font-weight: 600;
height: 4.3rem;
font-size: 1.1rem;
display: flex;
align-items: center;
justify-content: center;
gap: 0.55rem;
}
/* Icon inside Detect button */
div.stButton > button[kind="primary"] span {
font-size: 1.25rem;
line-height: 1;
}
div.stButton > button[kind="primary"]:hover {
background-color: #fd8d3c;
border-color: white;
}
div.stButton > button[kind="primary"]:active {
background-color: #fd8d3c;
border-color: white;
}
</style>
"""
st.markdown(_DETECT_BUTTON_CSS, unsafe_allow_html=True)
# -----------------
# Page Configuration
# -----------------
# NOTE(review): Streamlit requires set_page_config() to be the FIRST
# Streamlit command in the script; the st.markdown() calls above may raise
# a StreamlitAPIException (or a warning, depending on version) — confirm
# and consider moving this call before any other st.* usage.
st.set_page_config(
    page_title="DetectGPTPro",
    page_icon="🕵️",
)
# -----------------
# Model Loading (Cached)
# -----------------
@st.cache_resource
def load_model(from_pretrained, base_model, cache_dir, device):
    """Load and cache the detection model.

    ``st.cache_resource`` guarantees this runs only once per process (or
    when the arguments change), so the model is not reloaded on every
    Streamlit rerun.

    Args:
        from_pretrained: Local checkpoint path, or a HF Hub repo id
            (``username/repo-name``) when Space mode is enabled.
        base_model: Base model identifier forwarded to ``ComputeStat``.
        cache_dir: Directory used for HuggingFace downloads.
        device: Torch device string (e.g. ``'cpu'``, ``'mps'``).

    Returns:
        A ``ComputeStat`` instance with the ``'mean'`` criterion set.

    Raises:
        Exception: re-raised if the Hub download fails; whatever
            ``ComputeStat.from_pretrained`` may raise.
    """
    # Toggle for HF-Space-specific handling (forced CPU, /tmp cache, and
    # downloading the checkpoint from the Hub). Currently hard-disabled.
    # is_hf_space = os.environ.get('SPACE_ID') is not None
    is_hf_space = False
    if is_hf_space:
        cache_dir = '/tmp/huggingface_cache'
        os.makedirs(cache_dir, exist_ok=True)
        device = 'cpu'
        print("Using **CPU** now!")

        # HF token is required to access gated / private models.
        hf_token = os.environ.get('HF_TOKEN', None)
        if hf_token:
            # Best effort: a failed login only matters for gated repos.
            try:
                from huggingface_hub import login
                login(token=hf_token)
                print("✅ Successfully authenticated with HF token")
            except Exception as e:
                print(f"⚠️ HF login warning: {e}")

        # Treat "username/repo-name" (and not a relative path) as a Hub id
        # and fetch the whole repository to a local directory first.
        is_hf_hub = '/' in from_pretrained and not from_pretrained.startswith('.')
        if is_hf_hub:
            from huggingface_hub import snapshot_download
            print(f"📥 Downloading model from HuggingFace Hub: {from_pretrained}")
            try:
                local_model_path = snapshot_download(
                    repo_id=from_pretrained,
                    cache_dir=cache_dir,
                    token=hf_token,
                    repo_type="model"
                )
                print(f"✅ Model downloaded to: {local_model_path}")
                # Load from the downloaded local path from here on.
                from_pretrained = local_model_path
            except Exception as e:
                print(f"❌ Failed to download model: {e}")
                raise
    # (Removed a dead `else: cache_dir = cache_dir` no-op branch.)

    with st.spinner("🔄 Loading model... This may take a moment on first launch."):
        model = ComputeStat.from_pretrained(
            from_pretrained,
            base_model,
            device=device,
            cache_dir=cache_dir
        )
        model.set_criterion_fn('mean')
    return model
# -----------------
# Result Feedback Module Import
# -----------------
from feedback import FeedbackManager
from stats import StatsManager

# Initialize the feedback manager backed by a HF dataset.
# HF_TOKEN must be set to access a private repository.
FEEDBACK_DATASET_ID = os.environ.get('FEEDBACK_DATASET_ID', f'{account_name}/user-feedback')
feedback_manager = FeedbackManager(
    dataset_repo_id=FEEDBACK_DATASET_ID,
    hf_token=os.environ.get('HF_TOKEN'),
    # Keep a local backup only when running outside a HF Space
    # (idiomatic replacement for `False if os.environ.get('SPACE_ID') else True`).
    local_backup=not os.environ.get('SPACE_ID'),
)


@st.cache_resource
def get_stats_manager():
    """Create the process-wide (cached) visit/detection stats manager."""
    return StatsManager(
        dataset_repo_id=FEEDBACK_DATASET_ID,
        hf_token=os.environ.get('HF_TOKEN'),
        local_backup=not os.environ.get('SPACE_ID'),
    )


stats_manager = get_stats_manager()
# -----------------
# Configuration
# -----------------
# Force CPU on HF Spaces; use Apple 'mps' when running locally.
_ON_SPACE = bool(os.environ.get('SPACE_ID'))
MODEL_CONFIG = {
    'from_pretrained': './src/FineTune/ckpt/',
    'base_model': 'gemma-1b',
    'cache_dir': '../cache',
    'device': 'cpu' if _ON_SPACE else 'mps',
    # 'device': 'cuda',
}

# Domains offered in the selectbox; the first entry is the default.
DOMAINS = [
    "General",
    "Academia",
    "Finance",
    "Government",
    "Knowledge",
    "Legislation",
    "Medicine",
    "News",
    "UserReview",
]
# Load the model once at startup; remember any failure so the UI can show
# an error banner later instead of crashing the script here.
try:
    model = load_model(
        MODEL_CONFIG['from_pretrained'],
        MODEL_CONFIG['base_model'],
        MODEL_CONFIG['cache_dir'],
        MODEL_CONFIG['device'],
    )
except Exception as e:
    model_loaded = False
    error_message = str(e)
else:
    model_loaded = True
# =========== session_state defaults ===========
# Seed every key this script reads before any widget logic runs.
for _key, _default in (
    ('last_detection', None),   # last result dict, consumed by the feedback UI
    ('feedback_given', False),  # hides the feedback buttons once used
    ('pending_toast', None),    # (message, icon) queued just before st.rerun()
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ========================================

# Show any pending toast (set by feedback buttons before st.rerun()).
if st.session_state.pending_toast:
    _toast_msg, _toast_icon = st.session_state.pending_toast
    st.toast(_toast_msg, icon=_toast_icon)
    st.session_state.pending_toast = None

# ----- Visit Counter -----
# session_state resets on F5 / a new tab, so this runs exactly once per
# browser session.
if 'visit_counted' not in st.session_state:
    st.session_state.visit_counted = True
    stats_manager.increment_visit()
# -------------------------
# -----------------
# Streamlit Layout
# -----------------
_TITLE_HTML = "<h1 style='text-align: center;'> Detect AI-Generated Texts 🕵️ </h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Bail out with an error banner when the model failed to load at startup.
if not model_loaded:
    st.error(f"❌ Failed to load model: {error_message}")
    st.stop()
# -----------------
# Main Interface
# -----------------
# Text to analyse (label hidden; placeholder carries the instruction).
text_input = st.text_area(
    label="📝 Input Text to be Detected",
    placeholder="Paste your text here",
    height=240,
    label_visibility="hidden",
)

# Controls row: domain selector | detect button | significance slider.
domain_col, button_col, level_col = st.columns((1, 1, 1))
selected_domain = domain_col.selectbox(
    label="💡 Domain that matches your text",
    options=DOMAINS,
    index=0,  # default to "General"
)
detect_clicked = button_col.button("🔍 Detect", type="primary", use_container_width=True)
selected_level = level_col.slider(
    label="Significance level (α)",
    min_value=0.01,
    max_value=0.2,
    value=0.05,
    step=0.005,
)
# -----------------
# Detection Logic
# -----------------
if detect_clicked:
    if not text_input.strip():
        # Guard: reject empty / whitespace-only submissions.
        st.warning("⚠️ Please enter some text before detecting.")
    else:
        # ========== Reset feedback state ==========
        # A fresh detection invalidates feedback given for the previous one.
        st.session_state.feedback_given = False
        # ==========================================
        # Start timing so the processing time can be reported below.
        start_time = time.time()
        # Placeholders allow the status/result areas to be updated in place.
        status_placeholder = st.empty()
        result_placeholder = st.empty()
        try:
            # Show spinner for quick operations (< 2 seconds expected).
            with status_placeholder:
                with st.spinner(f"🔍 Analyzing text in {selected_domain} domain..."):
                    # Perform inference: test statistic + p-value for the domain.
                    crit, p_value = model.compute_p_value(text_input, selected_domain)
                    elapsed_time = time.time() - start_time
            # Convert tensors to plain Python scalars if needed.
            if hasattr(crit, 'item'):
                crit = crit.item()
            if hasattr(p_value, 'item'):
                p_value = p_value.item()
            # Clear the status area before rendering results.
            status_placeholder.empty()
            # ========== Save the detection result to session_state ==========
            # The feedback UI is rendered outside this `if detect_clicked`
            # branch, so the result must survive Streamlit reruns.
            st.session_state.last_detection = {
                'text': text_input,
                'domain': selected_domain,
                'statistics': crit,
                'p_value': p_value,
                'elapsed_time': elapsed_time
            }
            # Count detection once per unique detect action.
            # NOTE(review): hash() is salted per process (PYTHONHASHSEED), so
            # this key is only stable within one server process, and keys
            # accumulate in session_state — confirm this is acceptable.
            _det_key = f'det_counted_{hash(text_input[:80])}'
            if _det_key not in st.session_state:
                st.session_state[_det_key] = True
                stats_manager.increment_detection()
            # Verdict banner: reject H0 (human-written) when p < α.
            st.info(
                f"""
**Conclusion**:
{'Text is likely LLM-generated.' if p_value < selected_level else 'Fail to reject hypothesis that text is human-written.'}
based on the observation that $p$-value {p_value:.3f} is {'less' if p_value < selected_level else 'greater'} than significance level {selected_level:.2f} 📊
""",
                icon="💡"
            )
            st.markdown(
                """
<style>
/* Tighten spacing inside Clarification / Citation expanders */
div[data-testid="stExpander"] {
margin-top: -1.3rem;
}
div[data-testid="stExpander"] p,
div[data-testid="stExpander"] li {
line-height: 1.35;
margin-bottom: 0.1rem;
}
div[data-testid="stExpander"] ul {
margin-top: 0.1rem;
}
</style>
""",
                unsafe_allow_html=True
            )
            with st.expander("📋 Interpretation and Suggestions"):
                st.markdown(
                    """
+ Interpretation:
- $p$-value: Lower $p$-value (closer to 0) indicates text is **more likely AI-generated**; Higher $p$-value (closer to 1) indicates text is **more likely human-written**.
- Significance Level (α): a threshold set by the user to determine the sensitivity of the detection. Lower α means stricter criteria for claiming the text is AI-generated.
+ Suggestions for better detection:
- Provide longer text inputs for more reliable detection results.
- Select the domain that best matches the content of your text to improve detection accuracy.
"""
                )
            # Show detailed results (processing time) below the conclusion.
            with result_placeholder:
                st.caption(f"⏱️ Processing time: {elapsed_time:.2f} seconds")
        except Exception as e:
            # Surface any inference failure in the UI instead of crashing.
            status_placeholder.empty()
            st.error(f"❌ Error during detection: {str(e)}")
            st.exception(e)
# -----------------
# Feedback UI (outside if detect_clicked — persists across all reruns via session_state)
# -----------------
# Rendered only while a detection exists and no feedback has been given yet.
if st.session_state.last_detection is not None and not st.session_state.feedback_given:
    _ld = st.session_state.last_detection
    # Header with a hover-only privacy tooltip (pure CSS, no JavaScript).
    st.markdown(
        """
<style>
.fb-header { display: flex; align-items: center; gap: 0.4rem; margin-bottom: 0.3rem; }
.privacy-tip {
position: relative; display: inline-block;
cursor: help; color: #9ca3af; font-size: 0.9rem;
}
.privacy-tip .tip-text {
visibility: hidden; opacity: 0;
width: 240px; background-color: #374151; color: #f9fafb;
text-align: left; border-radius: 6px;
padding: 0.5rem 0.7rem; font-size: 0.78rem; line-height: 1.4;
position: absolute; z-index: 100;
bottom: 130%; left: 50%; transform: translateX(-50%);
transition: opacity 0.25s ease; pointer-events: none;
}
.privacy-tip:hover .tip-text { visibility: visible; opacity: 1; }
</style>
<div class="fb-header">
<strong>📝 Result Feedback</strong>: Does this detection result meet your expectations?
<span class="privacy-tip">🔒
<span class="tip-text">🔒 Your feedback is stored privately and will never be shared with third parties. It is used solely to improve detection accuracy.</span>
</span>
</div>
""",
        unsafe_allow_html=True
    )
    feedback_col1, feedback_col2 = st.columns(2)
    with feedback_col1:
        if st.button("✅ Expected", use_container_width=True, type="secondary",
                     key="expected_btn"):
            try:
                # Persist the feedback alongside the detection it refers to.
                fb_success, fb_message = feedback_manager.save_feedback(
                    _ld['text'], _ld['domain'], _ld['statistics'], _ld['p_value'], 'expected'
                )
                if fb_success:
                    st.session_state.feedback_given = True
                    # Queue a toast; it is shown after st.rerun() re-executes the script.
                    st.session_state.pending_toast = ("Thank you for your feedback!", "✅")
                    st.rerun()
                else:
                    st.error(f"Failed to save feedback: {fb_message}")
            except Exception as e:
                st.error(f"Failed to save feedback: {str(e)}")
    with feedback_col2:
        if st.button("❌ Unexpected", use_container_width=True, type="secondary",
                     key="unexpected_btn"):
            try:
                # Same flow as above, but labels the feedback 'unexpected'.
                fb_success, fb_message = feedback_manager.save_feedback(
                    _ld['text'], _ld['domain'], _ld['statistics'], _ld['p_value'], 'unexpected'
                )
                if fb_success:
                    st.session_state.feedback_given = True
                    st.session_state.pending_toast = ("Feedback recorded! This will help us improve.", "📝")
                    st.rerun()
                else:
                    st.error(f"Failed to save feedback: {fb_message}")
            except Exception as e:
                st.error(f"Failed to save feedback: {str(e)}")
# with st.expander("📋 Citation"):
# st.markdown(
# """
# If you find this tool useful for you, please cite our paper: **[AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees](https://arxiv.org/abs/2510.01268)**
# """
# )
# st.code(
# """
# @inproceedings{zhou2024adadetectgpt,
# title={AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees},
# author={Hongyi Zhou and Jin Zhu and Pingfan Su and Kai Ye and Ying Yang and Shakeel A O B Gavioli-Akilagun and Chengchun Shi},
# booktitle={The Thirty-Ninth Annual Conference on Neural Information Processing Systems},
# year={2025},
# }
# """,
# language="bibtex"
# )
# -----------------
# Statistics Chip (fixed top-right)
# -----------------
# Non-interactive badge pinned to the top-right corner showing total visits.
_visit_total = stats_manager.visit_count
st.markdown(
    f"""
<style>
.stats-chip {{
position: fixed;
top: 3.6rem;
right: 1rem;
display: flex;
align-items: center;
gap: 0.35rem;
font-size: 0.78rem;
color: #9ca3af;
z-index: 999;
pointer-events: none;
}}
</style>
<div class="stats-chip">
<span>{_visit_total:,} visits</span>
</div>
""",
    unsafe_allow_html=True
)
# -----------------
# Footer
# -----------------
# Fixed disclaimer bar pinned to the bottom of the viewport.
_FOOTER_HTML = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: gray;
text-align: center;
padding: 1px;
border-top: 1px solid #e0e0e0;
z-index: 999;
}
/* Add padding to main content to prevent overlap with fixed footer */
.main .block-container {
padding-bottom: 1px;
}
</style>
<div class='footer'>
<small> This tool is developed for research purposes only. The detection results are not 100% accurate and should not be used as the sole basis for any critical decisions. Users are advised to use this tool responsibly and ethically. </small>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)