Shortcuts

Source code for hfai.datasets.coco

from typing import Callable, List, Optional
from collections import defaultdict
import time
import json
import pickle
from PIL import Image
from ffrecord import FileReader
from pycocotools.coco import COCO

from .base import BaseDataset, get_data_dir, register_dataset

"""
Expected file organization:

    [coco_data_dir]
        train2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        val2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        panoptic_train2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        panoptic_val2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        annotations/
            captions_train2017.json
            captions_val2017.json
            instances_train2017.json
            instances_val2017.json
            person_keypoints_train2017.json
            person_keypoints_val2017.json
            panoptic_train2017.json
            panoptic_val2017.json
"""


class CocoPanopticBase(COCO):
    def __init__(self, annotation_file=None):
        super(CocoPanopticBase, self).__init__()

        # load dataset
        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if not annotation_file == None:
            print("loading annotations into memory...")
            tic = time.time()
            with open(annotation_file, "r") as fp:
                dataset = json.load(fp)
            assert type(dataset) == dict, "annotation file format {} not supported".format(type(dataset))
            print("Done (t={:0.2f}s)".format(time.time() - tic))
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        # create index
        print("creating index...")
        anns, cats, imgs = {}, {}, {}
        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
        if "annotations" in self.dataset:
            for ann_id, ann in enumerate(self.dataset["annotations"]):
                ann["id"] = ann_id
                imgToAnns[ann["image_id"]].append(ann)
                anns[ann["id"]] = ann

        if "images" in self.dataset:
            for img in self.dataset["images"]:
                imgs[img["id"]] = img

        if "categories" in self.dataset:
            for cat in self.dataset["categories"]:
                cats[cat["id"]] = cat

        if "annotations" in self.dataset and "categories" in self.dataset:
            for ann in self.dataset["annotations"]:
                for seg in ann["segments_info"]:
                    catToImgs[seg["category_id"]].append(ann["image_id"])

        print("index created!")

        # create class members
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats


class CocoReader():
    """
    这是COCO数据集的读取接口

    Args:
        split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``)
        check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``)
        miniset (bool): 是否使用 mini 集合(默认为 ``False``)

    """

    def __init__(self, split: str, check_data: bool = True, miniset: bool = False):
        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "COCO"

        assert split in ["train", "val"]
        self.split = split
        self.fname = self.data_dir / f"{split}2017.ffr"
        self.reader = FileReader(self.fname, check_data)

        self.panoptic_fname = self.data_dir / f"panoptic_{split}2017.ffr"
        self.panoptic_reader = FileReader(self.panoptic_fname, check_data)

        self.ids = None
        self.coco = None

    def __len__(self):
        return self.reader.n

    def get_cocoapi(self) -> COCO:
        """
        返回一个pycocotools.coco.COCO, 里面包含了对应的标注数据
        """
        return self.coco

    def load_captions(self) -> None:
        """
        加载 Coco Caption 标注数据
        """
        self.coco = COCO(self.data_dir / f"annotations/captions_{self.split}2017.json")
        self.ids = list(sorted(self.coco.imgs.keys()))

    def load_instances(self) -> None:
        """
        加载 Coco Object Detection 标注数据
        """
        self.coco = COCO(self.data_dir / f"annotations/instances_{self.split}2017.json")
        self.ids = list(sorted(self.coco.imgs.keys()))

    def load_keypoints(self) -> None:
        """
        加载 Coco Keypoint Detection 标注数据
        """
        self.coco = COCO(self.data_dir / f"annotations/person_keypoints_{self.split}2017.json")
        self.ids = list(sorted(self.coco.imgs.keys()))

    def load_panoptics(self) -> None:
        """
        加载 Coco Panoptic 标注数据
        """
        self.coco = CocoPanopticBase(self.data_dir / f"annotations/panoptic_{self.split}2017.json")
        self.ids = list(sorted(self.coco.imgs.keys()))

    def read_imgs(self, indices: List[int]) -> List[Image.Image]:
        """
        读取图片数据

        Args:
            indices (list): 图片数据的索引

        Returns:
            RGB格式的PIL.Image.Image图片
        """
        assert self.coco is not None, "annotations are not loaded yet."

        bytes_ = self.reader.read(indices)
        imgs = []
        for x in bytes_:
            img = pickle.loads(x).convert("RGB")
            imgs.append(img)
        return imgs

    def read_anno_imgs(self, indices: List[int]) -> List[Image.Image]:
        """
        读取带注解的图片数据(后缀png)

        Args:
            indices (list): 图片数据的索引

        Returns:
            RGB格式的PIL.Image.Image图片
        """
        assert self.coco is not None, "annotations are not loaded yet."

        bytes_ = self.panoptic_reader.read(indices)
        imgs = []
        for x in bytes_:
            img = pickle.loads(x).convert("RGB")
            imgs.append(img)
        return imgs

    def read_anno(self, index: int) -> List[dict]:
        """
        读取指定索引下的注解信息。更多信息参考:https://cocodataset.org/#format-data

        Args:
            index (int): 指定的索引

        Returns:
            注解信息,返回一个包含若干字典组成的列表,每一个列表里包括 ``instance`` 和 ``contains``,例如:

            .. code-block:: python

                captions:
                    {'image_id': 444010,
                     'id': 104057,
                     'caption': 'A group of friends sitting down at a table sharing a meal.'}

                instances:
                    {'segmentation': ...,
                     'area': 3514.564,
                     'iscrowd': 0,
                     'image_id': 444010,
                     'bbox': [x_left, y_top, w, h],
                     'category_id': 44,
                     'id': 91863}

                keypoints:
                    {'segmentation': ...,
                     'num_keypoints': 11,
                     'area': 34800.5498,
                     'iscrowd': 0,
                     'keypoints': ...,
                     'image_id': 444010,
                     'bbox': [x_left, y_top, w, h],
                     'category_id': 1,
                     'id': 1200757}

                panoptic:
                    {"image_id": int,
                     "file_name": str,
                     "segments_info":
                        {
                        "id": int,
                        "category_id": int,
                        "area": int,
                        "bbox": [x,y,width,height],
                        "iscrowd": 0 or 1,
                        },
                     }
        """
        img_id = self.ids[index]
        ann_id = self.coco.getAnnIds(img_id)
        ann = self.coco.loadAnns(ann_id)
        return ann


class CocoDataset(BaseDataset):
    def __init__(
            self,
            split: str,
            transform: Optional[Callable] = None,
            check_data: bool = True,
            miniset: bool = False,
    ) -> None:
        super(CocoDataset, self).__init__()

        self.split = split
        self.reader = CocoReader(split, check_data, miniset)
        self._load_annotations()  # load annotations into memory
        self.transform = transform
        self.coco = self.reader.coco

    def _load_annotations(self):
        raise NotImplementedError

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, indices):
        imgs = self.reader.read_imgs(indices)
        annos = [self.reader.read_anno(idx) for idx in indices]
        img_ids = [self.reader.ids[idx] for idx in indices]
        if self.transform is not None:
            samples = [self.transform(img, img_id, anno) for img, img_id, anno in zip(imgs, img_ids, annos)]
        else:
            samples = list(zip(imgs, img_ids, annos))
        return samples


[docs]@register_dataset class CocoPanoptic(CocoDataset): """ 这是一个用于全景分割的 COCO 数据集 更多信息参考:https://cocodataset.org Args: split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``) transform (Callable): transform 函数,对图片和标注进行 transfrom,接受一张图片、图片 id、标注图片和对应的标注作为输入,输出 transform 之后的图片、图片 id 和标注 check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``) miniset (bool): 是否使用 mini 集合(默认为 ``False``) Returns: pic, id, anno_pic, anno (PIL.Image.Image, int, PIL.Image.Image, dict): 返回的每条样本是一个四元组,包括一张RGB格式图片,对应的图片ID,一张标注的RGB格式图片,物体的标注信息。如下例所示: .. code-block:: python { "image_id": int, "file_name": str, "segments_info": { "id": int, "category_id": int, "area": int, "bbox": [x_left, y_top,width,height], "iscrowd": 0 or 1, }, } Examples: .. code-block:: python from hfai.datasets import CocoPanoptic def transform(pic, id, anno_pic, anno): ... dataset = CocoPanoptic(split, transform) loader = dataset.loader(batch_size=64, num_workers=4) coco = dataset.coco # same as pycocotools.coco.COCO for pic, id, anno_pic, anno in loader: # training model """ def __getitem__(self, indices): imgs = self.reader.read_imgs(indices) anno_imgs = self.reader.read_anno_imgs(indices) annos = [self.reader.read_anno(idx) for idx in indices] img_ids = [self.reader.ids[idx] for idx in indices] if self.transform is not None: samples = [ self.transform(img, img_id, anno_img, anno) for img, img_id, anno_img, anno in zip(imgs, img_ids, anno_imgs, annos) ] else: samples = list(zip(imgs, img_ids, anno_imgs, annos)) return samples def _load_annotations(self): self.reader.load_panoptics()
[docs]@register_dataset class CocoDetection(CocoDataset): """ 这是一个用于目标检测的 COCO 数据集 更多信息参考:https://cocodataset.org Args: split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``) transform (Callable): transform 函数,对图片和标注进行 transfrom,接受一张图片、图片 id、标注图片和对应的标注作为输入,输出 transform 之后的图片、图片 id 和标注 check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``) miniset (bool): 是否使用 mini 集合(默认为 ``False``) Returns: pic, id, anno (PIL.Image.Image, int, dict): 返回的每条样本是一个三元组,包括一张 RGB 格式图片,对应的图片 ID,物体的标注信息。如下例所示: .. code-block:: python { 'segmentation': ..., 'area': 3514.564, 'iscrowd': 0, 'image_id': 444010, 'bbox': [x_left, y_top, w, h], 'category_id': 44, 'id': 91863 } Examples: .. code-block:: python from hfai.datasets import CocoDetection def transform(pic, id, anno): ... dataset = CocoDetection(split, transform) loader = dataset.loader(batch_size=64, num_workers=4) coco = dataset.coco # same as pycocotools.coco.COCO for pic, id, anno in loader: # training model """ def _load_annotations(self): self.reader.load_instances()
[docs]@register_dataset class CocoCaption(CocoDataset): """ 这是一个用于图像说明的 COCO 数据集 更多信息参考:https://cocodataset.org Args: split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``) transform (Callable): transform 函数,对图片和标注进行 transfrom,接受一张图片、图片 id、标注图片和对应的标注作为输入,输出 transform 之后的图片、图片 id 和标注 check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``) miniset (bool): 是否使用 mini 集合(默认为 ``False``) Returns: pic, id, anno (PIL.Image.Image, int, dict): 返回的每条样本是一个三元组,包括一张 RGB 格式图片,对应的图片 ID,标注信息。标注如下例所示: .. code-block:: python { 'image_id': 444010, 'id': 104057, 'caption': 'A group of friends sitting down at a table sharing a meal.' } Examples: .. code-block:: python from hfai.datasets import CocoCaption def transform(pic, id, anno): ... dataset = CocoCaption(split, transform) loader = dataset.loader(batch_size=64, num_workers=4) coco = dataset.coco # same as pycocotools.coco.COCO for pic, id, anno in loader: # training model """ def _load_annotations(self): self.reader.load_captions()
[docs]@register_dataset class CocoKeypoint(CocoDataset): """ 这是一个用于关键点检测的COCO数据集 更多信息参考:https://cocodataset.org Args: split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``) transform (Callable): transform 函数,对图片和标注进行 transfrom,接受一张图片、图片 id、标注图片和对应的标注作为输入,输出 transform 之后的图片、图片 id 和标注 check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``) miniset (bool): 是否使用 mini 集合(默认为 ``False``) Returns: pic, id, anno (PIL.Image.Image, int, dict): 返回的每条样本是一个三元组,包括一张 RGB 格式图片,对应的图片 ID,物体的标注信息。标注如下例所示: .. code-block:: python { 'segmentation': ..., 'num_keypoints': 11, 'area': 34800.5498, 'iscrowd': 0, 'keypoints': ..., 'image_id': 444010, 'bbox': [x_left, y_top, w, h], 'category_id': 1, 'id': 1200757 } Examples: .. code-block:: python from hfai.datasets import CocoKeypoint def transform(pic, id, anno): ... dataset = CocoKeypoint(split, transform) loader = dataset.loader(batch_size=64, num_workers=4) coco = dataset.coco # same as pycocotools.coco.COCO for pic, id, anno in loader: # training model """ def _load_annotations(self): self.reader.load_keypoints()