Spaces:
Paused
Paused
| import os | |
| import json | |
| import time | |
| import socket | |
| import threading | |
| import gc | |
| import ctypes | |
| import multiprocessing as mp | |
| from pathlib import Path | |
| import numpy as np | |
| from tokenizers import Tokenizer | |
# -- Config --------------------------------------------------------------------
# Shared locations, all under the /data volume.
STATE_FILE: str = "/data/state.json"    # shard bookkeeping shared by workers
RAW_DIR: str = "/data/raw"              # raw input shards (read as JSON lines)
OUT_DIR: str = "/data/tokenized"        # finished .bin token files
TOK_PATH: str = "/data/tokenizer.json"  # tokenizer loaded by each pool process

# Identifies this worker in shard claims and in log lines.
WORKER_ID: str = socket.gethostname()

POLL_INTERVAL: int = 15  # seconds to wait when there is nothing to do
BATCH_SIZE: int = 2      # texts sent to a tokenizer pool process per task

# Import-time side effect: make sure the output directory exists.
os.makedirs(OUT_DIR, exist_ok=True)
# -- Keep-alive ----------------------------------------------------------------
_KEEPALIVE_RESPONSE = (
    b"HTTP/1.1 200 OK\r\n"
    b"Content-Type: text/plain\r\n"
    b"Content-Length: 2\r\n"
    b"Connection: close\r\n"
    b"\r\n"
    b"OK"
)


def _answer_keepalive(conn):
    """Best-effort read of the client's request, then send the fixed OK reply."""
    conn.settimeout(2.0)
    try:
        # Drain the request first: closing a TCP socket that still has unread
        # input makes the kernel send RST, which clients report as an error.
        conn.recv(4096)
    except OSError:
        pass  # timed out / reset by peer: still attempt the reply below
    # sendall(), unlike send(), retries until every byte has been written.
    conn.sendall(_KEEPALIVE_RESPONSE)


def serve():
    """Answer every TCP connection on 0.0.0.0:7860 with a fixed ``200 OK``.

    Never returns; meant to run in a daemon thread (presumably so the hosting
    platform sees a live HTTP listener on its expected port — confirm).
    Failures on an individual connection, or in ``accept()``, are logged and
    skipped so that one misbehaving client cannot stop the listener.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("0.0.0.0", 7860))
        s.listen(5)
        print(f"β [{WORKER_ID}] Listening on port 7860")
        while True:
            try:
                conn, _ = s.accept()
            except OSError as e:
                print(f" [{WORKER_ID}] Keep-alive accept failed: {e}")
                time.sleep(1)  # avoid a hot loop on persistent errors (e.g. EMFILE)
                continue
            with conn:  # always closes the connection, even on error
                try:
                    _answer_keepalive(conn)
                except OSError as e:
                    print(f" [{WORKER_ID}] Keep-alive reply failed: {e}")
# -- State ---------------------------------------------------------------------
def load_state():
    """Read STATE_FILE and return its decoded JSON content."""
    with open(STATE_FILE) as fh:
        raw_json = fh.read()
    return json.loads(raw_json)
def save_state(state):
    """Persist *state* to STATE_FILE as indented JSON via write-then-rename.

    The JSON goes to a per-worker temp file (so concurrent workers never write
    the same temp file), is flushed and fsync'ed, and only then moved over
    STATE_FILE with ``os.replace()``. Readers therefore see either the old or
    the new complete file, and a crash right after the rename is much less
    likely to leave a truncated state.json behind. On any error the temp file
    is removed and the exception re-raised.
    """
    tmp = STATE_FILE + f".tmp.{WORKER_ID}"
    try:
        with open(tmp, "w") as f:
            json.dump(state, f, indent=2)
            # Without fsync the rename can reach disk before the data does.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, STATE_FILE)
    except BaseException:
        try:
            os.unlink(tmp)
        except OSError:
            pass  # never created, or already gone
        raise
# -- Claim ---------------------------------------------------------------------
def _is_claimable(info, now, stale_after):
    """Return True if the shard described by *info* may be (re)claimed now."""
    status = info.get("status")
    if status == "pending":
        return True
    if status != "claimed":
        return False
    if info.get("worker") == WORKER_ID:
        # Our own leftover claim: this worker only claims between shards, so
        # such a claim was left by an earlier run that died mid-shard.
        return True
    claimed_at = info.get("claimed_at")
    return (
        stale_after is not None
        and claimed_at is not None
        and now - claimed_at > stale_after
    )


def claim_shard(state, stale_after=None):
    """Claim the first available shard whose raw file exists and save the claim.

    A shard is available when its status is "pending", or when it is "claimed"
    and either the claim belongs to this WORKER_ID (left over from a crashed
    run) or *stale_after* is given and the claim is older than that.

    Args:
        state: decoded state dict; the claimed entry in ``state["shards"]`` is
            updated in place before the whole state is saved.
        stale_after: optional age in seconds after which another worker's
            claim is treated as abandoned. None (default) never steals claims.

    Returns:
        ``(name, raw_path)`` of the claimed shard, or ``(None, None)``.

    NOTE(review): claiming is a read-modify-write of a shared file without a
    lock; two workers can still race for the same shard.
    """
    now = time.time()
    for name, info in state["shards"].items():
        if not _is_claimable(info, now, stale_after):
            continue
        raw_path = Path(RAW_DIR) / name
        if raw_path.exists():
            info["status"] = "claimed"
            info["worker"] = WORKER_ID
            info["claimed_at"] = time.time()
            save_state(state)
            return name, raw_path
    return None, None
# -- Tokenizer subprocess ------------------------------------------------------
# Per-process tokenizer state, filled in by init_worker() in each pool process.
_worker_tok = None  # Tokenizer loaded from the path given to init_worker
_worker_sep = None  # vocabulary id of the "<sep>" token


def init_worker(tok_path):
    """Pool initializer: load the tokenizer file and look up the ``<sep>`` id.

    NOTE(review): ``token_to_id`` presumably returns None when ``<sep>`` is
    not in the vocabulary; that None would later be appended to token lists.
    Confirm the tokenizer always defines ``<sep>``.
    """
    global _worker_tok, _worker_sep
    tokenizer = Tokenizer.from_file(tok_path)
    _worker_tok = tokenizer
    _worker_sep = tokenizer.token_to_id("<sep>")
def tokenize_texts(texts):
    """Encode *texts* with this process's tokenizer; return one id list per text.

    The ``<sep>`` id is appended only to encodings of two or more tokens;
    shorter encodings are returned unchanged.

    NOTE(review): 0/1-token documents therefore get no separator and run
    straight into the next document in the written output. Confirm this is
    intended.
    """
    encodings = _worker_tok.encode_batch(texts)
    id_lists = [encoding.ids for encoding in encodings]
    for ids in id_lists:
        if len(ids) >= 2:
            ids.append(_worker_sep)
    return id_lists
# -- Process shard -------------------------------------------------------------
def _iter_text_batches(fin, batch_size):
    """Yield lists of up to *batch_size* non-empty ``text`` values from *fin*.

    *fin* is read as JSON lines. Blank lines, invalid JSON, non-object records,
    non-string ``text`` values and whitespace-only texts are skipped.
    """
    batch = []
    for line in fin:
        line = line.strip()
        if not line:
            continue
        try:
            text = json.loads(line).get("text", "").strip()
        except (ValueError, AttributeError, RecursionError):
            continue  # malformed JSON / non-object record / non-string text
        if not text:
            continue
        batch.append(text)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def _write_docs(results, fout):
    """Append each id list in *results* to *fout* as raw uint16 values.

    Returns ``(docs_written, tokens_written)``. Raises ValueError if any id
    does not fit in uint16, instead of letting the cast silently wrap.
    """
    docs = tokens = 0
    limit = np.iinfo(np.uint16).max
    for ids in results:
        arr = np.asarray(ids, dtype=np.int64)
        if arr.size and (arr.min() < 0 or arr.max() > limit):
            raise ValueError(f"token id out of uint16 range (max={arr.max()})")
        arr.astype(np.uint16).tofile(fout)
        docs += 1
        tokens += int(arr.size)
    return docs, tokens


def process_shard(name, raw_path, pool, max_inflight=4, task_timeout=600):
    """Tokenize raw shard *raw_path* into ``OUT_DIR/<name with .jsonl -> .bin>``.

    Record texts are sent to *pool* (``tokenize_texts``) in batches of
    BATCH_SIZE. Every resulting id list is appended to the output as raw
    uint16 values, in input order. Up to *max_inflight* batches are queued at
    once so all pool processes stay busy; a blocking ``pool.apply`` would keep
    only one of them working at a time. Output is written to a ``.tmp`` file
    that is renamed into place only when complete.

    Args:
        name: shard file name; ".jsonl" in it is replaced by ".bin" for output.
        raw_path: path of the input shard.
        pool: multiprocessing pool whose processes ran ``init_worker``.
        max_inflight: maximum number of batches queued in the pool at once.
        task_timeout: seconds to wait for one batch. If a pool process dies
            mid-task its result never arrives, and without a timeout this
            call would block forever.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, reason)``. On failure
        the partial output file is deleted.
    """
    print(f" [{WORKER_ID}] Processing: {name}")
    out_name = name.replace(".jsonl", ".bin")
    out_path = Path(OUT_DIR) / out_name
    tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
    # Crash recovery: delete any output left by a previous attempt.
    tmp_path.unlink(missing_ok=True)
    out_path.unlink(missing_ok=True)

    total_tokens = 0
    total_docs = 0
    pending = []  # AsyncResults in submission order (at most max_inflight)
    try:
        with open(raw_path, "r", encoding="utf-8") as fin, \
                open(tmp_path, "wb") as fout:
            for texts in _iter_text_batches(fin, BATCH_SIZE):
                pending.append(pool.apply_async(tokenize_texts, (texts,)))
                if len(pending) < max_inflight:
                    continue
                try:
                    results = pending.pop(0).get(task_timeout)
                except Exception as e:
                    tmp_path.unlink(missing_ok=True)
                    return False, f"tokenize_failed: {e}"
                docs, tokens = _write_docs(results, fout)
                total_docs += docs
                total_tokens += tokens
            # Drain the batches still queued, preserving input order.
            while pending:
                try:
                    results = pending.pop(0).get(task_timeout)
                except Exception as e:
                    tmp_path.unlink(missing_ok=True)
                    return False, f"tokenize_failed_flush: {e}"
                docs, tokens = _write_docs(results, fout)
                total_docs += docs
                total_tokens += tokens
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        return False, f"process_failed: {e}"
    # Atomic rename: the .bin only becomes visible once it is complete.
    tmp_path.rename(out_path)
    print(f" β [{WORKER_ID}] {out_name} | {total_docs:,} docs | {total_tokens:,} tokens")
    return True, None
# -- Memory flush --------------------------------------------------------------
def flush_memory():
    """Run a full GC pass, then ask libc to hand freed heap memory back to the OS.

    The ``malloc_trim`` step is best effort: if ``libc.so.6`` cannot be loaded
    or the call fails, the error is ignored.
    """
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except Exception:
        # Not glibc (or symbol unavailable): nothing to trim, not an error.
        pass
# -- Worker loop ---------------------------------------------------------------
def _record_result(state, name, raw_path, success, error):
    """Persist the outcome of processing *name*; delete its raw file on success.

    The state file is re-read first to pick up changes other workers made while
    this shard was processed. If that re-read fails, the failure is logged and
    the caller's *state* copy is used instead.
    """
    try:
        state = load_state()
    except (OSError, ValueError) as e:
        print(f" [{WORKER_ID}] State re-read failed, using local copy: {e}")
    info = state["shards"].get(name)
    if info is None:
        print(f" [{WORKER_ID}] {name} missing from state; outcome not recorded")
        return
    if success:
        info["status"] = "done"
        info["error"] = None
        save_state(state)
        try:
            raw_path.unlink()
            print(f" [{WORKER_ID}] Deleted raw: {raw_path.name}")
        except OSError as e:
            print(f" [{WORKER_ID}] Delete failed: {e}")
    else:
        info["status"] = "pending"
        info["worker"] = None
        info["claimed_at"] = None
        info["error"] = error
        save_state(state)
        print(f" [{WORKER_ID}] Failed ({error}) β reset to pending: {name}")


def _worker_step(pool):
    """Run one poll iteration and return how many seconds to sleep afterwards."""
    if not os.path.exists(STATE_FILE):
        print(f" [{WORKER_ID}] Waiting for state.json...")
        return POLL_INTERVAL
    try:
        state = load_state()
    except (OSError, ValueError) as e:
        print(f" [{WORKER_ID}] State read error: {e}")
        return POLL_INTERVAL
    shards = state["shards"]
    total = len(shards) + len(state.get("queue", []))
    done = sum(1 for v in shards.values() if v["status"] == "done")
    if total > 0 and done == total:
        print(f" [{WORKER_ID}] All done. Sleeping.")
        return 300
    name, raw_path = claim_shard(state)
    if not name:
        print(f" [{WORKER_ID}] Nothing ready β polling in {POLL_INTERVAL}s")
        return POLL_INTERVAL
    print(f" [{WORKER_ID}] Claimed: {name}")
    success, error = process_shard(name, raw_path, pool)
    _record_result(state, name, raw_path, success, error)
    flush_memory()
    return 5


def worker_loop():
    """Claim and tokenize shards listed in STATE_FILE, forever.

    Starts a 2-process tokenizer pool (each process runs ``init_worker``), then
    repeatedly runs one poll step and sleeps for the delay that step returns.
    Unexpected exceptions in a step (e.g. a state.json missing expected keys)
    are logged and retried after POLL_INTERVAL. Without this, a single such
    error would silently end this daemon thread. The pool is terminated if the
    loop ever exits.
    """
    print(f"β [{WORKER_ID}] Starting worker...")
    pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
    print(f"β [{WORKER_ID}] 2-core tokenizer pool ready")
    try:
        while True:
            try:
                delay = _worker_step(pool)
            except Exception as e:
                print(f" [{WORKER_ID}] Worker error: {e!r}")
                delay = POLL_INTERVAL
            time.sleep(delay)
    finally:
        pool.terminate()
        pool.join()
# -- Entry point ---------------------------------------------------------------
if __name__ == "__main__":
    keepalive_thread = threading.Thread(target=serve, daemon=True)
    keepalive_thread.start()
    worker_thread = threading.Thread(target=worker_loop, daemon=True)
    worker_thread.start()
    # Both threads are daemons, so the main thread must stay alive for them
    # to keep running; it just idles.
    # NOTE(review): nothing here notices if worker_thread dies; the keep-alive
    # listener would keep answering OK regardless.
    while True:
        time.sleep(60)