Question about MaskCLIP implementation in FeatUp: Differences from official MaskCLIP (residual connections and MLP) #91
Description
Hi maintainers,
Thank you for your work on FeatUp! I've been comparing the MaskCLIP implementation in FeatUp with the official MaskCLIP code, and noticed significant differences in how the value (v) features are processed, specifically regarding residual connections and the FFN (MLP) layer. I'm writing to understand the design considerations behind these changes.
1. Official vs. FeatUp
Official MaskCLIP's `forward` method:

```python
def forward(self, x, return_qkv=False):
    q, k, v = None, None, None
    if return_qkv:
        y = self.norm1(x)
        y = F.linear(y, self.attn.attn.in_proj_weight, self.attn.attn.in_proj_bias)
        N, L, C = y.shape
        y = y.view(N, L, 3, C // 3).permute(2, 0, 1, 3).reshape(3 * N, L, C // 3)
        y = F.linear(y, self.attn.attn.out_proj.weight, self.attn.attn.out_proj.bias)
        q, k, v = y.tensor_split(3, dim=0)
        # Official: residual connection for v
        v += x
        # Official: FFN (MLP) with residual for v
        v = self.ffn(self.norm2(v), identity=v)
    x = self.attn(self.norm1(x), identity=x)
    x = self.ffn(self.norm2(x), identity=x)
    return x, q, k, v
```

FeatUp's `forward_v` method:
```python
def forward_v(self, x: torch.Tensor):
    v_in_proj_weight = self.attn.in_proj_weight[-self.attn.embed_dim:]
    v_in_proj_bias = self.attn.in_proj_bias[-self.attn.embed_dim:]
    # No residual, no FFN
    v_in = F.linear(self.ln_1(x), v_in_proj_weight, v_in_proj_bias)
    v_out = F.linear(v_in, self.attn.out_proj.weight, self.attn.out_proj.bias)
    return v_out

def forward(self, x: torch.Tensor):
    x = x + self.attention(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x
```

2. Key Differences Highlighted
From the code, two critical deviations from the official implementation stand out:

- Residual connections for `v`:
  - Official MaskCLIP explicitly adds a residual connection to `v` (`v += x`) and retains the residual in the FFN (`self.ffn(..., identity=v)`).
  - FeatUp's `forward_v` removes all residual logic: `v_out` is purely the result of linear projections, with no addition of the original input `x`.
- FFN (MLP) application to `v`:
  - Official MaskCLIP processes `v` through `self.ffn` after normalization (`self.norm2(v)`).
  - FeatUp's `forward_v` skips the FFN entirely, returning `v_out` directly after the linear projections.
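For what it's worth, the slicing step itself is equivalent in both implementations: taking the last `embed_dim` rows of `in_proj_weight`/`in_proj_bias` reproduces the value third of the full qkv projection. A minimal standalone sketch checking this (using random tensors as stand-ins for the layer's parameters, not FeatUp's actual module):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, seq_len = 8, 4

# Hypothetical stand-ins for a fused qkv projection's parameters
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)
x = torch.randn(seq_len, embed_dim)

# Full qkv projection, as in the official code path
qkv = F.linear(x, in_proj_weight, in_proj_bias)
v_full = qkv[:, -embed_dim:]  # last third corresponds to v

# FeatUp-style: project with only the value rows sliced out
v_sliced = F.linear(x, in_proj_weight[-embed_dim:], in_proj_bias[-embed_dim:])

assert torch.allclose(v_full, v_sliced)
```

So the divergence is really only in what happens *after* this projection (the residual and the FFN), which is what my questions below are about.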
3. Questions
I hope you don't mind me asking: while comparing FeatUp's `forward_v` with the official MaskCLIP implementation, I noticed that FeatUp removes the residual connection for `v` and skips the FFN processing, both of which are retained in the official version.
I'd be really grateful if you could share a bit about the thinking behind these differences. Were there specific considerations that led to these adjustments?