demo.onnx

#!/usr/bin/env python3
import tensorflow as tf
import numpy as np
from acuitylib.interface.importer import OnnxLoader
from acuitylib.interface.exporter import TimVxExporter
from acuitylib.interface.exporter import TFLiteExporter
from acuitylib.interface.quantization import Quantization, QuantizerType
from acuitylib.interface.inference import Inference


# prepare
onnx_model = "model/squeezenet.onnx"
image0 = "data/813.jpg"


# data generator
def get_data():
    for image in [image0]:
        # decode the JPEG into an HWC uint8 array
        arr = tf.io.decode_image(tf.io.read_file(image)).numpy()
        # HWC -> CHW to match the ONNX input layout
        arr = np.transpose(arr, [2, 0, 1])
        # add the batch dimension: (1, 3, 224, 224)
        arr = arr.reshape(1, 3, 224, 224)
        # feed dict keyed by the ONNX graph input name
        inputs = {'data_0': np.array(arr, dtype=np.float32)}
        yield inputs


def test_onnx_squeezenet():
    # import model
    model = OnnxLoader(model=onnx_model).load(inputs=['data_0'], input_size_list=[[1, 3, 224, 224]],
                                              outputs=['softmaxout_1'])
    # inference
    infer = Inference(model)
    infer.build_session()  # build inference session
    for i, data in enumerate(get_data()):
        ins, outs = infer.run_session(data)  # run inference session
        assert np.argmax(outs[0].flatten()) == 812  # check the top-1 class

    # export the float model for TIM-VX
    TimVxExporter(model).export('export_timvx/float16/squeezenet')

    # quantize to asymmetric affine uint8, calibrating with one pass of the data generator
    quantizer = QuantizerType.ASYMMETRIC_AFFINE
    qtype = 'uint8'
    Quantization(model).quantize(input_generator_func=get_data, quantizer=quantizer, qtype=qtype, iterations=1)

    # inference with quantized model
    infer = Inference(model)
    infer.build_session()  # build inference session
    for i, data in enumerate(get_data()):
        ins, outs = infer.run_session(data)  # run inference session
        assert np.argmax(outs[0].flatten()) == 812  # check the top-1 class

    # export the quantized model for TIM-VX and TFLite
    TimVxExporter(model).export('export_timvx/asymu8/squeezenet')
    TFLiteExporter(model).export('export_tflite/asymu8/squeezenet.tflite')


if __name__ == '__main__':
    test_onnx_squeezenet()
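
A quick way to double-check the exported TFLite file, not part of the Acuity demo itself, is to run it through the stock tf.lite.Interpreter. The sketch below assumes the paths and the 224x224 RGB test image used in the script above; whether the exported input stays NCHW or is converted to NHWC depends on the exporter, so both layouts are handled.

# Sanity-check sketch (assumption: not part of the Acuity demo above):
# run the exported TFLite model with the stock TensorFlow Lite interpreter
# and print the top-1 class, which should match the assert in the demo.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='export_tflite/asymu8/squeezenet.tflite')
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# Prepare the same 224x224 RGB image as get_data(); the exported layout
# (NCHW vs NHWC) depends on the converter, so handle both.
raw = tf.io.decode_image(tf.io.read_file('data/813.jpg')).numpy()  # HWC uint8
if list(inp['shape']) == [1, 3, 224, 224]:
    arr = np.transpose(raw, [2, 0, 1])[np.newaxis, ...].astype(np.float32)
else:
    arr = raw[np.newaxis, ...].astype(np.float32)

# Quantize the input if the exported model expects uint8.
if inp['dtype'] == np.uint8:
    scale, zero_point = inp['quantization']
    arr = np.clip(arr / scale + zero_point, 0, 255).astype(np.uint8)

interpreter.set_tensor(inp['index'], arr)
interpreter.invoke()
print('top-1 class:', np.argmax(interpreter.get_tensor(out['index']).flatten()))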