Saving sklearn example
parent
a5fc7b5646
commit
94c218ff1d
@ -0,0 +1,3 @@
|
|||||||
|
black
|
||||||
|
flake8
|
||||||
|
pip-tools
|
@ -0,0 +1,40 @@
|
|||||||
|
#
|
||||||
|
# This file is autogenerated by pip-compile with python 3.8
|
||||||
|
# To update, run:
|
||||||
|
#
|
||||||
|
# pip-compile dev-requirements.in
|
||||||
|
#
|
||||||
|
black==22.1.0
|
||||||
|
# via -r dev-requirements.in
|
||||||
|
click==8.0.4
|
||||||
|
# via
|
||||||
|
# black
|
||||||
|
# pip-tools
|
||||||
|
flake8==4.0.1
|
||||||
|
# via -r dev-requirements.in
|
||||||
|
mccabe==0.6.1
|
||||||
|
# via flake8
|
||||||
|
mypy-extensions==0.4.3
|
||||||
|
# via black
|
||||||
|
pathspec==0.9.0
|
||||||
|
# via black
|
||||||
|
pep517==0.12.0
|
||||||
|
# via pip-tools
|
||||||
|
pip-tools==6.5.1
|
||||||
|
# via -r dev-requirements.in
|
||||||
|
platformdirs==2.5.1
|
||||||
|
# via black
|
||||||
|
pycodestyle==2.8.0
|
||||||
|
# via flake8
|
||||||
|
pyflakes==2.4.0
|
||||||
|
# via flake8
|
||||||
|
tomli==2.0.1
|
||||||
|
# via black
|
||||||
|
typing-extensions==4.1.1
|
||||||
|
# via black
|
||||||
|
wheel==0.37.1
|
||||||
|
# via pip-tools
|
||||||
|
|
||||||
|
# The following packages are considered to be unsafe in a requirements file:
|
||||||
|
# pip
|
||||||
|
# setuptools
|
@ -0,0 +1,7 @@
|
|||||||
|
joblib
|
||||||
|
numpy
|
||||||
|
onnx
|
||||||
|
onnxruntime
|
||||||
|
opencv-python
|
||||||
|
skl2onnx
|
||||||
|
sklearn
|
@ -0,0 +1,55 @@
|
|||||||
|
#
|
||||||
|
# This file is autogenerated by pip-compile with python 3.8
|
||||||
|
# To update, run:
|
||||||
|
#
|
||||||
|
# pip-compile requirements.in
|
||||||
|
#
|
||||||
|
flatbuffers==2.0
|
||||||
|
# via onnxruntime
|
||||||
|
joblib==1.1.0
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# scikit-learn
|
||||||
|
numpy==1.22.3
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# onnx
|
||||||
|
# onnxconverter-common
|
||||||
|
# onnxruntime
|
||||||
|
# opencv-python
|
||||||
|
# scikit-learn
|
||||||
|
# scipy
|
||||||
|
# skl2onnx
|
||||||
|
onnx==1.11.0
|
||||||
|
# via
|
||||||
|
# -r requirements.in
|
||||||
|
# onnxconverter-common
|
||||||
|
# skl2onnx
|
||||||
|
onnxconverter-common==1.9.0
|
||||||
|
# via skl2onnx
|
||||||
|
onnxruntime==1.10.0
|
||||||
|
# via -r requirements.in
|
||||||
|
opencv-python==4.5.5.64
|
||||||
|
# via -r requirements.in
|
||||||
|
protobuf==3.19.4
|
||||||
|
# via
|
||||||
|
# onnx
|
||||||
|
# onnxconverter-common
|
||||||
|
# onnxruntime
|
||||||
|
# skl2onnx
|
||||||
|
scikit-learn==1.0.2
|
||||||
|
# via
|
||||||
|
# skl2onnx
|
||||||
|
# sklearn
|
||||||
|
scipy==1.8.0
|
||||||
|
# via
|
||||||
|
# scikit-learn
|
||||||
|
# skl2onnx
|
||||||
|
skl2onnx==1.11
|
||||||
|
# via -r requirements.in
|
||||||
|
sklearn==0.0
|
||||||
|
# via -r requirements.in
|
||||||
|
threadpoolctl==3.1.0
|
||||||
|
# via scikit-learn
|
||||||
|
typing-extensions==4.1.1
|
||||||
|
# via onnx
|
@ -0,0 +1,2 @@
|
|||||||
|
*.pkl
|
||||||
|
*.onnx
|
@ -0,0 +1,22 @@
|
|||||||
|
from turtle import pd
|
||||||
|
import joblib as jl
|
||||||
|
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Use iris dataset
|
||||||
|
iris = load_iris()
|
||||||
|
X, y = iris.data, iris.target
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||||
|
clr = RandomForestClassifier()
|
||||||
|
# Fit
|
||||||
|
clr.fit(X_train, y_train)
|
||||||
|
# Serialize the classifier to pickle file
|
||||||
|
jl.dump(clr, "./output/model.pkl", compress=9)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Building iris model...")
|
||||||
|
main()
|
||||||
|
print("Model trained and dumped as pickle file.")
|
@ -0,0 +1,27 @@
|
|||||||
|
# sklearn example
|
||||||
|
|
||||||
|
This is a simple demo showing the training of Randmon Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Train the model
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Convert to Onnx
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Make inferences
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model
|
||||||
|
|
||||||
|
![Model Graph](./img/model_graph.png)
|
@ -0,0 +1,48 @@
|
|||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import joblib as jl
|
||||||
|
|
||||||
|
from skl2onnx import convert_sklearn
|
||||||
|
from skl2onnx.common.data_types import FloatTensorType
|
||||||
|
|
||||||
|
|
||||||
|
def convert_model_to_onnx(path):
|
||||||
|
"""Converts the example sklearn classifier to Onnx"""
|
||||||
|
clr_model = jl.load(path)
|
||||||
|
# Define Initial Types
|
||||||
|
initial_type = [("float_input", FloatTensorType([None, 4]))]
|
||||||
|
onnx_model = convert_sklearn(clr_model, initial_types=initial_type)
|
||||||
|
return onnx_model
|
||||||
|
|
||||||
|
|
||||||
|
def write_onnx_file(model, output_path):
|
||||||
|
"""Writes an Onnx model to a desired file location"""
|
||||||
|
with open(output_path, mode="wb") as f:
|
||||||
|
f.write(model.SerializeToString())
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(f"Converting {args.model_path} to Onnx")
|
||||||
|
model = convert_model_to_onnx(args.model_path)
|
||||||
|
print(f"Writing saving Onnx file to {args.output_path}")
|
||||||
|
write_onnx_file(model, args.output_path)
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Converts a pickled classifier to onnx format"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"model_path", type=pathlib.Path, help="Path to classifier pickle file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-o" "--out",
|
||||||
|
type=pathlib.Path,
|
||||||
|
dest="output_path",
|
||||||
|
default=pathlib.Path("./output/clr_model.onnx"),
|
||||||
|
help="Path to output Onnx file",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
Binary file not shown.
After Width: | Height: | Size: 27 KiB |
@ -0,0 +1,2 @@
|
|||||||
|
*.pkl
|
||||||
|
*.onnx
|
@ -0,0 +1,24 @@
|
|||||||
|
from turtle import pd
|
||||||
|
import joblib as jl
|
||||||
|
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Use iris dataset
|
||||||
|
iris = load_iris()
|
||||||
|
X, y = iris.data, iris.target
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||||
|
clr = RandomForestClassifier()
|
||||||
|
# Fit
|
||||||
|
clr.fit(X_train, y_train)
|
||||||
|
# Serialize the classifier to pickle file
|
||||||
|
jl.dump(clr, "./output/model.pkl", compress=9)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Building iris model...")
|
||||||
|
main()
|
||||||
|
print("Model trained and dumped as pickle file.")
|
Loading…
Reference in New Issue