from typing import List, Optional

from paddlex.inference import results
from paddlex_hps_server import BaseTritonPythonModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from common.llm_utils import LLMName, LLMParams, llm_params_to_dict


class RetrieveKnowledgeInput(BaseModel):
    # Input model for TritonPythonModel (returned by its
    # `_get_input_model_type`). Field names are camelCase and are read
    # verbatim in `TritonPythonModel._execute`.
    #
    # Keys to look up; forwarded to the pipeline's `retrieval` call as
    # `key_list`.
    keys: List[str]
    # Wrapped as `results.VectorResult({"vector": vectorStore})` and forwarded
    # as `vector`.
    # NOTE(review): presumably a serialized vector store -- confirm the
    # expected encoding against the pipeline.
    vectorStore: str
    # Optional; forwarded as `llm_name` only when it is not None.
    llmName: Optional[LLMName] = None
    # Optional; `LLMParams` is discriminated on its `apiType` field. When not
    # None it is converted with `llm_params_to_dict` and forwarded as
    # `llm_params`.
    llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None


class RetrieveKnowledgeResult(BaseModel):
    # Result model for TritonPythonModel (returned by its
    # `_get_result_model_type`).
    #
    # Filled from the "retrieval" entry of the value returned by the
    # pipeline's `retrieval` call.
    retrievalResult: str


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python model that invokes the pipeline's ``retrieval`` method.

    Declares ``RetrieveKnowledgeInput`` as its input model and
    ``RetrieveKnowledgeResult`` as its result model.
    """

    def _get_input_model_type(self):
        """Return the model class used for the input payload."""
        return RetrieveKnowledgeInput

    def _get_result_model_type(self):
        """Return the model class used for the result payload."""
        return RetrieveKnowledgeResult

    def _execute(self, input, log_id):
        """Call ``self._pipeline.retrieval`` with arguments taken from *input*.

        Args:
            input: Object exposing ``keys``, ``vectorStore``, ``llmName`` and
                ``llmParams`` (presumably a ``RetrieveKnowledgeInput`` --
                confirm against the base class).
            log_id: Not used by this method.

        Returns:
            A ``RetrieveKnowledgeResult`` whose ``retrievalResult`` is the
            ``"retrieval"`` entry of the pipeline output.
        """
        retrieval_args = dict(
            key_list=input.keys,
            vector=results.VectorResult({"vector": input.vectorStore}),
        )

        # The LLM settings are optional: leave them out of the call entirely
        # when unset instead of passing an explicit None.
        llm_name = input.llmName
        if llm_name is not None:
            retrieval_args["llm_name"] = llm_name

        llm_params = input.llmParams
        if llm_params is not None:
            retrieval_args["llm_params"] = llm_params_to_dict(llm_params)

        pipeline_output = self._pipeline.retrieval(**retrieval_args)
        return RetrieveKnowledgeResult(retrievalResult=pipeline_output["retrieval"])

    def _no_predictors(self):
        """Return ``True``.

        NOTE(review): presumably signals to the base class that this model
        needs no predictors -- confirm against ``BaseTritonPythonModel``.
        """
        return True
