OpenFakeDemo / app /main.py
vicliv's picture
added api calls
ef74676
import io
import json
import mimetypes
import os
import random
import tempfile
import uuid
from datetime import datetime, timezone
from pathlib import Path
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.staticfiles import StaticFiles
from PIL import Image, ImageOps
from .model import load_detector, predict_image
from .screenshot import preprocess
from .video import sample_frames, sample_gif_frames
# Upload size limits, in megabytes.
MAX_IMAGE_SIZE_MB = 50
MAX_VIDEO_SIZE_MB = 300

# How many frames are sampled from a video or animated GIF for scoring.
N_VIDEO_FRAMES = 5

# Accepted MIME types, one set per media category.
IMAGE_TYPES = {"image/jpeg", "image/jpg", "image/png", "image/webp"}
VIDEO_TYPES = {"video/mp4", "video/quicktime", "video/webm", "video/x-matroska"}
GIF_TYPES = {"image/gif"}

# Lower-case file extension -> MIME type, consulted when an upload arrives
# without a usable Content-Type header.
_EXT_MIME = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
    ".gif": "image/gif",
    ".mp4": "video/mp4",
    ".mov": "video/quicktime",
    ".webm": "video/webm",
    ".mkv": "video/x-matroska",
}
def _resolve_content_type(content_type: str, filename: str | None) -> str:
if content_type not in ("", "application/octet-stream") or not filename:
return content_type
suffix = Path(filename).suffix.lower()
return _EXT_MIME.get(suffix) or mimetypes.guess_type(filename)[0] or content_type
# Hugging Face dataset repo that receives user error reports (env-overridable).
HF_REPORT_REPO = os.getenv("HF_REPORT_REPO", "ComplexDataLab/openfake-reports")
# Token used to upload reports; /api/report answers 503 while this is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="Deepfake Detector")
@app.on_event("startup")
def warmup():
    """Call ``load_detector()`` once when the application starts.

    Presumably this loads and caches the detector model so the first
    prediction request does not pay the loading cost; confirm in ``.model``.
    """
    load_detector()
def _predict_with_preprocess(image: Image.Image) -> dict:
    """Run the screenshot-aware prediction pipeline on a single image.

    The image is EXIF-transposed first, so ``image_size`` and every crop box
    are in the same orientation the browser renders. The frontend can then
    overlay the boxes directly. The image is passed to ``preprocess`` and
    scored according to the returned status:

    * ``"cropped"``: every crop is scored and ``p_fake`` is their mean.
      ``n_crops`` is added. If ``preprocess`` reports ``"cropped"`` but
      yields no crops, the whole image is scored instead and ``n_crops`` is 0.
    * ``"text_only"``: the whole image is scored. A score above 0.5 is
      replaced by a random value in [0.4, 0.6], and the unmodified score is
      returned as ``raw_p_fake``.
    * any other status: the whole image is scored as-is.

    Returns:
        A dict containing:

        * ``p_fake``;
        * ``preprocess_status``;
        * ``image_size`` (``[width, height]``);
        * ``crop_box``, a list of boxes (each a list) or ``None``;
        * plus the status-specific keys described above.
    """
    # Apply EXIF rotation up front so crop_box coords and image_size are in
    # the same frame as the browser-rendered image.
    image = ImageOps.exif_transpose(image)
    width, height = image.size
    result = preprocess(image)

    # preprocess may return a single box or a list of boxes; always expose a
    # list of lists so the frontend only has to handle one shape.
    crop_box = None
    if result.crop_box is not None:
        boxes = result.crop_box if isinstance(result.crop_box, list) else [result.crop_box]
        crop_box = [list(b) for b in boxes]

    base = {
        "preprocess_status": result.status,
        "image_size": [width, height],
        "crop_box": crop_box,
    }

    if result.status == "cropped":
        crops = result.image if isinstance(result.image, list) else [result.image]
        if crops:
            probs = [predict_image(c) for c in crops]
            return {**base, "p_fake": sum(probs) / len(probs), "n_crops": len(crops)}
        # An empty crop list would divide by zero (HTTP 500); score the whole
        # image instead so the request still produces a result.
        return {**base, "p_fake": predict_image(image), "n_crops": 0}

    if result.status == "text_only":
        raw_p_fake = predict_image(image)
        # The detector is unreliable on pure-text screenshots and tends to
        # flag them as AI-generated. If it leans "AI", soften to uncertain;
        # if it leans "real", keep the score.
        # NOTE(review): random.uniform makes identical uploads report
        # different scores; consider a deterministic mapping into [0.4, 0.6].
        p_fake = random.uniform(0.4, 0.6) if raw_p_fake > 0.5 else raw_p_fake
        return {**base, "p_fake": p_fake, "raw_p_fake": raw_p_fake}

    return {**base, "p_fake": predict_image(image)}
async def _read_upload(file: UploadFile, max_mb: int, label: str) -> bytes:
    """Read *file*, raising HTTP 413 if it exceeds *max_mb* megabytes.

    At most one byte past the limit is read, so an oversized upload is never
    loaded into memory in full.
    """
    max_bytes = max_mb * 1024 * 1024
    raw = await file.read(max_bytes + 1)
    if len(raw) > max_bytes:
        raise HTTPException(413, f"{label} exceeds {max_mb} MB")
    return raw


def _frames_response(media_type: str, frames) -> dict:
    """Score sampled frames and build the response for a video or GIF upload.

    Raises HTTP 400 when no frames were sampled (averaging would otherwise
    divide by zero).
    """
    # len() rather than truthiness: works whether frames is a list or an array.
    if len(frames) == 0:
        raise HTTPException(400, "No frames could be extracted from the file")
    probs = [predict_image(frame) for frame in frames]
    p_fake = sum(probs) / len(probs)
    return {
        "media_type": media_type,
        "p_fake": p_fake,
        "reliability": 1.0 - p_fake,
        "n_frames": len(frames),
        "frame_probs": probs,
    }


def _predict_image_bytes(raw: bytes) -> dict:
    """Decode *raw* as an image and run the screenshot-aware pipeline on it."""
    try:
        image = Image.open(io.BytesIO(raw))
        # Image.open only parses the header. Force a full decode here so that
        # truncated or corrupt pixel data is reported as a 400 instead of
        # failing later inside the prediction pipeline with a 500.
        image.load()
    except Exception as exc:
        raise HTTPException(400, "Invalid image") from exc
    pred = _predict_with_preprocess(image)
    p_fake = pred["p_fake"]
    return {
        "media_type": "image",
        "p_fake": p_fake,
        "reliability": 1.0 - p_fake,
        "n_frames": 1,
        **{k: v for k, v in pred.items() if k != "p_fake"},
    }


def _predict_video_bytes(raw: bytes, filename: str | None) -> dict:
    """Write *raw* to a temp file, sample frames from it and score them."""
    # Keep the original extension so the frame sampler can recognise the container.
    suffix = Path(filename or "video.mp4").suffix or ".mp4"
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = tmp.name
    try:
        # Close (and flush) before sampling so the reader sees the full file.
        with tmp:
            tmp.write(raw)
        frames = sample_frames(tmp_path, N_VIDEO_FRAMES)
    except ValueError as exc:
        raise HTTPException(400, str(exc)) from exc
    finally:
        # Best-effort cleanup: a stale temp file must not turn a successful
        # prediction into an error.
        try:
            Path(tmp_path).unlink(missing_ok=True)
        except OSError:
            pass
    return _frames_response("video", frames)


@app.post("/api/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image, video or animated GIF for AI generation.

    The media type is the upload's declared content type. When that is
    missing or generic, it is inferred from the filename extension.

    Returns:
        A dict with ``media_type``, ``p_fake``, ``reliability``
        (``1 - p_fake``) and ``n_frames``. Image responses also carry the
        preprocessing fields produced by ``_predict_with_preprocess``. Video
        and GIF responses carry per-frame scores in ``frame_probs``.

    Raises:
        HTTPException: one of the following statuses.

            * 415 for unsupported media types.
            * 413 when the file exceeds its size limit.
            * 400 when the media cannot be decoded or no frames can be
              sampled.
    """
    content_type = _resolve_content_type((file.content_type or "").lower(), file.filename)

    if content_type in IMAGE_TYPES:
        raw = await _read_upload(file, MAX_IMAGE_SIZE_MB, "Image")
        return _predict_image_bytes(raw)

    if content_type in VIDEO_TYPES:
        raw = await _read_upload(file, MAX_VIDEO_SIZE_MB, "Video")
        return _predict_video_bytes(raw, file.filename)

    if content_type in GIF_TYPES:
        raw = await _read_upload(file, MAX_IMAGE_SIZE_MB, "GIF")
        try:
            frames = sample_gif_frames(raw, N_VIDEO_FRAMES)
        except ValueError as exc:
            raise HTTPException(400, str(exc)) from exc
        return _frames_response("gif", frames)

    # Rejected before reading, so unsupported uploads are never buffered.
    raise HTTPException(415, f"Unsupported media type: {content_type}")
@app.post("/api/report")
async def report(
    file: UploadFile = File(...),
    is_real: str = Form(...),
    reason: str = Form(...),
    reason_other: str = Form(""),
    reason_details: str = Form(""),
    comment: str = Form(""),
    p_fake: float = Form(...),
    consent: str = Form(...),
):
    """Save an error report (form answers + media file) to a Hugging Face dataset repo.

    The report is written to a temporary folder named
    ``<UTC timestamp>_<8 hex chars>``. The folder contains ``report.json``
    and ``media<original extension>``. It is uploaded to
    ``reports/<folder>`` in the ``HF_REPORT_REPO`` dataset repo.

    Raises:
        HTTPException: one of the following statuses.

            * 400 without consent.
            * 503 when ``HF_TOKEN`` is unset.
            * 415 for unsupported media types.
            * 413 when the file exceeds the same size limits as
              ``/api/predict``.
            * 500 when the upload to the Hub fails.
    """
    import asyncio

    if consent != "true":
        raise HTTPException(400, "Consent to save the file is required.")
    if not HF_TOKEN:
        raise HTTPException(
            503, "Reporting is not configured (missing HF_TOKEN)."
        )

    content_type = _resolve_content_type((file.content_type or "").lower(), file.filename)
    if content_type not in IMAGE_TYPES | VIDEO_TYPES | GIF_TYPES:
        raise HTTPException(415, "Unsupported file type for reporting.")

    # Apply the same per-type size limits as /api/predict. Read at most one
    # byte past the limit so an oversized upload is never fully buffered
    # (or pushed to the dataset repo).
    if content_type in VIDEO_TYPES:
        max_mb, label = MAX_VIDEO_SIZE_MB, "Video"
    elif content_type in GIF_TYPES:
        max_mb, label = MAX_IMAGE_SIZE_MB, "GIF"
    else:
        max_mb, label = MAX_IMAGE_SIZE_MB, "Image"
    max_bytes = max_mb * 1024 * 1024
    raw = await file.read(max_bytes + 1)
    if len(raw) > max_bytes:
        raise HTTPException(413, f"{label} exceeds {max_mb} MB")

    # Capture the time once so the folder name and the recorded timestamp agree.
    now = datetime.now(timezone.utc)
    folder_name = f"{now.strftime('%Y-%m-%dT%H-%M-%S')}_{uuid.uuid4().hex[:8]}"
    report_data = {
        "timestamp": now.isoformat(),
        "is_real": is_real,
        "reason": reason,
        "reason_other": reason_other if reason == "other" else "",
        "reason_details": reason_details,
        "comment": comment,
        "p_fake": p_fake,
        "content_type": content_type,
        "original_filename": file.filename or "unknown",
    }

    # Write to a temp directory then upload to HF.
    with tempfile.TemporaryDirectory() as tmpdir:
        report_dir = Path(tmpdir) / folder_name
        report_dir.mkdir()
        # Explicit UTF-8: ensure_ascii=False may emit non-ASCII text, which the
        # platform's default locale encoding might not be able to represent.
        (report_dir / "report.json").write_text(
            json.dumps(report_data, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        # Save media file with original extension.
        ext = Path(file.filename or "file.bin").suffix or ".bin"
        (report_dir / f"media{ext}").write_bytes(raw)

        try:
            from huggingface_hub import HfApi

            api = HfApi(token=HF_TOKEN)
            # upload_folder is blocking network I/O. Run it in a worker thread
            # so a large upload does not stall every other request on the
            # event loop. The temp directory stays alive until the await ends.
            await asyncio.to_thread(
                api.upload_folder,
                folder_path=str(report_dir),
                path_in_repo=f"reports/{folder_name}",
                repo_id=HF_REPORT_REPO,
                repo_type="dataset",
            )
        except Exception as e:
            raise HTTPException(500, f"Failed to upload report: {e}") from e

    return {"status": "ok"}
# Serve the frontend from the "static" folder next to this module.
# Starlette matches routes in registration order, so mounting this catch-all
# "/" last lets the /api/* routes above take precedence. html=True makes
# directory paths serve their index.html.
static_dir = Path(__file__).parent / "static"
app.mount(
    "/",
    StaticFiles(directory=os.fspath(static_dir), html=True),
    name="static",
)