-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimclr_architecture.py
More file actions
420 lines (349 loc) · 14.7 KB
/
simclr_architecture.py
File metadata and controls
420 lines (349 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
class ResNetBackbone(nn.Module):
    """
    torchvision ResNet trunk re-wired for single-channel spectrogram input.

    The stock three-channel stem convolution is replaced by a one-channel
    convolution; every later stage (batch norm, pooling, the four residual
    stages and the global average pool) is taken from the torchvision model.
    The classification layer is dropped, so ``forward`` returns the pooled
    feature vector.

    Args:
        base_model: One of ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``.
        pretrained: Forwarded to the torchvision constructor. When true, the
            new stem is initialised with the channel-wise mean of the
            pretrained RGB stem weights.

    Raises:
        ValueError: If ``base_model`` is not one of the supported names.
    """

    def __init__(self, base_model='resnet18', pretrained=False):
        super().__init__()

        # Resolve the torchvision constructor by name
        if base_model not in ('resnet18', 'resnet34', 'resnet50'):
            raise ValueError(f"Unsupported base model: {base_model}")
        reference = getattr(models, base_model)(pretrained=pretrained)

        # One-channel stem with the same geometry as the stock ResNet stem
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Collapse the RGB filters into a single-channel filter
            rgb_kernels = reference.conv1.weight.data
            self.conv1.weight.data = torch.mean(rgb_kernels, dim=1, keepdim=True)

        # Everything after the stem is reused unchanged
        self.bn1 = reference.bn1
        self.relu = reference.relu
        self.maxpool = reference.maxpool
        self.layer1 = reference.layer1
        self.layer2 = reference.layer2
        self.layer3 = reference.layer3
        self.layer4 = reference.layer4
        self.avgpool = reference.avgpool

        # Width of the pooled feature vector (input size of the dropped fc layer)
        self.feature_dim = reference.fc.in_features

    def forward(self, x):
        """Run the trunk and return the flattened pooled features."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)
        return torch.flatten(x, 1)
class FrequencyAttention(nn.Module):
    """
    Attention mechanism that focuses on relevant frequency bands.

    The input is averaged over its last (width) axis, leaving one channel
    descriptor per row. A shared bottleneck MLP turns each descriptor into
    sigmoid gates in (0, 1), and the input is rescaled by those gates, so
    every row gets its own per-channel weighting.

    NOTE(review): the file's comments treat height as frequency and width as
    time, i.e. input is presumably (batch, channels, frequency, time) -
    confirm against the data pipeline.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Bottleneck factor of the gating MLP.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super(FrequencyAttention, self).__init__()
        # Keep at least one hidden unit so tiny channel counts do not create
        # an empty bottleneck layer.
        hidden = max(1, in_channels // reduction_ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d((None, 1))  # Pool along time dimension
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Gate ``x`` with per-row channel attention.

        Args:
            x: Feature map of shape [B, C, H, W].

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The pooled map is [B, C, H, 1]. It cannot be viewed as [B, C]
        # unless H == 1, so the MLP is applied to every row separately.
        pooled = self.avg_pool(x)                       # [B, C, H, 1]
        descriptors = pooled.squeeze(-1).transpose(1, 2)  # [B, H, C]
        gates = self.fc(descriptors)                    # [B, H, C]
        gates = gates.transpose(1, 2).unsqueeze(-1)     # [B, C, H, 1]
        # Broadcasting over the last axis applies each row's gate to all columns
        return x * gates
class TimeAttention(nn.Module):
    """
    Attention mechanism that focuses on relevant time segments.

    The input is averaged over its second-to-last (height) axis, leaving one
    channel descriptor per column. A shared bottleneck MLP turns each
    descriptor into sigmoid gates in (0, 1), and the input is rescaled by
    those gates, so every column gets its own per-channel weighting.

    NOTE(review): the file's comments treat height as frequency and width as
    time, i.e. input is presumably (batch, channels, frequency, time) -
    confirm against the data pipeline.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Bottleneck factor of the gating MLP.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super(TimeAttention, self).__init__()
        # Keep at least one hidden unit so tiny channel counts do not create
        # an empty bottleneck layer.
        hidden = max(1, in_channels // reduction_ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, None))  # Pool along frequency dimension
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Gate ``x`` with per-column channel attention.

        Args:
            x: Feature map of shape [B, C, H, W].

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The pooled map is [B, C, 1, W]. It cannot be viewed as [B, C]
        # unless W == 1, so the MLP is applied to every column separately.
        pooled = self.avg_pool(x)                        # [B, C, 1, W]
        descriptors = pooled.squeeze(2).transpose(1, 2)  # [B, W, C]
        gates = self.fc(descriptors)                     # [B, W, C]
        gates = gates.transpose(1, 2).unsqueeze(2)       # [B, C, 1, W]
        # Broadcasting over the height axis applies each column's gate to all rows
        return x * gates
class DualAttentionModule(nn.Module):
    """
    Chains frequency attention and time attention over a spectrogram feature map.

    Args:
        in_channels: Channel count handed to both attention sub-modules.
        reduction_ratio: Bottleneck factor handed to both attention sub-modules.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        self.freq_attention = FrequencyAttention(in_channels, reduction_ratio)
        self.time_attention = TimeAttention(in_channels, reduction_ratio)

    def forward(self, x):
        """Apply frequency attention first, then time attention to its output."""
        return self.time_attention(self.freq_attention(x))
class MultiScaleModule(nn.Module):
    """
    Four parallel branches with different receptive fields, concatenated
    along the channel axis.

    Branches 1-3 are conv -> batch norm -> ReLU with 1x1, 3x3 and 5x5 kernels;
    branch 4 is a 3x3 stride-1 max pool followed by a 1x1 conv unit. Paddings
    keep the spatial size unchanged. Each branch emits ``out_channels // 4``
    channels, so the output has ``4 * (out_channels // 4)`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Requested total number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        branch_width = out_channels // 4

        def conv_unit(kernel_size):
            # "Same" padding for an odd kernel at stride 1
            return [
                nn.Conv2d(in_channels, branch_width, kernel_size=kernel_size,
                          stride=1, padding=kernel_size // 2),
                nn.BatchNorm2d(branch_width),
                nn.ReLU(inplace=True),
            ]

        # Increasing kernel sizes capture patterns at increasing scales
        self.branch1 = nn.Sequential(*conv_unit(1))
        self.branch2 = nn.Sequential(*conv_unit(3))
        self.branch3 = nn.Sequential(*conv_unit(5))
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *conv_unit(1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dim 1."""
        paths = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([path(x) for path in paths], 1)
class EnhancedBackbone(nn.Module):
    """
    Enhanced backbone network with multi-scale processing and attention mechanisms.

    Wraps ``ResNetBackbone`` and inserts a ``MultiScaleModule`` followed by a
    ``DualAttentionModule`` after each of the first three residual stages.
    The fourth stage, global pooling and flattening are unchanged, so the
    output is a feature vector of width ``feature_dim``.

    Args:
        base_model: Name of the ResNet variant, forwarded to ``ResNetBackbone``.
        pretrained: Forwarded to ``ResNetBackbone``.
    """

    def __init__(self, base_model='resnet18', pretrained=False):
        super(EnhancedBackbone, self).__init__()
        # Base ResNet backbone
        self.backbone = ResNetBackbone(base_model, pretrained)
        # Feature dimension remains the same as the base backbone
        self.feature_dim = self.backbone.feature_dim

        # torchvision ResNets double the channel width at every stage, so
        # stages 1-3 emit feature_dim / 8, / 4 and / 2 channels. Deriving the
        # widths (instead of hard-coding 64/128/256) keeps the inserted
        # modules compatible with bottleneck variants such as resnet50,
        # whose stages emit 256/512/1024 channels.
        stage1_channels = self.feature_dim // 8
        stage2_channels = self.feature_dim // 4
        stage3_channels = self.feature_dim // 2

        # Add multi-scale modules after each ResNet block
        self.multi_scale1 = MultiScaleModule(stage1_channels, stage1_channels)
        self.multi_scale2 = MultiScaleModule(stage2_channels, stage2_channels)
        self.multi_scale3 = MultiScaleModule(stage3_channels, stage3_channels)
        # Add attention modules
        self.attention1 = DualAttentionModule(stage1_channels)
        self.attention2 = DualAttentionModule(stage2_channels)
        self.attention3 = DualAttentionModule(stage3_channels)

    def forward(self, x):
        """
        Extract a feature vector from a batch of spectrograms.

        Args:
            x: Input batch; the stem convolution expects one input channel.

        Returns:
            Tensor of shape [B, feature_dim].
        """
        # Initial layers
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        # Layer 1 with multi-scale and attention
        x = self.backbone.layer1(x)
        x = self.multi_scale1(x)
        x = self.attention1(x)
        # Layer 2 with multi-scale and attention
        x = self.backbone.layer2(x)
        x = self.multi_scale2(x)
        x = self.attention2(x)
        # Layer 3 with multi-scale and attention
        x = self.backbone.layer3(x)
        x = self.multi_scale3(x)
        x = self.attention3(x)
        # Final layers
        x = self.backbone.layer4(x)
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        return x
class ProjectionHead(nn.Module):
    """
    SimCLR projection head.

    A three-layer MLP that maps backbone features into the space where the
    contrastive loss is computed. The two hidden layers are
    Linear -> BatchNorm1d -> ReLU; the output layer is Linear -> BatchNorm1d.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dim: Width of both hidden layers.
        output_dim: Width of the projected embedding.
    """

    def __init__(self, input_dim, hidden_dim=512, output_dim=128):
        super().__init__()
        stages = []
        # Two hidden blocks: input -> hidden, hidden -> hidden
        for fan_in in (input_dim, hidden_dim):
            stages.extend([
                nn.Linear(fan_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ])
        # Output block has no non-linearity
        stages.extend([
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim),
        ])
        self.projection = nn.Sequential(*stages)

    def forward(self, x):
        """Project a batch of feature vectors [B, input_dim] to [B, output_dim]."""
        return self.projection(x)
class UnderwaterAcousticSimCLR(nn.Module):
    """
    Full SimCLR model: ``EnhancedBackbone`` followed by a ``ProjectionHead``.

    The projection head's hidden width equals the backbone feature width.

    Args:
        base_model: ResNet variant name forwarded to the backbone.
        pretrained: Forwarded to the backbone.
        projection_dim: Width of the projected embedding.
    """

    def __init__(self, base_model='resnet18', pretrained=False, projection_dim=128):
        super().__init__()
        # Feature extractor with multi-scale processing and attention
        self.backbone = EnhancedBackbone(base_model, pretrained)

        # Projection MLP sized from the backbone's output width
        feature_width = self.backbone.feature_dim
        self.projection_head = ProjectionHead(
            input_dim=feature_width,
            hidden_dim=feature_width,
            output_dim=projection_dim,
        )

    def forward(self, x):
        """Return ``(features, projections)`` for the input batch ``x``."""
        features = self.backbone(x)
        return features, self.projection_head(features)
class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss from SimCLR paper.

    For each of the 2N embeddings, the embedding of the other augmented view
    of the same sample is the positive and the remaining 2N - 2 embeddings
    are negatives. The loss is the mean cross entropy of picking the positive
    from temperature-scaled cosine similarities.

    Args:
        temperature: Softmax temperature applied to the cosine similarities.
        batch_size: Kept for backward compatibility and stored on the module.
            ``forward`` reads the batch size from its inputs, so a final
            batch of a different size is handled correctly.
    """

    def __init__(self, temperature=0.5, batch_size=256):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.batch_size = batch_size
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        # Not used by forward(); retained so existing attribute access keeps working
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def forward(self, z_i, z_j):
        """
        Calculate NT-Xent loss for batch of paired samples.

        Args:
            z_i, z_j: Batch of paired embeddings [N, D]; row k of ``z_i`` and
                row k of ``z_j`` are two views of the same sample.

        Returns:
            Scalar tensor: the loss averaged over all 2N embeddings.

        Raises:
            ValueError: If ``z_i`` and ``z_j`` have different shapes.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )
        # Use the actual batch size, not the configured one
        n = z_i.size(0)
        total = 2 * n

        # Concatenate embeddings from the two augmented views
        representations = torch.cat([z_i, z_j], dim=0)  # [2*N, D]

        # Cosine similarity as a product of unit vectors; this avoids
        # broadcasting a [2N, 2N, D] intermediate.
        unit = F.normalize(representations, dim=1)
        logits = unit @ unit.t() / self.temperature  # [2*N, 2*N]

        # Remove self-similarity: -inf gets zero weight in the softmax
        self_mask = torch.eye(total, dtype=torch.bool, device=representations.device)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # The positive for row k is row k + N (mod 2N). Each positive stays in
        # its own column, so it is counted once in the denominator.
        index = torch.arange(total, device=representations.device)
        labels = (index + n) % total

        loss = self.criterion(logits, labels)
        loss = loss / total
        return loss
# Data augmentation functions for underwater acoustic spectrograms
def time_shift(spectrogram, max_shift_percent=0.2):
    """
    Randomly shift the spectrogram along its last axis, padding with zeros.

    The shift is drawn uniformly from ``[-max_shift_percent, max_shift_percent]``
    of the axis length and truncated to an integer. A positive shift moves
    content towards higher indices, a negative one towards lower indices.
    Content pushed past the edge is dropped and the vacated columns are zero.

    NOTE(review): the last axis is presumably time, i.e. input is
    (..., frequency, time) - confirm against the data pipeline.

    Args:
        spectrogram: Tensor whose last axis is shifted.
        max_shift_percent: Maximum shift as a fraction of the axis length.

    Returns:
        Tensor with the same shape as ``spectrogram``. When the drawn shift
        is zero the input tensor itself is returned.
    """
    width = spectrogram.shape[-1]
    shift_amount = int(width * np.random.uniform(-max_shift_percent, max_shift_percent))
    if shift_amount == 0:
        return spectrogram

    # Write the surviving columns into a zero canvas. Both directions then
    # pad exactly |shift_amount| columns, so the output width always matches.
    shifted = torch.zeros_like(spectrogram)
    if shift_amount > 0:
        shifted[..., shift_amount:] = spectrogram[..., :-shift_amount]
    else:
        shifted[..., :shift_amount] = spectrogram[..., -shift_amount:]
    return shifted
def time_mask(spectrogram, max_mask_percent=0.2, num_masks=2):
    """
    Zero out random contiguous spans along the last axis.

    For each mask, a width is drawn uniformly from ``[0, max_mask_percent]``
    of the axis length (truncated to an integer), a start position is drawn
    uniformly among all positions where the span fits, and that span is set
    to zero. Masks may overlap.

    NOTE(review): the last axis is presumably time, i.e. input is
    (..., frequency, time) - confirm against the data pipeline.

    Args:
        spectrogram: Input tensor; it is not modified.
        max_mask_percent: Maximum width of one mask as a fraction of the axis length.
        num_masks: Number of masks to apply.

    Returns:
        A masked copy of ``spectrogram`` with the same shape.
    """
    width = spectrogram.shape[-1]
    masked = spectrogram.clone()
    for _ in range(num_masks):
        mask_width = int(width * np.random.uniform(0, max_mask_percent))
        # Clamp so fractions outside [0, 1] cannot yield an invalid span
        mask_width = max(0, min(width, mask_width))
        # randint's upper bound is exclusive; the +1 lets the span end on
        # the final column.
        mask_start = np.random.randint(0, width - mask_width + 1)
        masked[..., mask_start:mask_start + mask_width] = 0
    return masked
def freq_mask(spectrogram, max_mask_percent=0.2, num_masks=2):
    """
    Zero out random contiguous bands along the second-to-last axis.

    For each mask, a height is drawn uniformly from ``[0, max_mask_percent]``
    of the axis length (truncated to an integer), a start position is drawn
    uniformly among all positions where the band fits, and that band is set
    to zero. Masks may overlap.

    NOTE(review): the second-to-last axis is presumably frequency, i.e.
    input is (..., frequency, time) - confirm against the data pipeline.

    Args:
        spectrogram: Input tensor with at least two dimensions; it is not modified.
        max_mask_percent: Maximum height of one mask as a fraction of the axis length.
        num_masks: Number of masks to apply.

    Returns:
        A masked copy of ``spectrogram`` with the same shape.
    """
    height = spectrogram.shape[-2]
    masked = spectrogram.clone()
    for _ in range(num_masks):
        mask_height = int(height * np.random.uniform(0, max_mask_percent))
        # Clamp so fractions outside [0, 1] cannot yield an invalid band
        mask_height = max(0, min(height, mask_height))
        # randint's upper bound is exclusive; the +1 lets the band end on
        # the final row.
        mask_start = np.random.randint(0, height - mask_height + 1)
        masked[..., mask_start:mask_start + mask_height, :] = 0
    return masked
def amplitude_scale(spectrogram, min_factor=0.5, max_factor=1.5):
    """
    Multiply the whole spectrogram by one random gain.

    Args:
        spectrogram: Input tensor.
        min_factor: Lower bound of the uniformly drawn gain.
        max_factor: Upper bound of the uniformly drawn gain.

    Returns:
        ``spectrogram`` scaled by a single factor drawn from
        ``[min_factor, max_factor)``.
    """
    gain = np.random.uniform(low=min_factor, high=max_factor)
    scaled = spectrogram * gain
    return scaled
def add_gaussian_noise(spectrogram, max_noise_percent=0.1):
    """
    Add zero-mean Gaussian noise scaled relative to the spectrogram's mean.

    A noise level is drawn uniformly from ``[0, max_noise_percent)``.
    Standard-normal noise of the input's shape is multiplied by that level
    and by the mean of the spectrogram, then added to the input.

    Args:
        spectrogram: Input tensor.
        max_noise_percent: Upper bound of the relative noise level.

    Returns:
        A new tensor: the input plus the scaled noise.
    """
    level = np.random.uniform(0, max_noise_percent)
    perturbation = torch.randn_like(spectrogram) * level * torch.mean(spectrogram)
    noisy = spectrogram + perturbation
    return noisy
def freq_shift(spectrogram, max_shift_percent=0.2):
    """
    Randomly shift the spectrogram along its second-to-last axis, padding with zeros.

    The shift is drawn uniformly from ``[-max_shift_percent, max_shift_percent]``
    of the axis length and truncated to an integer. A positive shift moves
    content towards higher indices, a negative one towards lower indices.
    Content pushed past the edge is dropped and the vacated rows are zero.

    NOTE(review): the second-to-last axis is presumably frequency, i.e.
    input is (..., frequency, time) - confirm against the data pipeline.

    Args:
        spectrogram: Tensor with at least two dimensions.
        max_shift_percent: Maximum shift as a fraction of the axis length.

    Returns:
        Tensor with the same shape as ``spectrogram``. When the drawn shift
        is zero the input tensor itself is returned.
    """
    height = spectrogram.shape[-2]
    shift_amount = int(height * np.random.uniform(-max_shift_percent, max_shift_percent))
    if shift_amount == 0:
        return spectrogram

    # Write the surviving rows into a zero canvas. Both directions then pad
    # exactly |shift_amount| rows, so the output height always matches.
    shifted = torch.zeros_like(spectrogram)
    if shift_amount > 0:
        shifted[..., shift_amount:, :] = spectrogram[..., :-shift_amount, :]
    else:
        shifted[..., :shift_amount, :] = spectrogram[..., -shift_amount:, :]
    return shifted
def apply_augmentations(spectrogram, p=0.5):
    """
    Apply a random subset of the augmentations to a copy of the spectrogram.

    Each augmentation is considered in a fixed order - time shift, time mask,
    frequency shift, frequency mask, amplitude scaling, Gaussian noise - and
    is applied independently with probability ``p``, using its default
    parameters.

    Args:
        spectrogram: Input tensor; it is cloned before any augmentation runs.
        p: Probability of applying each individual augmentation.

    Returns:
        The augmented tensor.
    """
    result = spectrogram.clone()
    pipeline = (
        time_shift, time_mask,                 # time-domain
        freq_shift, freq_mask,                 # frequency-domain
        amplitude_scale, add_gaussian_noise,   # intensity
    )
    for transform in pipeline:
        # One coin flip per augmentation, drawn just before it would run
        if np.random.random() < p:
            result = transform(result)
    return result
# Configuration for the SimCLR model
simclr_config = {
    # Model
    "base_model": "resnet18",
    "pretrained": False,
    "projection_dim": 128,
    # Contrastive objective
    "batch_size": 64,
    "temperature": 0.5,
    # Optimisation
    "learning_rate": 3e-4,
    "weight_decay": 1e-4,
    "epochs": 100,
    # Probability used for each augmentation
    "augmentation_probability": 0.5,
}