from typing import Dict, List

from paddlex_hps_server import utils
from pydantic import BaseModel

from common.base_model import BaseFaceRecognitionModel
from common.index_data_utils import deserialize_index_data, serialize_index_data
from common.schema import ImageLabelPair


# Request body for adding labeled images to an index held in index storage.
class AddImagesToIndexInput(BaseModel):
    # Images to add, each paired with its label; TritonPythonModel._execute
    # reads `.image` and `.label` from every pair.
    imageLabelPairs: List[ImageLabelPair]
    # Key under which the serialized index is read from, and written back to,
    # the index storage.
    indexKey: str


# Response body returned after images have been appended to the index.
class AddImagesToIndexResult(BaseModel):
    # Copied from the updated index's `id_map`. Presumably maps each integer
    # index ID to its label — confirm against the pipeline's index type.
    idMap: Dict[int, str]


class TritonPythonModel(BaseFaceRecognitionModel):
    """Triton Python model that appends labeled images to a stored index.

    Each request names an index via ``indexKey``. The handler loads the
    serialized index from the index storage and lets the pipeline extend it
    with the request's images and labels. It then serializes the result and
    writes it back under the same key.
    """

    def _get_input_model_type(self):
        """Return the pydantic model describing the request payload."""
        return AddImagesToIndexInput

    def _get_result_model_type(self):
        """Return the pydantic model describing the response payload."""
        return AddImagesToIndexResult

    def _execute(self, input, log_id):
        """Add the request's images to the index stored under ``indexKey``.

        Args:
            input: The request, presumably already parsed as
                ``AddImagesToIndexInput`` (the type returned by
                ``_get_input_model_type``).
            log_id: Request log identifier; not used by this handler.

        Returns:
            ``AddImagesToIndexResult`` carrying the updated index's ``id_map``.
        """
        pairs = input.imageLabelPairs

        # Two passes on purpose: obtain the raw bytes of every image before
        # decoding any of them into arrays.
        raw_images = [utils.get_raw_bytes(pair.image) for pair in pairs]
        decoded_images = [utils.image_bytes_to_array(raw) for raw in raw_images]
        label_list = [pair.label for pair in pairs]

        # `_context` and `_pipeline` are presumably set up by
        # BaseFaceRecognitionModel (not visible in this file).
        storage = self._context["index_storage"]
        key = input.indexKey

        # NOTE(review): what happens when `key` is absent depends on the
        # storage's `get` and on `deserialize_index_data` — confirm.
        current_index = deserialize_index_data(storage.get(key))
        updated_index = self._pipeline.append_index(
            decoded_images, label_list, current_index
        )

        # NOTE(review): no lock guards this get -> append -> set sequence.
        # Concurrent requests for the same key could overwrite each other's
        # additions unless the storage or server serializes them — verify.
        storage.set(key, serialize_index_data(updated_index))

        return AddImagesToIndexResult(idMap=updated_index.id_map)
