Skip to content

Commit ac5bbb9

Browse files
zlin888meta-codesync[bot]
authored andcommitted
support Per Group in Float8DynamicActivationFloat8WeightConfig (#4182)
Summary: Pull Request resolved: #4182 as title Differential Revision: D97987011
1 parent 6f56403 commit ac5bbb9

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

torchao/float8/inference.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
import math
11+
import logging
1112
from typing import List, NamedTuple, Optional, Tuple, Union
1213

1314
import torch
@@ -19,6 +20,8 @@
1920
is_sm_at_least_89,
2021
)
2122

23+
logger: logging.Logger = logging.getLogger()
24+
2225
Tensor = torch.Tensor
2326

2427

@@ -237,6 +240,7 @@ def _normalize_granularity(
237240
],
238241
) -> Tuple[FP8Granularity, FP8Granularity]:
239242
from torchao.quantization.granularity import (
243+
PerGroup,
240244
PerRow,
241245
PerTensor,
242246
)
@@ -253,9 +257,12 @@ def _normalize_granularity(
253257
is_per_row = isinstance(granularity[0], PerRow) and isinstance(
254258
granularity[1], PerRow
255259
)
260+
is_per_group = isinstance(granularity[0], PerGroup) and isinstance(
261+
granularity[1], PerGroup
262+
)
256263
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
257264

258-
if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
265+
if not (is_per_tensor or is_per_row or is_per_group or is_a_1_128_w_128_128):
259266
raise ValueError(f"Unsupported granularity types: {granularity}.")
260267
if not isinstance(granularity[0], type(granularity[1])):
261268
raise ValueError(
@@ -281,6 +288,7 @@ def _check_hardware_support(
281288
ValueError: If invalid granularity type is provided
282289
"""
283290
from torchao.quantization.granularity import (
291+
PerGroup,
284292
PerRow,
285293
PerTensor,
286294
)
@@ -291,6 +299,9 @@ def _check_hardware_support(
291299
is_per_row = isinstance(granularities[0], PerRow) and isinstance(
292300
granularities[1], PerRow
293301
)
302+
is_per_group = isinstance(granularities[0], PerGroup) and isinstance(
303+
granularities[1], PerGroup
304+
)
294305
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
295306

296307
if is_per_tensor or is_per_row:
@@ -304,5 +315,10 @@ def _check_hardware_support(
304315
assert is_sm_at_least_89(), (
305316
"Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
306317
)
318+
elif is_per_group:
319+
logger.warning(
320+
"PerGroup blockwise FP8 quantization: no hardware check performed. "
321+
"Ensure the target device supports blockwise FP8 operations."
322+
)
307323
else:
308324
raise ValueError(f"Invalid granularities {granularities}.")

0 commit comments

Comments
 (0)