from typing import List

from paddlex_hps_server import BaseTritonPythonModel, utils
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypeAlias


class InferInput(BaseModel):
    # Request body for the detection endpoint.
    # `image` is passed straight to utils.get_raw_bytes in
    # TritonPythonModel._execute, so it is presumably base64 data or a URL —
    # TODO confirm against paddlex_hps_server.utils.
    # (Documented with comments rather than a docstring because pydantic
    # copies a class docstring into the model's JSON-schema description.)
    image: str


# A bounding box given as exactly four floats; the Field constraints make
# pydantic reject lists of any other length.
# Values come from each pipeline box's "coordinate" entry in
# TritonPythonModel._execute. The component order is presumably
# [xmin, ymin, xmax, ymax] — NOTE(review): confirm against the PaddleX
# detection result format.
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]


class DetectedObject(BaseModel):
    # One detection, built from a single entry of the pipeline result's
    # "boxes" list in TritonPythonModel._execute.
    # No aliases are configured here, so the camelCase attribute names are
    # the field names pydantic uses in the serialized response.
    bbox: BoundingBox  # from the box's "coordinate"
    categoryId: int  # from the box's "cls_id"
    score: float  # from the box's "score"; no range check is applied here


class InferResult(BaseModel):
    # Response body returned by TritonPythonModel._execute.
    detectedObjects: List[DetectedObject]  # one entry per pipeline box
    # Base64 of utils.image_to_bytes(result.img), i.e. the pipeline's
    # result image encoded for transport.
    image: str


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python-backend model that serves object detection.

    Each request's image is decoded and run through ``self._pipeline``, an
    attribute supplied by ``BaseTritonPythonModel`` (not defined in this
    class). The detected boxes and the pipeline's result image are returned
    as an ``InferResult``.
    """

    def _get_input_model_type(self):
        """Return the pydantic model used for request bodies."""
        return InferInput

    def _get_result_model_type(self):
        """Return the pydantic model used for response bodies."""
        return InferResult

    def _execute(self, input, log_id):
        """Run detection on a single request.

        Args:
            input: Validated ``InferInput``. ``input.image`` is handed to
                ``utils.get_raw_bytes`` (presumably base64 data or a URL —
                TODO confirm against ``paddlex_hps_server.utils``).
            log_id: Request log identifier; not used by this method.

        Returns:
            ``InferResult`` holding one ``DetectedObject`` per entry of the
            pipeline result's ``"boxes"`` list, plus the base64-encoded
            result image.
        """
        file_bytes = utils.get_raw_bytes(input.image)
        image = utils.image_bytes_to_array(file_bytes)

        # predict() yields results lazily; materialize the whole iterable so
        # the pipeline runs to completion, then keep the first (and, for the
        # single image submitted here, only) result.
        result = list(self._pipeline.predict(image))[0]

        objects: List[DetectedObject] = [
            DetectedObject(
                bbox=obj["coordinate"],
                categoryId=obj["cls_id"],
                score=obj["score"],
            )
            for obj in result["boxes"]
        ]

        output_image_base64 = utils.base64_encode(utils.image_to_bytes(result.img))

        return InferResult(detectedObjects=objects, image=output_image_base64)
