from typing import List

import numpy as np
import pycocotools.mask as mask_util
from paddlex_hps_server import BaseTritonPythonModel, utils
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypeAlias


class InferInput(BaseModel):
    # Request body for one inference call.
    # `image` is turned into raw bytes by `utils.get_raw_bytes` in
    # `TritonPythonModel._execute`. NOTE(review): presumably a URL or a
    # base64-encoded payload — confirm against paddlex_hps_server.utils.
    image: str


# Bounding box: a list of exactly four floats, filled from the pipeline's
# per-box `coordinate` field. NOTE(review): the coordinate order (e.g.
# xmin, ymin, xmax, ymax) is not established in this file — confirm against
# the pipeline output.
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]


class Mask(BaseModel):
    # Run-length-encoded mask for one instance.
    # The `counts` string of a pycocotools (COCO-style) RLE, as produced
    # by `_rle`.
    rleResult: str
    # Shape of the source mask array; exactly two ints.
    # Presumably [height, width] — confirm with the pipeline's mask layout.
    size: Annotated[List[int], Field(min_length=2, max_length=2)]


class Instance(BaseModel):
    # One predicted instance: a pipeline box paired with its mask.
    # Box taken from the pipeline's `coordinate` field.
    bbox: BoundingBox
    # Class index taken from the pipeline's `cls_id` field.
    categoryId: int
    # Value of the pipeline's `score` field (presumably a confidence).
    score: float
    # RLE-encoded mask paired positionally with this box.
    mask: Mask


class InferResult(BaseModel):
    # Response body for one inference call.
    # One entry per (box, mask) pair returned by the pipeline.
    instances: List[Instance]
    # Base64-encoded bytes of the pipeline result's `img`.
    # NOTE(review): presumably a rendered visualization — confirm.
    image: str


def _rle(mask: np.ndarray) -> str:
    """Encode one mask as a pycocotools RLE and return its ``counts`` text.

    ``pycocotools.mask.encode`` takes a Fortran-ordered ``uint8`` array whose
    last axis indexes masks. The input therefore gets a length-1 trailing axis
    before encoding. The single resulting RLE's ``counts`` bytes are decoded
    to a string.

    Args:
        mask: Mask array. Assumed to be 2-D (H, W) — TODO confirm with the
            pipeline that produces it.

    Returns:
        The UTF-8-decoded ``counts`` field of the RLE.
    """
    fortran_stack = np.asarray(mask[..., None], order="F", dtype="uint8")
    encoded_list = mask_util.encode(fortran_stack)
    first_rle = encoded_list[0]
    counts_bytes = first_rle["counts"]
    return counts_bytes.decode("utf-8")


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python model that serves per-instance boxes and RLE masks.

    Each request carries one image. The image is run through
    ``self._pipeline``. Each (box, mask) pair of the first prediction result
    is returned as an ``Instance``, together with the base64-encoded bytes of
    the result's ``img``.
    """

    def _get_input_model_type(self):
        """Return the pydantic model used for request payloads."""
        return InferInput

    def _get_result_model_type(self):
        """Return the pydantic model used for responses."""
        return InferResult

    def _execute(self, input, log_id):
        """Run the pipeline on one request image and build the response.

        Args:
            input: Validated ``InferInput``. ``input.image`` is converted to
                raw bytes by ``utils.get_raw_bytes``.
            log_id: Presumably a per-request log identifier; not used here.

        Returns:
            An ``InferResult`` with one ``Instance`` per (box, mask) pair and
            the base64-encoded result image.
        """
        file_bytes = utils.get_raw_bytes(input.image)
        image = utils.image_bytes_to_array(file_bytes)

        # Only the first prediction is used. A single input image presumably
        # yields exactly one result.
        result = list(self._pipeline.predict(image))[0]

        # Boxes and masks are paired positionally. Distinct names keep the
        # numpy mask array separate from the `Mask` response model (the
        # previous loop rebound `mask` from one to the other).
        instances: List[Instance] = [
            Instance(
                bbox=box["coordinate"],
                categoryId=box["cls_id"],
                score=box["score"],
                mask=Mask(
                    rleResult=_rle(mask_array),
                    # Pass a real list of ints so the `List[int]` field does
                    # not rely on lax tuple-to-list coercion.
                    size=[int(dim) for dim in mask_array.shape],
                ),
            )
            for box, mask_array in zip(result["boxes"], result["masks"])
        ]

        output_image_base64 = utils.base64_encode(utils.image_to_bytes(result.img))

        return InferResult(instances=instances, image=output_image_base64)
