from typing import List, Optional

from paddlex.inference import results
from paddlex_hps_server import BaseTritonPythonModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from common.llm_utils import LLMName, LLMParams, llm_params_to_dict


# Request payload accepted by the chat endpoint.
#
# Every optional field left as None is simply not forwarded to the pipeline's
# ``chat`` call, so the pipeline's own default for that argument applies.
class ChatInput(BaseModel):
    # Keys to query for; forwarded as ``key_list``.
    keys: List[str]
    # Visual information produced by an earlier stage; wrapped in
    # ``results.VisualInfoResult`` before being forwarded.
    visionInfo: dict
    # Serialized vector store; wrapped in ``results.VectorResult``.
    vectorStore: Optional[str] = Field(default=None)
    # Serialized retrieval result; wrapped in ``results.RetrievalResult``.
    retrievalResult: Optional[str] = Field(default=None)
    # Free-form prompt customizations forwarded verbatim.
    taskDescription: Optional[str] = Field(default=None)
    rules: Optional[str] = Field(default=None)
    fewShot: Optional[str] = Field(default=None)
    # LLM selection and parameters; ``llmParams`` is a union discriminated by
    # its ``apiType`` field.
    llmName: Optional[LLMName] = Field(default=None)
    llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None
    # When true, the pipeline is asked to keep the prompts it built.
    returnPrompts: bool = Field(default=False)


# Prompts reported back to the client when ``returnPrompts`` was requested.
# ``table`` and ``html`` stay None when the pipeline returned an empty
# (falsy) value for them.
class Prompts(BaseModel):
    ocr: str
    table: Optional[str] = Field(default=None)
    html: Optional[str] = Field(default=None)


# Response payload of the chat endpoint: the pipeline's raw chat output plus,
# optionally, the prompts that produced it.
class ChatResult(BaseModel):
    chatResult: dict
    prompts: Optional[Prompts] = Field(default=None)


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python-backend model exposing the pipeline's ``chat`` step.

    Incoming requests are described by ``ChatInput`` and responses by
    ``ChatResult``; the work itself is delegated to ``self._pipeline.chat``.
    """

    def _get_input_model_type(self):
        # Pydantic model describing the request payload.
        return ChatInput

    def _get_result_model_type(self):
        # Pydantic model describing the response payload.
        return ChatResult

    def _execute(self, input, log_id):
        """Run one chat request through the pipeline.

        Args:
            input: A ``ChatInput`` instance.
            log_id: Request log identifier; not used by this method.

        Returns:
            A ``ChatResult`` wrapping the pipeline's ``chat_res`` output and,
            when the pipeline's ``prompt`` entry is truthy, the prompts.
        """
        pipeline_kwargs = {
            "key_list": input.keys,
            "visual_info": results.VisualInfoResult(input.visionInfo),
        }

        # (pipeline keyword, request value, optional converter). Values that
        # are None are skipped so the pipeline's defaults apply to them.
        optional_args = (
            (
                "vector",
                input.vectorStore,
                lambda value: results.VectorResult({"vector": value}),
            ),
            (
                "retrieval_result",
                input.retrievalResult,
                lambda value: results.RetrievalResult({"retrieval": value}),
            ),
            ("user_task_description", input.taskDescription, None),
            ("rules", input.rules, None),
            ("few_shot", input.fewShot, None),
            ("llm_name", input.llmName, None),
            ("llm_params", input.llmParams, llm_params_to_dict),
        )
        for kwarg_name, value, convert in optional_args:
            if value is None:
                continue
            pipeline_kwargs[kwarg_name] = value if convert is None else convert(value)

        pipeline_kwargs["save_prompt"] = input.returnPrompts

        chat_output = self._pipeline.chat(**pipeline_kwargs)

        prompt_data = chat_output["prompt"]
        prompts = None
        if prompt_data:
            # Empty table/html prompts are reported as None rather than "".
            prompts = Prompts(
                ocr=prompt_data["ocr_prompt"],
                table=prompt_data["table_prompt"] or None,
                html=prompt_data["html_prompt"] or None,
            )

        return ChatResult(chatResult=chat_output["chat_res"], prompts=prompts)

    def _no_predictors(self):
        # NOTE(review): presumably tells the base class that this model needs
        # no predictors of its own — confirm against BaseTritonPythonModel.
        return True
