import os.path
from typing import Final, List, Literal, Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike
from paddlex_hps_server import BaseTritonPythonModel, logging, protocol, utils
from paddlex_hps_server.storage import Storage, SupportsGetURL, create_storage
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypeAlias

# Default image-size limit, passed as `max_img_size` to `utils.file_to_images`
# unless the app config's `extra["max_img_size"]` overrides it.
# NOTE(review): the axis order of the pair (width/height) is not visible in
# this file — confirm against `utils.file_to_images`.
_DEFAULT_MAX_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
# Default image-count limit, passed as `max_num_imgs` to
# `utils.file_to_images` unless the app config's `extra["max_num_imgs"]`
# overrides it.
_DEFAULT_MAX_NUM_IMGS: Final[int] = 10


# Numeric file-type code accepted in requests. `TritonPythonModel._execute`
# maps 0 to "PDF" and the only other allowed value, 1, to "IMAGE".
FileType: TypeAlias = Literal[0, 1]


# Optional tuning parameters attached to a request.
class InferenceParams(BaseModel):
    # Strictly positive integer when given. `TritonPythonModel._execute`
    # currently rejects any request that sets it with a 422 output.
    maxLongSide: Optional[Annotated[int, Field(gt=0)]] = None


# Request schema returned by `TritonPythonModel._get_input_model_type`.
class InferInput(BaseModel):
    # Input document reference. It is tested with `utils.is_url` and loaded
    # with `utils.get_raw_bytes`.
    # NOTE(review): non-URL values are presumably base64-encoded content —
    # confirm against `utils.get_raw_bytes`.
    file: str
    # 0 selects "PDF", 1 selects "IMAGE". When omitted, the type is inferred
    # from `file`, which only works if `file` is a URL.
    fileType: Optional[FileType] = None
    # Forwarded to the pipeline as `use_doc_image_ori_cls_model`.
    useImgOrientationCls: bool = True
    # Forwarded to the pipeline as `use_doc_image_unwarp_model`.
    useImgUnwarping: bool = True
    # Forwarded to the pipeline as `use_seal_text_det_model`.
    useSealTextDet: bool = True
    # See `InferenceParams`; requests that set `maxLongSide` are rejected.
    inferenceParams: Optional[InferenceParams] = None


# A box expressed as a list of exactly four floats.
# NOTE(review): the coordinate order is presumably (x_min, y_min, x_max,
# y_max) — confirm against the pipeline's `layout_bbox` output.
BoundingBox: TypeAlias = Annotated[List[float], Field(min_length=4, max_length=4)]


# One parsed region of a page, built in `TritonPythonModel._execute` from a
# single entry of the pipeline's `parsing_result` list.
class LayoutElement(BaseModel):
    # Taken from the entry's "layout_bbox" value.
    bbox: BoundingBox
    # The single entry key other than "input_path", "layout_bbox" and "layout".
    label: str
    # For image-like labels this is the entry's "image_text"; otherwise it is
    # the value stored under `label`.
    text: str
    # Taken from the entry's "layout" value.
    layoutType: Literal["single", "double"]
    # Set only for image-like labels: the string returned by
    # `_postprocess_image` (a storage URL or base64-encoded image bytes).
    image: Optional[str] = None


# Layout elements produced for one item of the pipeline's prediction output.
class LayoutParsingResult(BaseModel):
    # Elements in the order the pipeline reported them.
    layoutElements: List[LayoutElement]


# Response schema returned by `TritonPythonModel._get_result_model_type`.
class InferResult(BaseModel):
    # One entry per item yielded by the pipeline's `predict` call.
    layoutParsingResults: List[LayoutParsingResult]


def _postprocess_image(
    img: ArrayLike,
    log_id: str,
    filename: str,
    file_storage: Optional[Storage],
) -> str:
    """Encode an image and return a string that lets the client obtain it.

    The image is serialized with ``utils.image_array_to_bytes`` using the
    extension of ``filename``. When a storage backend is configured, the
    bytes are saved under ``"<log_id>/<filename>"``.

    Args:
        img: Array-like image data; converted with ``np.asarray``.
        log_id: Request identifier, used as the storage key prefix.
        filename: File name; its extension selects the encoding format.
        file_storage: Storage backend, or ``None`` to skip persisting.

    Returns:
        The storage URL when the backend is a ``SupportsGetURL``; otherwise
        the base64 encoding of the image bytes.
    """
    storage_key = f"{log_id}/{filename}"
    _, suffix = os.path.splitext(filename)
    encoded = utils.image_array_to_bytes(np.asarray(img), ext=suffix)

    # Without a storage backend the only option is to inline the image.
    if file_storage is None:
        return utils.base64_encode(encoded)

    # The bytes are persisted even if the backend cannot hand out a URL.
    file_storage.set(storage_key, encoded)
    if not isinstance(file_storage, SupportsGetURL):
        return utils.base64_encode(encoded)
    return file_storage.get_url(storage_key)


class TritonPythonModel(BaseTritonPythonModel):
    """Serving model that runs the layout parsing pipeline on a document.

    Requests are validated as ``InferInput`` and answered with
    ``InferResult``.
    """

    def initialize(self, args):
        """Build the model context from defaults plus app config overrides.

        The context holds the optional file storage backend and the limits
        later passed to ``utils.file_to_images``. The keys ``file_storage``,
        ``max_img_size`` and ``max_num_imgs`` of ``self._app_config.extra``
        override the defaults when present.
        """
        super().initialize(args)
        context = {
            "file_storage": None,
            "max_img_size": _DEFAULT_MAX_IMG_SIZE,
            "max_num_imgs": _DEFAULT_MAX_NUM_IMGS,
        }
        self._context = context

        extra = self._app_config.extra
        if not extra:
            return
        if "file_storage" in extra:
            context["file_storage"] = create_storage(extra["file_storage"])
        # These two options are copied from the config verbatim.
        for option in ("max_img_size", "max_num_imgs"):
            if option in extra:
                context[option] = extra[option]

    def _get_input_model_type(self):
        """Return the request schema class."""
        return InferInput

    def _get_result_model_type(self):
        """Return the response schema class."""
        return InferResult

    def _execute(self, input, log_id):
        """Run layout parsing for one request.

        Args:
            input: Validated ``InferInput`` instance.
            log_id: Request identifier, used in error outputs and as the
                prefix of stored image keys.

        Returns:
            An ``InferResult`` on success, or the value produced by
            ``protocol.create_output_without_result`` with code 422 when the
            file type cannot be determined or ``maxLongSide`` is set.

        Raises:
            RuntimeError: If a parsing result entry does not have exactly
                one key besides "input_path", "layout_bbox" and "layout".
        """
        # Resolve the file type: explicit code first, URL inference second.
        if input.fileType is not None:
            file_type = "PDF" if input.fileType == 0 else "IMAGE"
        elif not utils.is_url(input.file):
            return protocol.create_output_without_result(
                422,
                "Unknown file type",
                log_id=log_id,
            )
        else:
            try:
                file_type = utils.infer_file_type(input.file)
            except Exception:
                logging.exception("Failed to infer the file type")
                return protocol.create_output_without_result(
                    422,
                    "The file type cannot be inferred from the URL. Please specify the file type explicitly.",
                    log_id=log_id,
                )

        params = input.inferenceParams
        if params and params.maxLongSide:
            return protocol.create_output_without_result(
                422,
                "`max_long_side` is currently not supported.",
                log_id=log_id,
            )

        raw = utils.get_raw_bytes(input.file)
        images = utils.file_to_images(
            raw,
            file_type,
            max_img_size=self._context["max_img_size"],
            max_num_imgs=self._context["max_num_imgs"],
        )

        predictions = self._pipeline.predict(
            images,
            use_doc_image_ori_cls_model=input.useImgOrientationCls,
            use_doc_image_unwarp_model=input.useImgUnwarping,
            use_seal_text_det_model=input.useSealTextDet,
        )

        # Keys every entry carries; the one remaining key is the label.
        fixed_keys = {"input_path", "layout_bbox", "layout"}
        image_labels = ("image", "figure", "img", "fig")

        pages: List[LayoutParsingResult] = []
        for page_idx, page in enumerate(predictions):
            elements: List[LayoutElement] = []
            regions = page["layout_parsing_result"]["parsing_result"]
            for region_idx, region in enumerate(regions):
                label_keys = region.keys() - fixed_keys
                if len(label_keys) != 1:
                    raise RuntimeError(f"Unexpected result: {region}")
                (label,) = label_keys
                payload = region[label]
                if label not in image_labels:
                    image_ref = None
                    text = payload
                else:
                    # Image-like entries carry both pixel data and its text.
                    image_ref = _postprocess_image(
                        payload["img"],
                        log_id=log_id,
                        filename=f"image_{page_idx}_{region_idx}.jpg",
                        file_storage=self._context["file_storage"],
                    )
                    text = payload["image_text"]
                element = LayoutElement(
                    bbox=region["layout_bbox"],
                    label=label,
                    text=text,
                    layoutType=region["layout"],
                    image=image_ref,
                )
                elements.append(element)
            pages.append(LayoutParsingResult(layoutElements=elements))

        return InferResult(layoutParsingResults=pages)
