
Fix(lpips): load ImageNet backbone weights for pretrained models #4557

Merged
laggui merged 7 commits into tracel-ai:main from koreaygj:fix/lpips
Feb 25, 2026

Conversation

@koreaygj
Contributor

Summary

  • Add TAR format support to burn-store for legacy PyTorch models (AlexNet, SqueezeNet)
  • Fix LPIPS to load ImageNet backbone weights separately from LPIPS linear weights
  • Fix key remapping to match PyTorch's original model structure

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs


Changes

LPIPS pretrained weights were not loading correctly. The original implementation loaded only the LPIPS linear layer weights, not the ImageNet backbone weights (VGG16/AlexNet/SqueezeNet). Additionally, the AlexNet and SqueezeNet backbone weights are stored in TAR format, which was not supported.


Solution

  1. burn-store: Added TAR format support for legacy PyTorch models (pre-1.6)
    • Implemented TarSource in lazy_data.rs
    • Added TAR file detection and loading in reader.rs
  2. burn-train/lpips: Fixed pretrained weights loading
    • Now loads both ImageNet backbone weights and LPIPS linear layer weights separately
    • Fixed key remapping to match PyTorch's original model structure (features.X instead of net.sliceX)
    • Added backbone weight URLs from PyTorch official repository

Testing

# LPIPS pretrained tests (all pass)
cargo test -p burn-train --lib --features vision test_lpips_pretrained

# Results:
# - VGG: 0.4112574 ✅
# - AlexNet: 0.37245688 ✅
# - SqueezeNet: 0.1753163 ✅

# burn-store tests
cargo test -p burn-store --features pytorch

  Support loading PyTorch models saved in TAR format (pre-1.6),
  such as AlexNet and SqueezeNet from torchvision.
- not just use linear weights, add backbone weights
@koreaygj
Contributor Author

Currently, burn-train --features vision tests are not included in CI, so LPIPS pretrained test failures were not caught. Adding vision feature tests to CI should be considered in a follow-up PR.

@codecov

codecov bot commented Feb 22, 2026

Codecov Report

❌ Patch coverage is 77.46914% with 146 lines in your changes missing coverage. Please review.
✅ Project coverage is 62.04%. Comparing base (5923b1e) to head (f931b5d).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-store/src/pytorch/pickle_reader.rs 60.39% 80 Missing ⚠️
crates/burn-store/src/pytorch/lazy_data.rs 59.66% 48 Missing ⚠️
crates/burn-store/src/pytorch/reader.rs 86.81% 12 Missing ⚠️
...ates/burn-train/src/metric/vision/lpips/weights.rs 95.08% 6 Missing ⚠️

❌ Your patch check has failed because the patch coverage (77.46%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (62.04%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4557      +/-   ##
==========================================
+ Coverage   61.21%   62.04%   +0.83%     
==========================================
  Files        1062     1074      +12     
  Lines      136012   138456    +2444     
==========================================
+ Hits        83258    85907    +2649     
+ Misses      52754    52549     -205     


Member

@laggui laggui left a comment

> Currently, burn-train --features vision tests are not included in CI, so LPIPS pretrained test failures were not caught. Adding vision feature tests to CI should be considered in a follow-up PR.

Well that explains why it wasn't caught 😅 thanks

@koreaygj
Contributor Author

> > Currently, burn-train --features vision tests are not included in CI, so LPIPS pretrained test failures were not caught. Adding vision feature tests to CI should be considered in a follow-up PR.
>
> Well that explains why it wasn't caught 😅 thanks

The current CI (xtask/src/commands/test.rs) includes special feature tests for burn-dataset, burn-core, and burn-vision, but burn-train is not included.

If you agree with the change, I can open a PR to add burn-train vision tests in xtask/src/commands/test.rs, like below:

  // burn-train vision
  helpers::custom_crates_tests(
      vec!["burn-train"],
      handle_test_args(&["--features", "vision"], args.release),
      None,
      None,
      "std vision",
  )?;

@laggui
Member

laggui commented Feb 24, 2026

> I can add burn-train vision tests in xtask/src/commands/test.rs

Yeah we should do that! Can even be part of this PR if you want since it's related.

@antimora antimora self-requested a review February 24, 2026 18:17
Collaborator

@antimora antimora left a comment

Thanks for working on this! TAR format support is a useful addition since a number of older torchvision models (AlexNet, SqueezeNet) on PyTorch Hub are in this format, and the two-step weight loading approach for LPIPS is the right design (separate ImageNet backbone weights + LPIPS linear layer weights).

That said, there are several issues that need to be addressed before this can be merged. The biggest one: there are no integration tests for the TAR format itself. burn-store has ~47 test functions covering ZIP and legacy formats with Python-generated fixtures, but zero tests for TAR loading. We need the same treatment here.

Summary of issues (details in inline comments):

  1. Missing integration tests for TAR format in burn-store (critical)
  2. ~440 lines of copy-pasted code between rebuild_tensor and rebuild_tensor_v2
  3. BFloat16 bug in TarSource element size mapping (will silently corrupt data)
  4. Silent default to F32 in multiple places instead of returning errors for unknown storage types
  5. Unused parameter, wasteful allocation in read_range, and debug println! left in tests

{
    1
} else {
    4 // Default to float (4 bytes)
Collaborator

Bug: "BFloat16Storage".contains("Half") returns false, so BFloat16 falls through to 4. BFloat16 is 2 bytes. This will silently compute wrong offsets and corrupt tensor data.

The rebuild_tensor / rebuild_tensor_v2 code has the correct explicit mapping ("BFloat16Storage" => DType::BF16). Consider extracting a shared storage_type_to_dtype(name) -> DType helper and using dtype.size() here instead of a separate if/else chain with contains checks.
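The suggested helper could look roughly like this (a sketch only; `DType` and its variants stand in for burn-store's actual dtype enum, and the variant set is assumed from common PyTorch storage classes):

```rust
// Sketch: one explicit mapping from PyTorch storage class names to dtypes,
// reused for both dtype selection and element size. Avoids the
// contains("Half") chain that mis-sizes BFloat16 (2 bytes, not 4).
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType { F64, F32, F16, BF16, I64, I32, I16, U8, Bool }

impl DType {
    /// Element size in bytes.
    fn size(self) -> usize {
        match self {
            DType::F64 | DType::I64 => 8,
            DType::F32 | DType::I32 => 4,
            DType::F16 | DType::BF16 | DType::I16 => 2,
            DType::U8 | DType::Bool => 1,
        }
    }
}

/// Map a PyTorch storage class name to a dtype, returning an error on
/// unknown names instead of silently defaulting to F32.
fn storage_type_to_dtype(name: &str) -> Result<DType, String> {
    match name {
        "DoubleStorage" => Ok(DType::F64),
        "FloatStorage" => Ok(DType::F32),
        "HalfStorage" => Ok(DType::F16),
        "BFloat16Storage" => Ok(DType::BF16),
        "LongStorage" => Ok(DType::I64),
        "IntStorage" => Ok(DType::I32),
        "ShortStorage" => Ok(DType::I16),
        "ByteStorage" => Ok(DType::U8),
        "BoolStorage" => Ok(DType::Bool),
        other => Err(format!("unknown PyTorch storage type: {other}")),
    }
}

fn main() {
    // The contains("Half") check got this case wrong: BFloat16 is 2 bytes.
    assert_eq!(storage_type_to_dtype("BFloat16Storage").unwrap().size(), 2);
    assert_eq!(storage_type_to_dtype("FloatStorage").unwrap().size(), 4);
    assert!(storage_type_to_dtype("MysteryStorage").is_err());
    println!("storage type mapping ok");
}
```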

Contributor Author

Resolved!

/// * `storages_data` - Raw storages blob with structure:
/// - Count pickle (number of storages)
/// - For each storage: metadata pickle + u64 num_elements + raw binary data
pub fn new(_tensors_data: &[u8], storages_data: Vec<u8>) -> std::io::Result<Self> {
Collaborator

_tensors_data is never used. It propagates up to LazyDataSource::from_tar and the callsite in reader.rs, making the API misleading. Remove it.

Contributor Author

Resolved.

.unwrap_or_else(|poisoned| poisoned.into_inner());
let data = source.read(key)?;
let end = (offset + length).min(data.len());
Ok(data[offset..end].to_vec())
Collaborator

Double allocation: source.read(key) allocates a Vec<u8> for the full storage, then this line allocates again for the slice. Since TarSource already holds the full blob in memory (storages_data), add a read_range method on TarSource that slices directly from self.storages_data[storage_offset + offset..].
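A sketch of the suggested `read_range` (field and method names mirror the review comment, not the merged code): slice directly out of the in-memory blob with a single allocation, keeping the same end-clamping behavior the original had with `.min()`.

```rust
// Sketch: TarSource already holds the full storages blob in memory, so a
// range read can copy once, straight from the blob.
struct TarSource {
    storages_data: Vec<u8>,
}

impl TarSource {
    /// Copy `length` bytes starting at `storage_offset + offset`,
    /// clamping the range to the end of the blob.
    fn read_range(&self, storage_offset: usize, offset: usize, length: usize) -> Vec<u8> {
        let start = (storage_offset + offset).min(self.storages_data.len());
        let end = (start + length).min(self.storages_data.len());
        self.storages_data[start..end].to_vec()
    }
}

fn main() {
    let source = TarSource { storages_data: (0u8..8).collect() };
    assert_eq!(source.read_range(2, 1, 3), vec![3, 4, 5]);
    // Out-of-range requests are clamped rather than panicking.
    assert_eq!(source.read_range(6, 0, 10), vec![6, 7]);
    println!("read_range ok");
}
```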

Contributor Author

Resolved.

// tuple[2] is storage type class
let stype = match &tuple[2] {
super::pickle_reader::Object::Class { name, .. } => name.clone(),
_ => "FloatStorage".to_string(),
Collaborator

Silently defaulting to "FloatStorage" when the storage type is not a Class will mask real parsing errors. Return an error instead.

Contributor Author

Resolved.


/// Legacy _rebuild_tensor function for PyTorch < 1.6.
/// Same as rebuild_tensor_v2 but with fewer arguments: (storage, storage_offset, size, stride)
fn rebuild_tensor(
Collaborator

This entire function (~440 lines, through line 748) is a near-exact copy of rebuild_tensor_v2 below. The only difference is 4 args vs 5 args (v2 adds requires_grad and backward_hooks).

The TODO at line 936 acknowledges this. Please extract the shared logic: parse the storage args into a struct, build the closure once. Both functions should be thin wrappers. ~400 lines of duplication makes this harder to maintain and will cause bugs when one copy gets updated without the other.
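The shape of the suggested refactor, heavily simplified (types and names are illustrative; the real functions build lazy tensor closures rather than a plain struct):

```rust
// Sketch: parse the shared (storage_offset, size, stride) prefix once in a
// common function, and make both rebuild variants thin wrappers over it.
#[derive(Debug, PartialEq)]
struct TensorSpec {
    storage_offset: usize,
    size: Vec<usize>,
    stride: Vec<usize>,
    requires_grad: bool,
}

/// Shared logic used by both variants.
fn rebuild_tensor_common(
    storage_offset: usize,
    size: Vec<usize>,
    stride: Vec<usize>,
    requires_grad: bool,
) -> TensorSpec {
    TensorSpec { storage_offset, size, stride, requires_grad }
}

/// Legacy `_rebuild_tensor` (PyTorch < 1.6): four args, no grad flag.
fn rebuild_tensor(storage_offset: usize, size: Vec<usize>, stride: Vec<usize>) -> TensorSpec {
    rebuild_tensor_common(storage_offset, size, stride, false)
}

/// `_rebuild_tensor_v2`: same prefix plus `requires_grad`
/// (backward hooks are accepted but ignored).
fn rebuild_tensor_v2(
    storage_offset: usize,
    size: Vec<usize>,
    stride: Vec<usize>,
    requires_grad: bool,
) -> TensorSpec {
    rebuild_tensor_common(storage_offset, size, stride, requires_grad)
}

fn main() {
    let a = rebuild_tensor(0, vec![2, 3], vec![3, 1]);
    let b = rebuild_tensor_v2(0, vec![2, 3], vec![3, 1], false);
    // Both wrappers produce the same spec for the shared prefix.
    assert_eq!(a, b);
    println!("thin wrappers ok");
}
```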

Contributor Author

Resolved.

module_name: _,
name,
} => name.as_str(),
_ => "FloatStorage",
Collaborator

Same concern: defaulting to "FloatStorage" when the object type is unexpected will silently produce wrong results. Return an error for unexpected types.

Contributor Author

Resolved.

@@ -792,6 +818,137 @@ fn load_legacy_pytorch_file_with_metadata(
Ok((tensors, metadata))
Collaborator

This is the biggest gap: there are no integration tests for the TAR format. The existing test suite in tests/reader/mod.rs has 47 tests covering ZIP and legacy formats with Python-generated fixtures. TAR needs the same treatment.

Since modern torch.save() cannot produce TAR files (this format predates PyTorch 0.1.10), you will need a Python script that manually constructs the TAR archive structure, similar to how create_legacy_with_offsets.py works.

At minimum, please add tests for:

  • TAR format detection (is_tar_file())
  • Loading a float32 tensor from TAR and verifying values
  • Loading multiple tensors (weight + bias) with correct shapes
  • Loading different dtypes (float32, float64, int64)
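For the detection test, a minimal check can be built on the POSIX/ustar header layout, where the magic bytes "ustar" sit at offset 257 of the first 512-byte header block. This is a sketch of the idea, not burn-store's actual `is_tar_file()` implementation:

```rust
// Sketch: detect a TAR archive by the "ustar" magic at offset 257 of the
// first header block (POSIX/ustar format).
fn is_tar_file(bytes: &[u8]) -> bool {
    bytes.len() >= 262 && bytes[257..262] == *b"ustar"
}

fn main() {
    // Minimal in-memory fixture: a zeroed header block with the ustar magic.
    let mut fake_tar = vec![0u8; 512];
    fake_tar[257..262].copy_from_slice(b"ustar");
    assert!(is_tar_file(&fake_tar));
    // A ZIP local-file header (what modern torch.save produces) is not TAR.
    assert!(!is_tar_file(b"PK\x03\x04"));
    println!("tar detection ok");
}
```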

Contributor Author

Thanks for the suggestion! Added 8 TAR format tests:

  1. test_tar_format_detection - TAR file detection
  2. test_tar_float32_tensor - float32 tensor loading
  3. test_tar_float64_tensor - float64 tensor loading
  4. test_tar_int64_tensor - int64 tensor loading
  5. test_tar_multiple_tensors - multiple tensors (weight + bias)
  6. test_tar_multi_dtype - mixed dtypes
  7. test_tar_2d_tensor_shape - 2D shape verification
  8. test_tar_metadata - metadata verification

let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
let distance = lpips.forward(image1, image2, Reduction::Mean);
let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
println!("LPIPS VGG distance (black vs white): {}", distance_value);
Collaborator

Debug println! left in test code. Same at lines 787-789, 795-797, 815-817. Also the weight-printing block at lines 783-798 should be removed before merging.

Contributor Author

Removed.

pub struct SqueezeFeatureExtractor<B: Backend> {
/// Conv1: 3 -> 64, kernel 3x3, stride 2
conv1: Conv2d<B>,
pub conv1: Conv2d<B>,
Collaborator

conv1 was changed from private to pub only to support the debug weight printing in the test. Once that debug block is removed, revert this back to private.

Contributor Author

Removed.

/// Load ImageNet pretrained backbone weights.
fn load_backbone_weights<B: Backend>(
lpips: Lpips<B>,
_net: LpipsNet,
Collaborator

_net is unused since you match on the Lpips enum variant. Remove it. Same for load_lpips_weights at line 189.

Contributor Author

Removed.

@antimora antimora added the enhancement (Enhance existing features) and store labels Feb 24, 2026
- add TAR format integration tests
- deduplicate rebuild_tensor functions
- fix BFloat16 bug and silent F32 defaults
- remove unused parameters and debug println!
- remove _net
- remove println
@koreaygj
Contributor Author


Thanks for your review!
I've addressed all of your feedback 😊

  • Missing integration tests for TAR format in burn-store (critical)
  • ~440 lines of copy-pasted code between rebuild_tensor and rebuild_tensor_v2
  • BFloat16 bug in TarSource element size mapping (will silently corrupt data)
  • Silent default to F32 in multiple places instead of returning errors for unknown storage types
  • Unused parameter, wasteful allocation in read_range, and debug println! left in tests

Collaborator

@antimora antimora left a comment

Looks good, all the feedback has been addressed. Thanks for the thorough rework!

@laggui laggui merged commit 3f5c1bb into tracel-ai:main Feb 25, 2026
11 checks passed
@koreaygj koreaygj deleted the fix/lpips branch February 26, 2026 01:18