@@ -11,6 +11,7 @@

import torch
import torch.cuda.amp as amp
+ import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm

@@ -22,6 +23,19 @@
                            get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

+ # def convert_linear_conv_to_fp8(module):
+ #     for name, child in module.named_children():
+ #         # Recurse into child modules first
+ #         convert_linear_conv_to_fp8(child)
+
+ #         # Check whether this is a Linear or Conv layer
+ #         if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ #             # Convert the weight
+ #             if hasattr(child, 'weight') and child.weight is not None:
+ #                 # Keep the Parameter type; only modify its data
+ #                 child.weight.data = child.weight.data.to(dtype=torch.float8_e4m3fn)
+ #             # Optional: also convert the bias (enable if needed)
+

class WanT2V:

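For reference, an enabled version of the commented-out helper above could look like the sketch below. It is illustrative only: casting to `torch.float8_e4m3fn` shrinks weight storage to one byte per element, but the weights still have to be upcast (e.g. to bf16) before matmul, since fp8 compute kernels are not generally available for these layers.

```python
import torch
import torch.nn as nn

def convert_linear_conv_to_fp8(module: nn.Module) -> None:
    # Recurse into children first, then cast any Linear/Conv weights in place.
    for child in module.children():
        convert_linear_conv_to_fp8(child)
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            if child.weight is not None:
                # Keep the Parameter wrapper; swap only the underlying data.
                child.weight.data = child.weight.data.to(torch.float8_e4m3fn)

# Toy usage (requires PyTorch >= 2.1 for the float8 dtypes):
toy = nn.Sequential(nn.Linear(16, 16), nn.Conv2d(3, 8, 3))
convert_linear_conv_to_fp8(toy)
print(toy[0].weight.dtype)  # torch.float8_e4m3fn
```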
@@ -81,7 +95,9 @@ def __init__(
            device=self.device)

        logging.info(f"Creating WanModel from {checkpoint_dir}")
-         self.model = WanModel.from_pretrained(checkpoint_dir)
+         self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=torch.float8_e4m3fn)
+         # self.model = WanModel.from_pretrained(checkpoint_dir)
+
        self.model.eval().requires_grad_(False)

        if use_usp:
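As a rough sense of why the `torch_dtype=torch.float8_e4m3fn` load matters: fp8 halves weight memory relative to bf16. A back-of-envelope check, assuming a hypothetical 14B-parameter DiT:

```python
params = 14e9  # assumed parameter count, for illustration only
print(f"bf16 weights: {params * 2 / 1024**3:.1f} GiB")  # ~26.1 GiB
print(f"fp8  weights: {params * 1 / 1024**3:.1f} GiB")  # ~13.0 GiB
```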
@@ -102,7 +118,9 @@ def __init__(
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
+             # convert_linear_conv_to_fp8(self.model)
        else:
+             # convert_linear_conv_to_fp8(self.model)
            self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt
@@ -152,6 +170,7 @@ def generate(self,
                - W: Frame width from size)
        """
        # preprocess
+         offload_model = False  # override the offload_model argument; keep the model on GPU
        F = frame_num
        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                        size[1] // self.vae_stride[1],
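As a worked example of the `target_shape` arithmetic, assuming illustrative config values `vae_stride = (4, 8, 8)` and `z_dim = 16`, with 81 frames at 1280x720:

```python
F, W, H = 81, 1280, 720
vae_stride, z_dim = (4, 8, 8), 16              # assumed config values
t = (F - 1) // vae_stride[0] + 1               # 21 latent frames
h, w = H // vae_stride[1], W // vae_stride[2]  # 90 x 160 latent grid
print((z_dim, t, h, w))                        # (16, 21, 90, 160)
```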
@@ -225,6 +244,16 @@ def noop_no_sync():

            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}
+
+             # import gc
+             # del self.text_encoder
+             # del self.vae
+             # gc.collect()                      # trigger garbage collection immediately
+             # torch.cuda.empty_cache()          # clear the CUDA cache
+             # torch.cuda.reset_peak_memory_stats()
+
+             # start_mem = torch.cuda.memory_allocated()
+             # print(f"GPU memory allocated at the start of this stage: {start_mem / 1024**3:.2f} GB")

            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = latents
@@ -248,6 +277,9 @@ def noop_no_sync():
                    return_dict=False,
                    generator=seed_g)[0]
                latents = [temp_x0.squeeze(0)]
+
+             # peak_mem_bytes = torch.cuda.max_memory_allocated()
+             # print(f"peak GPU memory in this stage: {peak_mem_bytes / 1024**3:.2f} GB")

            x0 = latents
            if offload_model:
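The commented memory probes above can be packaged into a reusable helper. A minimal sketch using only standard `torch.cuda` memory APIs (the `cuda_peak_memory` name is hypothetical):

```python
import contextlib
import torch

@contextlib.contextmanager
def cuda_peak_memory(tag: str):
    # Reset the peak counter, run the wrapped block, then report the high-water mark.
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.memory_allocated()
    yield
    peak = torch.cuda.max_memory_allocated()
    print(f"[{tag}] start: {start / 1024**3:.2f} GB, peak: {peak / 1024**3:.2f} GB")

# Usage sketch:
# with cuda_peak_memory("denoising loop"):
#     for _, t in enumerate(tqdm(timesteps)):
#         ...
```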