You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

53 lines
1.2 KiB
Python

import argparse
import pathlib
import numpy as np
import onnxruntime as ort
# See https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset
# ['sepal length in cm', 'sepal width in cm', 'petal length in cm', 'petal width in cm']
# Help text passed to the argument parser as the program description.
# "Iris-Versicolour" is left as written; it matches the spelling used in the
# dataset description linked above.
DESCRIPTION = """
Makes Iris predictions using Onnx classifier.
Inputs:
0: 'sepal length in cm'
1: 'sepal width in cm'
2: 'petal length in cm'
3: 'petal width in cm'
Predicts:
Class
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica
"""
# Hard-coded sample batch fed to the model by main(): three rows of four
# measurements each. Columns are assumed to follow the input order listed in
# DESCRIPTION (sepal length, sepal width, petal length, petal width).
EXAMPLE_DATA = np.array(
    [
        [5.4, 3.9, 1.7, 0.4],
        [6.1, 2.6, 5.6, 1.4],
        [5.2, 2.7, 3.9, 1.4],
    ]
)
def main(args):
    """Run the ONNX model on EXAMPLE_DATA and print its first output.

    Args:
        args: Parsed command-line arguments; only ``args.model_path`` is read
            and it is converted to ``str`` before being handed to ONNX Runtime.
    """
    session = ort.InferenceSession(str(args.model_path))

    # Address the model by its first declared input and first declared output.
    feed_key = session.get_inputs()[0].name
    fetch_key = session.get_outputs()[0].name

    # The samples are cast to float32 before inference; presumably the model's
    # input tensor is declared as float32 -- confirm against the exported model.
    features = EXAMPLE_DATA.astype(np.float32)

    # run() returns one result per requested output name; only one is requested.
    results = session.run([fetch_key], {feed_key: features})
    print(results[0])
if __name__ == "__main__":
    # Script entry point: parse the model path from the command line and run
    # inference on the built-in example data.
    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        # The default formatter re-wraps the description into one paragraph,
        # which flattens the input/class lists in DESCRIPTION. The raw
        # formatter prints the text with its line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "model_path", type=pathlib.Path, help="Path to onnx classifier model."
    )
    args = parser.parse_args()
    main(args)