from typing import List, Optional

from paddlex_hps_server import utils
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypeAlias

from common.base_model import BaseFaceRecognitionModel
from common.index_data_utils import deserialize_index_data


class InferInput(BaseModel):
    # Request payload accepted by TritonPythonModel._execute.

    # Image reference as a string; _execute hands it to utils.get_raw_bytes
    # to obtain the raw image bytes.
    image: str
    # Optional key used to fetch serialized index data from the model's
    # "index_storage" context entry. When None, the pipeline is called with
    # index=None.
    indexKey: Optional[str] = None


# A bounding box is a list of exactly four floats (length enforced by the
# min_length/max_length constraints).
# NOTE(review): the coordinate order is not defined here; presumably it follows
# the pipeline's "coordinate" output -- confirm against the pipeline docs.
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]


class RecResult(BaseModel):
    # One recognition candidate attached to a detected face.

    # Label taken from the pipeline's per-face "labels" entry.
    label: str
    # Score taken from the matching position in the per-face "rec_scores" entry.
    score: float


class Face(BaseModel):
    # A single detected face as returned to the client.

    # Four-element box filled from the pipeline's per-face "coordinate" entry.
    bbox: BoundingBox
    # Recognition candidates; left empty when the pipeline reports
    # "rec_scores" as None for this face.
    recResults: List[RecResult]
    # Detection score filled from the pipeline's per-face "det_score" entry.
    score: float


class InferResult(BaseModel):
    # Response payload produced by TritonPythonModel._execute.

    # One entry per item in the pipeline result's "boxes".
    faces: List[Face]
    # Base64-encoded bytes of the pipeline result's `img`.
    image: str


class TritonPythonModel(BaseFaceRecognitionModel):
    """Model that runs the face-recognition pipeline on a single image.

    Request and response payloads are described by ``InferInput`` and
    ``InferResult`` respectively.
    """

    def _get_input_model_type(self):
        """Return the pydantic model class describing the request payload."""
        return InferInput

    def _get_result_model_type(self):
        """Return the pydantic model class describing the response payload."""
        return InferResult

    def _execute(self, input, log_id):
        """Run the pipeline on one image and package the outcome.

        Args:
            input: Request object exposing ``image`` and ``indexKey``
                (expected to be an ``InferInput`` instance).
            log_id: Request log identifier; not used by this method.

        Returns:
            An ``InferResult`` holding one ``Face`` per entry of the pipeline
            result's ``"boxes"`` plus the base64-encoded result image.
        """
        image = utils.image_bytes_to_array(utils.get_raw_bytes(input.image))

        # Index data is loaded only when the request names a key; otherwise
        # the pipeline is invoked with ``index=None``.
        index_data = None
        if input.indexKey is not None:
            stored_bytes = self._context["index_storage"].get(input.indexKey)
            index_data = deserialize_index_data(stored_bytes)

        # The prediction iterable is fully consumed; only its first item is
        # used afterwards.
        predictions = list(self._pipeline.predict(image, index=index_data))
        result = predictions[0]

        def to_face(box):
            """Convert one pipeline box entry into a ``Face``."""
            rec_scores = box["rec_scores"]
            if rec_scores is None:
                # No recognition output for this face: report no candidates.
                candidates = []
            else:
                candidates = [
                    RecResult(label=name, score=value)
                    for name, value in zip(box["labels"], rec_scores)
                ]
            return Face(
                bbox=box["coordinate"],
                recResults=candidates,
                score=box["det_score"],
            )

        faces = [to_face(box) for box in result["boxes"]]

        rendered_bytes = utils.image_to_bytes(result.img)
        return InferResult(
            faces=faces,
            image=utils.base64_encode(rendered_bytes),
        )
