Fix(lpips): load ImageNet backbone weights for pretrained models #4557
laggui merged 7 commits into tracel-ai:main
Conversation
Support loading PyTorch models saved in TAR format (pre-1.6), such as AlexNet and SqueezeNet from torchvision.
- Load the backbone weights as well, not just the linear weights
Codecov Report

❌ Your patch check has failed because the patch coverage (77.46%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

@@ Coverage Diff @@
##             main    #4557      +/-   ##
==========================================
+ Coverage   61.21%   62.04%   +0.83%
==========================================
  Files        1062     1074      +12
  Lines      136012   138456    +2444
==========================================
+ Hits        83258    85907    +2649
+ Misses      52754    52549     -205
laggui left a comment
Currently, burn-train --features vision tests are not included in CI, so LPIPS pretrained test failures were not caught. Adding vision feature tests to CI should be considered in a follow-up PR.
Well that explains why it wasn't caught 😅 thanks
Current CI does not include this. If you agree with the change, I will open a PR:

// burn-train vision
helpers::custom_crates_tests(
    vec!["burn-train"],
    handle_test_args(&["--features", "vision"], args.release),
    None,
    None,
    "std vision",
)?;
Yeah we should do that! Can even be part of this PR if you want since it's related.
antimora left a comment
Thanks for working on this! TAR format support is a useful addition since a number of older torchvision models (AlexNet, SqueezeNet) on PyTorch Hub are in this format, and the two-step weight loading approach for LPIPS is the right design (separate ImageNet backbone weights + LPIPS linear layer weights).
That said, there are several issues that need to be addressed before this can be merged. The biggest one: there are no integration tests for the TAR format itself. burn-store has ~47 test functions covering ZIP and legacy formats with Python-generated fixtures, but zero tests for TAR loading. We need the same treatment here.
Summary of issues (details in inline comments):
- Missing integration tests for TAR format in burn-store (critical)
- ~440 lines of copy-pasted code between rebuild_tensor and rebuild_tensor_v2
- BFloat16 bug in TarSource element size mapping (will silently corrupt data)
- Silent default to F32 in multiple places instead of returning errors for unknown storage types
- Unused parameter, wasteful allocation in read_range, and debug println! left in tests
{
    1
} else {
    4 // Default to float (4 bytes)
Bug: "BFloat16Storage".contains("Half") returns false, so BFloat16 falls through to 4. BFloat16 is 2 bytes. This will silently compute wrong offsets and corrupt tensor data.
The rebuild_tensor / rebuild_tensor_v2 code has the correct explicit mapping ("BFloat16Storage" => DType::BF16). Consider extracting a shared storage_type_to_dtype(name) -> DType helper and using dtype.size() here instead of a separate if/else chain with contains checks.
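A minimal sketch of the suggested helper. `DType`, its variants, and the exact set of storage class names are stand-ins here, not burn-store's actual definitions:

```rust
// Sketch only: `DType` mirrors the idea of burn-store's dtype enum, but the
// names and variant set below are assumptions for illustration.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    F64,
    F32,
    F16,
    BF16,
    I64,
    I32,
    I16,
    U8,
}

impl DType {
    fn size(self) -> usize {
        match self {
            DType::F64 | DType::I64 => 8,
            DType::F32 | DType::I32 => 4,
            DType::F16 | DType::BF16 | DType::I16 => 2,
            DType::U8 => 1,
        }
    }
}

/// Map a PyTorch storage class name to a dtype, erroring on unknown names
/// instead of silently defaulting to F32.
fn storage_type_to_dtype(name: &str) -> Result<DType, String> {
    match name {
        "DoubleStorage" => Ok(DType::F64),
        "FloatStorage" => Ok(DType::F32),
        "HalfStorage" => Ok(DType::F16),
        "BFloat16Storage" => Ok(DType::BF16),
        "LongStorage" => Ok(DType::I64),
        "IntStorage" => Ok(DType::I32),
        "ShortStorage" => Ok(DType::I16),
        "ByteStorage" => Ok(DType::U8),
        other => Err(format!("unknown storage type: {other}")),
    }
}

fn main() {
    // BFloat16 is 2 bytes; the contains("Half") check wrongly mapped it to 4.
    assert_eq!(storage_type_to_dtype("BFloat16Storage").unwrap().size(), 2);
    assert!(storage_type_to_dtype("MysteryStorage").is_err());
}
```

An explicit match also makes an added storage type a visible compile-time/review-time change instead of a silent fallthrough.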
/// * `storages_data` - Raw storages blob with structure:
///   - Count pickle (number of storages)
///   - For each storage: metadata pickle + u64 num_elements + raw binary data
pub fn new(_tensors_data: &[u8], storages_data: Vec<u8>) -> std::io::Result<Self> {
_tensors_data is never used. It propagates up to LazyDataSource::from_tar and the callsite in reader.rs, making the API misleading. Remove it.
.unwrap_or_else(|poisoned| poisoned.into_inner());
let data = source.read(key)?;
let end = (offset + length).min(data.len());
Ok(data[offset..end].to_vec())
Double allocation: source.read(key) allocates a Vec<u8> for the full storage, then this line allocates again for the slice. Since TarSource already holds the full blob in memory (storages_data), add a read_range method on TarSource that slices directly from self.storages_data[storage_offset + offset..].
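A sketch of the suggested `read_range` on `TarSource`, assuming the field names (`storages_data`, a per-key offset map) — the crate's real layout may differ:

```rust
use std::collections::HashMap;
use std::io::{Error, ErrorKind};

// Simplified stand-in for the crate's TarSource; field names are assumptions.
struct TarSource {
    /// Full storages blob kept in memory.
    storages_data: Vec<u8>,
    /// Byte offset of each storage inside `storages_data`.
    storage_offsets: HashMap<String, usize>,
}

impl TarSource {
    /// Borrow the requested range directly from the in-memory blob instead of
    /// allocating the full storage and then copying the slice again.
    fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result<&[u8]> {
        let base = *self
            .storage_offsets
            .get(key)
            .ok_or_else(|| Error::new(ErrorKind::NotFound, format!("unknown storage: {key}")))?;
        let start = base + offset;
        let end = (start + length).min(self.storages_data.len());
        Ok(&self.storages_data[start..end])
    }
}

fn main() {
    let mut offsets = HashMap::new();
    offsets.insert("w".to_string(), 4);
    let source = TarSource {
        storages_data: vec![0, 1, 2, 3, 4, 5, 6, 7],
        storage_offsets: offsets,
    };
    // Storage "w" starts at byte 4; range [1, 3) within it is bytes 5..7.
    assert_eq!(source.read_range("w", 1, 2).unwrap(), &[5, 6]);
}
```

Callers that truly need an owned `Vec<u8>` can still call `.to_vec()` on the returned slice; everyone else avoids the double copy.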
// tuple[2] is storage type class
let stype = match &tuple[2] {
    super::pickle_reader::Object::Class { name, .. } => name.clone(),
    _ => "FloatStorage".to_string(),
Silently defaulting to "FloatStorage" when the storage type is not a Class will mask real parsing errors. Return an error instead.
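A sketch of that fix, with `Object` as a simplified stand-in for the crate's `pickle_reader::Object`:

```rust
// Simplified stand-in for pickle_reader::Object; variants are assumptions.
#[derive(Debug)]
enum Object {
    Class { module_name: String, name: String },
    Int(i64),
}

/// Extract the storage class name, surfacing unexpected pickle objects as
/// errors instead of silently assuming "FloatStorage".
fn storage_class_name(obj: &Object) -> Result<String, String> {
    match obj {
        Object::Class { name, .. } => Ok(name.clone()),
        other => Err(format!("expected a storage class object, got {other:?}")),
    }
}

fn main() {
    let class = Object::Class {
        module_name: "torch".into(),
        name: "FloatStorage".into(),
    };
    assert_eq!(storage_class_name(&class).unwrap(), "FloatStorage");
    assert!(storage_class_name(&Object::Int(0)).is_err());
}
```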
/// Legacy _rebuild_tensor function for PyTorch < 1.6.
/// Same as rebuild_tensor_v2 but with fewer arguments: (storage, storage_offset, size, stride)
fn rebuild_tensor(
This entire function (~440 lines, through line 748) is a near-exact copy of rebuild_tensor_v2 below. The only difference is 4 args vs 5 args (v2 adds requires_grad and backward_hooks).
The TODO at line 936 acknowledges this. Please extract the shared logic: parse the storage args into a struct, build the closure once. Both functions should be thin wrappers. ~400 lines of duplication makes this harder to maintain and will cause bugs when one copy gets updated without the other.
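A sketch of that shape — the argument struct, function names, and simplified types below are illustrative stand-ins, not the crate's real signatures:

```rust
// Stand-in types for illustration; the real code works with pickle objects
// and lazy tensor snapshots.
struct RebuildArgs {
    storage_key: String,
    storage_offset: usize,
    size: Vec<usize>,
    stride: Vec<usize>,
}

/// The ~400 shared lines live here once: resolve the storage, compute
/// offsets, and build the tensor.
fn rebuild_tensor_impl(args: RebuildArgs) -> String {
    format!(
        "{} @ {} shape {:?} stride {:?}",
        args.storage_key, args.storage_offset, args.size, args.stride
    )
}

/// Legacy _rebuild_tensor for PyTorch < 1.6: (storage, storage_offset, size, stride).
fn rebuild_tensor(
    storage_key: String,
    storage_offset: usize,
    size: Vec<usize>,
    stride: Vec<usize>,
) -> String {
    rebuild_tensor_impl(RebuildArgs { storage_key, storage_offset, size, stride })
}

/// _rebuild_tensor_v2 adds requires_grad (and backward_hooks), which are
/// irrelevant for weight loading and simply ignored.
fn rebuild_tensor_v2(
    storage_key: String,
    storage_offset: usize,
    size: Vec<usize>,
    stride: Vec<usize>,
    _requires_grad: bool,
) -> String {
    rebuild_tensor_impl(RebuildArgs { storage_key, storage_offset, size, stride })
}

fn main() {
    let a = rebuild_tensor("w".into(), 0, vec![2, 3], vec![3, 1]);
    let b = rebuild_tensor_v2("w".into(), 0, vec![2, 3], vec![3, 1], false);
    assert_eq!(a, b); // both thin wrappers share one implementation
}
```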
    module_name: _,
    name,
} => name.as_str(),
_ => "FloatStorage",
Same concern: defaulting to "FloatStorage" when the object type is unexpected will silently produce wrong results. Return an error for unexpected types.
@@ -792,6 +818,137 @@ fn load_legacy_pytorch_file_with_metadata(
    Ok((tensors, metadata))
This is the biggest gap: there are no integration tests for the TAR format. The existing test suite in tests/reader/mod.rs has 47 tests covering ZIP and legacy formats with Python-generated fixtures. TAR needs the same treatment.
Since modern torch.save() cannot produce TAR files (this format predates PyTorch 0.1.10), you will need a Python script that manually constructs the TAR archive structure, similar to how create_legacy_with_offsets.py works.
At minimum, please add tests for:
- TAR format detection (is_tar_file())
- Loading a float32 tensor from TAR and verifying values
- Loading multiple tensors (weight + bias) with correct shapes
- Loading different dtypes (float32, float64, int64)
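For the detection case, POSIX ustar archives carry the magic bytes "ustar" at offset 257 of the first header block, which gives a cheap sniff. Whether the crate's `is_tar_file()` checks exactly this is an assumption (pre-POSIX v7 tar has no magic at all):

```rust
// Hedged sketch: detect a ustar TAR archive by its magic at header offset 257.
// This is one plausible implementation, not necessarily what burn-store does.
fn is_tar_file(bytes: &[u8]) -> bool {
    bytes.len() >= 262 && &bytes[257..262] == b"ustar"
}

fn main() {
    // A minimal 512-byte header block with the ustar magic in place.
    let mut header = vec![0u8; 512];
    header[257..262].copy_from_slice(b"ustar");
    assert!(is_tar_file(&header));
    assert!(!is_tar_file(b"PK\x03\x04")); // ZIP magic, not TAR
}
```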
Thanks for the suggestion! Added 8 TAR format tests:
- test_tar_format_detection - TAR file detection
- test_tar_float32_tensor - float32 tensor loading
- test_tar_float64_tensor - float64 tensor loading
- test_tar_int64_tensor - int64 tensor loading
- test_tar_multiple_tensors - multiple tensors (weight + bias)
- test_tar_multi_dtype - mixed dtypes
- test_tar_2d_tensor_shape - 2D shape verification
- test_tar_metadata - metadata verification
let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);
let distance = lpips.forward(image1, image2, Reduction::Mean);
let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];
println!("LPIPS VGG distance (black vs white): {}", distance_value);
Debug println! left in test code. Same at lines 787-789, 795-797, 815-817. Also the weight-printing block at lines 783-798 should be removed before merging.
pub struct SqueezeFeatureExtractor<B: Backend> {
    /// Conv1: 3 -> 64, kernel 3x3, stride 2
-   conv1: Conv2d<B>,
+   pub conv1: Conv2d<B>,
conv1 was changed from private to pub only to support the debug weight printing in the test. Once that debug block is removed, revert this back to private.
/// Load ImageNet pretrained backbone weights.
fn load_backbone_weights<B: Backend>(
    lpips: Lpips<B>,
    _net: LpipsNet,
_net is unused since you match on the Lpips enum variant. Remove it. Same for load_lpips_weights at line 189.
- add TAR format integration tests
- deduplicate rebuild_tensor functions
- fix BFloat16 bug and silent F32 defaults
- remove unused parameters and debug println!
- remove _net
- remove println
Thanks for your review!
antimora left a comment
Looks good, all the feedback has been addressed. Thanks for the thorough rework!
Summary
- Add TAR format support to burn-store for legacy PyTorch models (AlexNet, SqueezeNet)
- Fix LPIPS to load ImageNet backbone weights separately from LPIPS linear weights
- Match PyTorch's original model structure

Checklist
- The cargo run-checks command has been executed.

Related Issues/PRs

Changes
LPIPS pretrained weights were not loading correctly. The original implementation only loaded LPIPS linear layer weights but not the ImageNet backbone weights (VGG16/AlexNet/SqueezeNet). Additionally, AlexNet and SqueezeNet backbone weights are stored in TAR format, which was not supported.

Solution
- burn-store: Added TAR format support for legacy PyTorch models (pre-1.6)
  - TarSource in lazy_data.rs
  - TAR file detection and loading in reader.rs
- burn-train/lpips: Fixed pretrained weights loading
  - Loads ImageNet backbone weights and LPIPS linear layer weights separately
  - Matches PyTorch's original model structure (features.X instead of net.sliceX)
  - Weight URLs from the official PyTorch repository

Testing