37 changes: 37 additions & 0 deletions bitsandbytes/optim/optimizer.py
@@ -111,6 +111,8 @@ def register_module_override(self, module, param_name, config):


class Optimizer8bit(torch.optim.Optimizer):
    _FSDP_WRAPPED_QUANT_STATE_KEY = "__bnb_optimizer_quant_state__"

    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
        """
        Base 8-bit optimizer class.
@@ -152,6 +154,34 @@ def fill_qmap(self):
        self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
        self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)

    def state_dict(self):
        """Return the optimizer state, wrapping quantization tensors for FSDP compatibility.

        FSDP's full_optim_state_dict gathers all tensor states across ranks.
        Quantization states (state1, state2, absmax, etc.) have different shapes
        than the model parameters, so those gather operations fail. Wrapping
        these tensors in a nested dict lets FSDP skip them during gathering.
        """
        state_dict = super().state_dict()

        # Copy each per-parameter state dict so the wrapping below does not
        # mutate the live optimizer state (PyTorch's state_dict() only does a
        # shallow copy and returns references to it).
        state_dict["state"] = {
            k: {kk: vv for kk, vv in v.items()} if isinstance(v, dict) else v
            for k, v in state_dict["state"].items()
        }

        # Wrap quantization-specific tensors in a nested dict to hide them from FSDP.
        for param_state in state_dict["state"].values():
            if isinstance(param_state, dict):
                quant_state = {}
                keys_to_wrap = [k for k in param_state if k in self.non_castable_tensor_keys]
                for key in keys_to_wrap:
                    quant_state[key] = param_state.pop(key)
                if quant_state:
                    param_state[self._FSDP_WRAPPED_QUANT_STATE_KEY] = quant_state

        return state_dict

    def __setstate__(self, state):
        super().__setstate__(state)

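For reference, here is a minimal standalone sketch of the wrap/unwrap round trip this patch performs, using a hand-built per-parameter state dict and a stand-in key set rather than a real `Optimizer8bit` and its `non_castable_tensor_keys`:

```python
# Standalone sketch of the wrap/unwrap round trip (not the library API).
# The wrapper key matches the diff; the key set and tensor shapes are
# illustrative stand-ins for self.non_castable_tensor_keys.
import torch

WRAP_KEY = "__bnb_optimizer_quant_state__"
non_castable_tensor_keys = {"state1", "state2", "absmax1", "absmax2"}  # stand-in set

param_state = {
    "step": 10,
    "state1": torch.zeros(4096, dtype=torch.uint8),  # quantized optimizer state
    "absmax1": torch.ones(16),                       # per-block scaling factors
}

# Wrap (mirrors Optimizer8bit.state_dict above): move quantization tensors
# into a nested dict so FSDP's gather only sees the remaining flat entries.
quant_state = {k: param_state.pop(k) for k in list(param_state) if k in non_castable_tensor_keys}
if quant_state:
    param_state[WRAP_KEY] = quant_state
assert set(param_state) == {"step", WRAP_KEY}

# Unwrap (mirrors load_state_dict below): restore the flat layout on load.
if WRAP_KEY in param_state:
    param_state.update(param_state.pop(WRAP_KEY))
assert set(param_state) == {"step", "state1", "absmax1"}
```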
@@ -166,6 +196,13 @@ def load_state_dict(self, state_dict, move_to_device=True):
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)

        # Unwrap quantization states that were wrapped for FSDP compatibility
        for param_state in state_dict["state"].values():
            if isinstance(param_state, dict) and self._FSDP_WRAPPED_QUANT_STATE_KEY in param_state:
                quant_state = param_state.pop(self._FSDP_WRAPPED_QUANT_STATE_KEY)
                param_state.update(quant_state)

        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict["param_groups"]
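Under some assumptions (a CUDA device, bitsandbytes installed, and a parameter with at least 4096 elements so the 8-bit state path is used rather than the 32-bit fallback), a rough end-to-end round trip through the new pair of methods might look like the sketch below; aside from the wrapper key, the names come from public bitsandbytes/PyTorch APIs, and the exact per-parameter keys depend on the library version.

```python
# Rough save/reload sketch through the new state_dict()/load_state_dict() pair.
# Assumes a CUDA device and bitsandbytes installed; the parameter has >= 4096
# elements so the quantized (8-bit) optimizer state is actually allocated.
import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(64, 128, device="cuda"))
optim = bnb.optim.Adam8bit([p], lr=1e-3)

p.grad = torch.randn_like(p)
optim.step()  # populates the quantized optimizer state

sd = optim.state_dict()
# Quantization tensors now live under the wrapper key, so an FSDP
# full_optim_state_dict gather would skip them.
print("__bnb_optimizer_quant_state__" in next(iter(sd["state"].values())))

optim2 = bnb.optim.Adam8bit([p], lr=1e-3)
optim2.load_state_dict(sd)  # unwraps the nested dict back into flat per-param state
```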