from typing import List, Optional

from paddlex_hps_server import BaseTritonPythonModel, protocol, utils
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypeAlias


# Optional per-request parameters, carried in `InferInput.inferenceParams`.
class InferenceParams(BaseModel):
    # When provided, must be a positive integer (Field(gt=0)); None means unset.
    # Presumably caps the image's longer side in pixels — TODO confirm.
    # NOTE(review): TritonPythonModel._execute currently answers any request
    # that sets this with a 422 error, so it is effectively reserved.
    maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None


# Request payload for this model (returned by `_get_input_model_type`).
class InferInput(BaseModel):
    # Image source, passed to `utils.get_raw_bytes`. Presumably a URL or a
    # Base64-encoded file body — TODO confirm against paddlex_hps_server.utils.
    image: str
    # Optional extra parameters; None when the client sends none.
    inferenceParams: Optional[InferenceParams] = None


# A single vertex: a list of exactly two integers (presumably x, y — confirm).
Point: TypeAlias = Annotated[List[int], Field(min_length=2, max_length=2)]
# A polygon: a list of at least three `Point` vertices.
Polygon: TypeAlias = Annotated[List[Point], Field(min_length=3)]


# One recognized text region in the response.
class Text(BaseModel):
    # Region outline; filled from the OCR result's `dt_polys` entries.
    poly: Polygon
    # Recognized string; filled from the OCR result's `rec_text` entries.
    text: str
    # Recognition score from `rec_score`; no range constraint is enforced here.
    score: float


# Successful response payload (returned by `_get_result_model_type`).
class InferResult(BaseModel):
    # Recognized text regions built from the OCR result.
    texts: List[Text]
    # Base64-encoded bytes of the layout result's `.img` visualization.
    layoutImage: str
    # Base64-encoded bytes of the OCR result's `.img` visualization.
    ocrImage: str


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python-backend model wrapping a layout-analysis + OCR pipeline.

    Only the per-request hooks are implemented here. Request decoding,
    response encoding and pipeline construction (``self._pipeline``) are
    presumably provided by ``BaseTritonPythonModel`` — confirm in
    ``paddlex_hps_server``.
    """

    def _get_input_model_type(self):
        """Return the Pydantic model describing request payloads."""
        return InferInput

    def _get_result_model_type(self):
        """Return the Pydantic model describing successful responses."""
        return InferResult

    def _execute(self, input, log_id):
        """Run the pipeline on one image and assemble the response.

        Args:
            input: An ``InferInput`` instance for the current request.
            log_id: Request identifier, forwarded to error outputs.

        Returns:
            An ``InferResult`` on success. When the unsupported
            ``maxLongSide`` parameter is set, returns the output of
            ``protocol.create_output_without_result`` with status 422 instead.
        """
        params = input.inferenceParams
        # `maxLongSide` passes schema validation but is not implemented yet.
        # Reject it explicitly rather than silently ignoring it. The message
        # names the field exactly as clients send it.
        if params is not None and params.maxLongSide is not None:
            return protocol.create_output_without_result(
                422,
                "`maxLongSide` is currently not supported.",
                log_id=log_id,
            )

        file_bytes = utils.get_raw_bytes(input.image)
        image = utils.image_bytes_to_array(file_bytes)

        # A single image goes in, so only the first prediction is used.
        result = list(self._pipeline.predict(image))[0]
        ocr_result = result["ocr_result"]

        # Polygons, recognized strings and scores are parallel sequences.
        # zip() stops at the shortest, so they are assumed to be equal length.
        texts = [
            Text(poly=poly, text=text, score=score)
            for poly, text, score in zip(
                ocr_result["dt_polys"],
                ocr_result["rec_text"],
                ocr_result["rec_score"],
            )
        ]
        layout_image_base64 = utils.base64_encode(
            utils.image_to_bytes(result["layout_result"].img)
        )
        ocr_image_base64 = utils.base64_encode(
            utils.image_to_bytes(ocr_result.img)
        )

        return InferResult(
            texts=texts, layoutImage=layout_image_base64, ocrImage=ocr_image_base64
        )
