Fix PT2E Dynamic Quant regression
ghstack-source-id: a85e47669198f4c75108c28e185a397073734609
Pull Request resolved: #125207
leslie-fang-intel committed Apr 30, 2024
1 parent 89a5517 commit 613d159
Showing 3 changed files with 82 additions and 17 deletions.
58 changes: 55 additions & 3 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -1354,6 +1354,18 @@ def test_dynamic_qlinear_qat_cpu(self):
(torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_dynamic_qlinear_input_dim_exceeds_2(self):
r"""
This testcase will quantize a single Linear module.
"""
for bias in [True, False]:
self._qlinear_cpu_test_helper(
(torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@@ -1546,7 +1558,13 @@ def test_qlinear_gelu_int8_mixed_bf16(self):
(torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
)

def _qlinear_dequant_promotion_cpu_test_helper(self, inputs, int8_mixed_bf16=False):
def _qlinear_dequant_promotion_cpu_test_helper(
self,
inputs,
int8_mixed_bf16=False,
is_dynamic=False,
matcher_check_fn=None,
):
class M(torch.nn.Module):
def __init__(
self,
@@ -1564,7 +1582,7 @@ def forward(self, x):

mod = M().eval()

def matcher_check_fn():
def default_matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
# 2. dequant-linear pattern matched in quantization weight prepack * 3
@@ -1579,7 +1597,10 @@ def matcher_check_fn():
inputs,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
matcher_check_fn=matcher_check_fn
if matcher_check_fn is not None
else default_matcher_check_fn,
is_dynamic=is_dynamic,
)

@skipIfNoDynamoSupport
@@ -1662,6 +1683,37 @@ def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
(torch.randn((2, 3, 4)),), int8_mixed_bf16=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qlinear_dequant_promotion_dynamic_cpu(self):
r"""
This testcase tests whether the dequant node before linear is promoted correctly:
          X
          |
      Linear1(X)
       /      \
Linear2(X)  Linear3(X)
       \      /
         Add
          |
          Y
"""

def matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
# 2. dequant-linear pattern matched in quantization weight prepack * 3
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
)

self._qlinear_dequant_promotion_cpu_test_helper(
(torch.randn((2, 4)),),
matcher_check_fn=matcher_check_fn,
is_dynamic=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
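The new tests above drive Inductor through the standard PT2E export-quantization flow with a dynamic config. As a rough standalone sketch of that flow (assuming the capture_pre_autograd_graph, X86InductorQuantizer, and prepare_pt2e/convert_pt2e APIs the test helper in this file is built on; module, shapes, and values are illustrative only):

    import torch
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    mod = M().eval()
    example_inputs = (torch.randn(2, 3, 4),)  # input dim exceeds 2, as in the new test

    # Export, then annotate with a dynamic quantization config.
    exported = capture_pre_autograd_graph(mod, example_inputs)
    quantizer = X86InductorQuantizer().set_global(
        xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
    )
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)
    converted = convert_pt2e(prepared)

    # Compiling with Inductor runs the dequant-promotion and qlinear
    # weight-prepack pattern matchers that these tests count.
    with torch.no_grad():
        compiled = torch.compile(converted)
        compiled(*example_inputs)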
29 changes: 20 additions & 9 deletions torch/_inductor/fx_passes/quantization.py
@@ -1229,6 +1229,7 @@ def _inner(match):
dequant_pattern_end_node = match.output_node()
if dequant_pattern_end_node.target not in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
prims.convert_element_type.default,
aten.reshape.default,
]:
@@ -1254,7 +1255,11 @@ def _inner(match):
)

if (
dequant_node.target is quantized_decomposed.dequantize_per_tensor.default
dequant_node.target
in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]
and len(list(dequant_pattern_end_node.users)) > 1
):
# If the dequant pattern has more than 1 user, then do dequant promotion
@@ -1319,6 +1324,7 @@ def clone_to_new_node(graph, source_node, user_node):
dequant_pattern_end_node = match.output_node()
assert dequant_pattern_end_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
prims.convert_element_type.default,
aten.reshape.default,
]
@@ -1328,7 +1334,10 @@ def clone_to_new_node(graph, source_node, user_node):
# * OPT(prims.convert_element_type.default) (to_bf16)
# * dequantize_per_tensor
def _find_first_node_in_dequant_pattern(_node):
if _node.target is quantized_decomposed.dequantize_per_tensor.default:
if _node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]:
# For a dequant pattern, we expect the start node to be a dequantize_per_tensor node
return _node
else:
@@ -1341,10 +1350,10 @@ def _find_first_node_in_dequant_pattern(_node):
dequant_pattern_end_node
)

assert (
dequant_pattern_start_node.target
is quantized_decomposed.dequantize_per_tensor.default
)
assert dequant_pattern_start_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]

# Clone the dequant pattern for each user node
graph = match.graph
@@ -1993,9 +2002,9 @@ def _generate_qlinear_weight_prepack_patterns(

def _register_dequant_promotion():
dequant_pattern_cases = itertools.product(
[torch.float32, torch.bfloat16], [True, False]
[torch.float32, torch.bfloat16], [True, False], [True, False]
)
for dtype, input_dim_exceeds_two in dequant_pattern_cases:
for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
# 4 dequantization patterns will be matched based on the dtype and input dimension size.
# Case 1: int8-mixed-fp32, input dim size is 2
# Case 2: int8-mixed-fp32, input dim size exceeds 2
@@ -2019,7 +2028,9 @@ def _register_dequant_promotion():
_register_dequant_promotion_pass(
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(),
get_dequantize_per_tensor_activation_pattern(
is_tensor_overload=is_tensor_overload
),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
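The matcher changes above teach dequant promotion to accept the .tensor overload of dequantize_per_tensor, which dynamic quantization emits because its scale and zero_point are computed at runtime as Tensors instead of being baked in as scalars. A small sketch contrasting the two overloads, assuming the quantized_decomposed ops registered by torch.ao.quantization.fx._decomposed (all values illustrative):

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers quantized_decomposed ops)

    x = torch.randn(2, 4)

    # Static quantization: scale/zero_point are Python scalars -> .default overload.
    q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        x, 0.05, 0, -128, 127, torch.int8
    )
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        q, 0.05, 0, -128, 127, torch.int8
    )

    # Dynamic quantization: scale/zero_point are runtime Tensors -> .tensor overload,
    # which the promotion pass now also recognizes.
    scale = torch.tensor([0.05])
    zero_point = torch.tensor([0], dtype=torch.int64)
    q_t = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
        x, scale, zero_point, -128, 127, torch.int8
    )
    dq_t = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
        q_t, scale, zero_point, -128, 127, torch.int8
    )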
12 changes: 7 additions & 5 deletions torch/_inductor/lowering.py
@@ -1244,16 +1244,18 @@ def quantized_decomposed_quantize_per_tensor_tensor(
zero_point_loader = zero_point.make_loader()

def inner_fn(idx):
input = input_loader(idx)
_input = input_loader(idx)
_scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
_zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
if scale.dtype != torch.float32:
_scale = ops.to_dtype(_scale, torch.float32)
if input.dtype != scale.dtype:
_input = ops.to_dtype(_input, scale.dtype)
if zero_point.dtype != torch.float32:
_zero_point = ops.to_dtype(_zero_point, torch.float32)
val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
val = ops.round(_input * ops.reciprocal(_scale))
if scale.dtype != torch.float32:
val = ops.to_dtype(val, torch.float32)
qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
clamped = ops.minimum(ops.maximum(val + _zero_point, qmin), qmax)
return ops.to_dtype(clamped, dtype)

return Pointwise.create(
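The reworked inner_fn above follows the usual per-tensor quantize formula, q = clamp(round(x / scale) + zero_point, qmin, qmax), rounding x / scale in the scale's dtype and adding the zero point only after converting back to float32. A hypothetical eager-mode reference of those semantics, not part of the diff:

    import torch

    def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
        # Hypothetical reference: round(x / scale) first, then add the zero point
        # and clamp in float32, matching the order used by the lowering above.
        val = torch.round(x.to(torch.float32) * (1.0 / scale))
        clamped = torch.clamp(val + zero_point, quant_min, quant_max)
        return clamped.to(dtype)

    x = torch.randn(2, 4, dtype=torch.bfloat16)  # e.g. an int8-mixed-bf16 activation
    q = quantize_per_tensor_ref(
        x, scale=0.05, zero_point=10, quant_min=-128, quant_max=127, dtype=torch.int8
    )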
