[quant][pt2e] Fix conv-bn weight + bias per channel QAT #125208

Closed · wants to merge 1 commit
92 changes: 85 additions & 7 deletions test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -855,6 +855,60 @@ def test_qat_conv_transpose_bn(self):
def test_qat_conv_transpose_bn_relu(self):
self._do_test_qat_conv_transpose_bn(has_relu=True)

def test_qat_conv_bn_per_channel_weight_bias(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = capture_pre_autograd_graph(m, example_inputs)
quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m(*example_inputs)

# Expected graph:
# x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
# weight -> q_channel -> dq_channel /
# bias -> q_channel -> dq_channel /

(conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
conv_op = conv_node.target
conv_weight_dq_op = (
torch.ops.quantized_decomposed.dequantize_per_channel.default
)
node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 2,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(conv_weight_dq_op),
ns.call_function(conv_weight_dq_op),
ns.call_function(conv_op),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
]
self.checkGraphModuleNodes(
m,
expected_node_list=node_list,
expected_node_occurrence=node_occurrence,
)
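
For reference, each q_channel -> dq_channel pair asserted above is an affine round trip applied along the output-channel axis. A minimal standalone sketch of that arithmetic (illustrative shapes and scales, not part of the test suite):

import torch

def fake_quant_per_channel(x, scales, zero_points, axis, qmin, qmax):
    # Broadcast one (scale, zero_point) pair per channel along `axis`,
    # quantize with round/clamp, then dequantize back to float.
    shape = [1] * x.dim()
    shape[axis] = -1
    s = scales.reshape(shape)
    zp = zero_points.reshape(shape)
    q = torch.clamp(torch.round(x / s) + zp, qmin, qmax)
    return (q - zp) * s

weight = torch.randn(8, 4, 3, 3)                     # (out_channels, in_channels, kH, kW)
w_scales = weight.abs().amax(dim=(1, 2, 3)) / 127.0  # one scale per output channel
w_zps = torch.zeros(8)
w_fq = fake_quant_per_channel(weight, w_scales, w_zps, axis=0, qmin=-127, qmax=127)
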


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
@@ -952,22 +1006,45 @@ class ConvBnDerivedBiasQuantizer(Quantizer):
derived from the conv input activation and weight qparams.
"""

def __init__(self, is_per_channel: bool = False):
super().__init__()
self.is_per_channel = is_per_channel

def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs):
act_scale, _ = obs_or_fqs[0].calculate_qparams()
weight_scale, _ = obs_or_fqs[1].calculate_qparams()
-        bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
-        bias_zero_point = torch.tensor([0], dtype=torch.int32)
+        if self.is_per_channel:
+            bias_scale = act_scale * weight_scale
+            bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
+        else:
+            bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
+            bias_zero_point = torch.tensor([0], dtype=torch.int32)
return bias_scale, bias_zero_point
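
Why act_scale * weight_scale: an int8 conv accumulates q_act * q_weight products, and each product is worth act_scale * weight_scale[c] in real units, so a bias quantized to int32 at exactly that scale can be added straight into the accumulator. With per-channel weights the weight scale is a vector, which is why the bias scale above becomes a vector too. A small numeric sketch (values are made up):

import torch

act_scale = torch.tensor(0.02)                   # per-tensor activation scale
weight_scale = torch.tensor([0.10, 0.05, 0.40])  # one weight scale per output channel

bias = torch.tensor([0.37, -0.12, 0.90])
bias_scale = act_scale * weight_scale            # matches the per-channel branch above
q_bias = torch.round(bias / bias_scale).to(torch.int32)
recovered = q_bias.float() * bias_scale          # close to `bias`, within one step of bias_scale
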

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
if self.is_per_channel:
weight_qscheme = torch.per_channel_symmetric
weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver,
)
else:
weight_qscheme = torch.per_tensor_affine
weight_fq = default_fake_quant
conv_node, _, getitem_node = _get_conv_bn_getitem_nodes(model)
-        act_and_weight_qspec = QuantizationSpec(
+        act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=default_fake_quant,
)
weight_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=weight_qscheme,
observer_or_fake_quant_ctr=weight_fq,
)
bias_qspec = DerivedQuantizationSpec(
derived_from=[
(conv_node.args[0], conv_node),
@@ -977,18 +1054,19 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
+            qscheme=weight_qscheme,
+            ch_axis=0 if self.is_per_channel else None,
)
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
-                conv_node.args[0]: act_and_weight_qspec,
-                conv_node.args[1]: act_and_weight_qspec,
+                conv_node.args[0]: act_qspec,
+                conv_node.args[1]: weight_qspec,
conv_node.args[2]: bias_qspec,
},
_annotated=True,
)
getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
-            output_qspec=act_and_weight_qspec,
+            output_qspec=act_qspec,
_annotated=True,
)
return model
34 changes: 24 additions & 10 deletions torch/ao/quantization/pt2e/qat_utils.py
@@ -58,6 +58,7 @@
def _get_quantized_conv_bn_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
bias_is_quantized: bool,
is_cuda: bool,
) -> Dict[str, Any]:
"""
@@ -68,8 +69,11 @@ def _get_quantized_conv_bn_example_inputs_kwargs(
# Per tensor quantization uses literals to represent scale and zero
# point, so there is no need to include them here as kwargs
if is_per_channel:
-        kwargs["scale"] = torch.tensor([1], dtype=torch.float)
-        kwargs["zero_point"] = torch.tensor([0], dtype=torch.int)
+        kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
+        kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
+        if has_bias and bias_is_quantized:
+            kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float)
+            kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
if is_cuda:
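
Putting the branches above together for is_per_channel=True, has_bias=True, bias_is_quantized=True (and ignoring the elided is_cuda handling), the helper produces placeholder kwargs roughly like the following; the values are dummies because the pattern is later matched with ignore_literals=True:

import torch

example_kwargs = {
    "weight_scale": torch.tensor([1], dtype=torch.float),
    "weight_zero_point": torch.tensor([0], dtype=torch.int),
    "bias_scale": torch.tensor([1], dtype=torch.float),
    "bias_zero_point": torch.tensor([0], dtype=torch.int),
    "conv_bias": torch.randn(1),
}
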
@@ -157,7 +161,7 @@ def _qat_conv_bn_pattern_no_conv_bias(
return x
return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)

def _append_qdq(x, is_per_channel, kwargs):
def _append_qdq(x, is_per_channel, is_bias, kwargs):
"""
Helper function to append q-dq ops after `x`, using dummy values for the qparams
and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
@@ -167,8 +171,10 @@ def _append_qdq(x, is_per_channel, kwargs):
"""
# Dummy args to be passed into q-dq ops
per_channel_axis = 0
-    scale = kwargs["scale"] if is_per_channel else 1.0
-    zp = kwargs["zero_point"] if is_per_channel else 0
+    scale_key = "bias_scale" if is_bias else "weight_scale"
+    zp_key = "bias_zero_point" if is_bias else "weight_zero_point"
+    scale = kwargs[scale_key] if is_per_channel else 1.0
+    zp = kwargs[zp_key] if is_per_channel else 0
qmin = -127
qmax = 127
dtype = torch.int8
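
The appended pair is a quantize node followed immediately by a dequantize node, so numerically it behaves like a fake-quantize; dummy qparams are acceptable because the replacement pass matches with ignore_literals=True. A standalone sketch of such a per-tensor pair (this is not the elided body of the helper; the registration import path and the op signatures are assumptions):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed import; registers the quantized_decomposed ops

x = torch.randn(8, 4, 3, 3)
scale, zp, qmin, qmax = 1.0, 0, -127, 127  # dummy literals, ignored during matching
q = torch.ops.quantized_decomposed.quantize_per_tensor(x, scale, zp, qmin, qmax, torch.int8)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor(q, scale, zp, qmin, qmax, torch.int8)
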
@@ -215,11 +221,15 @@ def _quantized_qat_conv_bn_pattern(
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
-        scaled_weight = _append_qdq(scaled_weight, is_per_channel, kwargs)
+        scaled_weight = _append_qdq(
+            scaled_weight, is_per_channel, is_bias=False, kwargs=kwargs,
+        )
if has_bias:
zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
if bias_is_quantized:
-                zero_bias = _append_qdq(zero_bias, is_per_channel, kwargs)
+                zero_bias = _append_qdq(
+                    zero_bias, is_per_channel, is_bias=True, kwargs=kwargs,
+                )
x = conv_fn(x, scaled_weight, zero_bias)
else:
x = conv_fn(x, scaled_weight, None)
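
For reference, scale_factor above is the standard conv-bn folding term gamma / sqrt(running_var + eps): scaling the conv weight by it and convolving with a zero bias lets the QAT pattern stand in for conv followed by batch norm. A small sketch of that step under the usual definitions (values are illustrative, not the elided pattern code):

import torch

gamma = torch.rand(8) + 0.5          # bn weight
running_var = torch.rand(8) + 0.1
eps = 1e-5
scale_factor = gamma / torch.sqrt(running_var + eps)

conv_weight = torch.randn(8, 4, 3, 3)
weight_shape = [1] * conv_weight.dim()
weight_shape[0] = -1                 # broadcast one factor per output channel
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
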
@@ -252,11 +262,15 @@ def _folded_quantized_qat_conv_bn_pattern(
bn_running_var: torch.Tensor,
**kwargs,
) -> torch.Tensor:
-        conv_weight = _append_qdq(conv_weight, is_per_channel, kwargs)
+        conv_weight = _append_qdq(
+            conv_weight, is_per_channel, is_bias=False, kwargs=kwargs,
+        )
if has_bias:
bias = kwargs["conv_bias"]
if bias_is_quantized:
-                bias = _append_qdq(bias, is_per_channel, kwargs)
+                bias = _append_qdq(
+                    bias, is_per_channel, is_bias=True, kwargs=kwargs,
+                )
else:
bias = None
x = conv_fn(x, conv_weight, bias)
@@ -739,7 +753,7 @@ def _fold_conv_bn_qat_helper(
# filter out one of the values for this flag to avoid having duplicate patterns
if not has_bias and bias_is_quantized:
continue
-        kwargs = _get_quantized_conv_bn_example_inputs_kwargs(is_per_channel, has_bias, is_cuda)
+        kwargs = _get_quantized_conv_bn_example_inputs_kwargs(is_per_channel, has_bias, bias_is_quantized, is_cuda)
match_pattern = _get_quantized_qat_conv_bn_pattern(
is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
)
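
As a sanity check on the flag filtering above: restricted to the three flags named in the comment, the `continue` drops the combinations that have no conv bias but still ask for a quantized bias, leaving six of the eight. In sketch form:

import itertools

combos = [
    (is_per_channel, has_bias, bias_is_quantized)
    for is_per_channel, has_bias, bias_is_quantized in itertools.product([False, True], repeat=3)
    if has_bias or not bias_is_quantized  # mirrors the `continue` above
]
assert len(combos) == 6
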