Skip to content

Commit

Permalink
Merge pull request #161 from SthPhoenix/add_onnx_export
Browse files Browse the repository at this point in the history
Add export to ONNX
  • Loading branch information
chandrikadeb7 committed Jan 1, 2022
2 parents 397ef35 + c90b9c1 commit 4406861
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
32 changes: 32 additions & 0 deletions model2onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

# import the necessary packages
from tensorflow.keras.models import load_model, save_model
import argparse
import tf2onnx
import onnx

def model2onnx():
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", type=str,
default="mask_detector.model",
help="path to trained face mask detector model")
ap.add_argument("-o", "--output", type=str,
default='mask_detector.onnx',
help="path to trained face mask detector model")
args = vars(ap.parse_args())


# load the face mask detector model from disk
print("[INFO] loading face mask detector model...")
model = load_model(args["model"])
onnx_model, _ = tf2onnx.convert.from_keras(model, opset=13)

onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = '?'
onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param = '?'

onnx.save(onnx_model, args['output'])


if __name__ == "__main__":
model2onnx()
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ scipy==1.6.2
scikit-learn==0.24.1
pillow>=8.3.2
streamlit==0.79.0
onnx==1.10.1
tf2onnx==1.9.3

0 comments on commit 4406861

Please sign in to comment.