from typing import Callable, List, Optional
from collections import defaultdict
import time
import json
import pickle
from PIL import Image
from ffrecord import FileReader
from pycocotools.coco import COCO
from .base import BaseDataset, get_data_dir, register_dataset
"""
Expected file organization. ``[coco_data_dir]`` is ``<data_dir>/COCO``, where
``<data_dir>`` is the directory returned by ``get_data_dir()``; when the mini
set is requested it is ``<data_dir>/mini/COCO`` instead.

    [coco_data_dir]
        train2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        val2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        panoptic_train2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        panoptic_val2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        annotations/
            captions_train2017.json
            captions_val2017.json
            instances_train2017.json
            instances_val2017.json
            person_keypoints_train2017.json
            person_keypoints_val2017.json
            panoptic_train2017.json
            panoptic_val2017.json
"""
class CocoPanopticBase(COCO):
    """COCO API variant that can index COCO *panoptic* annotation files.

    Panoptic annotation entries describe a whole image and carry their
    category ids inside a ``segments_info`` list, so the generic index
    building is replaced by :meth:`createIndex` below, which assigns a
    synthetic annotation id and reads categories from ``segments_info``.

    Args:
        annotation_file (str or os.PathLike, optional): path to a panoptic
            annotation JSON file. When ``None``, an empty object is built and
            no index is created.
    """

    def __init__(self, annotation_file=None):
        super(CocoPanopticBase, self).__init__()
        # Start from empty containers; they are filled by createIndex() below.
        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if annotation_file is not None:
            print("loading annotations into memory...")
            tic = time.time()
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(annotation_file, "r", encoding="utf-8") as fp:
                dataset = json.load(fp)
            assert isinstance(dataset, dict), "annotation file format {} not supported".format(type(dataset))
            print("Done (t={:0.2f}s)".format(time.time() - tic))
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        """Build the lookup tables from ``self.dataset``.

        Populates:

        * ``anns``: synthetic annotation id -> annotation dict
        * ``imgs``: image id -> image record
        * ``cats``: category id -> category record
        * ``imgToAnns``: image id -> list of annotation dicts
        * ``catToImgs``: category id -> list of image ids (one entry per
          segment, so an image id can appear several times)

        Note: every annotation dict is mutated in place; its ``"id"`` key is
        overwritten with the annotation's position in
        ``self.dataset["annotations"]``.
        """
        print("creating index...")
        anns, cats, imgs = {}, {}, {}
        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
        if "annotations" in self.dataset:
            for ann_id, ann in enumerate(self.dataset["annotations"]):
                # Panoptic entries are per image; use the list position as id.
                ann["id"] = ann_id
                imgToAnns[ann["image_id"]].append(ann)
                anns[ann["id"]] = ann
        if "images" in self.dataset:
            for img in self.dataset["images"]:
                imgs[img["id"]] = img
        if "categories" in self.dataset:
            for cat in self.dataset["categories"]:
                cats[cat["id"]] = cat
        if "annotations" in self.dataset and "categories" in self.dataset:
            for ann in self.dataset["annotations"]:
                # Categories live on the individual segments, not the annotation.
                for seg in ann["segments_info"]:
                    catToImgs[seg["category_id"]].append(ann["image_id"])
        print("index created!")
        # Publish the finished tables as instance attributes.
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats
class CocoReader:
    """Reader for the COCO dataset stored as ffrecord files.

    Images are read from ``{split}2017.ffr`` and panoptic annotation images
    from ``panoptic_{split}2017.ffr``. Annotations come from the JSON files in
    ``annotations/`` and are loaded by one of the ``load_*`` methods, which
    must be called before any sample is read.

    Args:
        split (str): dataset split, either training (``train``) or validation (``val``)
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)
    """

    def __init__(self, split: str, check_data: bool = True, miniset: bool = False):
        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "COCO"
        assert split in ["train", "val"], f"split must be 'train' or 'val', got {split!r}"
        self.split = split
        self.fname = self.data_dir / f"{split}2017.ffr"
        self.reader = FileReader(self.fname, check_data)
        self.panoptic_fname = self.data_dir / f"panoptic_{split}2017.ffr"
        self.panoptic_reader = FileReader(self.panoptic_fname, check_data)
        # Sorted image ids of the loaded annotations. Position i in `ids` is
        # used together with record i of the ffrecord files, so both must be
        # stored in the same order.
        self.ids = None
        self.coco = None

    def __len__(self):
        # Sample count reported by the image FileReader.
        return self.reader.n

    def get_cocoapi(self) -> COCO:
        """Return the loaded COCO API object (``None`` before any ``load_*`` call)."""
        return self.coco

    def _load_api(self, api_cls: Callable[..., COCO], prefix: str) -> None:
        # Shared by the load_* methods: build the API object for
        # annotations/{prefix}_{split}2017.json and cache the sorted image ids.
        self.coco = api_cls(self.data_dir / f"annotations/{prefix}_{self.split}2017.json")
        self.ids = sorted(self.coco.imgs)

    def load_captions(self) -> None:
        """Load the COCO Caption annotations."""
        self._load_api(COCO, "captions")

    def load_instances(self) -> None:
        """Load the COCO Object Detection annotations."""
        self._load_api(COCO, "instances")

    def load_keypoints(self) -> None:
        """Load the COCO Keypoint Detection annotations."""
        self._load_api(COCO, "person_keypoints")

    def load_panoptics(self) -> None:
        """Load the COCO Panoptic annotations."""
        self._load_api(CocoPanopticBase, "panoptic")

    def _check_loaded(self) -> None:
        # Every read path relies on annotations having been loaded first.
        assert self.coco is not None, "annotations are not loaded yet."

    @staticmethod
    def _decode_rgb(records) -> List[Image.Image]:
        # Each record unpickles to a PIL image, which is converted to RGB.
        # NOTE: pickle.loads can execute arbitrary code; only read trusted
        # ffrecord files.
        return [pickle.loads(record).convert("RGB") for record in records]

    def read_imgs(self, indices: List[int]) -> List[Image.Image]:
        """Read images.

        Args:
            indices (list): indices of the images to read

        Returns:
            A list of RGB ``PIL.Image.Image`` objects.

        Raises:
            AssertionError: if no annotations have been loaded yet.
        """
        self._check_loaded()
        return self._decode_rgb(self.reader.read(indices))

    def read_anno_imgs(self, indices: List[int]) -> List[Image.Image]:
        """Read the panoptic annotation images (PNG).

        Args:
            indices (list): indices of the annotation images to read

        Returns:
            A list of RGB ``PIL.Image.Image`` objects.

        Raises:
            AssertionError: if no annotations have been loaded yet.
        """
        self._check_loaded()
        return self._decode_rgb(self.panoptic_reader.read(indices))

    def read_anno(self, index: int) -> List[dict]:
        """Return the annotations of the ``index``-th image.

        See https://cocodataset.org/#format-data for the full format.

        Args:
            index (int): position of the image in ``self.ids``

        Returns:
            List[dict]: the annotation dicts attached to that image (may be
            empty). Depending on which ``load_*`` method was called, each
            dict looks like:

            .. code-block:: python

                captions:
                {'image_id': 444010,
                 'id': 104057,
                 'caption': 'A group of friends sitting down at a table sharing a meal.'}
                instances:
                {'segmentation': ...,
                 'area': 3514.564,
                 'iscrowd': 0,
                 'image_id': 444010,
                 'bbox': [x_left, y_top, w, h],
                 'category_id': 44,
                 'id': 91863}
                keypoints:
                {'segmentation': ...,
                 'num_keypoints': 11,
                 'area': 34800.5498,
                 'iscrowd': 0,
                 'keypoints': ...,
                 'image_id': 444010,
                 'bbox': [x_left, y_top, w, h],
                 'category_id': 1,
                 'id': 1200757}
                panoptic:
                {"image_id": int,
                 "file_name": str,
                 "id": int,  # position in the annotation list (set by CocoPanopticBase)
                 "segments_info": [
                     {"id": int,
                      "category_id": int,
                      "area": int,
                      "bbox": [x, y, width, height],
                      "iscrowd": 0 or 1},
                     ...
                 ]}

        Raises:
            AssertionError: if no annotations have been loaded yet.
        """
        self._check_loaded()
        img_id = self.ids[index]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        return self.coco.loadAnns(ann_ids)
class CocoDataset(BaseDataset):
    """Common base for the COCO datasets.

    Subclasses pick which annotations are loaded by implementing
    ``_load_annotations``, which must populate ``self.reader``'s API object.

    Args:
        split (str): ``"train"`` or ``"val"``
        transform (Callable, optional): if given, called as
            ``transform(img, img_id, anno)`` for every sample, and its result
            replaces the sample
        check_data (bool): verify per-sample checksums (default ``True``)
        miniset (bool): use the mini subset (default ``False``)
    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super(CocoDataset, self).__init__()
        self.split = split
        self.reader = CocoReader(split, check_data, miniset)
        # Annotations have to be in memory before the API handle is captured.
        self._load_annotations()
        self.transform = transform
        self.coco = self.reader.coco

    def _load_annotations(self):
        """Load annotations into ``self.reader``; subclasses must override."""
        raise NotImplementedError

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, indices):
        """Return a list with one sample per index in ``indices``.

        A sample is ``(img, img_id, anno)``, or whatever the transform
        returns for those three values when a transform is set.
        """
        pictures = self.reader.read_imgs(indices)
        annotations = [self.reader.read_anno(i) for i in indices]
        image_ids = [self.reader.ids[i] for i in indices]
        triples = zip(pictures, image_ids, annotations)
        if self.transform is None:
            return list(triples)
        return [self.transform(*triple) for triple in triples]
@register_dataset
class CocoPanoptic(CocoDataset):
    """COCO dataset for panoptic segmentation.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split, either training (``train``) or validation (``val``)
        transform (Callable): transform applied to each sample; it receives an
            image, the image id, the panoptic annotation image and the image's
            annotations, and returns the transformed sample
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno_pic, anno (PIL.Image.Image, int, PIL.Image.Image, List[dict]):
        each sample is a 4-tuple: an RGB image, its image id, the RGB panoptic
        annotation image, and the list of annotation dicts for that image.
        Each annotation looks like:

        .. code-block:: python

            {
                "image_id": int,
                "file_name": str,
                "id": int,  # position in the annotation list, assigned at load time
                "segments_info": [
                    {
                        "id": int,
                        "category_id": int,
                        "area": int,
                        "bbox": [x_left, y_top, width, height],
                        "iscrowd": 0 or 1,
                    },
                    ...
                ],
            }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoPanoptic

        def transform(pic, id, anno_pic, anno):
            ...

        dataset = CocoPanoptic(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco  # COCO-API-compatible object for the panoptic annotations

        for pic, id, anno_pic, anno in loader:
            # training model
            ...
    """

    def __getitem__(self, indices):
        """Return a list with one sample per index in ``indices``.

        A sample is ``(img, img_id, anno_img, anno)``, or the transform's
        output for those four values when a transform is set.
        """
        imgs = self.reader.read_imgs(indices)
        anno_imgs = self.reader.read_anno_imgs(indices)
        annos = [self.reader.read_anno(idx) for idx in indices]
        img_ids = [self.reader.ids[idx] for idx in indices]
        samples = zip(imgs, img_ids, anno_imgs, annos)
        if self.transform is None:
            return list(samples)
        return [self.transform(*sample) for sample in samples]

    def _load_annotations(self):
        """Load the panoptic annotations into the reader."""
        self.reader.load_panoptics()
@register_dataset
class CocoDetection(CocoDataset):
    """COCO dataset for object detection.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split, either training (``train``) or validation (``val``)
        transform (Callable): transform applied to each sample; it receives an
            image, the image id and the image's annotations, and returns the
            transformed sample
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno (PIL.Image.Image, int, List[dict]): each sample is a
        3-tuple: an RGB image, its image id, and the list of object
        annotations for that image. Each annotation looks like:

        .. code-block:: python

            {
                'segmentation': ...,
                'area': 3514.564,
                'iscrowd': 0,
                'image_id': 444010,
                'bbox': [x_left, y_top, w, h],
                'category_id': 44,
                'id': 91863
            }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoDetection

        def transform(pic, id, anno):
            ...

        dataset = CocoDetection(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco  # same as pycocotools.coco.COCO

        for pic, id, anno in loader:
            # training model
            ...
    """

    def _load_annotations(self):
        """Load the instance (object detection) annotations into the reader."""
        self.reader.load_instances()
@register_dataset
class CocoCaption(CocoDataset):
    """COCO dataset for image captioning.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split, either training (``train``) or validation (``val``)
        transform (Callable): transform applied to each sample; it receives an
            image, the image id and the image's caption annotations, and
            returns the transformed sample
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno (PIL.Image.Image, int, List[dict]): each sample is a
        3-tuple: an RGB image, its image id, and the list of caption
        annotations for that image. Each annotation looks like:

        .. code-block:: python

            {
                'image_id': 444010,
                'id': 104057,
                'caption': 'A group of friends sitting down at a table sharing a meal.'
            }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoCaption

        def transform(pic, id, anno):
            ...

        dataset = CocoCaption(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco  # same as pycocotools.coco.COCO

        for pic, id, anno in loader:
            # training model
            ...
    """

    def _load_annotations(self):
        """Load the caption annotations into the reader."""
        self.reader.load_captions()
@register_dataset
class CocoKeypoint(CocoDataset):
    """COCO dataset for keypoint detection.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split, either training (``train``) or validation (``val``)
        transform (Callable): transform applied to each sample; it receives an
            image, the image id and the image's annotations, and returns the
            transformed sample
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno (PIL.Image.Image, int, List[dict]): each sample is a
        3-tuple: an RGB image, its image id, and the list of keypoint
        annotations for that image. Each annotation looks like:

        .. code-block:: python

            {
                'segmentation': ...,
                'num_keypoints': 11,
                'area': 34800.5498,
                'iscrowd': 0,
                'keypoints': ...,
                'image_id': 444010,
                'bbox': [x_left, y_top, w, h],
                'category_id': 1,
                'id': 1200757
            }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoKeypoint

        def transform(pic, id, anno):
            ...

        dataset = CocoKeypoint(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco  # same as pycocotools.coco.COCO

        for pic, id, anno in loader:
            # training model
            ...
    """

    def _load_annotations(self):
        """Load the person-keypoint annotations into the reader."""
        self.reader.load_keypoints()