#!/usr/bin/env python

import argparse
import pprint
import sys

from paddlex_hps_client import triton_request, utils
from tritonclient import grpc as triton_grpc

# Name of the LLM, sent as the "llmName" field of the vector, retrieval, and
# chat requests issued by main().
LLM_NAME = "ernie-3.5"


def ensure_no_error(output, additional_msg):
    """Abort the program if a server response reports an error.

    Args:
        output: Response mapping; its "errorCode" entry is compared with 0
            and, on failure, its "errorMsg" entry is reported.
        additional_msg: Context line printed before the error details.

    Returns:
        None when ``output["errorCode"]`` equals 0.

    Raises:
        SystemExit: With exit status 1 when the error code is non-zero,
            after the diagnostics have been written to stderr.
    """
    error_code = output["errorCode"]
    if error_code == 0:
        # Success: nothing to report, let the caller carry on.
        return

    print(additional_msg, file=sys.stderr)
    print(f"Error code: {error_code}", file=sys.stderr)
    print(f"Error message: {output['errorMsg']}", file=sys.stderr)
    sys.exit(1)


def _build_llm_params(args):
    """Build the "llmParams" request payload from parsed CLI arguments.

    Args:
        args: The ``argparse.Namespace`` produced in ``main``.

    Returns:
        None when ``--llm-api-type`` was not given; otherwise a dict holding
        "apiType" plus the credentials required by that API type.

    Raises:
        SystemExit: With exit status 1 when the credentials required by the
            selected API type are missing.
    """
    api_type = args.llm_api_type
    if not api_type:
        return None

    if api_type == "qianfan":
        if args.qianfan_api_key is None or args.qianfan_secret_key is None:
            print(
                "The Qianfan API key and secret key must be provided.",
                file=sys.stderr,
            )
            sys.exit(1)
        return {
            "apiType": "qianfan",
            "apiKey": args.qianfan_api_key,
            "secretKey": args.qianfan_secret_key,
        }

    # argparse `choices` restricts the value, so anything else is "aistudio".
    if args.aistudio_access_token is None:
        print("The AI Studio access token must be provided.", file=sys.stderr)
        sys.exit(1)
    return {
        # Must match the selected API type; this was previously hard-coded to
        # "qianfan", which paired an access token with the wrong backend.
        "apiType": "aistudio",
        "accessToken": args.aistudio_access_token,
    }


def main():
    """Run the ChatOCR pipeline against a Triton server from the command line.

    Parses the command-line arguments, then issues four requests in sequence
    to the models "chatocr-vision", "chatocr-vector", "chatocr-retrieval" and
    "chatocr-chat", feeding results of earlier stages into later ones. For
    each vision result the texts and tables are pretty-printed and the OCR and
    layout images are saved as ``ocr_<i>.jpg`` and ``layout_<i>.jpg`` in the
    current directory. The final chat result is printed to stdout.

    Raises:
        SystemExit: With exit status 1 when required LLM credentials are
            missing or when any stage returns a non-zero error code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, required=True)
    parser.add_argument("--keys", type=str, nargs="+", required=True)
    parser.add_argument("--file-type", type=int, choices=[0, 1])
    parser.add_argument("--use-img-orientation-cls", action="store_true")
    parser.add_argument("--use-img-unwarping", action="store_true")
    parser.add_argument("--use-seal-text-det", action="store_true")
    parser.add_argument("--llm-api-type", type=str, choices=["qianfan", "aistudio"])
    parser.add_argument("--qianfan-api-key", type=str)
    parser.add_argument("--qianfan-secret-key", type=str)
    parser.add_argument("--aistudio-access-token", type=str)
    parser.add_argument("--url", type=str, default="localhost:8001")
    args = parser.parse_args()

    # Validate credentials before opening a connection so that bad input
    # fails fast without touching the server.
    llm_params = _build_llm_params(args)

    client = triton_grpc.InferenceServerClient(args.url)

    # Stage 1: vision analysis. Optional fields are sent only when the user
    # set them explicitly, so omitted ones are left for the server to decide.
    input_ = {"file": utils.prepare_input_file(args.file)}
    if args.file_type is not None:
        input_["fileType"] = args.file_type
    if args.use_img_orientation_cls:
        input_["useImgOrientationCls"] = True
    if args.use_img_unwarping:
        input_["useImgUnwarping"] = True
    if args.use_seal_text_det:
        input_["useSealTextDet"] = True
    output = triton_request(client, "chatocr-vision", input_)
    ensure_no_error(output, "Failed to analyze the images")
    result_vision = output["result"]

    for i, res in enumerate(result_vision["visionResults"]):
        print("Texts:")
        pprint.pp(res["texts"])
        print("Tables:")
        pprint.pp(res["tables"])
        ocr_img_path = f"ocr_{i}.jpg"
        utils.save_output_file(res["ocrImage"], ocr_img_path)
        layout_img_path = f"layout_{i}.jpg"
        utils.save_output_file(res["layoutImage"], layout_img_path)
        print(f"Output images saved at {ocr_img_path} and {layout_img_path}")

    # Stage 2: build a vector store from the vision information.
    input_ = {
        "visionInfo": result_vision["visionInfo"],
        "llmName": LLM_NAME,
    }
    if llm_params is not None:
        input_["llmParams"] = llm_params
    output = triton_request(client, "chatocr-vector", input_)
    ensure_no_error(output, "Failed to build a vector store")
    result_vector = output["result"]

    # Stage 3: retrieve knowledge for the requested keys from the vector store.
    input_ = {
        "keys": args.keys,
        "vectorStore": result_vector["vectorStore"],
        "llmName": LLM_NAME,
    }
    if llm_params is not None:
        input_["llmParams"] = llm_params
    output = triton_request(client, "chatocr-retrieval", input_)
    ensure_no_error(output, "Failed to retrieve knowledge")
    result_retrieval = output["result"]

    # Stage 4: chat with the LLM using the outputs of all previous stages.
    input_ = {
        "keys": args.keys,
        "visionInfo": result_vision["visionInfo"],
        "vectorStore": result_vector["vectorStore"],
        "retrievalResult": result_retrieval["retrievalResult"],
        "llmName": LLM_NAME,
    }
    if llm_params is not None:
        input_["llmParams"] = llm_params
    output = triton_request(client, "chatocr-chat", input_)
    ensure_no_error(output, "Failed to chat with the LLM")
    result_chat = output["result"]
    print("Final result:")
    print(result_chat["chatResult"])


# Run the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
