from typing import List

from paddlex_hps_server import BaseTritonPythonModel, utils
from pydantic import BaseModel, Field
from typing_extensions import Annotated


class InferInput(BaseModel):
    # Request schema for TritonPythonModel._execute.
    # `image` is handed to `utils.get_raw_bytes` and then decoded with
    # `utils.image_bytes_to_array`; presumably a base64 string or a URL —
    # TODO confirm accepted forms against paddlex_hps_server.utils.
    image: str


class InferResult(BaseModel):
    # Response schema built by TritonPythonModel._execute.
    # The pipeline's 2-D prediction, flattened row by row into a single list
    # (presumably one label per pixel).
    labelMap: List[int]
    # Exactly two entries: [number of rows, number of columns] of the 2-D
    # prediction (presumably [height, width]); use it to un-flatten labelMap.
    size: Annotated[List[int], Field(min_length=2, max_length=2)]
    # Base64-encoded bytes of the pipeline's visualization image, converted
    # to RGB before encoding.
    image: str


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python-backend model that wraps a pipeline producing a 2-D label map.

    Requests are validated as ``InferInput`` and responses are built as
    ``InferResult``; the inference itself is delegated to ``self._pipeline``,
    which is provided by the base class.
    """

    def _get_input_model_type(self):
        """Return the pydantic model used to validate incoming requests."""
        return InferInput

    def _get_result_model_type(self):
        """Return the pydantic model used to build outgoing responses."""
        return InferResult

    def _execute(self, input, log_id):
        """Run the pipeline on the request image and package its output.

        Args:
            input: Validated ``InferInput`` request.
            log_id: Request log identifier; not used by this method.

        Returns:
            ``InferResult`` holding the flattened label map, its
            ``[rows, columns]`` size and a base64-encoded RGB visualization.
        """
        raw_bytes = utils.get_raw_bytes(input.image)
        decoded = utils.image_bytes_to_array(raw_bytes)

        # The pipeline yields results lazily; drain it fully and keep the first.
        predictions = list(self._pipeline.predict(decoded))
        first = predictions[0]

        rows = first["pred"][0].tolist()
        dims = [len(rows), len(rows[0])]

        # Flatten the 2-D prediction in row-major order.
        flat_labels = []
        for row in rows:
            flat_labels.extend(row)

        rgb_visual = first.img.convert("RGB")
        encoded_visual = utils.base64_encode(utils.image_to_bytes(rgb_visual))

        return InferResult(labelMap=flat_labels, size=dims, image=encoded_visual)
