Saving sklearn example
							parent
							
								
									a5fc7b5646
								
							
						
					
					
						commit
						94c218ff1d
					
				@ -0,0 +1,3 @@
 | 
			
		||||
black
 | 
			
		||||
flake8
 | 
			
		||||
pip-tools
 | 
			
		||||
@ -0,0 +1,40 @@
 | 
			
		||||
#
 | 
			
		||||
# This file is autogenerated by pip-compile with python 3.8
 | 
			
		||||
# To update, run:
 | 
			
		||||
#
 | 
			
		||||
#    pip-compile dev-requirements.in
 | 
			
		||||
#
 | 
			
		||||
black==22.1.0
 | 
			
		||||
    # via -r dev-requirements.in
 | 
			
		||||
click==8.0.4
 | 
			
		||||
    # via
 | 
			
		||||
    #   black
 | 
			
		||||
    #   pip-tools
 | 
			
		||||
flake8==4.0.1
 | 
			
		||||
    # via -r dev-requirements.in
 | 
			
		||||
mccabe==0.6.1
 | 
			
		||||
    # via flake8
 | 
			
		||||
mypy-extensions==0.4.3
 | 
			
		||||
    # via black
 | 
			
		||||
pathspec==0.9.0
 | 
			
		||||
    # via black
 | 
			
		||||
pep517==0.12.0
 | 
			
		||||
    # via pip-tools
 | 
			
		||||
pip-tools==6.5.1
 | 
			
		||||
    # via -r dev-requirements.in
 | 
			
		||||
platformdirs==2.5.1
 | 
			
		||||
    # via black
 | 
			
		||||
pycodestyle==2.8.0
 | 
			
		||||
    # via flake8
 | 
			
		||||
pyflakes==2.4.0
 | 
			
		||||
    # via flake8
 | 
			
		||||
tomli==2.0.1
 | 
			
		||||
    # via black
 | 
			
		||||
typing-extensions==4.1.1
 | 
			
		||||
    # via black
 | 
			
		||||
wheel==0.37.1
 | 
			
		||||
    # via pip-tools
 | 
			
		||||
 | 
			
		||||
# The following packages are considered to be unsafe in a requirements file:
 | 
			
		||||
# pip
 | 
			
		||||
# setuptools
 | 
			
		||||
@ -0,0 +1,7 @@
 | 
			
		||||
joblib
 | 
			
		||||
numpy
 | 
			
		||||
onnx
 | 
			
		||||
onnxruntime
 | 
			
		||||
opencv-python
 | 
			
		||||
skl2onnx
 | 
			
		||||
sklearn
 | 
			
		||||
@ -0,0 +1,55 @@
 | 
			
		||||
#
 | 
			
		||||
# This file is autogenerated by pip-compile with python 3.8
 | 
			
		||||
# To update, run:
 | 
			
		||||
#
 | 
			
		||||
#    pip-compile requirements.in
 | 
			
		||||
#
 | 
			
		||||
flatbuffers==2.0
 | 
			
		||||
    # via onnxruntime
 | 
			
		||||
joblib==1.1.0
 | 
			
		||||
    # via
 | 
			
		||||
    #   -r requirements.in
 | 
			
		||||
    #   scikit-learn
 | 
			
		||||
numpy==1.22.3
 | 
			
		||||
    # via
 | 
			
		||||
    #   -r requirements.in
 | 
			
		||||
    #   onnx
 | 
			
		||||
    #   onnxconverter-common
 | 
			
		||||
    #   onnxruntime
 | 
			
		||||
    #   opencv-python
 | 
			
		||||
    #   scikit-learn
 | 
			
		||||
    #   scipy
 | 
			
		||||
    #   skl2onnx
 | 
			
		||||
onnx==1.11.0
 | 
			
		||||
    # via
 | 
			
		||||
    #   -r requirements.in
 | 
			
		||||
    #   onnxconverter-common
 | 
			
		||||
    #   skl2onnx
 | 
			
		||||
onnxconverter-common==1.9.0
 | 
			
		||||
    # via skl2onnx
 | 
			
		||||
onnxruntime==1.10.0
 | 
			
		||||
    # via -r requirements.in
 | 
			
		||||
opencv-python==4.5.5.64
 | 
			
		||||
    # via -r requirements.in
 | 
			
		||||
protobuf==3.19.4
 | 
			
		||||
    # via
 | 
			
		||||
    #   onnx
 | 
			
		||||
    #   onnxconverter-common
 | 
			
		||||
    #   onnxruntime
 | 
			
		||||
    #   skl2onnx
 | 
			
		||||
scikit-learn==1.0.2
 | 
			
		||||
    # via
 | 
			
		||||
    #   skl2onnx
 | 
			
		||||
    #   sklearn
 | 
			
		||||
scipy==1.8.0
 | 
			
		||||
    # via
 | 
			
		||||
    #   scikit-learn
 | 
			
		||||
    #   skl2onnx
 | 
			
		||||
skl2onnx==1.11
 | 
			
		||||
    # via -r requirements.in
 | 
			
		||||
sklearn==0.0
 | 
			
		||||
    # via -r requirements.in
 | 
			
		||||
threadpoolctl==3.1.0
 | 
			
		||||
    # via scikit-learn
 | 
			
		||||
typing-extensions==4.1.1
 | 
			
		||||
    # via onnx
 | 
			
		||||
@ -0,0 +1,2 @@
 | 
			
		||||
*.pkl
 | 
			
		||||
*.onnx
 | 
			
		||||
@ -0,0 +1,22 @@
 | 
			
		||||
from turtle import pd
 | 
			
		||||
import joblib as jl
 | 
			
		||||
 | 
			
		||||
from sklearn.datasets import load_iris
 | 
			
		||||
from sklearn.model_selection import train_test_split
 | 
			
		||||
from sklearn.ensemble import RandomForestClassifier
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    # Use iris dataset
 | 
			
		||||
    iris = load_iris()
 | 
			
		||||
    X, y = iris.data, iris.target
 | 
			
		||||
    X_train, X_test, y_train, y_test = train_test_split(X, y)
 | 
			
		||||
    clr = RandomForestClassifier()
 | 
			
		||||
    # Fit
 | 
			
		||||
    clr.fit(X_train, y_train)
 | 
			
		||||
    # Serialize the classifier to pickle file
 | 
			
		||||
    jl.dump(clr, "./output/model.pkl", compress=9)
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    print("Building iris model...")
 | 
			
		||||
    main()
 | 
			
		||||
    print("Model trained and dumped as pickle file.")
 | 
			
		||||
@ -0,0 +1,27 @@
 | 
			
		||||
# sklearn example
 | 
			
		||||
 | 
			
		||||
This is a simple demo showing the training of Randmon Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model.
 | 
			
		||||
 | 
			
		||||
## Usage
 | 
			
		||||
 | 
			
		||||
Train the model
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
python train.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Convert to Onnx
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Make inferences
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Model
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
@ -0,0 +1,48 @@
 | 
			
		||||
import argparse
 | 
			
		||||
import pathlib
 | 
			
		||||
 | 
			
		||||
import joblib as jl
 | 
			
		||||
 | 
			
		||||
from skl2onnx import convert_sklearn
 | 
			
		||||
from skl2onnx.common.data_types import FloatTensorType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_model_to_onnx(path):
 | 
			
		||||
    """Converts the example sklearn classifier to Onnx"""
 | 
			
		||||
    clr_model = jl.load(path)
 | 
			
		||||
    # Define Initial Types
 | 
			
		||||
    initial_type = [("float_input", FloatTensorType([None, 4]))]
 | 
			
		||||
    onnx_model = convert_sklearn(clr_model, initial_types=initial_type)
 | 
			
		||||
    return onnx_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_onnx_file(model, output_path):
 | 
			
		||||
    """Writes an Onnx model to a desired file location"""
 | 
			
		||||
    with open(output_path, mode="wb") as f:
 | 
			
		||||
        f.write(model.SerializeToString())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(args):
 | 
			
		||||
    print(f"Converting {args.model_path} to Onnx")
 | 
			
		||||
    model = convert_model_to_onnx(args.model_path)
 | 
			
		||||
    print(f"Writing saving Onnx file to {args.output_path}")
 | 
			
		||||
    write_onnx_file(model, args.output_path)
 | 
			
		||||
    print("Done")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        description="Converts a pickled classifier to onnx format"
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "model_path", type=pathlib.Path, help="Path to classifier pickle file"
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "-o" "--out",
 | 
			
		||||
        type=pathlib.Path,
 | 
			
		||||
        dest="output_path",
 | 
			
		||||
        default=pathlib.Path("./output/clr_model.onnx"),
 | 
			
		||||
        help="Path to output Onnx file",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    main(args)
 | 
			
		||||
											
												Binary file not shown.
											
										
									
								| 
		 After Width: | Height: | Size: 27 KiB  | 
@ -0,0 +1,2 @@
 | 
			
		||||
*.pkl
 | 
			
		||||
*.onnx
 | 
			
		||||
@ -0,0 +1,24 @@
 | 
			
		||||
from turtle import pd
 | 
			
		||||
import joblib as jl
 | 
			
		||||
 | 
			
		||||
from sklearn.datasets import load_iris
 | 
			
		||||
from sklearn.model_selection import train_test_split
 | 
			
		||||
from sklearn.ensemble import RandomForestClassifier
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    # Use iris dataset
 | 
			
		||||
    iris = load_iris()
 | 
			
		||||
    X, y = iris.data, iris.target
 | 
			
		||||
    X_train, X_test, y_train, y_test = train_test_split(X, y)
 | 
			
		||||
    clr = RandomForestClassifier()
 | 
			
		||||
    # Fit
 | 
			
		||||
    clr.fit(X_train, y_train)
 | 
			
		||||
    # Serialize the classifier to pickle file
 | 
			
		||||
    jl.dump(clr, "./output/model.pkl", compress=9)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    print("Building iris model...")
 | 
			
		||||
    main()
 | 
			
		||||
    print("Model trained and dumped as pickle file.")
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue