feat(dinomaly): Add context-aware recentering option from dinomaly2 #3435
rajeshgangireddy wants to merge 8 commits into open-edge-platform:main from
Conversation
Pull request overview
This PR adds an optional Dinomaly2-style “Context-Aware Recentering” mode to Dinomaly and includes a Lightning hook to improve Rich progress-bar epoch display when training is configured with max_steps.
Changes:
- Added a `use_context_recentering` flag to toggle CLS-token-based recentering vs. the existing Dinomaly behavior.
- Implemented patch-feature recentering (subtracting the CLS token) in the encoder feature pipeline, and adjusted downstream token stripping accordingly.
- Added a Lightning `on_train_start` workaround to estimate epochs for the Rich progress display when `max_steps` is used.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/anomalib/models/image/dinomaly/torch_model.py` | Adds the `use_context_recentering` option, validates incompatibility with `remove_class_token`, and applies recentering in feature extraction. |
| `src/anomalib/models/image/dinomaly/lightning_model.py` | Plumbs the new flag into the Lightning module and monkey-patches the Rich progress-bar epoch description for step-based training. |
| `examples/configs/model/dinomaly.yaml` | Exposes `use_context_recentering` in the example config. |
Pull request overview
Adds an optional Dinomaly2-style “Context-Aware Recentering” mode to Dinomaly and includes a Lightning/Rich progress-bar compatibility tweak for step-based training (max_steps).
Changes:
- Introduces a `use_context_recentering` flag and enforces incompatibility with `remove_class_token`.
- Applies CLS-token subtraction to patch tokens prior to reconstruction when the flag is enabled.
- Adds an `on_train_start` hook to adjust the RichProgressBar epoch display when Lightning uses `max_epochs=-1` for step-based runs.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `src/anomalib/models/image/dinomaly/torch_model.py` | Adds the `use_context_recentering` flag and modifies token handling to recenter patch features. |
| `src/anomalib/models/image/dinomaly/lightning_model.py` | Wires the new flag through the Lightning module and patches the Rich progress-bar epoch display for `max_steps` runs. |
| `examples/configs/model/dinomaly.yaml` | Exposes `use_context_recentering` in the example config. |
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
```python
if use_context_recentering and remove_class_token:
    msg = (
        "use_context_recentering=True requires access "
        "to the class token and is incompatible with remove_class_token=True"
    )
```
The new use_context_recentering option adds a new feature-processing path and an incompatibility constraint with remove_class_token, but there are currently no unit tests covering DinomalyModel. Since this repo has unit tests for other image torch models, please add tests that (1) assert the incompatible flag combination raises, and (2) validate output tensor shapes/paths for use_context_recentering vs baseline to prevent regressions.
```python
"""Fix Rich progress bar epoch display when using step-based training.

Lightning internally sets ``max_epochs=-1`` when only ``max_steps`` is
provided. The Rich progress bar then displays "Epoch X/-2" because it
computes ``max_epochs - 1``. This hook calculates the estimated number
of epochs from ``max_steps`` and ``num_training_batches`` so the
progress bar shows "Epoch X/M", where ``M`` is the estimated last
epoch index (``estimated_epochs - 1``).

Note:
    This relies on Lightning's ``RichProgressBar._get_train_description``
    as implemented in Lightning 2.x. If the internals change in future
    versions, this hook may need updating.
"""
if self.trainer.max_epochs is not None and self.trainer.max_epochs < 0:
    try:
        from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar
    except ImportError:
        return

    progress_bar = getattr(self.trainer, "progress_bar_callback", None)
    if isinstance(progress_bar, RichProgressBar) and hasattr(
        progress_bar,
        "_get_train_description",
    ):
        num_batches = self.trainer.num_training_batches
        max_steps = self.trainer.max_steps
        if (
            max_steps > 0
            and isinstance(num_batches, (int, float))
            and num_batches > 0
            and math.isfinite(num_batches)
        ):
            est_max_epochs = math.ceil(max_steps / num_batches)
        else:
            est_max_epochs = None

        val_desc = getattr(progress_bar, "validation_description", "Validation")

        class _FixedRichProgressBar(RichProgressBar):
            """RichProgressBar subclass with corrected epoch display for step-based training."""

            def _get_train_description(self, current_epoch: int) -> str:  # noqa: PLR6301
                desc = f"Epoch {current_epoch}"
                if est_max_epochs is not None:
                    desc += f"/{est_max_epochs - 1}"
                if len(val_desc) > len(desc):
                    desc = f"{desc:{len(val_desc)}}"
                return desc

        progress_bar.__class__ = _FixedRichProgressBar
```
This hook monkey-patches the trainer’s RichProgressBar by reassigning progress_bar.__class__ and overriding the private method _get_train_description. This is brittle (depends on Lightning internals that can change without warning) and makes behavior harder to reason about/debug. Consider moving this to a dedicated callback in anomalib (so it’s explicit/configurable) or avoid patching private methods by deriving max_epochs earlier (e.g., when building the Trainer) instead of mutating the callback at runtime.
Suggested change (replace the monkey-patching hook body shown above with a no-op):

```python
"""Hook called at the beginning of training.

This method previously adjusted the Rich progress bar epoch display for
step-based training by monkey-patching Lightning's ``RichProgressBar``.
That behavior has been removed to avoid relying on private Lightning
internals and mutating callback classes at runtime.
"""
# Intentionally left without side effects to avoid brittle monkey-patching
# of Lightning internals while preserving the public hook API.
return None
```
ashwinvaidya17
left a comment
Thanks for adding this. Just one minor comment
```python
self._initialize_trainable_modules(self.trainable_modules)

def on_train_start(self) -> None:
    """Fix Rich progress bar epoch display when using step-based training.
```
I feel this isn't limited to just this model. Maybe all models can benefit from this method if they ever use step based training. Should we move this to the base class?
```python
from .checkpoint import ModelCheckpoint
from .graph import GraphLogger
from .model_loader import LoadModelCallback
from .rich_progress_bar import MaxStepsProgressCallback
from .tiler_configuration import TilerConfigurationCallback
from .timer import TimerCallback
```
The module docstring lists which callbacks are exported, but it isn’t updated to include MaxStepsProgressCallback even though it is now part of the public API via __all__. Please update the docstring list so it matches the actual exports.
```python
"""Max-steps progress callback for step-based training.

Lightning internally sets ``max_epochs=-1`` when only ``max_steps`` is provided.
The Rich progress bar then displays "Epoch X/-2" because it computes
``max_epochs - 1``. This callback estimates the total number of epochs from
``max_steps`` and ``num_training_batches`` so the progress bar shows a
meaningful "Epoch X/M" instead.
```
The module docstring says the progress bar will show "Epoch X/M", but the implementation appends est_max_epochs - 1 (e.g. "Epoch 0/9" for 10 epochs). Please align the docstring with the actual display convention or adjust the display logic to match the documented output.
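The off-by-one this comment describes is easy to see with a small, self-contained example. The `max_steps` value matches the 5000 steps mentioned in the PR description; the batch count of 219 is hypothetical, chosen only for illustration:

```python
import math

# Hypothetical values: max_steps follows the PR's 5000-step setting,
# num_training_batches (batches per epoch) is made up for illustration.
max_steps = 5000
num_training_batches = 219

# Estimated total number of epochs, as the callback computes it.
est_max_epochs = math.ceil(max_steps / num_training_batches)  # 23

# The implementation appends the *last epoch index* (est_max_epochs - 1),
# so 23 estimated epochs renders as "Epoch 0/22", not "Epoch 0/23".
desc = f"Epoch {0}/{est_max_epochs - 1}"
print(desc)  # Epoch 0/22
```

Either convention is defensible; the point of the comment is that the docstring and the rendered string should agree.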
```python
    from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar
except ImportError:
    return
```
RichProgressBar is imported from lightning.pytorch.callbacks.progress.rich_progress, which is an internal module path and more likely to change across Lightning versions. Prefer importing from the public API (e.g. lightning.pytorch.callbacks) with a fallback to the internal path if needed, to reduce breakage risk.
Suggested change:

```diff
-    from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar
-except ImportError:
-    return
+    from lightning.pytorch.callbacks import RichProgressBar
+except ImportError:
+    try:
+        from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar
+    except ImportError:
+        return
```
```python
if use_context_recentering and remove_class_token:
    msg = (
        "use_context_recentering=True requires access "
        "to the class token and is incompatible with remove_class_token=True"
    )
    raise ValueError(msg)
```
The new use_context_recentering behavior (including the ValueError when combined with remove_class_token=True) isn’t covered by tests. Please add unit tests that (1) assert the incompatibility check raises, and (2) verify the recentering path produces patch-only tokens and preserves the expected spatial reshape downstream.
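A minimal, self-contained sketch of the first requested test follows. `validate_flags` is a hypothetical stand-in that replicates the guard from the diff above; a real unit test would instantiate `DinomalyModel` directly (its constructor signature is not shown here):

```python
# Hypothetical stand-in reproducing the guard from the diff above.
def validate_flags(use_context_recentering: bool, remove_class_token: bool) -> None:
    if use_context_recentering and remove_class_token:
        msg = (
            "use_context_recentering=True requires access "
            "to the class token and is incompatible with remove_class_token=True"
        )
        raise ValueError(msg)


# The incompatible combination must raise ...
try:
    validate_flags(use_context_recentering=True, remove_class_token=True)
    raised = False
except ValueError:
    raised = True
assert raised

# ... while every other combination passes without error.
for recenter, remove in [(True, False), (False, True), (False, False)]:
    validate_flags(use_context_recentering=recenter, remove_class_token=remove)
```

In an anomalib test suite this would typically be written with `pytest.raises(ValueError)` against the actual model class.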
📝 Description
Dinomaly 2 was released a few months ago. Dinomaly2 extends Dinomaly with Context-Aware Recentering. The other concepts, like foundation Transformers, Noisy Bottleneck, Unfocused Linear Attention, and Loose Reconstruction, are identical to what's already implemented.
Context-Aware Recentering subtracts the CLS (class) token from each patch feature before reconstruction. This conditions feature reconstruction on class-specific context, preventing the decoder from reconstructing patterns that are "normal" in one object class but anomalous in another.
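For illustration, the CLS-token subtraction can be sketched on ViT-style tokens as below. This is a minimal sketch with hypothetical names, not the actual anomalib implementation; it assumes `tokens[:, 0]` holds the CLS token and the remaining entries are patch tokens:

```python
import torch


def recenter_patch_tokens(tokens: torch.Tensor) -> torch.Tensor:
    """Subtract the CLS token from every patch token (illustrative sketch).

    Assumes ViT-style layout: tokens[:, 0] is the CLS token and
    tokens[:, 1:] are the patch tokens. Function and variable names
    here are hypothetical.
    """
    cls_token = tokens[:, :1, :]      # (B, 1, D)
    patch_tokens = tokens[:, 1:, :]   # (B, N, D)
    return patch_tokens - cls_token   # broadcast subtraction over N


tokens = torch.randn(2, 1 + 196, 768)  # batch of 2, 14x14 patches, dim 768
recentered = recenter_patch_tokens(tokens)
print(recentered.shape)  # torch.Size([2, 196, 768])
```

Note the output drops the CLS token, which is why downstream token stripping has to be adjusted when the flag is enabled.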
The Dinomaly2 code is not public, but for the single-class UAD scenario the changes are simple (at least I assume).
This PR adds an extra `use_context_recentering` flag that toggles between the Dinomaly2 (if true) and Dinomaly implementations.
Additionally, added logic to show the right number of epochs when `max_steps` is passed (the Dinomaly paper uses 5000 steps).
There could be a slight improvement (~1 to 2%) in pixel-level scores.
Example numbers for the Bottle category of MVTec AD with max_steps = 5000
Need to test on more categories to get a clear picture.
✨ Changes
Select what type of change your PR is:
✅ Checklist
Before you submit your pull request, please make sure you have completed the following steps:
For more information about code review checklists, see the Code Review Checklist.