You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
1.2 KiB
Python
53 lines
1.2 KiB
Python
import argparse
|
|
import pathlib
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
|
|
# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
|
|
|
|
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
|
|
|
|
# Help text shown by the command-line parser: documents the expected feature
# columns (in input order) and the possible predicted classes.
DESCRIPTION = """
Makes Iris predictions using an ONNX classifier.

Inputs:

    0: 'sepal length in cm'
    1: 'sepal width in cm'
    2: 'petal length in cm'
    3: 'petal width in cm'

Predicts:

    Class
        - Iris-Setosa
        - Iris-Versicolour
        - Iris-Virginica
"""
|
|
|
|
# Three sample feature rows, one per line, with columns in the order listed in
# DESCRIPTION (sepal length, sepal width, petal length, petal width; in cm).
EXAMPLE_DATA = np.asarray(
    [
        [5.4, 3.9, 1.7, 0.4],
        [6.1, 2.6, 5.6, 1.4],
        [5.2, 2.7, 3.9, 1.4],
    ],
    dtype=np.float64,
)
|
|
|
|
|
|
def main(args, data=None):
    """Run the ONNX classifier on feature rows and print the predictions.

    Args:
        args: Parsed command-line namespace; only ``args.model_path`` is read.
        data: Optional array-like of feature rows to classify. Defaults to
            ``EXAMPLE_DATA`` (rows of four features in the column order
            documented in ``DESCRIPTION``).

    Returns:
        The values of the model's first output (treated here as the predicted
        label) for each input row. The same value is also printed.
    """
    features = EXAMPLE_DATA if data is None else data

    inference_session = ort.InferenceSession(str(args.model_path))
    input_name = inference_session.get_inputs()[0].name
    # Only the first output is requested.
    label_name = inference_session.get_outputs()[0].name

    # The input is converted to float32 as in the original script.
    # NOTE(review): presumably this matches the model's declared input type.
    prediction = inference_session.run(
        [label_name], {input_name: np.asarray(features, dtype=np.float32)}
    )[0]

    print(prediction)
    return prediction
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        # Keep DESCRIPTION's line breaks and indented lists in --help output;
        # the default HelpFormatter re-wraps the description into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "model_path", type=pathlib.Path, help="Path to onnx classifier model."
    )
    args = parser.parse_args()
    main(args)
|