forked from AUTOMATIC1111/stable-diffusion-webui-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_onnx.py
45 lines (36 loc) · 1.55 KB
/
export_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
from modules import sd_hijack, sd_unet
from modules import shared, devices
import torch
def export_current_unet_to_onnx(filename, opset_version=17):
    """Export the currently loaded model's UNet to an ONNX file.

    The module ``shared.sd_model.model.diffusion_model`` is traced with small
    dummy inputs and written to *filename* via ``torch.onnx.export``. The
    batch, spatial and context-sequence axes are declared dynamic.

    Before exporting, every submodule attribute ``use_checkpoint`` /
    ``checkpoint`` that equals ``True`` is set to ``False``. These flags are
    NOT switched back on afterwards. The webui UNet selection and attention
    optimizations are set to ``'None'`` for the export. They are re-applied
    afterwards, even if the export raises.

    Args:
        filename: Destination path of the ``.onnx`` file. Missing parent
            directories are created. A bare filename (no directory part) is
            accepted.
        opset_version: ONNX opset version passed through to
            ``torch.onnx.export``.

    Raises:
        Whatever ``os.makedirs`` or ``torch.onnx.export`` raises. A
        directory-creation error happens before any model state is touched.
        An export error propagates only after the UNet/optimization settings
        have been re-applied.
    """
    # Create the output directory before touching any model state, so that a
    # filesystem error cannot leave the webui half-reconfigured.
    # os.path.dirname() returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Dummy inputs used only for tracing. Their concrete sizes are made
    # dynamic through `dynamic_axes` below; 500 is an arbitrary timestep.
    # NOTE(review): 4 latent channels and a 768-wide context presumably match
    # an SD 1.x UNet; other architectures may need different sizes - confirm.
    x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
    timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
    context = torch.randn(1, 77, 768).to(devices.device, devices.dtype)

    def disable_checkpoint(module):
        # Compare with True explicitly rather than by truthiness. That way an
        # attribute of the same name that is not a boolean flag (e.g. a
        # callable) is left untouched.
        if getattr(module, 'use_checkpoint', False) == True:  # noqa: E712
            module.use_checkpoint = False
        if getattr(module, 'checkpoint', False) == True:  # noqa: E712
            module.checkpoint = False

    # nn.Module.apply calls the function on every submodule and on the module
    # itself.
    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    try:
        # Presumably selects the stock UNet and plain attention so that the
        # vanilla module is traced - confirm against sd_unet / sd_hijack.
        sd_unet.apply_unet("None")
        sd_hijack.model_hijack.apply_optimizations('None')

        with devices.autocast():
            torch.onnx.export(
                shared.sd_model.model.diffusion_model,
                (x, timesteps, context),
                filename,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=['x', 'timesteps', 'context'],
                output_names=['output'],
                dynamic_axes={
                    'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                    'timesteps': {0: 'batch_size'},
                    'context': {0: 'batch_size', 1: 'sequence_length'},
                    # Assumes the UNet output keeps the spatial size of `x`
                    # (as the SD noise-prediction UNet does). Without these
                    # axes the graph declares a fixed 16x16 output.
                    'output': {0: 'batch_size', 2: 'height', 3: 'width'},
                },
            )
    finally:
        # Re-apply the optimization and UNet selection whether or not the
        # export succeeded. Without this, a failed export would leave the
        # webui running unoptimized.
        sd_hijack.model_hijack.apply_optimizations()
        sd_unet.apply_unet()