| """ |
| |
| """ |
| from torch.utils.data import IterableDataset |
| import numpy as np |
| from typing import List |
| from ..readers import InputExample |
| import logging |
|
|
# Logger for this module; its name is the module's import path.
logger = logging.getLogger(name=__name__)
|
|
class SentenceLabelDataset(IterableDataset):
    """
    This dataset can be used for some specific Triplet Losses like BATCH_HARD_TRIPLET_LOSS which requires
    multiple examples with the same label in a batch.

    It draws n consecutive, random and unique samples from one label at a time. This is repeated for each label.

    Labels with fewer than n unique samples are ignored.
    This also applies to drawing without replacement: once fewer than n not-yet-drawn samples remain for a
    label, that label is skipped. When a full pass over the labels cannot draw from any label, the record
    of drawn samples is cleared and drawing starts over.

    This *DOES NOT* check if there are more labels than the batch is large or if the batch size is divisible
    by the samples drawn per label.
    """

    def __init__(self, examples: List[InputExample], samples_per_label: int = 2, with_replacement: bool = False):
        """
        Creates a LabelSampler for a SentenceLabelDataset.

        :param examples:
            a list with InputExamples
        :param samples_per_label:
            the number of consecutive, random and unique samples drawn per label. Batch size should be a multiple of samples_per_label
        :param with_replacement:
            if this is False, then each sample is drawn at most once until no label has samples_per_label
            undrawn samples left; then the record of drawn samples is cleared and drawing starts over.
            if this is True, then one sample can be drawn in multiple draws, but still not multiple times in the same
            drawing.
        """
        super().__init__()

        self.samples_per_label = samples_per_label

        # Group the examples by label, keeping their original order within each label.
        label2ex = {}
        for example in examples:
            label2ex.setdefault(example.label, []).append(example)

        # Flatten every usable group into one list. Group i occupies the half-open index range
        # [groups_right_border[i - 1], groups_right_border[i]) of grouped_inputs (left border 0 for i == 0).
        self.grouped_inputs = []
        self.groups_right_border = []
        num_labels = 0
        for label_examples in label2ex.values():
            # Labels with fewer than samples_per_label examples can never fill one draw; drop them.
            if len(label_examples) >= self.samples_per_label:
                self.grouped_inputs.extend(label_examples)
                self.groups_right_border.append(len(self.grouped_inputs))
                num_labels += 1

        # label_range holds group indices 0..num_labels-1; it is shuffled to randomize the visiting order.
        self.label_range = np.arange(num_labels)
        self.with_replacement = with_replacement
        np.random.shuffle(self.label_range)

        logger.info("SentenceLabelDataset: {} examples, from which {} examples could be used (those labels appeared at least {} times). {} different labels found.".format(len(examples), len(self.grouped_inputs), self.samples_per_label, num_labels))

    def __iter__(self):
        """
        Yield examples in runs of ``samples_per_label`` distinct examples that share a label.

        Groups are visited in the shuffled order of ``label_range``; the order is reshuffled after each full
        pass. Iteration stops once at least ``len(self)`` examples have been yielded. It may yield a few more
        when ``len(self)`` is not a multiple of ``samples_per_label``.
        """
        label_idx = 0
        count = 0
        # Group index -> set of indices into grouped_inputs drawn so far (used when drawing without replacement).
        already_seen = {}
        # Tracks whether the current pass over label_range drew from at least one group.
        drew_in_pass = False
        while count < len(self.grouped_inputs):
            label = self.label_range[label_idx]
            seen = already_seen.setdefault(label, set())

            left_border = 0 if label == 0 else self.groups_right_border[label - 1]
            right_border = self.groups_right_border[label]

            if self.with_replacement:
                selection = np.arange(left_border, right_border)
            else:
                selection = [i for i in np.arange(left_border, right_border) if i not in seen]

            if len(selection) >= self.samples_per_label:
                drew_in_pass = True
                # replace=False: a sample never appears twice within the same draw.
                for element_idx in np.random.choice(selection, self.samples_per_label, replace=False):
                    count += 1
                    seen.add(element_idx)
                    yield self.grouped_inputs[element_idx]

            label_idx += 1
            if label_idx >= len(self.label_range):
                label_idx = 0
                # Without replacement, the drawn-sample record must survive across passes; otherwise it is
                # never consulted (each group is visited only once per pass) and the mode would have no effect.
                # It is cleared only when a whole pass drew nothing. After clearing, every group again has
                # >= samples_per_label candidates, so the next pass makes progress and the loop terminates.
                if self.with_replacement or not drew_in_pass:
                    already_seen = {}
                drew_in_pass = False
                np.random.shuffle(self.label_range)

    def __len__(self):
        """Return the number of usable examples (those whose label occurs at least samples_per_label times)."""
        return len(self.grouped_inputs)