| | import argparse |
| | import json |
| | import re |
| | import os |
| | import unicodedata |
| | from typing import Tuple, List |
| | from multiprocessing import Pool |
| |
|
| | import fasttext |
| | import pandas as pd |
| | from tqdm import tqdm |
| | from transformers import LlamaTokenizerFast |
| |
|
| |
|
# Maps each supported language code to the on-disk path of its trained
# fastText quality-classifier model (see --language in parse_args).
language_model_map = {
    "en": "classifiers/ultra_fineweb_en.bin",
    "zh": "classifiers/ultra_fineweb_zh.bin"
}
| |
|
def parse_args():
    """Build and parse command-line arguments for single-content inference."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="Inference language, support: en, zh.",
    )
    arg_parser.add_argument(
        "--tokenizer-path",
        type=str,
        default="local_tokenizer",
        help="Tokenizer path.",
    )
    arg_parser.add_argument(
        "--content-file",
        type=str,
        default="scripts/local_scripts/single_content.txt",
        help="Content file to infer.",
    )
    return arg_parser.parse_args()
| |
|
| |
|
def fasttext_preprocess_func(content: str, tokenizer: "LlamaTokenizerFast") -> str:
    """Fasttext preprocess function.

    Normalizes raw text into the single-line, space-separated token form
    that the fastText classifier was trained on.

    Args:
        content (str): Raw content to process.
        tokenizer (LlamaTokenizerFast): Tokenizer whose encode/decode is used
            to split the content into individual tokens.

    Returns:
        str: Processed normalized content (one line, tokens joined by spaces,
        control whitespace escaped as literal ``\\n`` / ``\\r`` / ``\\t``).
    """
    # Collapse runs of 3+ newlines into a single blank line.
    content = re.sub(r'\n{3,}', '\n\n', content)

    # Lowercase for case-insensitive classification.
    content = content.lower()

    # NFKD-normalize and drop combining marks (category 'Mn'), e.g. accents.
    content = ''.join(
        c for c in unicodedata.normalize('NFKD', content)
        if unicodedata.category(c) != 'Mn')

    # Tokenize, then decode each token id individually so the result is a
    # space-separated sequence of token strings.
    token_ids = tokenizer.encode(content, add_special_tokens=False)
    content = ' '.join(tokenizer.decode([token_id]) for token_id in token_ids)

    # Escape control whitespace as the two-character sequences "\n"/"\r"/"\t"
    # so each sample stays on one line (fastText treats newlines as record
    # separators). str.replace is used instead of re.sub: these are fixed
    # strings, and it avoids regex replacement-template escape subtleties.
    content = content.replace('\n', '\\n')
    content = content.replace('\r', '\\r')
    content = content.replace('\t', '\\t')
    content = re.sub(r' +', ' ', content)
    return content.strip()
| |
|
| |
|
def fasttext_infer(norm_content: str, fasttext_model: "fasttext.FastText") -> Tuple[str, float]:
    """Fasttext inference function.

    Args:
        norm_content (str): Normalized single-line text, as produced by
            ``fasttext_preprocess_func``.
        fasttext_model (fasttext.FastText): Loaded fastText classifier.

    Returns:
        Tuple[str, float]: Predicted label (e.g. ``"__label__pos"`` or
        ``"__label__neg"``) and the positive-class score in [0, 1].
    """
    pred_label, pred_prob = fasttext_model.predict(norm_content)
    pred_label = pred_label[0]
    # fastText can report probabilities marginally above 1 due to float error,
    # so clamp to at most 1.
    _score = min(pred_prob.tolist()[0], 1)
    # Always report the positive-class probability: invert when the predicted
    # label is the negative one.
    if pred_label == "__label__neg":
        _score = 1 - _score
    return pred_label, _score
| |
|
| |
|
| |
|
def main():
    """Load tokenizer and classifier, score one content file, print results."""
    args = parse_args()
    language = args.language
    tokenizer_path = args.tokenizer_path
    content_file = args.content_file

    assert language in ["en", "zh"], f"Language {language} is not supported, please check the language."
    assert os.path.exists(content_file), f"Content file {content_file} does not exist, please check the content file."

    fasttext_model_path = language_model_map[language]

    # Load the tokenizer used for fastText-style preprocessing.
    tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path)

    # Load the language-specific fastText classifier.
    fasttext_model = fasttext.load_model(fasttext_model_path)

    # Read the raw content; a context manager ensures the handle is closed,
    # and an explicit encoding avoids platform-default surprises.
    with open(content_file, "r", encoding="utf-8") as f:
        content = f.read()

    norm_content = fasttext_preprocess_func(content, tokenizer)

    pred_label, pred_score = fasttext_infer(norm_content, fasttext_model)

    print("-" * 100)
    print(f"Content: {content}")
    print()
    print(f"Normalized content: {norm_content}")
    print()
    print(f" - Pred label: {pred_label}")
    print(f" - Pred score: {pred_score}")
    print("-" * 100)
| |
|
| |
|
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|