from typing import Optional

from paddlex.inference import results
from paddlex_hps_server import BaseTritonPythonModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from common.llm_utils import LLMName, LLMParams, llm_params_to_dict


class BuildVectorStoreInput(BaseModel):
    # Request payload for the "build vector store" operation.
    # Only `visionInfo` is required. For every other field, None means "not
    # supplied": TritonPythonModel._execute then omits the matching keyword
    # argument, so presumably the pipeline's own default is used.
    # NOTE: field order and annotations define the request schema; keep them
    # stable.

    # Raw visual-info mapping; wrapped in results.VisualInfoResult before use.
    # Its keys are not validated here (schema defined upstream — confirm).
    visionInfo: dict
    # Forwarded to the pipeline as `min_characters`.
    minChars: Optional[int] = None
    # Forwarded as `llm_request_interval`; units not shown here (presumably
    # seconds — TODO confirm against the pipeline).
    llmRequestInterval: Optional[float] = None
    # Forwarded as `llm_name`.
    llmName: Optional[LLMName] = None
    # LLM connection parameters. The variant is selected by the "apiType"
    # discriminator field. The value is converted with llm_params_to_dict
    # before being forwarded as `llm_params`.
    llmParams: Optional[Annotated[LLMParams, Field(discriminator="apiType")]] = None


class BuildVectorStoreResult(BaseModel):
    # Response payload for the "build vector store" operation.
    # Filled from the "vector" entry of the pipeline's build_vector() result.
    # Presumably a serialized vector store — confirm the format with the
    # pipeline.
    vectorStore: str


class TritonPythonModel(BaseTritonPythonModel):
    """Triton Python-backend model that builds a vector store via the pipeline.

    Requests are parsed into ``BuildVectorStoreInput``. They are forwarded to
    ``self._pipeline.build_vector`` (``_pipeline`` is presumably set up by
    ``BaseTritonPythonModel`` — confirm), and the result is wrapped in
    ``BuildVectorStoreResult``.
    """

    def _get_input_model_type(self):
        """Return the pydantic model used to parse incoming requests."""
        return BuildVectorStoreInput

    def _get_result_model_type(self):
        """Return the pydantic model used to shape responses."""
        return BuildVectorStoreResult

    def _execute(self, input, log_id):
        """Run ``build_vector`` for one parsed request.

        Args:
            input: A ``BuildVectorStoreInput`` instance.
            log_id: Request log identifier (not used by this model).

        Returns:
            A ``BuildVectorStoreResult`` whose ``vectorStore`` is the "vector"
            entry of the pipeline result.
        """
        call_kwargs = {"visual_info": results.VisualInfoResult(input.visionInfo)}

        # Optional request fields paired with the pipeline keyword each feeds.
        # Fields left as None are skipped, so the pipeline keeps its default.
        passthrough = (
            ("min_characters", input.minChars),
            ("llm_request_interval", input.llmRequestInterval),
            ("llm_name", input.llmName),
        )
        call_kwargs.update(
            (keyword, value) for keyword, value in passthrough if value is not None
        )

        # LLM params need converting to a plain dict before they are forwarded.
        llm_params = input.llmParams
        if llm_params is not None:
            call_kwargs["llm_params"] = llm_params_to_dict(llm_params)

        pipeline_output = self._pipeline.build_vector(**call_kwargs)
        return BuildVectorStoreResult(vectorStore=pipeline_output["vector"])

    def _no_predictors(self):
        """Always return True.

        Presumably this tells the base class that the model needs no
        standalone predictors and relies only on the pipeline — confirm
        against BaseTritonPythonModel.
        """
        return True
