Saving sklearn example
parent
a5fc7b5646
commit
94c218ff1d
@ -0,0 +1,3 @@
|
||||
black
|
||||
flake8
|
||||
pip-tools
|
@ -0,0 +1,40 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with python 3.8
|
||||
# To update, run:
|
||||
#
|
||||
# pip-compile dev-requirements.in
|
||||
#
|
||||
black==22.1.0
|
||||
# via -r dev-requirements.in
|
||||
click==8.0.4
|
||||
# via
|
||||
# black
|
||||
# pip-tools
|
||||
flake8==4.0.1
|
||||
# via -r dev-requirements.in
|
||||
mccabe==0.6.1
|
||||
# via flake8
|
||||
mypy-extensions==0.4.3
|
||||
# via black
|
||||
pathspec==0.9.0
|
||||
# via black
|
||||
pep517==0.12.0
|
||||
# via pip-tools
|
||||
pip-tools==6.5.1
|
||||
# via -r dev-requirements.in
|
||||
platformdirs==2.5.1
|
||||
# via black
|
||||
pycodestyle==2.8.0
|
||||
# via flake8
|
||||
pyflakes==2.4.0
|
||||
# via flake8
|
||||
tomli==2.0.1
|
||||
# via black
|
||||
typing-extensions==4.1.1
|
||||
# via black
|
||||
wheel==0.37.1
|
||||
# via pip-tools
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# pip
|
||||
# setuptools
|
@ -0,0 +1,7 @@
|
||||
joblib
|
||||
numpy
|
||||
onnx
|
||||
onnxruntime
|
||||
opencv-python
|
||||
skl2onnx
|
||||
sklearn
|
@ -0,0 +1,55 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with python 3.8
|
||||
# To update, run:
|
||||
#
|
||||
# pip-compile requirements.in
|
||||
#
|
||||
flatbuffers==2.0
|
||||
# via onnxruntime
|
||||
joblib==1.1.0
|
||||
# via
|
||||
# -r requirements.in
|
||||
# scikit-learn
|
||||
numpy==1.22.3
|
||||
# via
|
||||
# -r requirements.in
|
||||
# onnx
|
||||
# onnxconverter-common
|
||||
# onnxruntime
|
||||
# opencv-python
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# skl2onnx
|
||||
onnx==1.11.0
|
||||
# via
|
||||
# -r requirements.in
|
||||
# onnxconverter-common
|
||||
# skl2onnx
|
||||
onnxconverter-common==1.9.0
|
||||
# via skl2onnx
|
||||
onnxruntime==1.10.0
|
||||
# via -r requirements.in
|
||||
opencv-python==4.5.5.64
|
||||
# via -r requirements.in
|
||||
protobuf==3.19.4
|
||||
# via
|
||||
# onnx
|
||||
# onnxconverter-common
|
||||
# onnxruntime
|
||||
# skl2onnx
|
||||
scikit-learn==1.0.2
|
||||
# via
|
||||
# skl2onnx
|
||||
# sklearn
|
||||
scipy==1.8.0
|
||||
# via
|
||||
# scikit-learn
|
||||
# skl2onnx
|
||||
skl2onnx==1.11
|
||||
# via -r requirements.in
|
||||
sklearn==0.0
|
||||
# via -r requirements.in
|
||||
threadpoolctl==3.1.0
|
||||
# via scikit-learn
|
||||
typing-extensions==4.1.1
|
||||
# via onnx
|
@ -0,0 +1,2 @@
|
||||
*.pkl
|
||||
*.onnx
|
@ -0,0 +1,22 @@
|
||||
from turtle import pd
|
||||
import joblib as jl
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
def main():
|
||||
# Use iris dataset
|
||||
iris = load_iris()
|
||||
X, y = iris.data, iris.target
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||
clr = RandomForestClassifier()
|
||||
# Fit
|
||||
clr.fit(X_train, y_train)
|
||||
# Serialize the classifier to pickle file
|
||||
jl.dump(clr, "./output/model.pkl", compress=9)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Building iris model...")
|
||||
main()
|
||||
print("Model trained and dumped as pickle file.")
|
@ -0,0 +1,27 @@
|
||||
# sklearn example
|
||||
|
||||
This is a simple demo showing the training of Randmon Forest Classifier, serializing it as a pickle file, converting it to an Onnx model, then making inferences using the Onnx model.
|
||||
|
||||
## Usage
|
||||
|
||||
Train the model
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
Convert to Onnx
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
Make inferences
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
## Model
|
||||
|
||||
![Model Graph](./img/model_graph.png)
|
@ -0,0 +1,48 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import joblib as jl
|
||||
|
||||
from skl2onnx import convert_sklearn
|
||||
from skl2onnx.common.data_types import FloatTensorType
|
||||
|
||||
|
||||
def convert_model_to_onnx(path):
|
||||
"""Converts the example sklearn classifier to Onnx"""
|
||||
clr_model = jl.load(path)
|
||||
# Define Initial Types
|
||||
initial_type = [("float_input", FloatTensorType([None, 4]))]
|
||||
onnx_model = convert_sklearn(clr_model, initial_types=initial_type)
|
||||
return onnx_model
|
||||
|
||||
|
||||
def write_onnx_file(model, output_path):
|
||||
"""Writes an Onnx model to a desired file location"""
|
||||
with open(output_path, mode="wb") as f:
|
||||
f.write(model.SerializeToString())
|
||||
|
||||
|
||||
def main(args):
|
||||
print(f"Converting {args.model_path} to Onnx")
|
||||
model = convert_model_to_onnx(args.model_path)
|
||||
print(f"Writing saving Onnx file to {args.output_path}")
|
||||
write_onnx_file(model, args.output_path)
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Converts a pickled classifier to onnx format"
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_path", type=pathlib.Path, help="Path to classifier pickle file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o" "--out",
|
||||
type=pathlib.Path,
|
||||
dest="output_path",
|
||||
default=pathlib.Path("./output/clr_model.onnx"),
|
||||
help="Path to output Onnx file",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
Binary file not shown.
After Width: | Height: | Size: 27 KiB |
@ -0,0 +1,2 @@
|
||||
*.pkl
|
||||
*.onnx
|
@ -0,0 +1,24 @@
|
||||
from turtle import pd
|
||||
import joblib as jl
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
|
||||
def main():
|
||||
# Use iris dataset
|
||||
iris = load_iris()
|
||||
X, y = iris.data, iris.target
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||
clr = RandomForestClassifier()
|
||||
# Fit
|
||||
clr.fit(X_train, y_train)
|
||||
# Serialize the classifier to pickle file
|
||||
jl.dump(clr, "./output/model.pkl", compress=9)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Building iris model...")
|
||||
main()
|
||||
print("Model trained and dumped as pickle file.")
|
Loading…
Reference in New Issue