| | import logging |
| | import os |
| | import sys |
| |
|
| | from modules.ffmpeg_env import setup_ffmpeg_path |
| |
|
# Best-effort bootstrap that runs before the project modules further down are
# imported: call the project's ffmpeg path helper, then configure root logging
# (level taken from LOG_LEVEL, defaulting to INFO).
try:
    setup_ffmpeg_path()
    logging.basicConfig(
        level=os.getenv("LOG_LEVEL", "INFO"),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
except Exception:
    # Startup remains best-effort, but the failure is reported instead of being
    # silently discarded. BaseException is intentionally not caught so that
    # KeyboardInterrupt / SystemExit still propagate.
    logging.getLogger(__name__).warning(
        "ffmpeg/logging bootstrap failed; continuing with defaults",
        exc_info=True,
    )
| |
|
| | import argparse |
| |
|
| | from modules import config |
| | from modules.api.api_setup import process_api_args, setup_api_args |
| | from modules.api.app_config import app_description, app_title, app_version |
| | from modules.gradio_dcls_fix import dcls_patch |
| | from modules.models_setup import process_model_args, setup_model_args |
| | from modules.utils.env import get_and_update_env |
| | from modules.utils.ignore_warn import ignore_useless_warnings |
| | from modules.utils.torch_opt import configure_torch_optimizations |
| | from modules.webui import webui_config |
| | from modules.webui.app import create_interface, webui_init |
| |
|
| | import subprocess |
| |
|
# Best-effort install of flash-attn when this module is loaded.
# - pip is invoked through the running interpreter so the package lands in the
#   same environment this process uses.
# - The current environment is inherited (PATH, HOME, virtualenv variables, ...)
#   and only FLASH_ATTENTION_SKIP_CUDA_BUILD is added; passing a bare dict as
#   ``env`` would replace the whole environment of the child process.
# - An argument list is used instead of a shell string, so no shell is needed.
# - The return code is deliberately not checked: a failed install must not stop
#   the web UI from starting.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Project-provided startup hooks, executed once at import time.
# NOTE(review): judging by the module names, dcls_patch() works around a gradio
# issue and ignore_useless_warnings() installs warning filters -- confirm
# against modules.gradio_dcls_fix and modules.utils.ignore_warn.
dcls_patch()
ignore_useless_warnings()
| |
|
| |
|
def setup_webui_args(parser: argparse.ArgumentParser):
    """Register the web UI command-line options on ``parser``.

    No option declares a default, so an option that is not supplied is left at
    argparse's implicit default (``None`` for valued options, ``False`` for
    ``store_true`` flags).

    Args:
        parser: The argument parser to extend in place.
    """
    # (flag, add_argument keyword arguments), kept in registration order so
    # the generated --help output lists the options in the same sequence.
    option_table = (
        ("--server_name", {"type": str, "help": "server name"}),
        ("--server_port", {"type": int, "help": "server port"}),
        ("--share", {"action": "store_true", "help": "share the gradio interface"}),
        ("--debug", {"action": "store_true", "help": "enable debug mode"}),
        ("--auth", {"type": str, "help": "username:password for authentication"}),
        ("--tts_max_len", {"type": int, "help": "Max length of text for TTS"}),
        ("--ssml_max_len", {"type": int, "help": "Max length of text for SSML"}),
        ("--max_batch_size", {"type": int, "help": "Max batch size for TTS"}),
        (
            "--webui_experimental",
            {"action": "store_true", "help": "Enable webui_experimental features"},
        ),
        (
            "--language",
            {"type": str, "help": "Set the default language for the webui"},
        ),
        (
            "--api",
            {
                "action": "store_true",
                "help": "use api=True to launch the API together with the webui (run launch.py for only API server)",
            },
        ),
    )
    for flag, add_kwargs in option_table:
        parser.add_argument(flag, **add_kwargs)
| |
|
| |
|
def _parse_auth(auth):
    """Split a ``username:password`` string into a ``(username, password)`` pair.

    Only the first ``:`` separates the two fields, so the password itself may
    contain colons.

    Raises:
        ValueError: If ``auth`` contains no ``:`` separator.
    """
    username, separator, password = auth.partition(":")
    if not separator:
        raise ValueError("--auth must be given as username:password")
    return username, password


def process_webui_args(args):
    """Resolve the web UI settings, build the Gradio interface and serve it.

    Each setting is resolved through ``get_and_update_env(args, name, default,
    type)``, the limits are stored on ``webui_config``, the interface is
    launched without blocking, the API routes are attached when requested, and
    finally the call blocks in ``demo.block_thread()``.

    Args:
        args: Parsed command-line namespace carrying the options registered by
            ``setup_webui_args`` (and, when ``--api`` is used, whatever
            ``process_api_args`` reads from it).

    Raises:
        ValueError: If an auth value is supplied without a ``:`` separator.
    """
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    auth = get_and_update_env(args, "auth", None, str)
    # The resolved language is not used in this function; the call is kept
    # because the helper may also record the value (its name suggests it
    # updates the environment) -- NOTE(review): confirm in modules.utils.env.
    get_and_update_env(args, "language", "en", str)
    api = get_and_update_env(args, "api", False, bool)

    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    configure_torch_optimizations()
    webui_init()
    demo = create_interface()

    if auth:
        auth = _parse_auth(auth)

    # The interactive docs pages are exposed only when the API is enabled and
    # the runtime configuration does not turn them off.
    docs_disabled = api is False or config.runtime_env_vars.no_docs
    redoc_url = None if docs_disabled else "/redoc"
    docs_url = None if docs_disabled else "/docs"

    # prevent_thread_lock=True makes launch() return immediately so the app
    # object can still be adjusted below before the thread is blocked.
    app, _local_url, _share_url = demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
        prevent_thread_lock=True,
        inbrowser=sys.platform == "win32",
        app_kwargs={
            "title": app_title,
            "description": app_description,
            "version": app_version,
            "redoc_url": redoc_url,
            "docs_url": docs_url,
        },
    )

    # Drop every registered middleware whose class is named
    # "CustomCORSMiddleware" from the launched app.
    # NOTE(review): the motivation is not visible here -- confirm why this
    # middleware is unwanted.
    app.user_middleware = [
        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
    ]

    if api:
        process_api_args(args, app)

    # Keep the process alive to serve requests.
    demo.block_thread()
| |
|
| |
|
if __name__ == "__main__":
    import dotenv

    # Load settings from the file named by ENV_FILE (default ".env.webui")
    # before any argument is resolved.
    env_file = os.getenv("ENV_FILE", ".env.webui")
    dotenv.load_dotenv(dotenv_path=env_file)

    cli_parser = argparse.ArgumentParser(description="Gradio App")

    # Register the option groups in a fixed order: web UI, models, API.
    for register_options in (setup_webui_args, setup_model_args, setup_api_args):
        register_options(cli_parser)

    cli_args = cli_parser.parse_args()

    # Model options are processed first; process_webui_args() then launches
    # the interface and blocks.
    process_model_args(cli_args)
    process_webui_args(cli_args)
| |
|