You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

49 lines
1.4 KiB
Python

import argparse
import pathlib
import joblib as jl
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
def convert_model_to_onnx(path, *, n_features=4):
    """Convert a pickled sklearn classifier to an ONNX model.

    Args:
        path: Path to the joblib/pickle file holding the classifier.
        n_features: Number of float input features the classifier expects.
            Defaults to 4, the input width of the example classifier.

    Returns:
        The ONNX model produced by ``skl2onnx.convert_sklearn``.
    """
    # SECURITY: joblib.load unpickles arbitrary objects and can execute code
    # embedded in the file -- only load model files from trusted sources.
    clr_model = jl.load(path)
    # One float tensor input; the batch dimension is left dynamic (None).
    initial_type = [("float_input", FloatTensorType([None, n_features]))]
    return convert_sklearn(clr_model, initial_types=initial_type)
def write_onnx_file(model, output_path):
    """Write an ONNX model to ``output_path``, creating parent directories.

    Args:
        model: Object exposing ``SerializeToString()`` returning ``bytes``,
            e.g. the model returned by ``convert_model_to_onnx``.
        output_path: Destination file path, as ``str`` or ``pathlib.Path``.
    """
    output_path = pathlib.Path(output_path)
    # The CLI default (./output/clr_model.onnx) points into a directory that
    # may not exist yet; without this, opening the file raises
    # FileNotFoundError.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_bytes(model.SerializeToString())
def main(args):
    """Convert the classifier named on the command line and save it as ONNX.

    Args:
        args: Parsed command-line namespace; ``args.model_path`` is the input
            pickle file and ``args.output_path`` is the destination ONNX file.
    """
    print(f"Converting {args.model_path} to Onnx")
    model = convert_model_to_onnx(args.model_path)
    print(f"Saving Onnx file to {args.output_path}")
    write_onnx_file(model, args.output_path)
    print("Done")
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` taking a positional ``model_path`` and
        an optional ``-o``/``--out`` path stored as ``output_path``.
    """
    parser = argparse.ArgumentParser(
        description="Converts a pickled classifier to onnx format"
    )
    parser.add_argument(
        "model_path", type=pathlib.Path, help="Path to classifier pickle file"
    )
    parser.add_argument(
        # Two separate option strings: without the comma Python concatenates
        # adjacent literals into a single bogus "-o--out" flag.
        "-o",
        "--out",
        type=pathlib.Path,
        dest="output_path",
        default=pathlib.Path("./output/clr_model.onnx"),
        help="Path to output Onnx file",
    )
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())