Adding basic onnx inference on sklearn example
parent
94c218ff1d
commit
904eebae8c
@ -0,0 +1,52 @@
|
|||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
|
||||||
|
|
||||||
|
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
|
||||||
|
|
||||||
|
DESCRIPTION = """
|
||||||
|
Makes Iris predictions using Onnx classier.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
|
||||||
|
0: 'sepal length in cm'
|
||||||
|
1: 'sepal width in cm'
|
||||||
|
2: 'petal length in cm'
|
||||||
|
3: 'petal width in cm'
|
||||||
|
|
||||||
|
Pedicts:
|
||||||
|
|
||||||
|
Class
|
||||||
|
- Iris-Setosa
|
||||||
|
- Iris-Versicolour
|
||||||
|
- Iris-Virginica
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXAMPLE_DATA = np.array(
|
||||||
|
[[5.4, 3.9, 1.7, 0.4], [6.1, 2.6, 5.6, 1.4], [5.2, 2.7, 3.9, 1.4]]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
inference_session = ort.InferenceSession(str(args.model_path))
|
||||||
|
input_name = inference_session.get_inputs()[0].name
|
||||||
|
label_name = inference_session.get_outputs()[0].name
|
||||||
|
prediction = inference_session.run(
|
||||||
|
[label_name], {input_name: EXAMPLE_DATA.astype(np.float32)}
|
||||||
|
)[0]
|
||||||
|
print(prediction)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=DESCRIPTION
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"model_path", type=pathlib.Path, help="Path to onnx classifier model."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
Loading…
Reference in New Issue