import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from abc import abstractmethod


class CocoBase(Dataset):
    """needed for (image, caption, segmentation) pairs"""
    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
                 crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True, crop_type=None):
        self.split = self.get_split()
        self.size = size
        if crop_size is None:
            self.crop_size = size
        else:
            self.crop_size = crop_size

        assert crop_type in [None, 'random', 'center']
        self.crop_type = crop_type
        self.use_segmentation = use_segmentation
        self.onehot = onehot_segmentation
        self.stuffthing = use_stuffthing
        if self.onehot and not self.stuffthing:
            raise NotImplementedError("One hot mode is only supported for the "
                                      "stuffthings version because labels are stored "
                                      "a bit different.")

        data_json = datajson
        with open(data_json) as json_file:
            self.json_data = json.load(json_file)
        self.img_id_to_captions = dict()
        self.img_id_to_filepath = dict()
        self.img_id_to_segmentation_filepath = dict()

        assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
                                            f"captions_val{self.year()}.json"]
        if self.use_segmentation:
            if self.stuffthing:
                self.segmentation_prefix = (
                    f"data/cocostuffthings/val{self.year()}" if
                    data_json.endswith(f"captions_val{self.year()}.json") else
                    f"data/cocostuffthings/train{self.year()}")
            else:
                self.segmentation_prefix = (
                    f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
                    data_json.endswith(f"captions_val{self.year()}.json") else
                    f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")

        imagedirs = self.json_data["images"]
        self.labels = {"image_ids": list()}
        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
            self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
            self.img_id_to_captions[imgdir["id"]] = list()
            pngfilename = imgdir["file_name"].replace("jpg", "png")
            if self.use_segmentation:
                self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
                    self.segmentation_prefix, pngfilename)
            if given_files is not None:
                if pngfilename in given_files:
                    self.labels["image_ids"].append(imgdir["id"])
            else:
                self.labels["image_ids"].append(imgdir["id"])

        capdirs = self.json_data["annotations"]
        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # collect all captions belonging to each image id
            self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        if self.split == "validation":
            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        else:
            # during training, default to random cropping unless a center crop is requested
            if self.crop_type in [None, 'random']:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        self.preprocessor = albumentations.Compose(
            [self.rescaler, self.cropper],
            additional_targets={"segmentation": "image"})
        if force_no_crop:
            self.rescaler = albumentations.Resize(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose(
                [self.rescaler],
                additional_targets={"segmentation": "image"})

    @abstractmethod
    def year(self):
        raise NotImplementedError()

    def __len__(self):
        return len(self.labels["image_ids"])

    def preprocess_image(self, image_path, segmentation_path=None):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        if segmentation_path:
            segmentation = Image.open(segmentation_path)
            if not self.onehot and not segmentation.mode == "RGB":
                segmentation = segmentation.convert("RGB")
            segmentation = np.array(segmentation).astype(np.uint8)
            if self.onehot:
                assert self.stuffthing
                # shift the stuffthing labels by one so that valid classes start at 1;
                # since the map is uint8, the "unlabeled" value 255 wraps around to 0
                assert segmentation.dtype == np.uint8
                segmentation = segmentation + 1

            processed = self.preprocessor(image=image, segmentation=segmentation)
            image, segmentation = processed["image"], processed["segmentation"]
        else:
            image = self.preprocessor(image=image)["image"]

        image = (image / 127.5 - 1.0).astype(np.float32)
        if segmentation_path:
            if self.onehot:
                assert segmentation.dtype == np.uint8
                # expand the label map to a one-hot encoding over the 183 classes
                n_labels = 183
                flatseg = np.ravel(segmentation)
                onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
                onehot[np.arange(flatseg.size), flatseg] = True
                onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
                segmentation = onehot
            else:
                segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
            return image, segmentation
        else:
            return image

    def __getitem__(self, i):
        img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
        if self.use_segmentation:
            seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
            image, segmentation = self.preprocess_image(img_path, seg_path)
        else:
            image = self.preprocess_image(img_path)
        captions = self.img_id_to_captions[self.labels["image_ids"][i]]
        # randomly draw one of the available captions for this image
        caption = captions[np.random.randint(0, len(captions))]
        example = {"image": image,
                   "caption": caption,
                   "img_path": img_path,
                   "filename_": img_path.split(os.sep)[-1]
                   }
        if self.use_segmentation:
            example.update({"seg_path": seg_path, "segmentation": segmentation})
        return example


class CocoImagesAndCaptionsTrain2017(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
        super().__init__(size=size,
                         dataroot="data/coco/train2017",
                         datajson="data/coco/annotations/captions_train2017.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)

    def get_split(self):
        return "train"

    def year(self):
        return '2017'


class CocoImagesAndCaptionsValidation2017(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
                 given_files=None):
        super().__init__(size=size,
                         dataroot="data/coco/val2017",
                         datajson="data/coco/annotations/captions_val2017.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
                         given_files=given_files)

    def get_split(self):
        return "validation"

    def year(self):
        return '2017'


class CocoImagesAndCaptionsTrain2014(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
                 crop_type='random'):
        super().__init__(size=size,
                         dataroot="data/coco/train2014",
                         datajson="data/coco/annotations2014/annotations/captions_train2014.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
                         use_segmentation=False,
                         crop_type=crop_type)

    def get_split(self):
        return "train"

    def year(self):
        return '2014'


class CocoImagesAndCaptionsValidation2014(CocoBase):
    """returns a pair of (image, caption)"""
    def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
                 given_files=None, crop_type='center', **kwargs):
        super().__init__(size=size,
                         dataroot="data/coco/val2014",
                         datajson="data/coco/annotations2014/annotations/captions_val2014.json",
                         onehot_segmentation=onehot_segmentation,
                         use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
                         given_files=given_files,
                         use_segmentation=False,
                         crop_type=crop_type)

    def get_split(self):
        return "validation"

    def year(self):
        return '2014'


if __name__ == '__main__':
    with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
        json_data = json.load(json_file)
    capdirs = json_data["annotations"]
    # import pudb; pudb.set_trace()  # optional: drop into a debugger to inspect the raw annotations

    d2 = CocoImagesAndCaptionsValidation2014(size=256)
    print("constructed dataset.")
    print(f"length of {d2.__class__.__name__}: {len(d2)}")

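    # A minimal, self-contained sketch of the one-hot encoding performed in
    # preprocess_image when onehot_segmentation=True; the label values below are toy
    # data, not real COCO annotations. Labels are shifted by one so the uint8
    # "unlabeled" value 255 wraps around to 0, then expanded to 183 channels.
    toy_seg = np.array([[0, 1], [181, 255]], dtype=np.uint8) + 1  # 255 wraps to 0
    flat = np.ravel(toy_seg)
    toy_onehot = np.zeros((flat.size, 183), dtype=bool)
    toy_onehot[np.arange(flat.size), flat] = True
    print(toy_onehot.reshape(toy_seg.shape + (183,)).shape)  # (2, 2, 183)
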
    ex2 = d2[0]
    print(ex2["image"].shape)
    print(ex2["caption"].__class__.__name__)
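
    # A minimal sketch of how such a dataset is typically consumed: wrapping it in a
    # torch DataLoader. The batch_size/shuffle/num_workers values are illustrative
    # choices, not values prescribed by this module. Default collation stacks the
    # HWC float32 images into a (B, H, W, C) tensor and groups captions into a list.
    from torch.utils.data import DataLoader
    loader = DataLoader(d2, batch_size=4, shuffle=False, num_workers=0)
    batch = next(iter(loader))
    print(batch["image"].shape)   # e.g. torch.Size([4, 256, 256, 3])
    print(len(batch["caption"]))  # 4 caption strings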