Adding basic onnx inference on sklearn example
parent
94c218ff1d
commit
904eebae8c
@ -0,0 +1,52 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
|
||||
|
||||
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
|
||||
|
||||
DESCRIPTION = """
|
||||
Makes Iris predictions using Onnx classier.
|
||||
|
||||
Inputs:
|
||||
|
||||
0: 'sepal length in cm'
|
||||
1: 'sepal width in cm'
|
||||
2: 'petal length in cm'
|
||||
3: 'petal width in cm'
|
||||
|
||||
Pedicts:
|
||||
|
||||
Class
|
||||
- Iris-Setosa
|
||||
- Iris-Versicolour
|
||||
- Iris-Virginica
|
||||
"""
|
||||
|
||||
EXAMPLE_DATA = np.array(
|
||||
[[5.4, 3.9, 1.7, 0.4], [6.1, 2.6, 5.6, 1.4], [5.2, 2.7, 3.9, 1.4]]
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
inference_session = ort.InferenceSession(str(args.model_path))
|
||||
input_name = inference_session.get_inputs()[0].name
|
||||
label_name = inference_session.get_outputs()[0].name
|
||||
prediction = inference_session.run(
|
||||
[label_name], {input_name: EXAMPLE_DATA.astype(np.float32)}
|
||||
)[0]
|
||||
print(prediction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=DESCRIPTION
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_path", type=pathlib.Path, help="Path to onnx classifier model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
Loading…
Reference in New Issue