88"""
99
1010import math
11+ import logging
1112from typing import List , NamedTuple , Optional , Tuple , Union
1213
1314import torch
1920 is_sm_at_least_89 ,
2021)
2122
# Module-scoped logger. Using ``__name__`` (instead of the argument-less
# ``logging.getLogger()``, which returns the ROOT logger) lets applications
# filter or silence this module's messages without touching global logging
# configuration; records still propagate to the root logger's handlers.
logger: logging.Logger = logging.getLogger(__name__)
24+
2225Tensor = torch .Tensor
2326
2427
@@ -237,6 +240,7 @@ def _normalize_granularity(
237240 ],
238241) -> Tuple [FP8Granularity , FP8Granularity ]:
239242 from torchao .quantization .granularity import (
243+ PerGroup ,
240244 PerRow ,
241245 PerTensor ,
242246 )
@@ -253,9 +257,12 @@ def _normalize_granularity(
253257 is_per_row = isinstance (granularity [0 ], PerRow ) and isinstance (
254258 granularity [1 ], PerRow
255259 )
260+ is_per_group = isinstance (granularity [0 ], PerGroup ) and isinstance (
261+ granularity [1 ], PerGroup
262+ )
256263 is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128 (granularity )
257264
258- if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128 ):
265+ if not (is_per_tensor or is_per_row or is_per_group or is_a_1_128_w_128_128 ):
259266 raise ValueError (f"Unsupported granularity types: { granularity } ." )
260267 if not isinstance (granularity [0 ], type (granularity [1 ])):
261268 raise ValueError (
@@ -281,6 +288,7 @@ def _check_hardware_support(
281288 ValueError: If invalid granularity type is provided
282289 """
283290 from torchao .quantization .granularity import (
291+ PerGroup ,
284292 PerRow ,
285293 PerTensor ,
286294 )
@@ -291,6 +299,9 @@ def _check_hardware_support(
291299 is_per_row = isinstance (granularities [0 ], PerRow ) and isinstance (
292300 granularities [1 ], PerRow
293301 )
302+ is_per_group = isinstance (granularities [0 ], PerGroup ) and isinstance (
303+ granularities [1 ], PerGroup
304+ )
294305 is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128 (granularities )
295306
296307 if is_per_tensor or is_per_row :
@@ -304,5 +315,10 @@ def _check_hardware_support(
304315 assert is_sm_at_least_89 (), (
305316 "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
306317 )
318+ elif is_per_group :
319+ logger .warning (
320+ "PerGroup blockwise FP8 quantization: no hardware check performed. "
321+ "Ensure the target device supports blockwise FP8 operations."
322+ )
307323 else :
308324 raise ValueError (f"Invalid granularities { granularities } ." )
0 commit comments