@@ -1048,6 +1048,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
10481048 if (128 , 128 ) in result .keys ():
10491049 result [(128 , 128 )].append (
10501050 FmhaFwdTileSize (256 , 32 , 128 , 128 , 32 , 128 , 8 , 1 , 1 , 8 , 1 , 1 , 32 , 32 , 16 , 32 , 32 , 16 , - 1 )) # fmt: skip
1051+ elif dtype in cls ._DT_FP8BF16 :
1052+ if (128 , 128 ) in result .keys ():
1053+ result [(128 , 128 )].append (
1054+ FmhaFwdTileSize (256 , 64 , 128 , 128 , 64 , 128 , 8 , 1 , 1 , 8 , 1 , 1 , 32 , 32 , 32 , 32 , 32 , 32 , - 1 )) # fmt: skip
10511055 return result
10521056
10531057 @classmethod
@@ -1085,6 +1089,15 @@ def get_pipelines(
10851089 pipelines .append (FmhaFwdPipeline ("qr_async_trload_v3" , "row" , "t" , "t" , "f" , "f" ,
10861090 F_logits = logits , F_bias = "no" , F_lse = "f" , F_dropout = "f" , F_qscale = qscale , F_mask = mask , F_skip = "f" , F_trload = "t" , F_sink = "f" )) # fmt: skip
10871091
1092+ elif dtype in cls ._DT_FP8BF16 :
1093+ # lse/dropout kernels are not needed
1094+ # qr_async_trload_v3 only supports (generic) causal mask
1095+ for logits , qscale , mask in itertools .product (
1096+ ["t" , "f" ],
1097+ ["no" , "pertensor" ],
1098+ ["no" , "causal" ],
1099+ ):
1100+ pipelines .append (FmhaFwdPipeline ("qr_async_trload_v3" , "row" , "t" , "t" , "f" , "f" , F_logits = logits , F_bias = "no" , F_lse = "f" , F_dropout = "f" , F_qscale = qscale , F_mask = mask , F_skip = "f" , F_trload = "t" , F_sink = "f" )) # fmt: skip
10881101 return pipelines
10891102
10901103
0 commit comments