import os.path
from typing import Final, List, Literal, Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike
from paddlex_hps_server import BaseTritonPythonModel, logging, protocol, utils
from paddlex_hps_server.storage import Storage, SupportsGetURL, create_storage
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypeAlias

# Default for the `max_img_size` limit handed to `utils.file_to_images`; used
# unless the app config's `extra` section provides its own `max_img_size`.
# NOTE(review): the axis order of the pair is not visible in this file — confirm.
_DEFAULT_MAX_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
# Default for the `max_num_imgs` limit handed to `utils.file_to_images`; used
# unless the app config's `extra` section provides its own `max_num_imgs`.
_DEFAULT_MAX_NUM_IMGS: Final[int] = 10


# Numeric file-type code accepted in requests. `TritonPythonModel._execute`
# treats 0 as "PDF" and the other allowed value (1) as "IMAGE".
FileType: TypeAlias = Literal[0, 1]


# Optional inference tuning parameters nested inside the request.
# (Documented with `#` comments rather than a docstring so the pydantic model's
# description stays untouched.)
class InferenceParams(BaseModel):
    # Strictly positive integer when given. `TritonPythonModel._execute`
    # currently answers with a 422 error whenever this is set.
    maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None


# Request schema returned by `TritonPythonModel._get_input_model_type`.
class AnalyzeImagesInput(BaseModel):
    # The file to analyze. It is tested with `utils.is_url` and loaded with
    # `utils.get_raw_bytes`.
    # NOTE(review): non-URL values are presumably base64-encoded content — confirm.
    file: str
    # 0 for PDF, 1 for image. When omitted, the type is inferred from the URL;
    # a non-URL `file` without an explicit type is rejected with a 422 error.
    fileType: Optional[FileType] = None
    # Forwarded to the pipeline as `use_doc_image_ori_cls_model`.
    useImgOrientationCls: bool = True
    # Forwarded to the pipeline as `use_doc_image_unwarp_model`.
    useImgUnwarping: bool = True
    # Forwarded to the pipeline as `use_seal_text_det_model`.
    useSealTextDet: bool = True
    # Optional extra parameters; see `InferenceParams`.
    inferenceParams: Optional[InferenceParams] = None


# A point: a list of exactly two integers.
# NOTE(review): presumably [x, y] pixel coordinates — confirm.
Point: TypeAlias = Annotated[List[int], Field(min_length=2, max_length=2)]
# A polygon: a list of at least three points.
Polygon: TypeAlias = Annotated[List[Point], Field(min_length=3)]
# A bounding box: a list of exactly four floats.
# NOTE(review): coordinate order is not visible in this file — confirm.
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]


# One recognized text item, built in `TritonPythonModel._execute` from the
# OCR result of a single image.
class Text(BaseModel):
    # Taken from the OCR result's `dt_polys` entry.
    poly: Polygon
    # Taken from the OCR result's `rec_text` entry.
    text: str
    # Taken from the OCR result's `rec_score` entry.
    score: float


# One table item, built in `TritonPythonModel._execute` from an entry of the
# per-image `table_result`.
class Table(BaseModel):
    # Taken from the table entry's `layout_bbox` value.
    bbox: BoundingBox
    # Taken from the table entry's `html` value.
    html: str


# Per-image result. The three image fields hold whatever `_postprocess_image`
# returns: a URL when the configured storage can provide one, otherwise the
# base64-encoded image bytes.
class VisionResult(BaseModel):
    # Text items recognized in the image.
    texts: List[Text]
    # Table items found in the image.
    tables: List[Table]
    # The image that was fed to the pipeline.
    inputImage: str
    # The `img` attribute of the image's `ocr_result`.
    ocrImage: str
    # The `img` attribute of the image's `layout_result`.
    layoutImage: str


# Response schema returned by `TritonPythonModel._get_result_model_type`.
class AnalyzeImagesResult(BaseModel):
    # One entry per input image paired with a pipeline result.
    visionResults: List[VisionResult]
    # Second element of the value returned by the pipeline's `visual_predict`,
    # passed through unchanged.
    visionInfo: dict


def _postprocess_image(
    img: ArrayLike,
    log_id: str,
    filename: str,
    file_storage: Optional[Storage],
) -> str:
    """Encode an image and return it as a URL or a base64 string.

    The image is converted to an array and serialized with
    `utils.image_array_to_bytes`, using the extension of `filename` as the
    target format.

    Args:
        img: Image data; anything `numpy.asarray` accepts.
        log_id: Request identifier, used as the prefix of the storage key.
        filename: Name of the stored object; its extension selects the format.
        file_storage: Storage backend, or `None` to skip storing.

    Returns:
        The URL from `file_storage.get_url` when a storage is configured and it
        is an instance of `SupportsGetURL`; otherwise the base64-encoded image
        bytes.
    """
    _, suffix = os.path.splitext(filename)
    encoded = utils.image_array_to_bytes(np.asarray(img), ext=suffix)

    # Without a storage backend the only option is to inline the image.
    if file_storage is None:
        return utils.base64_encode(encoded)

    storage_key = f"{log_id}/{filename}"
    file_storage.set(storage_key, encoded)

    # The image is stored either way; only URL-capable storages return a link.
    if not isinstance(file_storage, SupportsGetURL):
        return utils.base64_encode(encoded)
    return file_storage.get_url(storage_key)


class TritonPythonModel(BaseTritonPythonModel):
    """Serving model that runs the pipeline's `visual_predict` on one file.

    Requests are validated as `AnalyzeImagesInput` and successful responses are
    `AnalyzeImagesResult` objects.
    """

    def initialize(self, args):
        """Build the model context from defaults and app-config overrides.

        The `extra` section of the app config may provide `file_storage` (passed
        to `create_storage`), `max_img_size` and `max_num_imgs`.
        """
        super().initialize(args)
        self._context = {
            "file_storage": None,
            "max_img_size": _DEFAULT_MAX_IMG_SIZE,
            "max_num_imgs": _DEFAULT_MAX_NUM_IMGS,
        }

        overrides = self._app_config.extra
        if not overrides:
            return

        if "file_storage" in overrides:
            self._context["file_storage"] = create_storage(overrides["file_storage"])
        # The two limits are copied from the config verbatim.
        for limit_name in ("max_img_size", "max_num_imgs"):
            if limit_name in overrides:
                self._context[limit_name] = overrides[limit_name]

    def _get_input_model_type(self):
        """Return the request model class."""
        return AnalyzeImagesInput

    def _get_result_model_type(self):
        """Return the result model class."""
        return AnalyzeImagesResult

    def _execute(self, input, log_id):
        """Run visual prediction for one request.

        Args:
            input: The validated `AnalyzeImagesInput`.
            log_id: Request identifier; used in error outputs and as the prefix
                of stored image keys.

        Returns:
            An `AnalyzeImagesResult` on success, or the output of
            `protocol.create_output_without_result` with status 422 when the
            file type cannot be determined or `maxLongSide` is set.
        """
        # Resolve the file type: an explicit code wins, otherwise try the URL.
        if input.fileType is not None:
            file_type = "PDF" if input.fileType == 0 else "IMAGE"
        elif not utils.is_url(input.file):
            return protocol.create_output_without_result(
                422,
                "Unknown file type",
                log_id=log_id,
            )
        else:
            try:
                file_type = utils.infer_file_type(input.file)
            except Exception:
                logging.exception("Failed to infer the file type")
                return protocol.create_output_without_result(
                    422,
                    "The file type cannot be inferred from the URL. Please specify the file type explicitly.",
                    log_id=log_id,
                )

        params = input.inferenceParams
        if params and params.maxLongSide:
            return protocol.create_output_without_result(
                422,
                "`max_long_side` is currently not supported.",
                log_id=log_id,
            )

        raw = utils.get_raw_bytes(input.file)
        images = utils.file_to_images(
            raw,
            file_type,
            max_img_size=self._context["max_img_size"],
            max_num_imgs=self._context["max_num_imgs"],
        )

        result = self._pipeline.visual_predict(
            images,
            use_doc_image_ori_cls_model=input.useImgOrientationCls,
            use_doc_image_unwarp_model=input.useImgUnwarping,
            use_seal_text_det_model=input.useSealTextDet,
        )

        storage = self._context["file_storage"]

        def export(picture, stem, index):
            # Store/encode one image under "<log_id>/<stem>_<index>.jpg".
            return _postprocess_image(
                picture,
                log_id=log_id,
                filename=f"{stem}_{index}.jpg",
                file_storage=storage,
            )

        vision_results: List[VisionResult] = []
        for idx, (page_img, page) in enumerate(zip(images, result[0])):
            # Images are exported in a fixed order: input, OCR, layout.
            input_ref = export(page_img, "input_image", idx)
            ocr = page["ocr_result"]
            ocr_ref = export(ocr.img, "ocr_image", idx)
            layout_ref = export(page["layout_result"].img, "layout_image", idx)

            recognized = [
                Text(poly=polygon, text=content, score=confidence)
                for polygon, content, confidence in zip(
                    ocr["dt_polys"], ocr["rec_text"], ocr["rec_score"]
                )
            ]
            found_tables = [
                Table(bbox=entry["layout_bbox"], html=entry["html"])
                for entry in page["table_result"]
            ]
            vision_results.append(
                VisionResult(
                    texts=recognized,
                    tables=found_tables,
                    inputImage=input_ref,
                    ocrImage=ocr_ref,
                    layoutImage=layout_ref,
                )
            )

        return AnalyzeImagesResult(
            visionResults=vision_results,
            visionInfo=result[1],
        )
