
Poor performance of quantized models after torch.compile() #113019

Closed
AlexKoff88 opened this issue Nov 6, 2023 · 13 comments · May be fixed by pytorch/tutorials#2682
Assignees
Labels
docathon-h2-2023 Issues for the docathon in H2 2023 oncall: pt2 oncall: quantization Quantization support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@AlexKoff88

AlexKoff88 commented Nov 6, 2023

🐛 Describe the bug

I tried to benchmark the code from the post-training quantization tutorial and noticed significant performance degradation for quantized models vs. non-quantized ones.

Here is the script I used:

import time

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torchvision.models as models
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


# Create the Eager Model
model_name = "resnet18"
model = models.__dict__[model_name](pretrained=True)

# Set the model to eval mode
model = model.eval()

# Create the data, using the dummy data here as an example
traced_bs = 1
x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last)
example_inputs = (x,)

# Capture the FX Graph to be quantized
with torch.no_grad():
    # if you are using the PyTorch nightlies or building from source with the pytorch master,
    # use the API of `capture_pre_autograd_graph`
    # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be updated to use the official `torch.export` API when that is ready.
    exported_model = capture_pre_autograd_graph(model, example_inputs)
    # Note 2: if you are using the PyTorch 2.1 release binary or building from source with the PyTorch 2.1 release branch,
    # please use the API of `torch._dynamo.export` to capture the FX Graph.
    # exported_model, guards = torch._dynamo.export(
    #     model,
    #     *copy.deepcopy(example_inputs),
    #     aten_graph=True,
    # )


optimized_model = torch.compile(model)
# Running some benchmark
iters = 100
count = 0
start_t = time.time()

res = optimized_model(*example_inputs)  # warmup

for i in range(iters):
    res = optimized_model(*example_inputs)
    count += 1

print("fp32 elapsed: ", time.time() - start_t, "count: ", count)

quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared_model = prepare_pt2e(exported_model, quantizer)


prepared_model(*example_inputs)

converted_model = convert_pt2e(prepared_model)

optimized_model = torch.compile(converted_model)

# Running some benchmark
iters = 100
count = 0
start_t = time.time()

res = optimized_model(*example_inputs)  # warmup

for i in range(iters):
    res = optimized_model(*example_inputs)
    count += 1

print("int8 elapsed: ", time.time() - start_t, "count: ", count)

Here is the output I got:

fp32 elapsed:  9.725126504898071 count:  100
int8 elapsed:  29.863123655319214 count:  100

The story is similar for batch size 50.

I used a 4th Gen Intel Xeon CPU and a torch nightly build.

Here is the command that I used to run this script: TORCHINDUCTOR_FREEZING=1 python pytorch_compile.py.

I wonder what the problem is.

Error logs

Unexpected performance results.

Minified repro

from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.specialize_int = True
torch._dynamo.config.automatic_dynamic_shapes = False
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.allow_rnn = True

from torch.nn import *
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.L__self___conv1 = Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.L__self___bn1 = BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.L__self___relu = ReLU(inplace=True)
self.L__self___maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
self.getattr_L__self___layer1___0___conv1 = Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer1___0___bn1 = BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer1___0___relu = ReLU(inplace=True)
self.getattr_L__self___layer1___0___conv2 = Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer1___0___bn2 = BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer1___1___conv1 = Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer1___1___bn1 = BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer1___1___relu = ReLU(inplace=True)
self.getattr_L__self___layer1___1___conv2 = Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer1___1___bn2 = BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer2___0___conv1 = Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
self.getattr_L__self___layer2___0___bn1 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer2___0___relu = ReLU(inplace=True)
self.getattr_L__self___layer2___0___conv2 = Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer2___0___bn2 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer2___0___downsample_0 = Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
self.getattr_L__self___layer2___0___downsample_1 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer2___1___conv1 = Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer2___1___bn1 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer2___1___relu = ReLU(inplace=True)
self.getattr_L__self___layer2___1___conv2 = Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer2___1___bn2 = BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer3___0___conv1 = Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
self.getattr_L__self___layer3___0___bn1 = BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer3___0___relu = ReLU(inplace=True)
self.getattr_L__self___layer3___0___conv2 = Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer3___0___bn2 = BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer3___0___downsample_0 = Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
self.getattr_L__self___layer3___0___downsample_1 = BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer3___1___conv1 = Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer3___1___bn1 = BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer3___1___relu = ReLU(inplace=True)
self.getattr_L__self___layer3___1___conv2 = Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer3___1___bn2 = BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer4___0___conv1 = Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
self.getattr_L__self___layer4___0___bn1 = BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer4___0___relu = ReLU(inplace=True)
self.getattr_L__self___layer4___0___conv2 = Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer4___0___bn2 = BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer4___0___downsample_0 = Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
self.getattr_L__self___layer4___0___downsample_1 = BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer4___1___conv1 = Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer4___1___bn1 = BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.getattr_L__self___layer4___1___relu = ReLU(inplace=True)
self.getattr_L__self___layer4___1___conv2 = Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.getattr_L__self___layer4___1___bn2 = BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.L__self___avgpool = AdaptiveAvgPool2d(output_size=(1, 1))
self.L__self___fc = Linear(in_features=512, out_features=1000, bias=True)

def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    x = self.L__self___conv1(l_x_);  l_x_ = None
    x_1 = self.L__self___bn1(x);  x = None
    x_2 = self.L__self___relu(x_1);  x_1 = None
    identity = self.L__self___maxpool(x_2);  x_2 = None
    out = self.getattr_L__self___layer1___0___conv1(identity)
    out_1 = self.getattr_L__self___layer1___0___bn1(out);  out = None
    out_2 = self.getattr_L__self___layer1___0___relu(out_1);  out_1 = None
    out_3 = self.getattr_L__self___layer1___0___conv2(out_2);  out_2 = None
    out_4 = self.getattr_L__self___layer1___0___bn2(out_3);  out_3 = None
    out_4 += identity;  out_5 = out_4;  out_4 = identity = None
    identity_1 = self.getattr_L__self___layer1___0___relu(out_5);  out_5 = None
    out_7 = self.getattr_L__self___layer1___1___conv1(identity_1)
    out_8 = self.getattr_L__self___layer1___1___bn1(out_7);  out_7 = None
    out_9 = self.getattr_L__self___layer1___1___relu(out_8);  out_8 = None
    out_10 = self.getattr_L__self___layer1___1___conv2(out_9);  out_9 = None
    out_11 = self.getattr_L__self___layer1___1___bn2(out_10);  out_10 = None
    out_11 += identity_1;  out_12 = out_11;  out_11 = identity_1 = None
    identity_2 = self.getattr_L__self___layer1___1___relu(out_12);  out_12 = None
    out_14 = self.getattr_L__self___layer2___0___conv1(identity_2)
    out_15 = self.getattr_L__self___layer2___0___bn1(out_14);  out_14 = None
    out_16 = self.getattr_L__self___layer2___0___relu(out_15);  out_15 = None
    out_17 = self.getattr_L__self___layer2___0___conv2(out_16);  out_16 = None
    out_18 = self.getattr_L__self___layer2___0___bn2(out_17);  out_17 = None
    getattr_l__self___layer2___0___downsample_0 = self.getattr_L__self___layer2___0___downsample_0(identity_2);  identity_2 = None
    identity_3 = self.getattr_L__self___layer2___0___downsample_1(getattr_l__self___layer2___0___downsample_0);  getattr_l__self___layer2___0___downsample_0 = None
    out_18 += identity_3;  out_19 = out_18;  out_18 = identity_3 = None
    identity_4 = self.getattr_L__self___layer2___0___relu(out_19);  out_19 = None
    out_21 = self.getattr_L__self___layer2___1___conv1(identity_4)
    out_22 = self.getattr_L__self___layer2___1___bn1(out_21);  out_21 = None
    out_23 = self.getattr_L__self___layer2___1___relu(out_22);  out_22 = None
    out_24 = self.getattr_L__self___layer2___1___conv2(out_23);  out_23 = None
    out_25 = self.getattr_L__self___layer2___1___bn2(out_24);  out_24 = None
    out_25 += identity_4;  out_26 = out_25;  out_25 = identity_4 = None
    identity_5 = self.getattr_L__self___layer2___1___relu(out_26);  out_26 = None
    out_28 = self.getattr_L__self___layer3___0___conv1(identity_5)
    out_29 = self.getattr_L__self___layer3___0___bn1(out_28);  out_28 = None
    out_30 = self.getattr_L__self___layer3___0___relu(out_29);  out_29 = None
    out_31 = self.getattr_L__self___layer3___0___conv2(out_30);  out_30 = None
    out_32 = self.getattr_L__self___layer3___0___bn2(out_31);  out_31 = None
    getattr_l__self___layer3___0___downsample_0 = self.getattr_L__self___layer3___0___downsample_0(identity_5);  identity_5 = None
    identity_6 = self.getattr_L__self___layer3___0___downsample_1(getattr_l__self___layer3___0___downsample_0);  getattr_l__self___layer3___0___downsample_0 = None
    out_32 += identity_6;  out_33 = out_32;  out_32 = identity_6 = None
    identity_7 = self.getattr_L__self___layer3___0___relu(out_33);  out_33 = None
    out_35 = self.getattr_L__self___layer3___1___conv1(identity_7)
    out_36 = self.getattr_L__self___layer3___1___bn1(out_35);  out_35 = None
    out_37 = self.getattr_L__self___layer3___1___relu(out_36);  out_36 = None
    out_38 = self.getattr_L__self___layer3___1___conv2(out_37);  out_37 = None
    out_39 = self.getattr_L__self___layer3___1___bn2(out_38);  out_38 = None
    out_39 += identity_7;  out_40 = out_39;  out_39 = identity_7 = None
    identity_8 = self.getattr_L__self___layer3___1___relu(out_40);  out_40 = None
    out_42 = self.getattr_L__self___layer4___0___conv1(identity_8)
    out_43 = self.getattr_L__self___layer4___0___bn1(out_42);  out_42 = None
    out_44 = self.getattr_L__self___layer4___0___relu(out_43);  out_43 = None
    out_45 = self.getattr_L__self___layer4___0___conv2(out_44);  out_44 = None
    out_46 = self.getattr_L__self___layer4___0___bn2(out_45);  out_45 = None
    getattr_l__self___layer4___0___downsample_0 = self.getattr_L__self___layer4___0___downsample_0(identity_8);  identity_8 = None
    identity_9 = self.getattr_L__self___layer4___0___downsample_1(getattr_l__self___layer4___0___downsample_0);  getattr_l__self___layer4___0___downsample_0 = None
    out_46 += identity_9;  out_47 = out_46;  out_46 = identity_9 = None
    identity_10 = self.getattr_L__self___layer4___0___relu(out_47);  out_47 = None
    out_49 = self.getattr_L__self___layer4___1___conv1(identity_10)
    out_50 = self.getattr_L__self___layer4___1___bn1(out_49);  out_49 = None
    out_51 = self.getattr_L__self___layer4___1___relu(out_50);  out_50 = None
    out_52 = self.getattr_L__self___layer4___1___conv2(out_51);  out_51 = None
    out_53 = self.getattr_L__self___layer4___1___bn2(out_52);  out_52 = None
    out_53 += identity_10;  out_54 = out_53;  out_53 = identity_10 = None
    x_7 = self.getattr_L__self___layer4___1___relu(out_54);  out_54 = None
    x_8 = self.L__self___avgpool(x_7);  x_7 = None
    x_9 = torch.flatten(x_8, 1);  x_8 = None
    x_10 = self.L__self___fc(x_9);  x_9 = None
    return (x_10,)

mod = Repro()

def load_args(reader):
buf0 = reader.storage('aee8c913ace6d6ec115d6e3b12687aadf94e8582', 30105600)
reader.tensor(buf0, (50, 3, 224, 224), (150528, 1, 672, 3), is_leaf=True) # L_x_
load_args._version = 0

if __name__ == '__main__':
from torch._dynamo.repro.after_dynamo import run_repro
run_repro(mod, load_args, accuracy=False, command='minify',
save_dir='./torch_compile_debug/run_2023_11_06_10_51_32_497128-pid_3389471/minifier/checkpoints', autocast=False, backend=None)

Versions

PyTorch version: 2.2.0.dev20231106+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.5
Libc version: glibc-2.35

Python version: 3.8.10 (default, May 26 2023, 14:05:08) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.17.0-1033-oem-x86_64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Gold 6430L
Stepping: 7
CPU MHz: 2670.194
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 3800.00
L1d cache: 3 MiB
L1i cache: 2 MiB
L2 cache: 128 MiB
L3 cache: 120 MiB
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] torch==2.2.0.dev20231106+cpu
[pip3] torchaudio==2.2.0.dev20231105+cpu
[pip3] torchvision==0.17.0.dev20231105+cpu
[conda] Could not collect

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519

@AlexKoff88
Author

@supriyar, @jerryzh

@jerryzh168
Contributor

@leslie-fang-intel @jgong5 @Xia-Weiwen could you take a look?

@jerryzh168 jerryzh168 added oncall: quantization Quantization support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Nov 6, 2023
@leslie-fang-intel
Collaborator

Sure, I will take a look soon.

@leslie-fang-intel
Collaborator

leslie-fang-intel commented Nov 7, 2023

Hi @AlexKoff88, thanks for reporting the issue. You need to put torch.compile under the torch.no_grad() context (the same applies to fp32/bf16). Otherwise, the Inductor freezing pass will fail to apply, because of this check:

if config.freezing and not torch.is_grad_enabled():

Here is the script for your reference https://gist.github.com/leslie-fang-intel/0e9d7bcd222e7a9a4db0e7c5fb4a88bc.

Another suggestion: since the warm-up run of torch.compile triggers compilation and takes a lot of time, it is better to exclude the warm-up run from the elapsed time count.
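As a reference, here is a minimal sketch of the corrected benchmarking pattern (it reuses converted_model and example_inputs from the script above and assumes the script is still run with TORCHINDUCTOR_FREEZING=1):

import time

import torch

with torch.no_grad():
    optimized_model = torch.compile(converted_model)

    # Warm-up run: triggers compilation, so keep it outside the timed region
    optimized_model(*example_inputs)

    iters = 100
    start_t = time.time()
    for _ in range(iters):
        optimized_model(*example_inputs)

print("int8 elapsed: ", time.time() - start_t, "count: ", iters)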

@leslie-fang-intel
Collaborator

Hi @jingxu10, could you kindly help modify the tutorial https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html to tell users to put torch.compile and the benchmark run under the torch.no_grad context?

@AlexKoff88
Author

Hi @leslie-fang-intel, thanks for the prompt response. I am able to see the speedup from quantization on Xeon for batch size 50 after applying your suggestions. However, the situation is still different for batch size 1. I guess this is a runtime issue.

@leslie-fang-intel
Collaborator

@AlexKoff88 glad to hear that. What is the speedup at BS=50 and BS=1, specifically?

@AlexKoff88
Author

BS=50

fp32 elapsed:  16.489201068878174 count:  100
int8 elapsed:  9.073071241378784 count:  100

BS=1

fp32 elapsed:  1.8941051959991455 count:  100
int8 elapsed:  6.1039512157440186 count:  100

@leslie-fang-intel
Collaborator

Another suggestion: since the warm-up run of torch.compile triggers compilation and takes a lot of time, it is better to exclude the warm-up run from the elapsed time count.

Notice that in your reported script you count the warm-up run in the elapsed time; have you excluded it in the performance data above?

Here is the performance data I got with 4 cores and BS=1:

fp32 elapsed:  0.6051821708679199 count:  100
int8 elapsed:  0.1602315902709961 count:  100

@AlexKoff88
Author

Yes, I excluded the warmup. The difference is that I use all the available cores, but the performance is still very different from what you got. Here is the new script:

import time

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torchvision.models as models
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

print(torch._dynamo.list_backends())

# Create the Eager Model
model_name = "resnet18"
model = models.__dict__[model_name](pretrained=True)

# Set the model to eval mode
model = model.eval()

# Create the data, using the dummy data here as an example
traced_bs = 1
x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last)
example_inputs = (x,)

# Capture the FX Graph to be quantized
with torch.no_grad():
    # if you are using the PyTorch nightlies or building from source with the pytorch master,
    # use the API of `capture_pre_autograd_graph`
    # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be updated to use the official `torch.export` API when that is ready.
    exported_model = capture_pre_autograd_graph(model, example_inputs)
    # Note 2: if you are using the PyTorch 2.1 release binary or building from source with the PyTorch 2.1 release branch,
    # please use the API of `torch._dynamo.export` to capture the FX Graph.
    # exported_model, guards = torch._dynamo.export(
    #     model,
    #     *copy.deepcopy(example_inputs),
    #     aten_graph=True,
    # )

with torch.no_grad():
    optimized_model = torch.compile(model)
# Running some benchmark
res = optimized_model(*example_inputs)  # warmup

iters = 100
count = 0
start_t = time.time()

with torch.no_grad():
    for i in range(iters):
        res = optimized_model(*example_inputs)
        count += 1

print("fp32 elapsed: ", time.time() - start_t, "count: ", count)

quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared_model = prepare_pt2e(exported_model, quantizer)


prepared_model(*example_inputs)

converted_model = convert_pt2e(prepared_model)

with torch.no_grad():
    optimized_model = torch.compile(converted_model)

# Running some benchmark
res = optimized_model(*example_inputs)  # warmup

count = 0
start_t = time.time()

with torch.no_grad():
    for i in range(iters):
        res = optimized_model(*example_inputs)
        count += 1

print("int8 elapsed: ", time.time() - start_t, "count: ", count)

@jingxu10 jingxu10 added the docathon-h2-2023 Issues for the docathon in H2 2023 label Nov 10, 2023
@leslie-fang-intel
Collaborator

Hi @AlexKoff88, based on your script:

with torch.no_grad():
    optimized_model = torch.compile(converted_model)

# Running some benchmark
res = optimized_model(*example_inputs)  # warmup
  • The warm-up run also needs to be under the no_grad context. Otherwise, I think you are still counting the compile time in the reported performance. Here is the performance data I got with 56 cores, CMD: TORCHINDUCTOR_FREEZING=1 numactl -C 56-111 -m 1 python test_113019.py:
    fp32 elapsed:  0.19747042655944824 count:  100
    int8 elapsed:  0.1437819004058838 count:  100
    
  • Since BS=1 is small, I still suggest you use 4 cores to run this model (see the sketch below).
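A minimal sketch of one way to limit the run to 4 cores from inside the script (using torch.set_num_threads as an alternative to pinning with numactl; note this only restricts intra-op threads and does not bind memory to a NUMA node):

import torch

# Restrict intra-op parallelism to 4 threads before the benchmark runs
torch.set_num_threads(4)
print(torch.get_num_threads())  # should print 4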

@AlexKoff88
Author

Thanks, @leslie-fang-intel. Everything works now. I think it makes sense to prepare an end-to-end example that shows the correct flow; otherwise, it is easy to get incorrect results. I also wonder whether it is possible to avoid the torch.no_grad context when the model is in eval mode? That could improve the UX, I guess.
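For context, a minimal sketch of why eval mode alone is not enough today: model.eval() only switches layer behavior (dropout, batch norm statistics), it does not disable gradient tracking, and the freezing check quoted above looks at torch.is_grad_enabled():

import torch
import torchvision.models as models

model = models.resnet18(pretrained=True).eval()

print(torch.is_grad_enabled())      # True, even though the model is in eval mode

with torch.no_grad():
    print(torch.is_grad_enabled())  # False, so Inductor freezing can apply here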

@leslie-fang-intel
Collaborator

Glad to hear that, @AlexKoff88. I will close this issue now. Feel free to let me know if you have any other questions.

Labels
docathon-h2-2023 Issues for the docathon in H2 2023 oncall: pt2 oncall: quantization Quantization support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging a pull request may close this issue.

4 participants