You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
3 years ago
|
import argparse
|
||
|
import pathlib
|
||
|
|
||
|
import joblib as jl
|
||
|
|
||
|
from skl2onnx import convert_sklearn
|
||
|
from skl2onnx.common.data_types import FloatTensorType
|
||
|
|
||
|
|
||
|
def convert_model_to_onnx(path, n_features=4):
    """Load a joblib-pickled sklearn model and convert it to Onnx.

    Args:
        path: Location of the pickled model; passed unchanged to
            ``joblib.load``.
        n_features: Size of the second dimension of the float input tensor
            declared to the converter. Defaults to 4, the value that was
            previously hard-coded.

    Returns:
        Whatever ``skl2onnx.convert_sklearn`` returns for the loaded model.
    """
    # NOTE(security): joblib.load unpickles the file, and unpickling can run
    # arbitrary code. Only use this on model files from a trusted source.
    clr_model = jl.load(path)

    # Define Initial Types: one float input named "float_input".
    # NOTE(review): the leading None presumably leaves the batch dimension
    # unspecified -- confirm against the skl2onnx documentation.
    initial_type = [("float_input", FloatTensorType([None, n_features]))]

    return convert_sklearn(clr_model, initial_types=initial_type)
|
||
|
|
||
|
|
||
|
def write_onnx_file(model, output_path):
    """Write an Onnx model to a desired file location.

    Missing parent directories of ``output_path`` are created first, so
    writing into a directory that does not exist yet no longer fails.

    Args:
        model: Object exposing ``SerializeToString()`` that returns the bytes
            to store.
        output_path: Destination file, as a ``str`` or ``pathlib.Path``.
    """
    # Serialize before touching the filesystem so a serialization failure
    # does not leave an empty or truncated file behind.
    payload = model.SerializeToString()

    target = pathlib.Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, mode="wb") as f:
        f.write(payload)
|
||
|
|
||
|
|
||
|
def main(args):
    """Convert the pickled model to Onnx and write it to disk.

    Args:
        args: Object with ``model_path`` (input pickle) and ``output_path``
            (destination Onnx file) attributes, e.g. the namespace produced
            by the argument parser in this script.
    """
    print(f"Converting {args.model_path} to Onnx")
    model = convert_model_to_onnx(args.model_path)

    # Message fixed: it previously read "Writing saving Onnx file to ...".
    print(f"Writing Onnx file to {args.output_path}")
    write_onnx_file(model, args.output_path)

    print("Done")
|
||
|
|
||
|
|
||
|
def build_parser():
    """Build the command-line parser for the converter script.

    Returns:
        An ``argparse.ArgumentParser`` with a positional ``model_path`` and
        an optional ``-o`` / ``--out`` flag stored as ``output_path``.
    """
    parser = argparse.ArgumentParser(
        description="Converts a pickled classifier to onnx format"
    )
    parser.add_argument(
        "model_path", type=pathlib.Path, help="Path to classifier pickle file"
    )
    parser.add_argument(
        # These must be two separate arguments. Without the comma the
        # adjacent literals concatenate into the single flag "-o--out",
        # so "--out" is never registered.
        "-o",
        "--out",
        type=pathlib.Path,
        dest="output_path",
        default=pathlib.Path("./output/clr_model.onnx"),
        help="Path to output Onnx file",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)
|