Adding basic onnx inference on sklearn example
							parent
							
								
									94c218ff1d
								
							
						
					
					
						commit
						904eebae8c
					
				@ -0,0 +1,52 @@
 | 
			
		||||
import argparse
 | 
			
		||||
import pathlib
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import onnxruntime as ort
 | 
			
		||||
 | 
			
		||||
# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
 | 
			
		||||
 | 
			
		||||
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
 | 
			
		||||
 | 
			
		||||
DESCRIPTION = """
 | 
			
		||||
Makes Iris predictions using Onnx classier.
 | 
			
		||||
 | 
			
		||||
Inputs:
 | 
			
		||||
 | 
			
		||||
0: 'sepal length in cm'
 | 
			
		||||
1: 'sepal width in cm'
 | 
			
		||||
2: 'petal length in cm'
 | 
			
		||||
3: 'petal width in cm'
 | 
			
		||||
 | 
			
		||||
Pedicts:
 | 
			
		||||
 | 
			
		||||
Class
 | 
			
		||||
    - Iris-Setosa
 | 
			
		||||
    - Iris-Versicolour
 | 
			
		||||
    - Iris-Virginica
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
EXAMPLE_DATA = np.array(
 | 
			
		||||
    [[5.4, 3.9, 1.7, 0.4], [6.1, 2.6, 5.6, 1.4], [5.2, 2.7, 3.9, 1.4]]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(args):
 | 
			
		||||
    inference_session = ort.InferenceSession(str(args.model_path))
 | 
			
		||||
    input_name = inference_session.get_inputs()[0].name
 | 
			
		||||
    label_name = inference_session.get_outputs()[0].name
 | 
			
		||||
    prediction = inference_session.run(
 | 
			
		||||
        [label_name], {input_name: EXAMPLE_DATA.astype(np.float32)}
 | 
			
		||||
    )[0]
 | 
			
		||||
    print(prediction)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        description=DESCRIPTION
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "model_path", type=pathlib.Path, help="Path to onnx classifier model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    main(args)
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue