---
datasets:
- stanfordnlp/imdb
language:
- en
- hi
base_model:
- google-bert/bert-base-multilingual-cased
---
# Language-Agnostic Text Classifier

Trained only on **English** data (IMDB movie reviews). <br>
Works on both **English** and **Hindi** at inference time without retraining *(other languages not tested)*.

**Task:** Binary sentence-level sentiment classification <br>
**Base model:** [bert-base-multilingual-cased](https://huggingface.co/google-bert/bert-base-multilingual-cased) <br>
**For more details:** *[GitHub repository](https://github.com/wizardoftrap/language_agnostic_classifier)*
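
## Usage

Requires `torch` and `transformers`. The released checkpoint is a plain PyTorch `state_dict` for the custom model defined below. That model is an mBERT encoder, followed by mean pooling over non-padding tokens, followed by a linear head. The class therefore has to be defined before the weights can be loaded.

```python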
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

| | class LanguageAgnosticClassifier(nn.Module): |
| | def __init__(self, base_model, num_labels): |
| | super().__init__() |
| | self.encoder = AutoModel.from_pretrained(base_model) |
| | hidden = self.encoder.config.hidden_size |
| | self.classifier = nn.Linear(hidden, num_labels) |
| | |
| | def mean_pool(self, hidden, mask): |
| | mask = mask.unsqueeze(-1).float() |
| | return (hidden * mask).sum(1) / mask.sum(1) |
| | |
| | def forward(self, input_ids, attention_mask): |
| | out = self.encoder(input_ids=input_ids, attention_mask=attention_mask) |
| | pooled = self.mean_pool(out.last_hidden_state, attention_mask) |
| | return self.classifier(pooled) |
| | |
# Hub repository holding both the tokenizer files and the fine-tuned weights.
REPO_ID = "wizardoftrap/language_agnostic_classifier"

tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

# Build the architecture from the public base model. All of its weights
# (encoder and head) are then replaced by the fine-tuned checkpoint below.
model = LanguageAgnosticClassifier(
    base_model="bert-base-multilingual-cased",
    num_labels=2,
)

# The checkpoint is a pickled PyTorch `.bin` fetched over the network.
# `weights_only=True` restricts unpickling to tensors and plain containers,
# so a tampered file cannot run arbitrary code on load (recent PyTorch 2.x).
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{REPO_ID}/resolve/main/bert-language_agnostic-classifier.bin",
    map_location="cpu",
    weights_only=True,
)

# strict=True (the default) fails loudly on any missing or unexpected key.
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout

def predict(text, max_length=128):
    """Classify one sentence, or a list of sentences, with the loaded model.

    Uses the module-level ``tokenizer`` and ``model`` created above.

    Args:
        text: A single string, or a list of strings classified as one batch.
        max_length: Maximum tokens per input; longer inputs are truncated.

    Returns:
        The predicted class index (``int``) for a single string, or a list of
        indices for a list of strings. Presumably 0 = negative and
        1 = positive (the stanfordnlp/imdb convention); confirm against the
        training code.
    """
    # Send inputs to wherever the model lives, so `model.to("cuda")` just works.
    device = next(model.parameters()).device
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        # Pad only to the longest input. Padded positions are masked out by
        # both the encoder and mean pooling, so fixed 128-token padding only
        # wastes compute.
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(
            enc["input_ids"].to(device),
            enc["attention_mask"].to(device),
        )
    labels = logits.argmax(dim=1).tolist()
    return labels[0] if isinstance(text, str) else labels

# predict() only returns a value: bare calls display nothing when this file is
# run as a script, so print each result next to its input.
# Assumes the stanfordnlp/imdb label convention (0 = negative, 1 = positive)
# was kept during training; confirm against the training code.
LABELS = {0: "negative", 1: "positive"}

examples = {
    "English": [
        "This movie was amazing",
        "This movie was terrible",
        "The film was not bad, but not great either",
        "Despite good acting, the story failed to impress me",
    ],
    "Hindi (never seen during training)": [
        "यह फिल्म बहुत शानदार थी",  # "This film was wonderful"
        "यह फिल्म बहुत खराब थी",  # "This film was very bad"
        "फिल्म बुरी नहीं थी, लेकिन खास भी नहीं लगी",  # "Not bad, but nothing special either"
        "अभिनय अच्छा था, पर कहानी कमजोर रह गई",  # "Acting was good, but the story was weak"
    ],
    "Code-mixed Hindi-English": [
        "Story अच्छी थी but execution weak था",
        "Acting was good लेकिन movie boring लगी",
        "Concept अच्छा था but screenplay खराब था",
    ],
    "Sarcasm (hard cases)": [
        "Yeah, this movie was a masterpiece… said no one ever",
        "फिल्म इतनी अच्छी थी कि नींद आ गई",  # "The film was so good I fell asleep"
    ],
    "Neutral (a 2-label model must still pick a side)": [
        "The movie was okay",
        "फिल्म ठीक-ठाक थी",  # "The film was so-so"
    ],
}

for group, sentences in examples.items():
    print(f"--- {group} ---")
    for sentence in sentences:
        label = predict(sentence)
        print(f"{LABELS.get(label, label)}: {sentence}")
```

*- Shiv Prakash Verma*