Skip to content

When will support be provided for customizing the momentum value and eps value in BN? #4107

@PhilCuriosity

Description

@PhilCuriosity

When will this function support custom momentum values and eps values?

# TODO(Leslie): This function still fails to support custom momentum and eps value.

The TODO appears in:
https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/qat_utils.py

We tested this on YOLO11: after QAT, the mAP50-95 of the YOLO11s model dropped from 47 to 43.

demo:

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization import disable_fake_quant, disable_observer

class ConvBnReluModel(nn.Module):
    """Minimal Conv2d -> BatchNorm2d -> ReLU model used to reproduce the issue.

    The BatchNorm layer is deliberately built with non-default ``eps`` and
    ``momentum`` so the export/prepare/convert passes can be checked for
    preserving them.
    """

    def __init__(self, eps=1e-3, momentum=0.03):
        super().__init__()
        # bias=False: the BN affine parameters provide the shift.
        self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum)
        self.act = nn.ReLU()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)


def get_batch_norm_node_args(gm):
    """Return the args of the first ``aten.batch_norm.default`` call in *gm*.

    Scans the FX graph of the exported module and returns the matching
    node's positional args as a tuple (input, weight, bias, running_mean,
    running_var, training, momentum, eps, cudnn_enabled).

    Raises:
        RuntimeError: if the graph contains no such node.
    """
    target = torch.ops.aten.batch_norm.default
    match = next(
        (
            n
            for n in gm.graph.nodes
            if n.op == "call_function" and n.target == target
        ),
        None,
    )
    if match is None:
        raise RuntimeError("No aten.batch_norm.default node found")
    return tuple(match.args)

# Deterministic init so the printed diffs below are reproducible.
torch.manual_seed(0)
# NOTE(review): hard-coded CUDA device -- the repro requires a GPU as written.
device = 'cuda' 

model = ConvBnReluModel().train().to(device)
inputs = (torch.randn(2, 4, 8, 8).to(device),)
# Export for training; per the log, the exported aten.batch_norm node still
# carries the custom values (momentum=0.03, eps=0.001) at this point.
exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module()

print('before prepare_qat_pt2e')
print("\tbefore move to eval:", get_batch_norm_node_args(exported))
# Per the log, after this call the batch_norm args revert to the defaults
# (momentum=0.1, eps=1e-05) -- the behavior this issue reports.
torch.ao.quantization.move_exported_model_to_eval(exported)
print("\tafter move to eval: ", get_batch_norm_node_args(exported))

# Symmetric QAT config; prepare_qat_pt2e fuses conv+bn for fake-quant
# (the batch_norm input becomes "div_1" in the printed args).
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
prepared = prepare_qat_pt2e(exported, quantizer)

# The custom momentum/eps are not restored after prepare either.
print("after prepare_qat_pt2e:", get_batch_norm_node_args(prepared))
print("\tbefore move to eval:", get_batch_norm_node_args(prepared))
torch.ao.quantization.move_exported_model_to_eval(prepared)
torch.ao.quantization.allow_exported_model_train_eval(prepared)
print("\tafter move to eval: ", get_batch_norm_node_args(prepared))
prepared.to(device)
prepared.eval()

qat = convert_pt2e(prepared)
# Disabled: convert_pt2e folds/removes the BN node, so the lookup would fail.
# print("\tafter convert_pt2e: ", get_batch_norm_node_args(qat))
 
# Compare the four variants on the same batch with observers and
# fake-quant disabled, so the diffs isolate the BN-arg discrepancy.
with torch.no_grad():
    model.eval()
    torch.ao.quantization.move_exported_model_to_eval(exported)
    prepared.apply(disable_observer)
    prepared.apply(disable_fake_quant)
    prepared.eval()

    float_out = model(inputs[0])
    exported_out = exported(inputs[0])
    prepared_out = prepared(inputs[0])
    qat_out = qat(inputs[0])

# Max absolute element-wise differences between each pair of outputs.
print("float - exported max abs diff: \t", float((float_out - exported_out).abs().max()))
print("exported - prepared max abs diff:", float((exported_out - prepared_out).abs().max()))
print("prepared - qat max abs diff:\t", float((prepared_out - qat_out).abs().max()))
print("float - qat max abs diff: \t", float((float_out - qat_out).abs().max()))

log:

before prepare_qat_pt2e
        before move to eval: (conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.03, 0.001, True)
        after move to eval:  (conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True)
after prepare_qat_pt2e: (div_1, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True)
        before move to eval: (div_1, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True)
        after move to eval:  (div_1, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True)
float - exported max abs diff:   1.3814878463745117
exported - prepared max abs diff: 1.321423053741455
prepared - qat max abs diff:     1.610581636428833
float - qat max abs diff:        1.5505168437957764

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions