import uuid
from typing import Dict, List

from paddlex_hps_server import utils
from pydantic import BaseModel

from common.base_model import BaseFaceRecognitionModel
from common.index_data_utils import serialize_index_data
from common.schema import ImageLabelPair


class BuildIndexInput(BaseModel):
    """Request payload for building a face-recognition index.

    Each entry pairs an image with the label it should be indexed under.
    """

    # Images to index together with their labels. `ImageLabelPair` comes from
    # `common.schema`; this module reads its `.image` and `.label` attributes.
    imageLabelPairs: List[ImageLabelPair]


class BuildIndexResult(BaseModel):
    """Response payload returned after an index has been built and stored."""

    # Storage key under which the serialized index was saved; in this module it
    # is produced by `_generate_index_key` (a UUID4 string).
    indexKey: str
    # Copied from the built index's `id_map` attribute. Presumably maps internal
    # integer index ids to the submitted labels -- confirm against the pipeline.
    idMap: Dict[int, str]


def _generate_index_key():
    return str(uuid.uuid4())


class TritonPythonModel(BaseFaceRecognitionModel):
    """Triton model that builds and stores a face-recognition index.

    Decodes the submitted images, asks the pipeline to build an index
    (``index_type="Flat"``, ``metric_type="IP"``), stores the serialized index
    under a freshly generated key, and returns that key with the index's id map.
    """

    def _get_input_model_type(self):
        """Return the pydantic model describing this model's request payload."""
        return BuildIndexInput

    def _get_result_model_type(self):
        """Return the pydantic model describing this model's response payload."""
        return BuildIndexResult

    def _execute(self, input, log_id):
        """Build an index from labeled images and persist it.

        Args:
            input: A ``BuildIndexInput`` instance holding image/label pairs.
            log_id: Request log identifier; not used by this method.

        Returns:
            A ``BuildIndexResult`` with the storage key and the index's id map.
        """
        pairs = input.imageLabelPairs

        # All raw bytes are fetched before any decoding starts, so a bad image
        # reference fails before decoding work is done.
        image_refs = [p.image for p in pairs]
        raw_payloads = [utils.get_raw_bytes(ref) for ref in image_refs]
        decoded = [utils.image_bytes_to_array(payload) for payload in raw_payloads]
        label_list = [p.label for p in pairs]

        built = self._pipeline.build_index(
            decoded,
            label_list,
            index_type="Flat",
            metric_type="IP",
        )

        storage = self._context["index_storage"]
        new_key = _generate_index_key()
        storage.set(new_key, serialize_index_data(built))

        return BuildIndexResult(indexKey=new_key, idMap=built.id_map)
