import ktc
import os
from PIL import Image
import numpy as np
import kneronnxopt


## Test the various model converters.

## Clone https://github.com/kneron/ConvertorExamples first:
## git clone https://github.com/kneron/ConvertorExamples.git
## cd ConvertorExamples && git lfs pull

## Configure the root logger up front so INFO-level progress messages are shown.
## NOTE(review): logging.basicConfig is a no-op if the root logger already has
## handlers (e.g. if an earlier import configured logging) — confirm it takes effect.
import logging

logging.basicConfig(level="INFO")

## Keras to ONNX
# result_m = ktc.onnx_optimizer.keras2onnx_flow('/data1/ConvertorExamples/keras_example/onet-0.417197.hdf5')

## PyTorch to ONNX
# import torch
# import torch.onnx
import onnx
# from ultralytics import YOLO
# Load the pth saved model
# add_safe_globals([DetectionModel])
# pth_model = torch.load("/workspace/mymodel/anpr_v8.pt", map_location='cpu')
# # Export the model
# dummy_input = torch.randn(1, 3, 640, 640)
# torch.onnx.export(pth_model, dummy_input, '/data1/anpr_v8.onnx', opset_version=11)

# # load the yolo model via ultralytics
# model = YOLO("/workspace/mymodel/anpr_v8.pt")

# # export to onnx
# onnx_path = model.export(format="onnx", imgsz=640, dynamic=False, opset=11, simplify=False, half=False, device="cpu")

# # Load the exported onnx model as an onnx object
# print(f"onnx path: {onnx_path}")


# Alternative model batches; uncomment one to replace the default below.
# modelList = ["yolo12n", "yolo11n", "yolov10n", "yolov9t", "yolov6n", "yolov5nu", "yolov5su"]
# modelList = ["replace_conv_80"]

# Base names (without ".onnx") of the models in /workspace/mymodel to process.
modelList = ["best"]

import sys

## Paths and settings shared by every model processed below.
MODEL_DIR = "/workspace/mymodel"
IMG_DIR = os.path.join(MODEL_DIR, "img")
# Name of the ONNX graph input fed by every simulator run.
INPUT_NAME = "images"
# Target platform passed to ModelConfig and to the fixed-point / hardware simulators.
PLATFORM = 720
# NOTE(review): every model in modelList is compiled into this same directory,
# so with more than one model a later NEF may overwrite an earlier one.
COMPILE_OUTPUT_DIR = "/data1/anpr"
# File extensions accepted as calibration images; anything else in IMG_DIR is ignored.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def preprocess(input_file):
    """Load an image and return it as a (1, 3, 640, 640) float64 array in [0, 1].

    The image is converted to RGB, resized to 640x640 with bilinear filtering
    (the input size anpr_v8 needs), scaled by 1/255, transposed from
    [H, W, C] to [C, H, W] and given a leading batch axis.
    """
    # The context manager closes the file handle; convert() returns an
    # independent, fully loaded image, so it stays usable after the close.
    with Image.open(input_file) as image:
        rgb = image.convert("RGB")
    img_data = np.array(rgb.resize((640, 640), Image.BILINEAR)) / 255
    img_data = np.transpose(img_data, (2, 0, 1))  # [H, W, C] -> [C, H, W]
    return np.expand_dims(img_data, 0)  # -> [1, C, H, W]


def _list_image_files(img_dir):
    """Return the sorted paths of the regular image files directly inside img_dir.

    Sorting makes the quantization input order deterministic. Subdirectories
    and files whose extension is not in IMAGE_EXTENSIONS (hidden files, notes)
    are skipped, so PIL is never asked to open them.
    """
    candidates = (os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir)))
    return [
        path
        for path in candidates
        if os.path.isfile(path)
        and os.path.splitext(path)[1].lower() in IMAGE_EXTENSIONS
    ]


def _run_simulator(section, input_data, **kwargs):
    """Run ktc.kneron_inference for one tutorial section and return its results.

    Raises:
        RuntimeError: if the simulator returns None, so the caller can record
        the failure for this model and continue with the rest.
    """
    results = ktc.kneron_inference(input_data, **kwargs)
    if results is None:
        print(f"Section {section} E2E simulator failed.")
        raise RuntimeError(f"Section {section} E2E simulator returned no results")
    print(f"Section {section} E2E simulator finished.")
    return results


def _process_model(file_name):
    """Optimize, evaluate, quantize, compile and cross-check one model.

    Reads MODEL_DIR/<file_name>.onnx and writes MODEL_DIR/<file_name>_optimized.onnx.
    Any failure is raised to the caller.
    """
    onnx_path = os.path.join(MODEL_DIR, f"{file_name}.onnx")
    optimized_path = os.path.join(MODEL_DIR, f"{file_name}_optimized.onnx")
    # optimized_path = onnx_path  # use the un-optimized file instead

    # Optimize the exported ONNX model.
    exported_m = onnx.load(onnx_path)
    result_m = ktc.onnx_optimizer.torch_exported_onnx_flow(exported_m)

    ## Caffe to ONNX
    # result_m = ktc.onnx_optimizer.caffe2onnx_flow('/data1/ConvertorExamples/caffe_example/mobilenetv2.prototxt', '/data1/ConvertorExamples/caffe_example/mobilenetv2.caffemodel')

    ## TF Lite to ONNX
    # result_m = ktc.onnx_optimizer.tflite2onnx_flow('/data1/ConvertorExamples/tflite_example/model_unquant.tflite')

    ## ONNX Optimization
    # optimized_m = ktc.onnx_optimizer.onnx2onnx_flow(result_m, eliminate_tail=True)
    optimized_m = kneronnxopt.optimize(result_m)
    onnx.save(optimized_m, optimized_path)

    ## Section 3: evaluate the optimized model and simulate it from the ONNX file.
    # km = ktc.ModelConfig(32769, "8b28", "720", onnx_path="/workspace/examples/mobilenetv2/mobilenetv2_zeroq.origin.onnx")
    km = ktc.ModelConfig(32769, "8b28", str(PLATFORM), onnx_path=optimized_path)
    km.evaluate()

    input_data = [preprocess(os.path.join(IMG_DIR, "0.jpg"))]
    _run_simulator(3, input_data, onnx_file=optimized_path, input_names=[INPUT_NAME])

    ## Section 4: quantize using every image in IMG_DIR, then run the fixed-point simulator.
    image_paths = _list_image_files(IMG_DIR)
    if not image_paths:
        raise FileNotFoundError(f"No calibration images found in {IMG_DIR}")
    input_mapping = {INPUT_NAME: [preprocess(path) for path in image_paths]}

    # Use single thread to avoid multiprocessing manager socket permission issues in some sandboxes.
    bie_path = km.analysis(input_mapping, threads=1)
    fixed_results = _run_simulator(
        4, input_data, bie_file=bie_path, input_names=[INPUT_NAME], platform=PLATFORM
    )

    ## Section 5: batch compile, then run the hardware simulator on the NEF.
    compile_result = ktc.compile([km], output_dir=COMPILE_OUTPUT_DIR)
    hw_results = _run_simulator(
        5, input_data, nef_file=compile_result, platform=PLATFORM, input_names=[INPUT_NAME]
    )

    # Fixed-point and hardware simulation should agree to 4 decimal places.
    try:
        np.testing.assert_almost_equal(fixed_results, hw_results, 4)
    except AssertionError:
        print("Section 4 and Section 5 results mismatch!")
        raise
    print("Section 4 and Section 5 results are the same")


# Process every model, even if one fails. Failures are logged with a full
# traceback, and the process exits with status 1 at the end if any model failed.
failed_models = []
for file_name in modelList:
    print("========================================")
    print(f"Processing model: {file_name}")
    try:
        _process_model(file_name)
    except Exception:
        print("Workflow failed!")
        logging.exception("Workflow failed for model %s", file_name)
        failed_models.append(file_name)

if failed_models:
    print(f"Failed models: {', '.join(failed_models)}")
    sys.exit(1)
