-
Notifications
You must be signed in to change notification settings - Fork 372
Expand file tree
/
Copy pathvanish_nodes.py
More file actions
146 lines (124 loc) · 4.73 KB
/
vanish_nodes.py
File metadata and controls
146 lines (124 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
import torch.nn.functional as F
from comfy_api.latest import io
from .nodes_registry import comfy_node
@comfy_node(name="LTXVDilateVideoMask")
class LTXVDilateVideoMask(io.ComfyNode):
    """Dilates a video mask spatially and/or temporally using max-pooling and thresholds the result."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVDilateVideoMask",
            category="Lightricks/mask_operations",
            description=(
                "Dilates a video mask spatially and/or temporally using "
                "separable max-pooling and thresholds the result."
            ),
            inputs=[
                io.Int.Input(
                    "spatial_radius",
                    default=1,
                    min=0,
                    max=30,
                    tooltip="Half-size of the spatial dilation kernel. Kernel = 2*radius+1.",
                ),
                io.Int.Input(
                    "temporal_radius",
                    default=0,
                    min=0,
                    max=10,
                    tooltip="Half-size of the temporal dilation kernel. Kernel = 2*radius+1.",
                ),
                io.Mask.Input(
                    "mask",
                    optional=True,
                    tooltip="Video mask to dilate. Either this or image_as_mask must be provided.",
                ),
                io.Image.Input(
                    "image_as_mask",
                    optional=True,
                    tooltip="Image to use as mask (channel-averaged). Either this or mask must be provided.",
                ),
            ],
            outputs=[
                io.Mask.Output("mask"),
            ],
        )

    @classmethod
    def execute(
        cls,
        spatial_radius: int,
        temporal_radius: int,
        mask: torch.Tensor | None = None,
        image_as_mask: torch.Tensor | None = None,
    ) -> io.NodeOutput:
        """Dilate *mask* (or a channel-averaged *image_as_mask*) and binarize at 0.5.

        Raises:
            ValueError: If neither ``mask`` nor ``image_as_mask`` is supplied.
        """
        if mask is None and image_as_mask is None:
            raise ValueError("Either 'mask' or 'image_as_mask' must be provided.")

        # Fall back to the image input, averaged over its channel axis, when
        # no explicit mask is provided.
        if mask is None:
            mask = image_as_mask.mean(dim=-1)

        # A 4-D mask carries a trailing channel axis; keep only the first channel.
        if mask.ndim == 4:
            mask = mask[:, :, :, 0]

        spatial_kernel = 2 * spatial_radius + 1
        temporal_kernel = 2 * temporal_radius + 1

        # Separable dilation: 2D spatial + 1D temporal (much faster than 3D pooling)
        if spatial_kernel > 1:
            pooled = F.max_pool2d(
                mask.unsqueeze(1),  # (B, 1, H, W)
                kernel_size=spatial_kernel,
                stride=1,
                padding=spatial_radius,
            )
            mask = pooled.squeeze(1)

        if temporal_kernel > 1:
            num_frames, height, width = mask.shape
            # Fold the spatial axes into the batch dimension so the frame axis
            # becomes the 1D pooling length.
            flat = mask.permute(1, 2, 0).reshape(height * width, 1, num_frames)
            flat = F.max_pool1d(
                flat, kernel_size=temporal_kernel, stride=1, padding=temporal_radius
            )
            mask = flat.reshape(height, width, num_frames).permute(2, 0, 1)

        # Binarize the dilated result.
        return io.NodeOutput((mask > 0.5).float())
# Chroma-key green (#66FF00) as 8-bit RGB: the background color painted where
# the inpainting mask is active (102 = 0x66, 255 = 0xFF, 0 = 0x00).
_BG_COLOR_RGB = (102, 255, 0)
@comfy_node(name="LTXVInpaintPreprocess")
class LTXVInpaintPreprocess(io.ComfyNode):
    """Composites images with a green (#66FF00) background where mask is active.

    If the mask has a single frame it is broadcast to match the video length.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVInpaintPreprocess",
            category="Lightricks/image_processing",
            description=(
                "Composites images with a green background where mask is "
                "active, for inpainting conditioning."
            ),
            inputs=[
                io.Image.Input(
                    "images",
                    tooltip="Video frames to composite onto the green background.",
                ),
                io.Mask.Input(
                    "mask",
                    tooltip="Mask indicating regions to replace with green. Single-frame masks are broadcast.",
                ),
            ],
            outputs=[
                io.Image.Output("image"),
            ],
        )

    @classmethod
    def execute(
        cls,
        images: torch.Tensor,
        mask: torch.Tensor,
    ) -> io.NodeOutput:
        """Blend *images* toward the green key color where *mask* is active.

        Args:
            images: Video frames, shape (B, H, W, 3); values presumably in
                [0, 1] per ComfyUI convention — not validated here.
            mask: Mask of shape (B, H, W) or (B, H, W, C); values near 1
                select the green background.

        Returns:
            io.NodeOutput wrapping the composited frames, with the same
            shape, dtype, and device as ``images``.
        """
        # A 4-D mask carries a trailing channel axis; keep only the first channel.
        if mask.ndim == 4:
            mask = mask[:, :, :, 0]

        # Broadcast a single-frame mask across the whole clip.
        if mask.shape[0] == 1 and images.shape[0] > 1:
            mask = mask.expand(images.shape[0], -1, -1)

        # Trim both inputs to the shorter frame count so they stay aligned.
        min_frames = min(mask.shape[0], images.shape[0])
        mask = mask[:min_frames]
        images = images[:min_frames]

        # Fix: cast mask and background color to the images' device and dtype.
        # Previously bg_color was always float32, which silently promoted
        # half-precision inputs to float32, and a mask on a different device
        # would raise. Results for float32 inputs are unchanged.
        mask_4d = mask.to(device=images.device, dtype=images.dtype).unsqueeze(
            -1
        )  # (B, H, W, 1) for broadcasting
        bg_color = (
            torch.tensor(_BG_COLOR_RGB, device=images.device, dtype=images.dtype) / 255
        )

        result = images * (1 - mask_4d) + bg_color.view(1, 1, 1, 3) * mask_4d
        return io.NodeOutput(result)