Conversation

@raimbekovm (Contributor) commented Dec 24, 2025

What does this PR do?

Fixes incorrect default interpolation in the ViT, EfficientNet, and PVT image processors. The original implementations use BICUBIC, but the Hugging Face processors defaulted to BILINEAR (ViT, PVT) or NEAREST (EfficientNet).

Changes:

  • ViT: BILINEAR → BICUBIC
  • ViTFast: BILINEAR → BICUBIC
  • EfficientNet: NEAREST → BICUBIC
  • PVT: BILINEAR → BICUBIC

Verified against timm (all three use BICUBIC): https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py#L75

Note: This is a breaking change for users relying on previous defaults.

Fixes part of #28180

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case. Verify interpolation of image processors #28180
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@NielsRogge

The original Vision Transformer implementation uses BICUBIC interpolation
for image preprocessing, but the HuggingFace image processor defaulted to
BILINEAR. This change aligns the default with the original implementation.

Changes:
- Update default resample from BILINEAR to BICUBIC in ViTImageProcessor
- Update default resample in ViTImageProcessorFast
- Update docstrings to reflect BICUBIC

Reference: https://github.com/huggingface/pytorch-image-models
Fixes part of huggingface#28180

These models copy the resize method from ViT, so they need to be updated
to match the new BICUBIC default interpolation.
@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: efficientnet, imagegpt, layoutlmv2, layoutlmv3, pvt, segformer, vit

@raimbekovm (Contributor, Author)

Quick clarification on the file changes:

The second commit updates 6 additional models because they copy ViT's resize method. After changing ViT from BILINEAR to BICUBIC, these copies needed to stay in sync.

This PR complements the ongoing interpolation fixes (#28180) by @lukepayyapilli for other models.

Let me know if any changes are needed!

@CodersAcademy006

This makes sense and aligns ViT preprocessing with expected defaults (e.g. TIMM). Since this changes a default behavior, it might be worth explicitly calling it out as a breaking change for users relying on the previous interpolation. A small regression test asserting the default resampling value could also help prevent accidental future changes.

@Rocketknight1 (Member)

Hey, did you check this against the original implementation in each case?

After verification against original implementations:
- ImageGPT: OpenAI original uses BILINEAR
- Segformer: MMSegmentation uses BILINEAR by default
- LayoutLMv2: Microsoft/Detectron2 uses BILINEAR
- LayoutLMv3: Microsoft/Detectron2 uses BILINEAR

These models were incorrectly changed to BICUBIC. Only ViT, EfficientNet,
and PVT should use BICUBIC (verified against timm).
@raimbekovm (Contributor, Author)

Yes, verified against original implementations.

BICUBIC confirmed (timm): ViT, EfficientNet, PVT

Reverted to BILINEAR: ImageGPT, Segformer, LayoutLMv2, LayoutLMv3. These use BILINEAR in their original repos (OpenAI, MMSegmentation, Microsoft/Detectron2).

Pushed a fix, PR now only changes the 3 models that actually need BICUBIC.

@raimbekovm raimbekovm changed the title 🚨 Fix ViT image processor default interpolation to BICUBIC Fix default interpolation to BICUBIC for ViT, EfficientNet, PVT Jan 12, 2026
@Rocketknight1 (Member)

cc @NielsRogge for review in that case!
