diff --git a/python/paddle/distributed/auto_parallel/fully_shard_fusion.py b/python/paddle/distributed/auto_parallel/fully_shard_fusion.py
index 6cc3ff9bfbaea6..74c660411b1ca2 100644
--- a/python/paddle/distributed/auto_parallel/fully_shard_fusion.py
+++ b/python/paddle/distributed/auto_parallel/fully_shard_fusion.py
@@ -26,18 +26,9 @@
     alignment,
     get_current_device_type,
 )
-
-# Global registry for fsdp_context
-_g_fsdp_context = None
-
-
-def register_fsdp_context(context):
-    global _g_fsdp_context
-    _g_fsdp_context = context
-
-
-def get_fsdp_context():
-    return _g_fsdp_context
+from paddle.distributed.fsdp._fsdp_context import (
+    register_fsdp_context,
+)


 class BufferState(Enum):
diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index da6c6e1ec33530..cd6d0c3aa49399 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -138,7 +138,12 @@ def _dygraph_clip(self, params_grads):
                 and getattr(p, 'is_firstly_shared', True)
             )

-            if not_shared_enable:
+            from paddle.distributed.fsdp._fsdp_context import (
+                get_fsdp_context,
+            )
+
+            fsdp_context = get_fsdp_context()
+            if not_shared_enable and fsdp_context is None:
                 if p.is_distributed:
                     if g.dtype == paddle.float16:
                         sum_square_dist_fp16.append(sum_square)
@@ -280,7 +285,12 @@ class HybridParallelOptimizer:
     def __init__(self, optimizer, hcg, strategy):
         # Note: Only sharding stage 1 is considered in HybridParallelOptimizer.
         # The sharding stage2 and stage3 optimizers are invoked in other api.
-        if hcg.get_sharding_parallel_world_size() > 1:
+        from paddle.distributed.fsdp._fsdp_context import (
+            get_fsdp_context,
+        )
+
+        fsdp_context = get_fsdp_context()
+        if hcg.get_sharding_parallel_world_size() > 1 and fsdp_context is None:
             split_param = strategy.hybrid_configs[
                 'sharding_configs'
             ].split_param
diff --git a/python/paddle/distributed/fsdp/_fsdp_context.py b/python/paddle/distributed/fsdp/_fsdp_context.py
new file mode 100644
index 00000000000000..9f7074b0e2de43
--- /dev/null
+++ b/python/paddle/distributed/fsdp/_fsdp_context.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Shared FSDP context module.
+This module provides a unified global registry for fsdp_context,
+used by both fsdp.fully_shard_fusion and auto_parallel.fully_shard_fusion.
+""" + +# Global registry for fsdp_context +_g_fsdp_context = None + + +def register_fsdp_context(context): + global _g_fsdp_context + _g_fsdp_context = context + + +def get_fsdp_context(): + return _g_fsdp_context diff --git a/python/paddle/distributed/fsdp/fully_shard.py b/python/paddle/distributed/fsdp/fully_shard.py index 96adc7302065dd..53dcdc4eabe81e 100644 --- a/python/paddle/distributed/fsdp/fully_shard.py +++ b/python/paddle/distributed/fsdp/fully_shard.py @@ -22,9 +22,8 @@ import paddle.distributed as dist import paddle from paddle.distributed.auto_parallel.fully_shard import FullyShardAuto -from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_fully_shard import ( - FullyShard, -) + +from .fully_shard_fusion import FullyShardFusion def in_auto_parallel_mode() -> bool: @@ -56,7 +55,8 @@ def _fully_shard_manual_parallel( ignored_params, enable_tensor_fusion_and_overlap, ): - return FullyShard(module) + FullyShardFusion(module) + return module def _fully_shard_auto_parallel( diff --git a/python/paddle/distributed/fsdp/fully_shard_fusion.py b/python/paddle/distributed/fsdp/fully_shard_fusion.py new file mode 100644 index 00000000000000..0a1357d02d83ff --- /dev/null +++ b/python/paddle/distributed/fsdp/fully_shard_fusion.py @@ -0,0 +1,740 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from collections import OrderedDict
+from dataclasses import dataclass, field
+from enum import Enum
+
+import numpy as np
+
+import paddle
+from paddle.autograd import PyLayer
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils.tensor_fusion_helper import (
+    align,
+    alignment,
+    get_current_device_type,
+)
+from paddle.distributed.fsdp._fsdp_context import (
+    register_fsdp_context,
+)
+
+
+class BufferState(Enum):
+    # Buffer status for the lazy double buffer mechanism
+    #
+    # State transitions:
+    #   FREED ──all_gather──> USING ──computation done──> READY ──release──> FREED
+    #                           ^                            │
+    #                           │          (reuse)           │
+    #                           └────────────────────────────┘
+
+    FREED = 1  # Released, buffer data is sharded, tmp_buffer not allocated
+    USING = 2  # Unsharded and actively in use
+    READY = 3  # Unsharded, marked for lazy release, can be reused
+    SYNCING = 4  # Communication in progress
+
+
+@dataclass
+class BufferGroup:
+    params: list = field(default_factory=list)
+    dtype: object = None
+    trainable: bool = None
+    fsdp_unit_id: int = None
+    is_tie: bool = False
+    is_expert_param: bool = False
+    fsdp_group: object = None
+    params_buffer: 'TensorFusionBuffer' = None
+    grads_buffer: 'TensorFusionBuffer' = None
+    params_use_sum: int = 0
+    params_use_cnt: int = 0
+    grads_use_sum: int = 0
+    grads_use_cnt: int = 0
+
+
+class TensorFusionBuffer:
+    def __init__(
+        self,
+        unique_key,
+        params,
+        fsdp_group,
+        dtype,
+        is_params=False,
+        main_grad_dtype=None,
+    ):
+        # Calculate total buffer size needed (with padding)
+        self.unique_key = unique_key
+        self.fsdp_degree = fsdp_group.nranks
+        self.dtype = dtype
+        self.main_grad_dtype = (
+            main_grad_dtype if main_grad_dtype is not None else dtype
+        )
+        self.total_buffer_size = 0
+        self.param_offsets = {}
+        self.tmp_data_buffer = None
+        self.comm_task = None
+        self.trainable = params[0].trainable
+
+        for param in params:
+            self.param_offsets[param.name] = self.total_buffer_size
+            self.total_buffer_size += self.get_padded_size(param)
+
+        if is_params:
+            # Create fused params_buffer
+            # TODO(lizhenxing): Build full params_buffer on CPU and only move shards to GPU to minimize mem peaks
+            self.data_buffer = paddle.zeros(
+                shape=[self.total_buffer_size],
+                dtype=dtype,
+            )
+            # Use the BufferState enum instead of an is_shard boolean; the
+            # initial state is FREED (sharded)
+            self.status = BufferState.FREED
+
+            for param in params:
+                offset = self.param_offsets[param.name]
+                stop_gradient = param.stop_gradient
+                _shape = param.shape
+                param.stop_gradient = True
+                param.flatten_()
+                paddle.assign(
+                    param,
+                    self.data_buffer._slice(
+                        offset,
+                        offset + param._numel(),
+                    ),
+                )
+                param._clear_data()
+                param.stop_gradient = stop_gradient
+                param.get_tensor()._set_dims(_shape)
+            paddle.device.cuda.empty_cache()
+
+            curr_rank = paddle.distributed.get_rank(fsdp_group)
+            total_nums = self.data_buffer.shape[0]
+            piece_len = (total_nums + self.fsdp_degree - 1) // self.fsdp_degree
+            start = curr_rank * piece_len
+            end = min(start + piece_len, total_nums)
+            self.data_buffer = paddle.slice(
+                self.data_buffer, [0], [start], [end]
+            ).clone()
+            paddle.device.cuda.empty_cache()
+
+            # Init params_buffer attrs
+            self.data_buffer.name = "fuse_params_" + str(unique_key)
+            self.data_buffer.stop_gradient = params[0].stop_gradient
+            self.data_buffer.optimize_attr = params[0].optimize_attr
+        else:
+            # Create the fused grads_buffer directly in sharded form
+            self.data_buffer = paddle.zeros(
+                shape=[self.total_buffer_size // self.fsdp_degree],
+                dtype=self.main_grad_dtype,
+            )
+
+        # Register a get_main_grad method for each param that returns a
+        # view_slice of the grad buffer
+        for param in params:
+            if param.trainable:
+                param._fusion_buffer = self
+                param._param_offsets = self.param_offsets
+
+                def get_grad_from_tmp_buf(param):
+                    tmp_buffer = param._fusion_buffer.get_tmp_buffer()
+                    offset = param._param_offsets[param.name]
+                    main_grad = paddle._C_ops.view_slice(
+                        tmp_buffer,
+                        offset,
+                        offset + param._numel(),
+                    )
+                    return main_grad
+
+                param.get_main_grad = get_grad_from_tmp_buf.__get__(param)
+
+    def get_padded_size(self, param):
+        size = np.prod(param.shape)
+        align_size = (
+            alignment[get_current_device_type()]
+            // align[param.dtype]
+            * self.fsdp_degree
+        )
+        return ((size + align_size - 1) // align_size) * align_size
+
+    def get_tmp_buffer(self):
+        # Reuse tmp_buffer if it exists, else create it
+        if self.tmp_data_buffer is None:
+            self.tmp_data_buffer = paddle.zeros(
+                shape=[self.total_buffer_size], dtype=self.dtype
+            )
+        return self.tmp_data_buffer
+
+    def clear_tmp_buffer(self):
+        if self.tmp_data_buffer is not None:
+            self.tmp_data_buffer._clear_data()
+            self.tmp_data_buffer = None
+        # paddle.device.cuda.empty_cache()
+
+
+class FSDPBufferManager:
+    def __init__(self, model, fsdp_unit_layers=None, moe_layers_name=None):
+        self.model = model
+        self.hcg = fleet.get_hybrid_communicate_group()
+        self.dp_group = self.hcg.get_data_parallel_group()
+        self._fsdp_group = self.hcg.get_sharding_parallel_group()
+        self.main_grad_dtype = paddle.float32
+
+        paddle.device.cuda.empty_cache()
+        topk = None
+        if hasattr(self.model, 'config') and hasattr(
+            self.model.config, 'num_experts_per_tok'
+        ):
+            topk = self.model.config.num_experts_per_tok
+
+        # Layer types to wrap as FSDP sharding layers
+        # Note: 'Qwen3VLTextDecoderLayer' is temporary; fleet models all use 'TransformerLayer'
+        self.fsdp_unit_layers = fsdp_unit_layers or [
+            'TransformerLayer',
+            'Qwen3VLTextDecoderLayer',
+            'Qwen3MoeDecoderLayer',
+        ]
+        self.moe_layers_name = moe_layers_name or [
+            'Qwen3MoeMLP',
+            'StandardMLPExpert',
+        ]
+
+        # Get tie_param_name if using tie_weights
+        self.tie_param_name = None
+        # Note: need to add get_input_embeddings in fleet modeling
+        # if hasattr(self.model, "get_input_embeddings"):
+        #     self.tie_param_name = self.model.get_input_embeddings().weight.name
+
+        # Create buffer_groups
+        grouped_params, group_is_expert = self._build_groups()
+        self.buffer_groups = []
+        self.param_to_buffer_id = {}
+
+        # Create params_buffer and grads_buffer for each group
+        for gid, params in grouped_params.items():
+            is_expert = group_is_expert.get(gid, False)
+            # NOTE: expert params are intended to use the EP group; currently
+            # all buffer groups use the FSDP sharding group
+
+            params_buffer = TensorFusionBuffer(
+                gid,
+                params,
+                self._fsdp_group,
+                params[0].dtype,
+                is_params=True,
+            )
+
+            if not params[0].stop_gradient:
+                grads_buffer = TensorFusionBuffer(
+                    gid,
+                    params,
+                    self._fsdp_group,
+                    params[0].dtype,
+                    main_grad_dtype=self.main_grad_dtype,
+                )
+            else:
+                grads_buffer = None
+
+            if is_expert:
+                _params_use_sum = topk
+                _grads_use_sum = topk
+            else:
+                _params_use_sum = len(params)
+                _grads_use_sum = len(params)
+            self.buffer_groups.append(
+                BufferGroup(
+                    params=params,
+                    dtype=params[0].dtype,
+                    trainable=params[0].trainable,
+                    is_expert_param=is_expert,
+                    fsdp_group=self._fsdp_group,
+                    params_buffer=params_buffer,
+                    grads_buffer=grads_buffer,
+                    params_use_sum=_params_use_sum,
+                    params_use_cnt=0,
+                    grads_use_sum=_grads_use_sum,
+                    grads_use_cnt=0,
+                )
+            )
+
+            for param in params:
+                self.param_to_buffer_id[param.name] = gid
+
+    def _build_groups(self):
+        parameters = self.model.parameters()
+        grouped_params = OrderedDict()
+        group_is_expert = {}
+        curr_gid = 0
+
+        param_to_unit_id = {}
+        for unit_id, m in enumerate(self.model.modules()):
+            if type(m).__name__ in self.fsdp_unit_layers:
+                for p in m.parameters():
+                    param_to_unit_id[p.name] = unit_id
+            if type(m).__name__ in self.moe_layers_name:
+                for p in m.parameters():
+                    p.is_moe_param = True
+
+        param_groups = []
+        for param in parameters:
+            name = param.name
+            is_expert = getattr(param, "is_moe_param", False)
+            if is_expert:
+                continue
+            is_tie = (
+                self.tie_param_name is not None and name == self.tie_param_name
+            )
+
+            param_attrs = {
+                "dtype": param.dtype,
+                "trainable": param.trainable,
+                "fsdp_unit_id": param_to_unit_id.get(name),
+                "is_tie": is_tie,
+                "is_expert_param": is_expert,
+            }
+
+            found_group = False
+            for param_group in param_groups:
+                if (
+                    param_group.dtype == param_attrs["dtype"]
+                    and param_group.trainable == param_attrs["trainable"]
+                    and param_group.fsdp_unit_id == param_attrs["fsdp_unit_id"]
+                    and param_group.is_tie == param_attrs["is_tie"]
+                    and param_group.is_expert_param
+                    == param_attrs["is_expert_param"]
+                ):
+                    param_group.params.append(param)
+                    found_group = True
+                    break
+
+            # Create a new group if none matches
+            if not found_group:
+                param_groups.append(BufferGroup(params=[param], **param_attrs))
+
+        def group_sort_key(group):
+            priority = 0 if group.is_tie else (1 if not group.trainable else 2)
+            return (
+                priority,
+                group.fsdp_unit_id
+                if group.fsdp_unit_id is not None
+                else float('inf'),
+            )
+
+        sorted_groups = sorted(param_groups, key=group_sort_key)
+
+        # Assign buffer ids to the sorted parameter groups in execution order
+        for param_group in sorted_groups:
+            cur_params = param_group.params
+            if len(cur_params) == 0:
+                continue
+            for p in cur_params:
+                grouped_params.setdefault(curr_gid, []).append(p)
+            group_is_expert[curr_gid] = param_group.is_expert_param
+            curr_gid += 1
+
+        return grouped_params, group_is_expert
+
+
+class FSDPCommManager:
+    def __init__(
+        self,
+        buffer_manager,
+        enable_overlap=True,
+        double_buffer_limit=2,
+    ):
+        self.buffer_manager = buffer_manager
+        self.enable_overlap = enable_overlap
+        self.grad_reduce_queue = []
+
+        # Double buffer mechanism config
+        self.double_buffer_limit = double_buffer_limit
+        self.buffer_cnt_in_using = 0
+        self.need_zero_grads = True
+
+    def _release_one_buffer_if_needed(self):
+        # Release a buffer with the READY status if needed
+        while self.buffer_cnt_in_using >= self.double_buffer_limit:
+            found = False
+            for gid_idx, group in enumerate(self.buffer_manager.buffer_groups):
+                if group.params_buffer.status == BufferState.READY:
+                    group.params_buffer.status = BufferState.FREED
+                    group.params_buffer.clear_tmp_buffer()
+                    self.buffer_cnt_in_using -= 1
+                    found = True
+                    break
+            if not found:
+                break
+
+    def _next_buffer_id(self, gid, is_backward):
+        # Get the next buffer id for prefetch
+        if is_backward:
+            next_gid = gid - 1
+            # Search backward for trainable buffer_groups
+            while (
+                next_gid >= 0
+                and not self.buffer_manager.buffer_groups[
+                    next_gid
+                ].params_buffer.trainable
+            ):
+                next_gid -= 1
+            return max(next_gid, 0)
+        else:
+            return min(gid + 1, len(self.buffer_manager.buffer_groups) - 1)
+
+    def all_gather_params(self, params, is_backward=False):
+        if len(params) == 0:
+            return
+        for param in params:
+            if hasattr(param, "is_moe_param"):
+                continue
+            gid = self.buffer_manager.param_to_buffer_id[param.name]
+            group = self.buffer_manager.buffer_groups[gid]
+            group.params_use_cnt += 1
+            params_buffer = group.params_buffer
+            # Use the group-specific fsdp_group
+            fsdp_group = group.fsdp_group or self.buffer_manager._fsdp_group
+
+            # Double buffer: reuse the buffer if its status is READY
+            if params_buffer.status == BufferState.READY:
+                # Reuse: READY -> USING, no need to all_gather again
+                params_buffer.status = BufferState.USING
+
+            # Overlap prefetch comm
+            if self.enable_overlap:
+                prefetch_count = 2
+                curr_next_gid = gid
+                for _ in range(prefetch_count):
+                    next_gid = self._next_buffer_id(curr_next_gid, is_backward)
+                    next_group = self.buffer_manager.buffer_groups[next_gid]
+                    next_params_buffer = next_group.params_buffer
+                    next_fsdp_group = (
+                        next_group.fsdp_group or self.buffer_manager._fsdp_group
+                    )
+                    if next_params_buffer.status == BufferState.FREED:
+                        # Check double_buffer_limit before prefetch
+                        self._release_one_buffer_if_needed()
+                        next_params_buffer.status = BufferState.SYNCING
+                        tmp_buffer_prefetch = (
+                            next_params_buffer.get_tmp_buffer()
+                        )
+                        next_params_buffer.comm_task = (
+                            paddle.distributed.all_gather(
+                                tmp_buffer_prefetch,
+                                next_params_buffer.data_buffer,
+                                group=next_fsdp_group,
+                                sync_op=False,
+                            )
+                        )
+                        self.buffer_cnt_in_using += 1
+
+            # Wait for async comm to complete: SYNCING -> USING
+            if params_buffer.status == BufferState.SYNCING:
+                params_buffer.status = BufferState.USING
+                params_buffer.comm_task.wait()
+                params_buffer.comm_task = None
+
+            tmp_buffer = params_buffer.get_tmp_buffer()
+            # Do the all_gather synchronously: FREED -> USING
+            if params_buffer.status == BufferState.FREED:
+                fsdp_group.process_group.all_gather(
+                    params_buffer.data_buffer, tmp_buffer
+                ).wait()
+                params_buffer.status = BufferState.USING
+                self.buffer_cnt_in_using += 1
+
+            # Bind the unsharded param to the real param
+            offset = params_buffer.param_offsets[param.name]
+            tmp_param = paddle._C_ops.view_slice(
+                tmp_buffer,
+                offset,
+                offset + param._numel(),
+            )
+            tmp_param.get_tensor()._set_dims(param.shape)
+            param.get_tensor()._share_data_with(tmp_param.get_tensor())
+
+    def shard_params(self, params, is_backward=False):
+        affected_gids = set()
+        for param in params:
+            if hasattr(param, "is_moe_param"):
+                continue
+            gid = self.buffer_manager.param_to_buffer_id.get(param.name)
+            group = self.buffer_manager.buffer_groups[gid]
+            stop_gradient = param.stop_gradient
+            _shape = param.shape
+            param._clear_data()
+            param.stop_gradient = stop_gradient
+            param.get_tensor()._set_dims(_shape)
+
+            affected_gids.add(gid)
+
+        for gid in affected_gids:
+            group = self.buffer_manager.buffer_groups[gid]
+            if group.params_buffer.status == BufferState.USING:
+                group.params_buffer.status = BufferState.READY
+
+    def reduce_scatter_grads(self, param):
+        if self.need_zero_grads:
+            self.need_zero_grads = False
+            for group in self.buffer_manager.buffer_groups:
+                if group.grads_buffer is not None:
+                    group.grads_buffer.data_buffer.zero_()
+        gid = self.buffer_manager.param_to_buffer_id.get(param.name)
+        group = self.buffer_manager.buffer_groups[gid]
+        group.grads_use_cnt += 1
+        fsdp_group = group.fsdp_group or self.buffer_manager._fsdp_group
+        param.main_grad = None
+
+        if group.grads_use_cnt == group.grads_use_sum:
+            group.grads_use_cnt = 0
+
+            # reduce_scatter from the tmp grad buffer into grads_buffer
+            grads_buffer = group.grads_buffer
+
+            # Grad queue mechanism: wait for and release completed
+            # reduce_scatter async tasks
+            self._wait_for_grad_comm()
+
+            tmp_buffer = grads_buffer.get_tmp_buffer()
+            shard_size = grads_buffer.data_buffer.shape[0]
+            grad_buffer_shard = tmp_buffer._slice(0, shard_size)
+            if self.enable_overlap:
+                # Comm grads async and check all comm_tasks before the
+                # optimizer update
+                grads_buffer.comm_task = paddle.distributed.reduce_scatter(
+                    grad_buffer_shard,
+                    tmp_buffer,
+                    op=paddle.distributed.ReduceOp.SUM,
+                    group=fsdp_group,
+                    sync_op=False,
+                )
+
+                # Add the async task to the queue
+                self.grad_reduce_queue.append(grads_buffer)
+            else:
+                paddle.distributed.reduce_scatter(
+                    grad_buffer_shard,
+                    tmp_buffer,
+                    op=paddle.distributed.ReduceOp.SUM,
+                    group=fsdp_group,
+                    sync_op=False,
+                ).wait()
+                grads_buffer.data_buffer.add_(grad_buffer_shard)
+                grads_buffer.clear_tmp_buffer()
+
+    def _wait_for_grad_comm(self, queue_limit=2):
+        # Wait for async reduce_scatter tasks to complete and release resources
+        # queue_limit: max queue size, defaults to 2; 0 means wait for all
+        while len(self.grad_reduce_queue) > queue_limit:
+            grads_buffer = self.grad_reduce_queue.pop(0)
+            if grads_buffer.comm_task is not None:
+                grads_buffer.comm_task.wait()
+                grads_buffer.comm_task = None
+            tmp_buffer = grads_buffer.get_tmp_buffer()
+            shard_size = grads_buffer.data_buffer.shape[0]
+            grad_buffer_shard = tmp_buffer._slice(0, shard_size)
+            grads_buffer.data_buffer.add_(grad_buffer_shard)
+            grads_buffer.clear_tmp_buffer()
+
+    def finish_grads_sync(self):
+        # Wait for all async reduce_scatter tasks; call before optimizer.step()
+        self._wait_for_grad_comm(queue_limit=0)
+
+    def reset_params_buffer_status(self):
+        for group in self.buffer_manager.buffer_groups:
+            params_buffer = group.params_buffer
+            if params_buffer.status in (BufferState.READY, BufferState.USING):
+                # Clear the stale tmp_buffer to force a re-all_gather with the
+                # updated data_buffer
+                params_buffer.clear_tmp_buffer()
+                params_buffer.status = BufferState.FREED
+                if self.buffer_cnt_in_using > 0:
+                    self.buffer_cnt_in_using -= 1
+
+
+class FusionBackwardHook(PyLayer):
+    @staticmethod
+    def forward(ctx, *inputs, layer, comm_manager, recursive=False):
+        ctx.layer = layer
+        ctx.comm_manager = comm_manager
+        ctx.recursive = recursive
+        return inputs if len(inputs) > 1 else inputs[0]
+
+    @staticmethod
+    def backward(ctx, *args):
+        trainable_params = []
+
+        for param in ctx.layer.parameters(include_sublayers=ctx.recursive):
+            if param.trainable:
+                trainable_params.append(param)
+
+        ctx.comm_manager.all_gather_params(trainable_params, is_backward=True)
+        return args
+
+
+class FusionForwardHook(PyLayer):
+    @staticmethod
+    def forward(ctx, *inputs, layer, comm_manager, recursive=False):
+        ctx.layer = layer
+        ctx.comm_manager = comm_manager
+        ctx.recursive = recursive
+        return inputs if len(inputs) > 1 else inputs[0]
+
+    @staticmethod
+    def backward(ctx, *args):
+        ctx.comm_manager.shard_params(
+            ctx.layer.parameters(include_sublayers=ctx.recursive),
+            is_backward=True,
+        )
+        return args
+
+
+class FullyShardFusion:
+    def __init__(self, model, fsdp_unit_layers=None, moe_layers_name=None):
+        self.model = model
+        self.buffer_manager = FSDPBufferManager(
+            self.model, fsdp_unit_layers, moe_layers_name
+        )
+        self.comm_manager = FSDPCommManager(self.buffer_manager)
+        self.register_tensor_fusion_hooks(self.model)
+        register_fsdp_context(self)
+
+    def comm_sync_and_reset_status(self):
+        self.comm_manager.finish_grads_sync()
+        self.comm_manager.reset_params_buffer_status()
+        self.comm_manager.need_zero_grads = True
+        # Reset main_grad for all trainable parameters
+        for param in self.model.parameters():
+            if param.trainable:
+                param.main_grad = None
+
+    def register_tensor_fusion_hooks(self, model):
+        def _pre_forward_hook(sublayers, recursive=False):
+            comm_manager = self.comm_manager
+
+            @paddle.autograd.no_grad()
+            def all_gather_comm(*_):
+                comm_manager.all_gather_params(
+                    sublayers.parameters(include_sublayers=recursive)
+                )
+
+            return all_gather_comm
+
+        def _post_forward_hook(sublayers, recursive=False):
+            comm_manager = self.comm_manager
+
+            @paddle.autograd.no_grad()
+            def shard_comm(*_):
+                comm_manager.shard_params(
+                    sublayers.parameters(include_sublayers=recursive)
+                )
+
+            return shard_comm
+
+        def _update_main_grad_hook(param):
+            comm_manager = self.comm_manager
+
+            @paddle.autograd.no_grad()
+            def comm_hook(grad):
+                if grad is not None and grad._is_initialized():
+                    # Share memory with the grads tmp_buffer
+                    _main_grad = param.get_main_grad()
+                    _main_grad.get_tensor()._set_dims(grad.shape)
+                    param.main_grad = _main_grad
+                    param.main_grad.copy_(grad)
+                    grad._clear_data()
+                    comm_manager.shard_params([param], is_backward=True)
+                    comm_manager.reduce_scatter_grads(param)
+
+            return comm_hook
+
+        def _post_backward_hook(param):
+            param.main_grad = None
+            if hasattr(param, "get_main_grad"):
+                param._register_grad_hook(_update_main_grad_hook(param))
+
+        for param in model.parameters():
+            if param.trainable:
+                _post_backward_hook(param)
+
+        def _register_recursive(layer):
+            is_unit = (
+                type(layer).__name__ in self.buffer_manager.fsdp_unit_layers
+            )
+
+            if is_unit:
+                # For an FSDP unit, register recursive hooks and stop recursion
+                layer.register_forward_pre_hook(
+                    _pre_forward_hook(layer, recursive=True)
+                )
+                layer.register_forward_post_hook(
+                    _post_forward_hook(layer, recursive=True)
+                )
+                self._register_fusion_layer_hooks(layer, recursive=True)
+                return
+
+            if layer.parameters(include_sublayers=False):
+                layer.register_forward_pre_hook(
+                    _pre_forward_hook(layer, recursive=False)
+                )
+                layer.register_forward_post_hook(
+                    _post_forward_hook(layer, recursive=False)
+                )
+                self._register_fusion_layer_hooks(layer, recursive=False)
+
+            for child in layer.children():
+                _register_recursive(child)
+
+        _register_recursive(model)
+
+    def _register_fusion_layer_hooks(self, layer, recursive=False):
+        def _forward_post_hook(layer, inputs, outputs):
+            if isinstance(outputs, dict):
+                for key, value in outputs.items():
+                    if (
+                        isinstance(value, paddle.Tensor)
+                        and not value.stop_gradient
+                    ):
+                        outputs[key] = FusionBackwardHook.apply(
+                            value,
+                            layer=layer,
+                            comm_manager=self.comm_manager,
+                            recursive=recursive,
+                        )
+                return outputs
+            elif isinstance(outputs, tuple):
+                result = FusionBackwardHook.apply(
+                    *outputs,
+                    layer=layer,
+                    comm_manager=self.comm_manager,
+                    recursive=recursive,
+                )
+                if not isinstance(result, tuple):
+                    result = (result,)
+                return result
+            else:
+                return FusionBackwardHook.apply(
+                    outputs,
+                    layer=layer,
+                    comm_manager=self.comm_manager,
+                    recursive=recursive,
+                )
+
+        def _forward_pre_hook(layer, inputs):
+            return FusionForwardHook.apply(
+                *inputs,
+                layer=layer,
+                comm_manager=self.comm_manager,
+                recursive=recursive,
+            )
+
+        layer.register_forward_post_hook(_forward_post_hook)
+
+        # Register an additional pre-forward hook for tie_weights shard_params
+        for param in layer.parameters(include_sublayers=False):
+            if param.name == self.comm_manager.buffer_manager.tie_param_name:
+                layer.register_forward_pre_hook(_forward_pre_hook)
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index a1fcb180826a59..44b46e3352b217 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -1694,7 +1694,7 @@ def _apply_optimize(
                 paddle.static.default_startup_program(),
             ):
                 auto_dp = paddle.distributed.auto_parallel.auto_dp_utils.in_auto_dp_mode()
-                from paddle.distributed.auto_parallel.fully_shard_fusion import (
+                from paddle.distributed.fsdp._fsdp_context import (
                     get_fsdp_context,
                 )
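
Usage note (illustration only, not part of the patch): the point of the new _fsdp_context module is that FullyShardFusion registers itself once and every consumer (grad clipping, HybridParallelOptimizer, optimizer._apply_optimize) merely probes the registry instead of importing the wrapper class, which removes the registry previously duplicated in auto_parallel/fully_shard_fusion.py. A minimal sketch of that pattern follows; FakeFusionContext is a hypothetical stand-in for a FullyShardFusion instance, which in real use is constructed by _fully_shard_manual_parallel above.

from paddle.distributed.fsdp._fsdp_context import (
    get_fsdp_context,
    register_fsdp_context,
)


class FakeFusionContext:
    # Hypothetical stand-in for a FullyShardFusion instance.
    def comm_sync_and_reset_status(self):
        print("drain grad reduce_scatter queue, reset params_buffer states")


# Before any wrapper is constructed the registry holds None, so consumers
# keep their legacy behavior (e.g. HybridParallelOptimizer still reads
# sharding_configs.split_param, and _dygraph_clip keeps its sharded path).
assert get_fsdp_context() is None

# FullyShardFusion.__init__ performs exactly this registration on itself.
register_fsdp_context(FakeFusionContext())

# Consumers detect fusion mode with a plain None check, mirroring the two
# `fsdp_context is None` guards added in hybrid_parallel_optimizer.py.
ctx = get_fsdp_context()
if ctx is not None:
    ctx.comm_sync_and_reset_status()

Because the registry is plain module-level state, the last register_fsdp_context call wins and there is no automatic deregistration; `get_fsdp_context() is None` is the only signal the consumers rely on.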