"""Convert a pickled scikit-learn classifier to the ONNX format."""

import argparse
import pathlib

import joblib as jl
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType


def convert_model_to_onnx(path, n_features=4):
    """Convert the example sklearn classifier stored at *path* to ONNX.

    Args:
        path: Path to a joblib/pickle file containing the sklearn model.
        n_features: Width of the single float input tensor. Defaults to 4,
            the value this script originally hard-coded.

    Returns:
        The ONNX model object returned by ``skl2onnx.convert_sklearn``.
    """
    # SECURITY: joblib.load unpickles arbitrary objects and can execute code
    # embedded in the file; only pass model files from a trusted source.
    clr_model = jl.load(path)
    # One float input named "float_input"; ``None`` leaves the batch
    # dimension dynamic so the exported model accepts any number of rows.
    initial_type = [("float_input", FloatTensorType([None, n_features]))]
    return convert_sklearn(clr_model, initial_types=initial_type)


def write_onnx_file(model, output_path):
    """Serialize *model* and write it to *output_path* as binary.

    Missing parent directories are created first, so the default
    ``./output/...`` location works on a fresh checkout.

    Args:
        model: Object exposing ``SerializeToString()`` (the ONNX model).
        output_path: Destination file path (``str`` or ``pathlib.Path``).
    """
    output_path = pathlib.Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, mode="wb") as f:
        f.write(model.SerializeToString())


def main(args):
    """Run the conversion described by parsed command-line *args*.

    Args:
        args: Namespace with ``model_path`` and ``output_path`` attributes.
    """
    print(f"Converting {args.model_path} to Onnx")
    model = convert_model_to_onnx(args.model_path)
    print(f"Writing Onnx file to {args.output_path}")
    write_onnx_file(model, args.output_path)
    print("Done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Converts a pickled classifier to onnx format"
    )
    parser.add_argument(
        "model_path", type=pathlib.Path, help="Path to classifier pickle file"
    )
    parser.add_argument(
        # Two separate option strings; without the comma Python would
        # concatenate them into the single unusable flag "-o--out".
        "-o",
        "--out",
        type=pathlib.Path,
        dest="output_path",
        default=pathlib.Path("./output/clr_model.onnx"),
        help="Path to output Onnx file",
    )
    args = parser.parse_args()
    main(args)