From 904eebae8c13824838d1c35125b541e828986f63 Mon Sep 17 00:00:00 2001 From: androiddrew Date: Thu, 24 Mar 2022 14:16:35 -0400 Subject: [PATCH] Adding basic onnx inference on sklearn example --- sklearn_ex/README.md | 4 ++-- sklearn_ex/onnx_infer.py | 52 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/sklearn_ex/README.md b/sklearn_ex/README.md index 6b8d0e5..24b737a 100644 --- a/sklearn_ex/README.md +++ b/sklearn_ex/README.md @@ -1,6 +1,6 @@ # sklearn example -This is a simple demo showing the training of Randmon Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model. +This is a simple demo showing the training of Random Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model. ## Usage @@ -13,7 +13,7 @@ python train.py Convert to Onnx ``` - +python convert_to_onnx.py ./output/model.pkl ``` Make inferences diff --git a/sklearn_ex/onnx_infer.py b/sklearn_ex/onnx_infer.py index e69de29..1dadb6a 100644 --- a/sklearn_ex/onnx_infer.py +++ b/sklearn_ex/onnx_infer.py @@ -0,0 +1,52 @@ +import argparse +import pathlib + +import numpy as np +import onnxruntime as ort + +# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset + +# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm'] + +DESCRIPTION = """ +Makes Iris predictions using Onnx classier. + +Inputs: + +0: 'sepal length in cm' +1: 'sepal width in cm' +2: 'petal length in cm' +3: 'petal width in cm' + +Pedicts: + +Class + - Iris-Setosa + - Iris-Versicolour + - Iris-Virginica +""" + +EXAMPLE_DATA = np.array( + [[5.4, 3.9, 1.7, 0.4], [6.1, 2.6, 5.6, 1.4], [5.2, 2.7, 3.9, 1.4]] +) + + +def main(args): + inference_session = ort.InferenceSession(str(args.model_path)) + input_name = inference_session.get_inputs()[0].name + label_name = inference_session.get_outputs()[0].name + prediction = inference_session.run( + [label_name], {input_name: EXAMPLE_DATA.astype(np.float32)} + )[0] + print(prediction) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=DESCRIPTION + ) + parser.add_argument( + "model_path", type=pathlib.Path, help="Path to onnx classifier model." + ) + args = parser.parse_args() + main(args)