-
Notifications
You must be signed in to change notification settings - Fork 467
Open
Description
When will this function support custom momentum and eps values?
ao/torchao/quantization/pt2e/export_utils.py
Line 113 in a79d48f
| # TODO(Leslie): This function still fails to support custom momentum and eps value. |
and
https://github.com/pytorch/ao/blob/main/torchao/quantization/pt2e/qat_utils.py
We tested this on YOLO11: the mAP50-95 of the yolo11s model decreased from 47 to 43.
demo:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torch.ao.quantization import disable_fake_quant, disable_observer
class ConvBnReluModel(nn.Module):
    """Minimal Conv -> BatchNorm -> ReLU block with non-default BN eps/momentum.

    The custom ``eps``/``momentum`` values are the whole point of this repro:
    the pt2e export/QAT rewrites are expected to preserve them.
    """

    def __init__(self, eps=1e-3, momentum=0.03):
        super().__init__()
        # bias=False so the BatchNorm affine term supplies the shift.
        self.conv = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        # Non-default eps/momentum — these should survive export + QAT prepare.
        self.bn = nn.BatchNorm2d(4, eps=eps, momentum=momentum)
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
def get_batch_norm_node_args(gm):
    """Return the args tuple of the first ``aten.batch_norm.default`` call in *gm*.

    Raises:
        RuntimeError: if the graph contains no such node.
    """
    wanted = torch.ops.aten.batch_norm.default
    matches = (
        n
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == wanted
    )
    node = next(matches, None)
    if node is None:
        raise RuntimeError("No aten.batch_norm.default node found")
    return tuple(node.args)
# Repro: show that the pt2e export/QAT rewrites reset the custom BatchNorm
# momentum/eps (0.03 / 1e-3) back to the defaults (0.1 / 1e-5), then compare
# outputs of the float, exported, prepared, and converted models.
torch.manual_seed(0)
device = 'cuda'

model = ConvBnReluModel().train().to(device)
inputs = (torch.randn(2, 4, 8, 8).to(device),)

# Export for training; print BN args before/after moving the exported graph to eval.
exported = torch.export.export_for_training(copy.deepcopy(model), inputs).module()
print('before prepare_qat_pt2e')
print("\tbefore move to eval:", get_batch_norm_node_args(exported))
torch.ao.quantization.move_exported_model_to_eval(exported)
print("\tafter move to eval: ", get_batch_norm_node_args(exported))

# Prepare for QAT with the symmetric XNNPACK config and re-inspect BN args.
qat_config = get_symmetric_quantization_config(is_qat=True)
quantizer = XNNPACKQuantizer().set_global(qat_config)
prepared = prepare_qat_pt2e(exported, quantizer)
print("after prepare_qat_pt2e:", get_batch_norm_node_args(prepared))
print("\tbefore move to eval:", get_batch_norm_node_args(prepared))
torch.ao.quantization.move_exported_model_to_eval(prepared)
torch.ao.quantization.allow_exported_model_train_eval(prepared)
print("\tafter move to eval: ", get_batch_norm_node_args(prepared))

prepared.to(device)
prepared.eval()
qat = convert_pt2e(prepared)
# print("\tafter convert_pt2e: ", get_batch_norm_node_args(qat))

# Compare all four variants on the same input with observers/fake-quant disabled.
with torch.no_grad():
    model.eval()
    torch.ao.quantization.move_exported_model_to_eval(exported)
    prepared.apply(disable_observer)
    prepared.apply(disable_fake_quant)
    prepared.eval()

    float_out = model(inputs[0])
    exported_out = exported(inputs[0])
    prepared_out = prepared(inputs[0])
    qat_out = qat(inputs[0])

    print("float - exported max abs diff: \t", float((float_out - exported_out).abs().max()))
    print("exported - prepared max abs diff:", float((exported_out - prepared_out).abs().max()))
    print("prepared - qat max abs diff:\t", float((prepared_out - qat_out).abs().max()))
    print("float - qat max abs diff: \t", float((float_out - qat_out).abs().max()))
log:
before prepare_qat_pt2e
before move to eval: (conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.03, 0.001, True)
after move to eval: (conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True)
after prepare_qat_pt2e: (div_1, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True)
before move to eval: (div_1, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True)
after move to eval: (div_1, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True)
float - exported max abs diff: 1.3814878463745117
exported - prepared max abs diff: 1.321423053741455
prepared - qat max abs diff: 1.610581636428833
float - qat max abs diff: 1.5505168437957764
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels