import os, sys
import numpy as np
import onnx
import ktc
import torch
import torch.onnx
import torch.nn as nn
import glob

# ---------------------------------------------------------------------------
# Paths and run configuration shared by every stage below.
# ---------------------------------------------------------------------------

# ONNX model that stage 2 loads and optimizes.
onnx_path = '/docker_mount/rknn/output/trained_model/epoch19_4d3.onnx'
# Directory that receives every artifact written by this script.
output_root = '/docker_mount/kneron/output'
# Optimized + cut ONNX written by stage 2 and consumed by stages 3-5.
edit_optimized_onnx_output_path = os.path.join(output_root, 'epoch19_4d3_optimized_edited.onnx')
# inference_output_path = os.path.join(output_root, 'npy')
bie_output_path = os.path.join(output_root, 'output.bie')
nef_output_path = os.path.join(output_root, 'output.nef')

inference_data_root = '/docker_mount/rknn/test_data/train_npy_ori'
# glob.glob returns matches in arbitrary (filesystem-dependent) order. Sort the
# list so that the sample picked via inference_data_path[0] and the order of
# the quantization inputs are the same on every run and every machine.
inference_data_path = sorted(glob.glob(os.path.join(inference_data_root, '*npy')))


# First stage to execute: each `if stage <= N` section below runs when its
# number N is greater than or equal to this value.
stage = 2

if stage <= 2:
    # Stage 2: run the torch-export optimizer flow on the source ONNX, cut the
    # graph at the configured node(s), run the generic onnx2onnx flow and save
    # the result for the later stages.
    print('----- stage 2 cut off unsupport op')
    model = onnx.load(onnx_path)
    model = ktc.onnx_optimizer.torch_exported_onnx_flow(model)
    # Pass the cut points as a list of node names. A bare string is itself a
    # sequence of characters, so any iteration or `in` test performed on it by
    # the callee would operate on characters/substrings instead of whole names.
    # NOTE(review): the Kneron docs describe this argument as List[str] - confirm.
    node_names = ['Transpose_30']
    # node_names = ['LogSoftmax_33']
    edit_m = ktc.onnx_optimizer.cut_graph_from_nodes(model, node_names)

    optimized_m = ktc.onnx_optimizer.onnx2onnx_flow(edit_m, eliminate_tail=True, opt_matmul=False)
    # onnx.save does not create missing parent directories.
    os.makedirs(output_root, exist_ok=True)
    onnx.save(optimized_m, edit_optimized_onnx_output_path)
    print('==> generate cut off unsupport op optimized onnx at ', edit_optimized_onnx_output_path)

if stage <= 3:
    # Stage 3: wrap the optimized ONNX in a ktc.ModelConfig and run its
    # evaluate() step. `km` is kept at module level because the quantization
    # and compile stages reuse it.
    print('----- stage 3 IP evaluate')
    # Positional ModelConfig arguments, kept together for readability.
    # NOTE(review): presumably (model id, version, target platform) - confirm
    # against the ktc.ModelConfig signature.
    _model_config_args = (32769, "0001", "720")
    km = ktc.ModelConfig(
        *_model_config_args,
        onnx_path=edit_optimized_onnx_output_path,
    )
    eval_result = km.evaluate()
    print('==> evaluate result at /data1/compiler/ProfileResult.txt(530,720) or ip_eval_prof.txt(520)')

if stage <= 4:
    # Stage 4: run the optimized ONNX through the Kneron simulator on one
    # sample and save the first output tensor for later comparison.
    print('----- stage 4 E2E simulater(Floating Point)')
    # Fail with an explicit message instead of a bare IndexError when the glob
    # over inference_data_root matched nothing.
    if not inference_data_path:
        raise FileNotFoundError(
            "no '*npy' input files found under {}".format(inference_data_root)
        )
    # Pre-processing: insert a singleton axis at position 2, then reorder the
    # axes (d0, d1, 1) -> (d1, 1, d0), and cast to float16.
    inference_data = np.load(inference_data_path[0])
    inference_data = np.expand_dims(inference_data, axis=2)
    inference_data = np.transpose(inference_data, (1, 2, 0))
    inference_data = inference_data.astype('float16')
    print(inference_data.shape)
    inf_results = ktc.kneron_inference(
        [inference_data],
        onnx_file=edit_optimized_onnx_output_path,
        input_names=['input'],
        platform=720,
    )
    # NOTE: despite its name, output_dir holds the path of the .npy file.
    output_dir = os.path.join(output_root, 'E2E_floating.npy')
    np.save(output_dir, inf_results[0])
    print('==> save E2E simulater(fp) result at ', output_dir)

if stage <= 5:
    # Stage 5: quantize the model (BIE) using every sample as analysis input,
    # then run the quantized model on one sample and save the first output.
    print('----- stage 5 Quantization(BIE Workflow)')

    # `km` is only created by stage 3. When the script is started at a later
    # stage, rebuild it from the optimized ONNX written by a previous run so
    # this stage does not die with a NameError.
    if stage > 3:
        km = ktc.ModelConfig(32769, "0001", "720", onnx_path=edit_optimized_onnx_output_path)

    # Same pre-processing as the floating-point stage, applied to every sample:
    # add a singleton axis at position 2, then reorder axes with (1, 2, 0).
    input_data = [
        np.transpose(np.expand_dims(np.load(path), axis=2), (1, 2, 0))
        for path in inference_data_path
    ]
    if not input_data:
        raise FileNotFoundError(
            "no '*npy' input files found under {}".format(inference_data_root)
        )

    # create the input mapping
    input_mapping = {"input": input_data}

    # Quantization
    bie_path = km.analysis(input_mapping, output_bie=bie_output_path, threads=4)
    print('==> generate bie data at ', bie_output_path)

    # Derive the single test sample here instead of relying on stage 4 having
    # run: first file, same pre-processing, cast to float16 - the same value
    # stage 4 produces for `inference_data`.
    inference_data = input_data[0].astype('float16')

    fixed_results = ktc.kneron_inference([inference_data], bie_file=bie_path, input_names=["input"], platform=720)
    # NOTE: despite its name, output_dir holds the path of the .npy file.
    output_dir = os.path.join(output_root, 'E2E_fixed.npy')
    np.save(output_dir, fixed_results[0])
    # The original message said "(fp)", copied from the floating-point stage.
    print('==> save E2E simulater(fixed point) result at ', output_dir)


if stage <= 6:
    # Stage 6: compile the quantized model into a NEF, run the NEF on one
    # sample and save the first output tensor.
    print('----- stage 6 Batch Compile (NEF Workflow)')

    # `km` only exists if stage 3 or 5 ran in this process. When starting
    # directly at stage 6, rebuild the config from the BIE file produced by a
    # previous stage-5 run instead of failing with a NameError.
    if stage > 5:
        km = ktc.ModelConfig(32769, "0001", "720", bie_path=bie_output_path)

    # Likewise `inference_data` is only guaranteed when stage 4 ran; reload the
    # first sample with the same pre-processing otherwise.
    if stage > 4:
        if not inference_data_path:
            raise FileNotFoundError(
                "no '*npy' input files found under {}".format(inference_data_root)
            )
        inference_data = np.load(inference_data_path[0])
        inference_data = np.expand_dims(inference_data, axis=2)
        inference_data = np.transpose(inference_data, (1, 2, 0))
        inference_data = inference_data.astype('float16')

    compile_result = ktc.compile([km], output_dir=nef_output_path)
    print('==> generate nef data at ', nef_output_path)

    hw_results = ktc.kneron_inference([inference_data], nef_file=compile_result, platform=720)
    # NOTE: despite its name, output_dir holds the path of the .npy file.
    output_dir = os.path.join(output_root, 'E2E_hardware.npy')
    np.save(output_dir, hw_results[0])
    # The original message said "(fp)", copied from the floating-point stage.
    print('==> save E2E simulater(hardware) result at ', output_dir)

    # print('----- NEF Combine')
    # ktc.combine_nef(nef_path_list, output_path = )
    # print('==> generate combined nef data at /data1/combined')

print('=====> kneron model convert done ! ')

