-
Notifications
You must be signed in to change notification settings - Fork 334
Expand file tree
/
Copy pathmasks.py
More file actions
212 lines (183 loc) · 7.87 KB
/
masks.py
File metadata and controls
212 lines (183 loc) · 7.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from typing import Any
import torch
from comfy.model_management import is_device_cpu
from comfy_extras.nodes_mask import GrowMask
from .nodes_registry import comfy_node
@comfy_node(name="LTXVPreprocessMasks")
class LTXVPreprocessMasks:
"""
Preprocesses masks for use with the LTXVideo model's latent masking.
This node handles temporal mask processing by:
1. Validating mask dimensions against VAE downscaling factors
2. Optionally inverting masks
3. Handling the first frame mask separately (can be ignored)
4. Combining temporal masks using max pooling
5. Growing/shrinking masks with morphological operations
6. Clamping mask values to ensure proper opacity
The output is a set of masks ready for latent-space masking operations.
"""
@classmethod
def INPUT_TYPES(cls: "LTXVPreprocessMasks") -> dict[str, Any]: # noqa: N802
return {
"required": {
"masks": ("MASK",),
"vae": ("VAE",),
"invert_input_masks": (
"BOOLEAN",
{
"default": False,
"tooltip": "Invert the input masks before processing."
"Useful for masking vs unmasking operations.",
},
),
"ignore_first_mask": (
"BOOLEAN",
{
"default": True,
"tooltip": "Zero out the first mask, typically used when"
"it corresponds to the conditioning frame.",
},
),
"pooling_method": (
["max", "mean", "min"],
{
"default": "max",
"tooltip": "Method to combine temporal masks. Max preserves strongest values, "
"mean averages them, min takes weakest values.",
},
),
"grow_mask": (
"INT",
{
"default": 0,
"min": -16384,
"max": 16384,
"step": 1,
"tooltip": "Pixels to grow (positive) or shrink (negative) the mask."
"Uses morphological operations.",
},
),
"tapered_corners": (
"BOOLEAN",
{
"default": True,
"tooltip": "Use tapered corners in morphological operations for smoother mask edges.",
},
),
"clamp_min": (
"FLOAT",
{
"default": 0.5,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Minimum value to clamp masks to, ensuring minimum opacity.",
},
),
"clamp_max": (
"FLOAT",
{
"default": 1.0,
"min": 0.0,
"max": 1.0,
"step": 0.01,
"tooltip": "Maximum value to clamp masks to.",
},
),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "preprocess_masks"
CATEGORY = "Lightricks/mask_operations"
DESCRIPTION = (
"Preprocess masks to be used for masking latents in the LTXVideo model."
)
def preprocess_masks(
self,
masks: torch.Tensor,
vae, # noqa: ANN001
pooling_method: str,
grow_mask: int,
tapered_corners: bool,
clamp_min: float,
clamp_max: float,
ignore_first_mask: bool,
invert_input_masks: bool,
) -> tuple[torch.Tensor]:
"""
Preprocess masks for temporal latent masking.
Args:
masks: Input masks tensor of shape (batch_size, H, W) where batch_size = 1 + (num_frames - 1)
The first mask is for the conditioning frame, remaining masks are for generated frames
vae: VAE model object containing downscale_index_formula attribute
pooling_method: Method to combine temporal masks ('max', 'mean', 'min')
grow_mask: Pixels to grow (positive) or shrink (negative) the mask
tapered_corners: Whether to use tapered corners in morphological operations
clamp_min: Minimum value to clamp masks to (ensures minimum opacity)
clamp_max: Maximum value to clamp masks to
ignore_first_mask: Whether to zero out the first mask (often used for conditioning)
invert_input_masks: Whether to invert all input masks (1 - mask)
Returns:
Tuple containing the processed masks tensor
Raises:
ValueError: If mask dimensions are not compatible with VAE downscaling factors
"""
if masks.ndim != 3:
raise ValueError("Masks must be of shape (batch_size, H, W).")
batch_size, H, W = masks.shape
# Get VAE downscaling factors for temporal and spatial dimensions
time_scale_factor, width_scale_factor, height_scale_factor = (
vae.downscale_index_formula
)
# Validate that mask dimensions are compatible with VAE downscaling
# The batch size minus 1 (for the first frame) must be divisible by time_scale_factor
if (batch_size - 1) % time_scale_factor != 0:
raise ValueError(
f"Masks batch size must have a multiple of {time_scale_factor} masks + 1."
)
if H % height_scale_factor != 0:
raise ValueError(
f"Masks height must have a multiple of {height_scale_factor}."
)
if W % width_scale_factor != 0:
raise ValueError(
f"Masks width must have a multiple of {width_scale_factor}."
)
# Optionally invert all masks (useful for masking vs unmasking operations)
if invert_input_masks:
masks = 1.0 - masks
# Handle first mask separately - often used as conditioning frame
if ignore_first_mask:
first_mask = masks[0:1, :, :] * 0 # Zero out the first mask
else:
first_mask = masks[0:1, :, :]
# Process remaining masks by grouping them according to time_scale_factor
# and taking the maximum value across each temporal group
remaining = masks[1:].view(-1, time_scale_factor, H, W)
# Apply selected pooling method
pooling_methods = {
"max": lambda x: x.max(dim=1)[0],
"mean": lambda x: x.mean(dim=1),
"min": lambda x: x.min(dim=1)[0],
}
combined = pooling_methods[pooling_method](remaining)
# Apply morphological operations to grow or shrink masks
if grow_mask != 0:
grow_mask_node = GrowMask()
# GrowMask node is not supported on GPU, so we need to move the combined tensor to CPU
# if it is not already on CPU
input_device = combined.device
if not is_device_cpu(input_device):
combined = combined.cpu()
(combined,) = grow_mask_node.expand_mask(
mask=combined, expand=grow_mask, tapered_corners=tapered_corners
)
# Move the combined tensor back to the original device
if not is_device_cpu(input_device):
combined = combined.to(input_device)
# Combine first mask with processed remaining masks and clamp values
# Clamping ensures masks have proper opacity values
processed_masks = torch.cat(
[first_mask, torch.clamp(combined, clamp_min, clamp_max)], dim=0
)
return (processed_masks,)