[GPTQ][ddp] enabling DDP for GPTQ #2333

Open

HDCharles wants to merge 14 commits into main from 94_ddp_api

Conversation

@HDCharles
Collaborator

@HDCharles HDCharles commented Feb 6, 2026

After the changes in vllm-project/compressed-tensors#572, vllm-project/compressed-tensors#534, and #2340, we're ready to start rolling out DDP implementations of various modifiers.

API:

The API we've landed on attempts to maintain the normal flow, with only the minimal changes necessary to enable DDP:

  1. the user will call torchrun --nproc_per_node=<num_threads> script.py to start the script
  2. the user will initialize the distributed context (they can use the init_dist helper to do this)
  3. the user will load the model using the new context manager, setting the device map as outlined here (for most users this will be "auto_offload")
  4. (optional) the user can partition the dataset at load time using get_rank_partition, or load as normal and let oneshot partition the data later (the latter loads one copy of the dataset into CPU memory per rank, which may be onerous)
from compressed_tensors.offload import load_offloaded_model, init_dist

init_dist()
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto_offload")
...
ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
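
For illustration, here is a minimal sketch of a rank-partitioning helper along the lines of get_rank_partition, assuming it returns a datasets split-slice string (e.g. "train[0:256]") and that the last rank absorbs any remainder; the body is illustrative, not the shipped implementation:

import torch.distributed as dist

def get_rank_partition(split: str, num_samples: int) -> str:
    """Return the slice of `split` this rank should load (illustrative only)."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    per_rank = num_samples // world_size
    start = rank * per_rank
    # the last rank picks up any remainder so every sample is covered
    end = num_samples if rank == world_size - 1 else start + per_rank
    return f"{split}[{start}:{end}]"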

Implementation

Adding DDP support to GPTQ was relatively straightforward, though optimizing it for speed was a bit trickier. There are four steps, sketched in code below:

  1. assign each module to the rank that will compress it
  2. for each module, send the Hessian information from all other ranks to its assigned rank
  3. each rank compresses the modules it was assigned
  4. broadcast the final quantized values to all ranks
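
A minimal sketch of this flow, assuming an in-memory Hessian dict keyed by module, asynchronous collectives, and placeholder parameter names standing in for the real _GPTQ_Q_PARAMS (illustrative only, not the actual GPTQModifier code):

import torch.distributed as dist

def distributed_compress(modules, hessians, num_samples, compress_single_module):
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # 1. assign each module to a rank (the real code load-balances; see below)
    assignment = {module: i % world_size for i, module in enumerate(modules)}

    # 2. reduce each Hessian onto its assigned rank, asynchronously
    #    (sample counts are handled similarly in the real implementation)
    pending = [
        dist.reduce(hessians[module], dst=assignment[module],
                    op=dist.ReduceOp.SUM, async_op=True)
        for module in modules
    ]
    for work in pending:
        work.wait()

    # 3. each rank compresses only the modules assigned to it
    for module in modules:
        if assignment[module] == rank:
            compress_single_module(module, hessians[module], num_samples[module])

    # 4. broadcast the quantized parameters from the owning rank to all ranks
    pending = []
    for module in modules:
        for name in ("weight", "weight_scale", "weight_zero_point"):  # placeholders
            tensor = getattr(module, name, None)
            if tensor is not None:
                pending.append(dist.broadcast(tensor, src=assignment[module],
                                              async_op=True))
    for work in pending:
        work.wait()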

Step 1 required the largest optimization: without any load balancing, we ran into situations where one rank could be doing twice as much work as another. We therefore implemented basic load balancing with time estimation, which seems to be working well in practice. The other major optimization was using asynchronous ops for rank-to-rank communication. Before these optimizations, 2-rank GPTQ was as fast as 1-rank GPTQ for llama3-8B; afterward it results in a 27% speedup despite this being a relatively small model.
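
For reference, the load balancing can be done with a simple greedy bin packing over estimated per-module compression times; a sketch under that assumption (the PR adds a greedy_bin_packing utility, but this body is illustrative rather than the actual code):

from typing import Dict, Hashable, TypeVar

T = TypeVar("T", bound=Hashable)

def greedy_bin_packing(costs: Dict[T, float], num_bins: int) -> Dict[T, int]:
    """Assign each item to the currently least-loaded bin, largest first.

    costs maps an item (e.g. a module) to an estimated compression time;
    the result maps each item to a bin (rank) index.
    """
    loads = [0.0] * num_bins
    assignment: Dict[T, int] = {}
    # placing the most expensive items first keeps the bins balanced
    for item, cost in sorted(costs.items(), key=lambda kv: kv[1], reverse=True):
        bin_idx = min(range(num_bins), key=lambda i: loads[i])
        assignment[item] = bin_idx
        loads[bin_idx] += cost
    return assignment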

TODO insert benchmarks here

GPTQ Changes

While validating the numerical accuracy of the DDP technique, we noticed that accuracy improved significantly with each rank added. After some debugging we realized this was because the existing Hessian calculation was accumulating floating point error. By rewriting the Hessian calculation to sum the intermediate Hessians and only divide by num_samples at the end, we improved the GSM8K evaluation from (.67, .66) to (.71, .71). You can reproduce these results here
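
Roughly, the fix replaces a per-batch running average with an unnormalized sum that is divided exactly once before quantization. A sketch, assuming the standard GPTQ Hessian H = (2/N) * sum(x^T x) and illustrative variable names:

import torch

def accumulate_hessian(H: torch.Tensor, inp: torch.Tensor, num_samples: int):
    """Add one calibration batch's unnormalized contribution, 2 * X^T X.

    Dividing by the total sample count only once, at quantization time,
    avoids the repeated rescaling (and the accumulated rounding error)
    of a running-average update.
    """
    inp = inp.to(torch.float32)   # accumulate in fp32 for stability
    H += 2.0 * inp.t() @ inp      # inp has shape (batch, hidden_dim)
    return H, num_samples + inp.shape[0]

# at quantization time (illustrative):
# hessian = H / num_samples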

TODO remove test code and add an example script

@github-actions

github-actions bot commented Feb 6, 2026

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@gemini-code-assist
Contributor

Summary of Changes

Hello @HDCharles, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces initial support for running GPTQ quantization with Distributed Data Parallel (DDP). The core changes involve distributing the quantization workload across multiple processes, where each process handles a portion of the model's modules. It also includes necessary adjustments for dataset partitioning, model loading, and saving in a distributed environment, aiming to enable more efficient quantization of large language models.

Highlights

  • Distributed GPTQ Quantization: The GPTQModifier has been updated to support Distributed Data Parallel (DDP) quantization, allowing multiple ranks to collectively quantize model modules by distributing the workload, reducing Hessians, and broadcasting results.
  • Dataset Partitioning for DDP: A new utility function get_rank_partition was added to src/llmcompressor/datasets/utils.py to enable partitioning of datasets across DDP ranks, ensuring each rank processes a distinct subset of calibration samples.
  • Distributed Model Loading and Saving: Modifications were made to save_pretrained_wrapper to handle DDP environments, ensuring only the main rank saves the model and temporarily disabling offloading during the save process. A new ct_offload context manager was introduced to patch from_pretrained for distributed loading.
  • Dynamic Device Assignment: The get_main_device utility now dynamically assigns CUDA or XPU devices based on the current DDP rank, improving device utilization in distributed setups.
  • DDP Test Case: A new test file test_ddp.py has been added to demonstrate and validate the GPTQ quantization process within a DDP environment, including distributed model loading, dataset preparation, and quantized model saving.

Changelog
  • src/llmcompressor/datasets/__init__.py
    • Exports the newly added get_rank_partition function.
  • src/llmcompressor/datasets/utils.py
    • Adds get_rank_partition function to divide a dataset split into partitions for each DDP rank, ensuring the last rank gets any remainder.
  • src/llmcompressor/modifiers/quantization/gptq/base.py
    • Imports torch.distributed as dist for DDP functionalities.
    • Refactors compress_modules to support both non-distributed and distributed quantization workflows.
    • Introduces a distributed logic where ranks collectively quantize modules: each rank processes a subset of modules, reduces Hessians, and then broadcasts quantization parameters.
    • Extracts the single module compression logic into a new compress_single_module method for reusability.
  • src/llmcompressor/transformers/compression/compressed_tensors_utils.py
    • Imports OffloadCache for managing offloaded tensors.
    • Modifies save_pretrained_wrapper to ensure only rank 0 performs the model saving in a distributed setup (see the sketch after this changelog).
    • Adds a context manager OffloadCache.disable_onloading() around the original_save_pretrained call to prevent materializing the entire model on device during saving.
  • src/llmcompressor/utils/dev.py
    • Updates get_main_device to use the specific CUDA/XPU device corresponding to the current DDP rank, rather than always 0.
  • test_ddp.py
    • Adds a new test file for GPTQ with DDP.
    • Includes init_dist for DDP setup and is_ddp utility.
    • Defines convert_to_ct_offload to handle model offloading based on device_map configurations.
    • Implements patch_from_pretrained to modify from_pretrained behavior for distributed loading, adjusting device_map for each rank.
    • Introduces ct_offload context manager to wrap from_pretrained calls, manage DDP initialization, and apply patching/cleanup.
    • Demonstrates loading a model, partitioning a dataset using get_rank_partition, applying GPTQModifier, and saving the quantized model in a DDP environment.
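
As a rough illustration of the two save-time changes above, assuming torch.distributed is already initialized when appropriate and that OffloadCache.disable_onloading behaves as described in this changelog (a sketch, not the actual wrapper):

import torch.distributed as dist

def save_pretrained_ddp(model, save_directory, original_save_pretrained, OffloadCache):
    """Save only from rank 0, without onloading the full model (illustrative)."""
    is_main = not dist.is_initialized() or dist.get_rank() == 0
    if is_main:
        # keep offloaded tensors offloaded while serializing
        with OffloadCache.disable_onloading():
            original_save_pretrained(model, save_directory)
    if dist.is_initialized():
        dist.barrier()  # keep other ranks in step while rank 0 writes to disk
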
Activity
  • This pull request is marked as a Proof of Concept (PoC) for GPTQ with DDP, indicating it's an initial implementation.
  • The author notes that there are 'some specifics to work through as apis are updated in compressed tensors', suggesting potential future refinements or dependencies.
  • A test plan is provided, instructing users to run torchrun --nproc_per_node=2 test_ddp.py to validate the changes.

@HDCharles HDCharles added the enhancement (New feature or request) and gptq (For any PR / issue related to GPTQ support) labels Feb 6, 2026
@mergify
Contributor

mergify bot commented Feb 6, 2026

The quality checks have failed. Please run make style and make quality under
the root directory to address the lint failures. You will need to install the
dev optional install to get the required linting packages:
https://github.com/vllm-project/llm-compressor/blob/main/CONTRIBUTING.md

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a Proof of Concept for GPTQ with Distributed Data Parallel (DDP). The changes are mainly in the GPTQ modifier to handle distributed computation of Hessians and quantization. My review has identified a critical bug in the non-distributed path that would lead to incomplete quantization, as well as a high-severity issue in the new distributed logic that could cause a runtime error. I've also provided suggestions to improve code clarity and remove redundant or temporary code sections.


medium

The key for module is already removed from self._num_samples on line 283 using pop(). This second call to pop() is redundant and can be removed.

@mergify mergify bot removed the quality-failed label Feb 9, 2026

@HDCharles HDCharles changed the title [GPTQ][ddp] PoC for GPTQ with DDP [GPTQ][ddp] enabling DDP for GPTQ Feb 18, 2026
@HDCharles HDCharles added the ready (When a PR is ready for review) and dist (Work pertaining to distributed work) labels Feb 18, 2026
@mergify mergify bot added the documentation (Improvements or additions to documentation) label Feb 18, 2026
module=module,
quant_args=quant_args,
hessians_dict=self._hessians,
hessian=self._hessians[module] / self._num_samples[module],
Collaborator

Consider passing num_samples as an arg to quantize_weight

Collaborator Author

i'm unsure why this was implemented by passing entire dicts originally. Seems like i'd rather make the function more explicit on what its acting on i.e. what i have here.

Collaborator

@kylesayrs kylesayrs Feb 18, 2026

I agree that not passing a dictionary is cleaner, but it comes at a memory cost since we cannot "move" the hessian memory. This is an instance where I feel like (better behavior) is preferable to (cleaner code)

Collaborator

Spoke offline, we agreed to pop the value from the dict.
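
For clarity, the agreed change amounts to roughly the following (illustrative, not the exact diff):

# pop so the Hessian buffer is released as soon as this module is compressed
num_samples = self._num_samples.pop(module)
hessian = self._hessians.pop(module) / num_samples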


Collaborator

@kylesayrs kylesayrs left a comment


Just small programming nits, otherwise looks excellent

Comment on lines +340 to +343
if rank == target_rank:
wait_for_comms(pending_comms)
self._hessians[module] = H
self._num_samples[module] = n
Collaborator

I don't fully understand this logic. Why is this step needed? Can't you just write to the memory address directly using

for module in module_list:
    h_comm = dist.reduce(
        self._hessians[module],
        op=dist.ReduceOp.SUM,
        dst=target_rank,
        async_op=True
    )

    pending_comms.append(h_comm)

wait_for_comms(pending_comms)

This way seems to maximize throughput more than how it's written now, right?

Collaborator

It seems like you take a similar approach when broadcasting.


# Broadcast each tensor asynchronously
# note: update in place, since compress_module_list updated the offload
for tensor in to_broadcast:
Collaborator

What's the benefit of splitting into two for loops? Why not just write

for module in module_list:
    for attr in _GPTQ_Q_PARAMS:
        if (tensor := getattr(module, attr, None)) is not None:
            pending_comms.append(dist.broadcast(tensor, ...))


T = TypeVar("T", bound=Hashable)


def greedy_bin_packing(
Collaborator

Beautiful

Comment on lines +72 to +73
torch.cuda.reset_peak_memory_stats()
start_time = time.time()
Collaborator

Did you mean to keep these?
