
Conversation

@ved1beta (Contributor) commented Jan 11, 2026

#89

The state_dict() method in Optimizer8bit wraps the quantization states (state1, state2, absmax1, etc.) in a nested dict under the key bnb_optimizer_quant_state.
This allows FSDP's full_optim_state_dict to gather optimizer states across ranks without trying to gather the quantization tensors, which have different shapes than the model parameters.
The load_state_dict() method automatically unwraps these states when loading; a rough sketch of the wrapped layout follows.
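For illustration only, here is a minimal sketch of that layout. The helper names below are hypothetical and the exact set of quantization keys is an assumption; only the nesting key bnb_optimizer_quant_state comes from this PR.

# Hypothetical sketch of the wrap/unwrap idea; not the bitsandbytes implementation.
QUANT_STATE_KEY = "bnb_optimizer_quant_state"  # key name introduced by this PR
QUANT_KEYS = ("state1", "state2", "absmax1", "absmax2", "qmap1", "qmap2")  # assumed set

def wrap_param_state(param_state: dict) -> dict:
    # Conceptually what state_dict() does: nest the quantization tensors under one
    # key so FSDP only gathers entries shaped like the model parameter.
    wrapped = {k: v for k, v in param_state.items() if k not in QUANT_KEYS}
    wrapped[QUANT_STATE_KEY] = {k: v for k, v in param_state.items() if k in QUANT_KEYS}
    return wrapped

def unwrap_param_state(param_state: dict) -> dict:
    # Conceptually what load_state_dict() does: flatten the nested dict back out.
    unwrapped = {k: v for k, v in param_state.items() if k != QUANT_STATE_KEY}
    unwrapped.update(param_state.get(QUANT_STATE_KEY, {}))
    return unwrapped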

Tested with:
torchrun --nproc_per_node=2 test.py

"""
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullOptimStateDictConfig, StateDictType
import bitsandbytes as bnb

# Initialize distributed
# For single GPU testing, set environment variables if not already set
if "RANK" not in os.environ:
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

dist.init_process_group("nccl")

rank = dist.get_rank()
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

# Create model and wrap with FSDP
class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 64)
    
    def forward(self, x):
        return self.linear(x)

model = YourModel().to(device)
model = FSDP(model)

# Use 8-bit optimizer
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

# Training loop
model.train()
for step in range(3):
    x = torch.randn(2, 128, device=device)
    y = model(x).sum()
    y.backward()
    optimizer.step()
    optimizer.zero_grad()
    if rank == 0:
        print(f"Step {step} completed")

# This call now works!
# The optimizer's state_dict() method automatically wraps quantization states
# in a nested dict that FSDP can handle during full_optim_state_dict gathering
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    optim_state = FSDP.full_optim_state_dict(model, optimizer)
    
    if rank == 0:
        print("\n✓ Successfully retrieved full optimizer state dict!")
        print(f"State dict contains {len(optim_state.get('state', {}))} parameter states")
        
        # Verify quantization states are wrapped (for debugging)
        if optim_state.get("state"):
            first_param_state = next(iter(optim_state["state"].values()))
            wrapped_key = bnb.optim.optimizer.Optimizer8bit._FSDP_WRAPPED_QUANT_STATE_KEY
            if wrapped_key in first_param_state:
                print(f"✓ Quantization states are properly wrapped under '{wrapped_key}'")
                print(f"  Wrapped keys: {list(first_param_state[wrapped_key].keys())}")

dist.destroy_process_group()
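For completeness, a sketch of the reverse direction: re-sharding the gathered state and loading it back into the optimizer. This would go before dist.destroy_process_group() in a script like the one above. scatter_full_optim_state_dict is the long-standing FSDP helper for redistributing a rank0-only full optimizer state dict; exact usage can vary across PyTorch versions, so treat this as a sketch rather than the tested path from this PR.

# Sketch (continues from the script above; not part of the PR's test):
# re-shard the rank0-only full optimizer state dict and load it back.
sharded_osd = FSDP.scatter_full_optim_state_dict(optim_state, model)
optimizer.load_state_dict(sharded_osd)  # unwraps the nested quantization states
if rank == 0:
    print("✓ Reloaded optimizer state from the gathered full state dict")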



@ved1beta ved1beta marked this pull request as draft January 11, 2026 14:29
@ved1beta ved1beta marked this pull request as ready for review January 11, 2026 14:32
@matthewdouglas matthewdouglas added the FSDP and Optimizers labels Jan 12, 2026

@matthewdouglas (Member)

LGTM, thanks!

@matthewdouglas matthewdouglas added this to the v0.49.2 milestone Jan 14, 2026
@matthewdouglas matthewdouglas linked an issue Jan 14, 2026 that may be closed by this pull request: "8-bit optimizers don't work with FSDP"
@matthewdouglas matthewdouglas merged commit 31610c9 into bitsandbytes-foundation:main Jan 14, 2026
84 of 85 checks passed