Adding basic onnx inference on sklearn example

master
Drew Bednar 3 years ago
parent 94c218ff1d
commit 904eebae8c

@ -1,6 +1,6 @@
# sklearn example # sklearn example
This is a simple demo showing the training of Randmon Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model. This is a simple demo showing the training of Random Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model.
## Usage ## Usage
@ -13,7 +13,7 @@ python train.py
Convert to Onnx Convert to Onnx
``` ```
python convert_to_onnx.py ./output/model.pkl
``` ```
Make inferences Make inferences

@ -0,0 +1,52 @@
import argparse
import pathlib
import numpy as np
import onnxruntime as ort
# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
DESCRIPTION = """
Makes Iris predictions using Onnx classier.
Inputs:
0: 'sepal length in cm'
1: 'sepal width in cm'
2: 'petal length in cm'
3: 'petal width in cm'
Pedicts:
Class
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica
"""
EXAMPLE_DATA = np.array(
[[5.4, 3.9, 1.7, 0.4], [6.1, 2.6, 5.6, 1.4], [5.2, 2.7, 3.9, 1.4]]
)
def main(args):
inference_session = ort.InferenceSession(str(args.model_path))
input_name = inference_session.get_inputs()[0].name
label_name = inference_session.get_outputs()[0].name
prediction = inference_session.run(
[label_name], {input_name: EXAMPLE_DATA.astype(np.float32)}
)[0]
print(prediction)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=DESCRIPTION
)
parser.add_argument(
"model_path", type=pathlib.Path, help="Path to onnx classifier model."
)
args = parser.parse_args()
main(args)
Loading…
Cancel
Save