import gzip
import logging
import random
from typing import List, Optional

from torch.utils.data import Dataset

from .. import SentenceTransformer
from ..readers import InputExample


logger = logging.getLogger(__name__)


class ParallelSentencesDataset(Dataset):
    """
    This dataset reader can be used to read-in parallel sentences, i.e., it reads in a file with tab-separated
    sentences with the same sentence in different languages. For example, the file can look like this
    (columns EN, DE, ES separated by tabs):
        hello world         hallo welt      hola mundo
        second sentence     zweiter satz    segunda oración

    The sentence in the first column will be mapped to a sentence embedding using the given teacher model. For
    example, the teacher is a mono-lingual sentence embedding method for English. The sentences in the other
    languages will also be mapped to this English sentence embedding.

    When getting a sample from the dataset, we get one sentence with the according sentence embedding for this
    sentence. The index passed to __getitem__ is ignored: samples are popped from an internal, shuffled cache that
    is refilled on demand.

    teacher_model can be any class that implements an encode function. The encode function gets a list of
    sentences and returns a list of sentence embeddings.
    """

    def __init__(self, student_model: SentenceTransformer, teacher_model: SentenceTransformer, batch_size: int = 8, use_embedding_cache: bool = True):
        """
        Parallel sentences dataset reader to train student model given a teacher model

        :param student_model: Student sentence embedding model that should be trained
        :param teacher_model: Teacher model, that provides the sentence embeddings for the first column in the dataset file
        :param batch_size: Batch size passed to teacher_model.encode when computing the source embeddings
        :param use_embedding_cache: If True, teacher embeddings are stored per source sentence and re-used
        """
        super().__init__()
        self.student_model = student_model
        self.teacher_model = teacher_model
        # One entry per loaded dataset: a list of (source_sentence, set_of_parallel_sentences) tuples
        self.datasets = []
        # Read position inside each dataset (parallel to self.datasets)
        self.datasets_iterator = []
        self.datasets_tokenized = []
        # Dataset ids, each repeated `weight` times; one refill draws one entry per id in this list
        self.dataset_indices = []
        self.copy_dataset_indices = []
        # Shuffled InputExamples waiting to be returned by __getitem__
        self.cache = []
        self.batch_size = batch_size
        self.use_embedding_cache = use_embedding_cache
        # source sentence -> teacher embedding
        self.embedding_cache = {}
        self.num_sentences = 0

    @staticmethod
    def _exceeds_max_length(sentences: List[str], max_sentence_length: Optional[int]) -> bool:
        """Return True if the length filter is active (positive limit) and any sentence has more characters."""
        if max_sentence_length is None or max_sentence_length <= 0:
            return False
        return any(len(sent) > max_sentence_length for sent in sentences)

    def load_data(self, filepath: str, weight: int = 100, max_sentences: Optional[int] = None, max_sentence_length: Optional[int] = 128) -> None:
        """
        Reads in a tab-separated .txt/.csv/.tsv or .gz file. The different columns contain the different
        translations of the sentence in the first column. Blank lines are skipped.

        :param filepath: Filepath to the file. Files ending in .gz are opened with gzip
        :param weight: If more than one dataset is loaded with load_data: With which frequency should data be sampled from this dataset?
        :param max_sentences: Max number of lines to be read from filepath (None or <= 0 means no limit)
        :param max_sentence_length: Skip the example if one of the sentences has more characters than max_sentence_length (None or <= 0 disables the filter)
        :return:
        """
        logger.info("Load %s", filepath)
        parallel_sentences = []
        limit_lines = max_sentences is not None and max_sentences > 0

        opener = gzip.open if filepath.endswith('.gz') else open
        with opener(filepath, 'rt', encoding='utf8') as fIn:
            for line in fIn:
                line = line.strip()
                if not line:
                    # A blank line would otherwise become an empty sentence paired with itself
                    continue

                sentences = line.split("\t")
                if self._exceeds_max_length(sentences, max_sentence_length):
                    continue

                parallel_sentences.append(sentences)
                if limit_lines and len(parallel_sentences) >= max_sentences:
                    break

        self.add_dataset(parallel_sentences, weight=weight, max_sentences=max_sentences, max_sentence_length=max_sentence_length)

    def add_dataset(self, parallel_sentences: List[List[str]], weight: int = 100, max_sentences: Optional[int] = None, max_sentence_length: Optional[int] = 128) -> None:
        """
        Registers a list of parallel sentences as a new dataset.

        Rows are grouped by their first sentence (the source). All sentences of a row, including the source
        itself, are collected in a set for that source. If no row survives the filters, nothing is registered.

        :param parallel_sentences: List of rows; each row is a list with the same sentence in different languages
        :param weight: Number of entries that are drawn from this dataset each time the sample cache is refilled
        :param max_sentences: Stop after this many distinct source sentences (None or <= 0 means no limit)
        :param max_sentence_length: Skip a row if one of its sentences has more characters than this (None or <= 0 disables the filter)
        :return:
        """
        limit_sources = max_sentences is not None and max_sentences > 0

        sentences_map = {}
        for sentences in parallel_sentences:
            # Empty rows carry no source sentence; skip them instead of failing
            if not sentences or self._exceeds_max_length(sentences, max_sentence_length):
                continue

            sentences_map.setdefault(sentences[0], set()).update(sentences)

            if limit_sources and len(sentences_map) >= max_sentences:
                break

        if not sentences_map:
            return

        self.num_sentences += sum(len(targets) for targets in sentences_map.values())

        dataset_id = len(self.datasets)
        self.datasets.append(list(sentences_map.items()))
        self.datasets_iterator.append(0)
        self.dataset_indices.extend([dataset_id] * weight)

    def generate_data(self) -> None:
        """
        Refills the sample cache.

        Draws one (source, targets) entry for every id in self.dataset_indices, computes the teacher embedding
        of each source sentence and appends one InputExample per target sentence, labelled with the embedding of
        its source. The cache is shuffled afterwards.
        """
        source_sentences_list = []
        target_sentences_list = []
        for data_idx in self.dataset_indices:
            src_sentence, trg_sentences = self.next_entry(data_idx)
            source_sentences_list.append(src_sentence)
            target_sentences_list.append(trg_sentences)

        src_embeddings = self.get_embeddings(source_sentences_list)

        for src_embedding, trg_sentences in zip(src_embeddings, target_sentences_list):
            for trg_sentence in trg_sentences:
                self.cache.append(InputExample(texts=[trg_sentence], label=src_embedding))

        random.shuffle(self.cache)

    def next_entry(self, data_idx: int):
        """
        Returns the next (source_sentence, target_sentences) tuple of the dataset with id data_idx and advances
        its read position. Once the end is reached, the position is reset and the dataset is reshuffled.
        """
        source, target_sentences = self.datasets[data_idx][self.datasets_iterator[data_idx]]

        self.datasets_iterator[data_idx] += 1
        if self.datasets_iterator[data_idx] >= len(self.datasets[data_idx]):
            self.datasets_iterator[data_idx] = 0
            random.shuffle(self.datasets[data_idx])

        return source, target_sentences

    def get_embeddings(self, sentences: List[str]):
        """
        Returns the teacher embeddings for the given sentences, in the same order.

        Without the embedding cache, the result of teacher_model.encode is returned directly. With the cache,
        only sentences that are not cached yet are encoded (each distinct sentence once) and a list with one
        embedding per input sentence is returned.
        """
        if not self.use_embedding_cache:
            return self.teacher_model.encode(sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True)

        # dict.fromkeys drops repeated sentences while keeping first-seen order, so the teacher
        # never encodes the same sentence twice within one call
        new_sentences = [sent for sent in dict.fromkeys(sentences) if sent not in self.embedding_cache]

        if new_sentences:
            new_embeddings = self.teacher_model.encode(new_sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True)
            self.embedding_cache.update(zip(new_sentences, new_embeddings))

        return [self.embedding_cache[sent] for sent in sentences]

    def __len__(self) -> int:
        """Total number of sentences (summed over all sources and datasets) that were registered."""
        return self.num_sentences

    def __getitem__(self, idx) -> InputExample:
        """
        Returns the next sample from the shuffled cache, refilling it when it is empty. idx is ignored.

        :raises IndexError: If no data has been loaded, so the cache cannot be filled
        """
        if not self.cache:
            self.generate_data()

        if not self.cache:
            raise IndexError("ParallelSentencesDataset contains no data; call load_data() or add_dataset() first")

        return self.cache.pop()