from typing import Callable, List, Optional
from collections import defaultdict
import time
import json
import pickle
from PIL import Image
from ffrecord import FileReader
from pycocotools.coco import COCO
from .base import BaseDataset, get_data_dir, register_dataset
"""
Expected file organization:
    [coco_data_dir]
        train2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        val2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        panoptic_train2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        panoptic_val2017.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        annotations/
            captions_train2017.json
            captions_val2017.json
            instances_train2017.json
            instances_val2017.json
            person_keypoints_train2017.json
            person_keypoints_val2017.json
            panoptic_train2017.json
            panoptic_val2017.json
"""
class CocoPanopticBase(COCO):
    """
    COCO API wrapper for panoptic annotation files.

    Panoptic annotation files store one record per image, each carrying a
    ``segments_info`` list, so the stock ``COCO.createIndex`` (which expects a
    ``category_id`` directly on every annotation) cannot be reused as-is.

    Args:
        annotation_file: path to a panoptic annotation JSON file, or ``None``
            to create an empty API object.
    """

    def __init__(self, annotation_file=None):
        super().__init__()
        # Start from empty lookup tables; createIndex() fills them in.
        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if annotation_file is not None:
            print("loading annotations into memory...")
            tic = time.time()
            with open(annotation_file, "r") as fp:
                dataset = json.load(fp)
            assert isinstance(dataset, dict), "annotation file format {} not supported".format(type(dataset))
            print("Done (t={:0.2f}s)".format(time.time() - tic))
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        """Build image/category/annotation lookup tables from ``self.dataset``."""
        print("creating index...")
        anns, cats, imgs = {}, {}, {}
        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
        if "annotations" in self.dataset:
            # Panoptic annotations carry no unique id of their own; assign one
            # from the enumeration order so each can be addressed individually.
            for ann_id, ann in enumerate(self.dataset["annotations"]):
                ann["id"] = ann_id
                imgToAnns[ann["image_id"]].append(ann)
                anns[ann["id"]] = ann
        if "images" in self.dataset:
            for img in self.dataset["images"]:
                imgs[img["id"]] = img
        if "categories" in self.dataset:
            for cat in self.dataset["categories"]:
                cats[cat["id"]] = cat
        if "annotations" in self.dataset and "categories" in self.dataset:
            # Category ids live on the per-segment records, not on the
            # annotation itself, so walk segments_info for the reverse index.
            for ann in self.dataset["annotations"]:
                for seg in ann["segments_info"]:
                    catToImgs[seg["category_id"]].append(ann["image_id"])
        print("index created!")
        # create class members
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats
class CocoReader:
    """
    Read interface for the ffrecord-packed COCO dataset.

    Args:
        split (str): dataset split: train set (``train``) or validation set (``val``)
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)
    """

    def __init__(self, split: str, check_data: bool = True, miniset: bool = False):
        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "COCO"
        assert split in ["train", "val"]
        self.split = split
        self.fname = self.data_dir / f"{split}2017.ffr"
        self.reader = FileReader(self.fname, check_data)
        self.panoptic_fname = self.data_dir / f"panoptic_{split}2017.ffr"
        self.panoptic_reader = FileReader(self.panoptic_fname, check_data)
        # Populated by one of the load_* methods below.
        self.ids = None
        self.coco = None

    def __len__(self):
        return self.reader.n

    def get_cocoapi(self) -> "COCO":
        """Return the ``pycocotools.coco.COCO`` object holding the loaded annotations."""
        return self.coco

    def _set_coco(self, coco) -> None:
        # Shared tail of every load_* method: keep the API object and a
        # sorted, stable list of image ids used for integer indexing.
        self.coco = coco
        self.ids = sorted(coco.imgs.keys())

    def load_captions(self) -> None:
        """Load COCO Caption annotations."""
        self._set_coco(COCO(self.data_dir / f"annotations/captions_{self.split}2017.json"))

    def load_instances(self) -> None:
        """Load COCO Object Detection annotations."""
        self._set_coco(COCO(self.data_dir / f"annotations/instances_{self.split}2017.json"))

    def load_keypoints(self) -> None:
        """Load COCO Keypoint Detection annotations."""
        self._set_coco(COCO(self.data_dir / f"annotations/person_keypoints_{self.split}2017.json"))

    def load_panoptics(self) -> None:
        """Load COCO Panoptic annotations."""
        self._set_coco(CocoPanopticBase(self.data_dir / f"annotations/panoptic_{self.split}2017.json"))

    @staticmethod
    def _decode_imgs(reader, indices):
        # Samples are pickled PIL images; convert to RGB for a uniform mode.
        # NOTE: pickle.loads is only acceptable here because the .ffr files
        # are trusted local dataset files, never untrusted input.
        return [pickle.loads(x).convert("RGB") for x in reader.read(indices)]

    def read_imgs(self, indices: List[int]) -> List["Image.Image"]:
        """
        Read image data.

        Args:
            indices (list): indices of the images to read

        Returns:
            ``PIL.Image.Image`` images in RGB format
        """
        assert self.coco is not None, "annotations are not loaded yet."
        return self._decode_imgs(self.reader, indices)

    def read_anno_imgs(self, indices: List[int]) -> List["Image.Image"]:
        """
        Read annotated (panoptic, PNG-sourced) image data.

        Args:
            indices (list): indices of the images to read

        Returns:
            ``PIL.Image.Image`` images in RGB format
        """
        assert self.coco is not None, "annotations are not loaded yet."
        return self._decode_imgs(self.panoptic_reader, indices)

    def read_anno(self, index: int) -> List[dict]:
        """
        Read the annotations for the sample at ``index``.

        See https://cocodataset.org/#format-data for the per-task annotation
        schemas (captions / instances / keypoints / panoptic).

        Args:
            index (int): sample index

        Returns:
            a list of annotation dicts for the corresponding image, e.g. for
            captions ``{'image_id': ..., 'id': ..., 'caption': ...}``, for
            instances/keypoints the usual detection records with
            ``segmentation``/``bbox``/``category_id``, and for panoptic a
            record carrying ``segments_info``.
        """
        img_id = self.ids[index]
        ann_id = self.coco.getAnnIds(img_id)
        ann = self.coco.loadAnns(ann_id)
        return ann
class CocoDataset(BaseDataset):
    """
    Common base class for the COCO datasets in this module.

    Subclasses implement :meth:`_load_annotations` to choose which annotation
    file the underlying :class:`CocoReader` loads.
    """

    def __init__(
            self,
            split: str,
            transform: Optional[Callable] = None,
            check_data: bool = True,
            miniset: bool = False,
    ) -> None:
        super().__init__()
        self.split = split
        self.reader = CocoReader(split, check_data, miniset)
        self._load_annotations()  # load annotations into memory
        self.transform = transform
        self.coco = self.reader.coco

    def _load_annotations(self):
        # Subclass hook: must call one of the reader's load_* methods.
        raise NotImplementedError

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, indices):
        """Read a batch of samples; ``indices`` is a list of integer indices."""
        imgs = self.reader.read_imgs(indices)
        annos = [self.reader.read_anno(idx) for idx in indices]
        img_ids = [self.reader.ids[idx] for idx in indices]
        if self.transform is not None:
            samples = [self.transform(img, img_id, anno) for img, img_id, anno in zip(imgs, img_ids, annos)]
        else:
            samples = list(zip(imgs, img_ids, annos))
        return samples
@register_dataset
class CocoPanoptic(CocoDataset):
    """
    COCO dataset for panoptic segmentation.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split: train set (``train``) or validation set (``val``)
        transform (Callable): transform applied to images and annotations; receives an
            image, its image id, the annotation image and the annotations, and returns
            the transformed image, image id and annotations
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno_pic, anno (PIL.Image.Image, int, PIL.Image.Image, dict): each sample is
        a 4-tuple of an RGB image, its image id, an annotated RGB image, and the object
        annotations, e.g.:

    .. code-block:: python

        {
            "image_id": int,
            "file_name": str,
            "segments_info": {
                "id": int,
                "category_id": int,
                "area": int,
                "bbox": [x_left, y_top,width,height],
                "iscrowd": 0 or 1,
            },
        }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoPanoptic

        def transform(pic, id, anno_pic, anno):
            ...

        dataset = CocoPanoptic(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco # same as pycocotools.coco.COCO

        for pic, id, anno_pic, anno in loader:
            # training model
    """

    def __getitem__(self, indices):
        """Read a batch of (img, img_id, anno_img, anno) samples."""
        imgs = self.reader.read_imgs(indices)
        anno_imgs = self.reader.read_anno_imgs(indices)
        annos = [self.reader.read_anno(idx) for idx in indices]
        img_ids = [self.reader.ids[idx] for idx in indices]
        if self.transform is not None:
            samples = [
                self.transform(img, img_id, anno_img, anno)
                for img, img_id, anno_img, anno in zip(imgs, img_ids, anno_imgs, annos)
            ]
        else:
            samples = list(zip(imgs, img_ids, anno_imgs, annos))
        return samples

    def _load_annotations(self):
        self.reader.load_panoptics()
@register_dataset
class CocoDetection(CocoDataset):
    """
    COCO dataset for object detection.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split: train set (``train``) or validation set (``val``)
        transform (Callable): transform applied to images and annotations; receives an
            image, its image id and the annotations, and returns the transformed image,
            image id and annotations
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno (PIL.Image.Image, int, dict): each sample is a 3-tuple of an RGB
        image, its image id, and the object annotations, e.g.:

    .. code-block:: python

        {
            'segmentation': ...,
            'area': 3514.564,
            'iscrowd': 0,
            'image_id': 444010,
            'bbox': [x_left, y_top, w, h],
            'category_id': 44,
            'id': 91863
        }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoDetection

        def transform(pic, id, anno):
            ...

        dataset = CocoDetection(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco # same as pycocotools.coco.COCO

        for pic, id, anno in loader:
            # training model
    """

    def _load_annotations(self):
        self.reader.load_instances()
@register_dataset
class CocoCaption(CocoDataset):
    """
    COCO dataset for image captioning.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split: train set (``train``) or validation set (``val``)
        transform (Callable): transform applied to images and annotations; receives an
            image, its image id and the annotations, and returns the transformed image,
            image id and annotations
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno (PIL.Image.Image, int, dict): each sample is a 3-tuple of an RGB
        image, its image id, and the caption annotations, e.g.:

    .. code-block:: python

        {
            'image_id': 444010,
            'id': 104057,
            'caption': 'A group of friends sitting down at a table sharing a meal.'
        }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoCaption

        def transform(pic, id, anno):
            ...

        dataset = CocoCaption(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco # same as pycocotools.coco.COCO

        for pic, id, anno in loader:
            # training model
    """

    def _load_annotations(self):
        self.reader.load_captions()
@register_dataset
class CocoKeypoint(CocoDataset):
    """
    COCO dataset for keypoint detection.

    More information: https://cocodataset.org

    Args:
        split (str): dataset split: train set (``train``) or validation set (``val``)
        transform (Callable): transform applied to images and annotations; receives an
            image, its image id and the annotations, and returns the transformed image,
            image id and annotations
        check_data (bool): whether to verify the checksum of every sample (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, id, anno (PIL.Image.Image, int, dict): each sample is a 3-tuple of an RGB
        image, its image id, and the keypoint annotations, e.g.:

    .. code-block:: python

        {
            'segmentation': ...,
            'num_keypoints': 11,
            'area': 34800.5498,
            'iscrowd': 0,
            'keypoints': ...,
            'image_id': 444010,
            'bbox': [x_left, y_top, w, h],
            'category_id': 1,
            'id': 1200757
        }

    Examples:

    .. code-block:: python

        from hfai.datasets import CocoKeypoint

        def transform(pic, id, anno):
            ...

        dataset = CocoKeypoint(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)
        coco = dataset.coco # same as pycocotools.coco.COCO

        for pic, id, anno in loader:
            # training model
    """

    def _load_annotations(self):
        self.reader.load_keypoints()