"""Run Iris-species predictions with an ONNX classifier.

Loads the ONNX model given on the command line, feeds it a small built-in
batch of example flower measurements (EXAMPLE_DATA) and prints the model's
first output.
"""
import argparse
import pathlib

import numpy as np
import onnxruntime as ort

# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
DESCRIPTION = """
Makes Iris predictions using Onnx classifier.

Inputs:
    0: 'sepal length in cm'
    1: 'sepal width in cm'
    2: 'petal length in cm'
    3: 'petal width in cm'

Predicts: Class
    - Iris-Setosa
    - Iris-Versicolour
    - Iris-Virginica
"""

# Three sample rows; each row's columns follow the feature order listed above.
EXAMPLE_DATA = np.array(
    [[5.4, 3.9, 1.7, 0.4], [6.1, 2.6, 5.6, 1.4], [5.2, 2.7, 3.9, 1.4]]
)


def main(args: argparse.Namespace) -> None:
    """Load the ONNX model at ``args.model_path`` and print predictions.

    Only the model's first declared input and first declared output are used.
    EXAMPLE_DATA is fed to that input, and the resulting array is printed.

    Args:
        args: Parsed command-line arguments; must provide ``model_path``.
    """
    # Pin the CPU provider explicitly: since ONNX Runtime 1.9, builds that
    # ship extra execution providers (e.g. onnxruntime-gpu) raise if
    # ``providers`` is omitted, and CPU is available in every build.
    inference_session = ort.InferenceSession(
        str(args.model_path), providers=["CPUExecutionProvider"]
    )
    input_name = inference_session.get_inputs()[0].name
    label_name = inference_session.get_outputs()[0].name
    # NOTE(review): assumes the model's input tensor is float32 (common for
    # scikit-learn exports) -- confirm against the exported model's signature.
    prediction = inference_session.run(
        [label_name], {input_name: EXAMPLE_DATA.astype(np.float32)}
    )[0]
    print(prediction)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        # Preserve DESCRIPTION's line breaks and indentation in --help; the
        # default formatter would re-wrap it into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "model_path", type=pathlib.Path, help="Path to onnx classifier model."
    )
    args = parser.parse_args()
    # Fail with a clear CLI error instead of an opaque ONNX Runtime exception.
    if not args.model_path.is_file():
        parser.error(f"model file not found: {args.model_path}")
    main(args)