Commit 55dcc33

Add block swap debug prints and cleanup unused code
1 parent 9087791 commit 55dcc33

6 files changed: +26 −258 lines changed


nodes.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ class WanVideoBlockSwap:
    def INPUT_TYPES(s):
        return {
            "required": {
-                "blocks_to_swap": ("INT", {"default": 20, "min": 0, "max": 40, "step": 1, "tooltip": "Number of double blocks to swap"}),
+                "blocks_to_swap": ("INT", {"default": 20, "min": 0, "max": 40, "step": 1, "tooltip": "Number of transformer blocks to swap, the 14B model has 40, while the 1.3B model has 30 blocks"}),
                "offload_img_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload img_emb to offload_device"}),
                "offload_txt_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload time_emb to offload_device"}),
            },

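Note on the change above: the tooltip now documents the model-specific depths (40 transformer blocks in the 14B model, 30 in the 1.3B model) while the widget maximum stays at 40. A minimal sketch of how a caller could guard against requesting more swaps than the loaded model has blocks; clamp_blocks_to_swap is a hypothetical helper, and transformer.blocks is assumed to be the same ModuleList that block_swap() iterates over, not code from this commit:

def clamp_blocks_to_swap(blocks_to_swap, transformer):
    # Hypothetical guard: the widget allows 0..40, but the 1.3B model
    # only has 30 blocks, so never request more swaps than blocks exist.
    num_blocks = len(transformer.blocks)
    return min(blocks_to_swap, num_blocks - 1)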
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
[project]
name = "ComfyUI-WanVideoWrapper"
description = "ComfyUI diffusers wrapper nodes for WanVideo"
-version = "1.0.3"
+version = "1.0.4"
license = {file = "LICENSE"}
dependencies = ["accelerate >= 1.2.1", "diffusers >= 0.31.0", "ftfy"]

utils.py

Lines changed: 8 additions & 1 deletion
@@ -21,4 +21,11 @@ def print_memory(device):
    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
    #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False)
-    #log.info(f"Memory Summary:\n{memory_summary}")
+    #log.info(f"Memory Summary:\n{memory_summary}")
+
+def get_module_memory_mb(module):
+    memory = 0
+    for param in module.parameters():
+        if param.data is not None:
+            memory += param.nelement() * param.element_size()
+    return memory / (1024 * 1024) # Convert to MB

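The new get_module_memory_mb() sums param.nelement() * param.element_size() over a module's parameters and converts to MB; it counts parameters only, not registered buffers, so modules with large buffers are under-reported. A short usage sketch (report_block_memory and the transformer.blocks attribute are illustrative assumptions, not part of this commit):

def report_block_memory(transformer):
    # Print the parameter memory of each transformer block and the total,
    # using the get_module_memory_mb() helper added above.
    total = 0.0
    for i, block in enumerate(transformer.blocks):
        mb = get_module_memory_mb(block)
        total += mb
        print(f"block {i}: {mb:.2f} MB")
    print(f"total: {total:.2f} MB")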
wanvideo/modules/clip.py

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import torchvision.transforms as T
1010

1111
from .attention import attention
12-
from .tokenizers import HuggingfaceTokenizer
13-
from .xlm_roberta import XLMRoberta
1412

1513
__all__ = [
1614
'XLMRobertaCLIP',
@@ -155,60 +153,6 @@ def forward(self, x):
155153
x = x + self.mlp(self.norm2(x))
156154
return x
157155

158-
159-
class AttentionPool(nn.Module):
160-
161-
def __init__(self,
162-
dim,
163-
mlp_ratio,
164-
num_heads,
165-
activation='gelu',
166-
proj_dropout=0.0,
167-
norm_eps=1e-5):
168-
assert dim % num_heads == 0
169-
super().__init__()
170-
self.dim = dim
171-
self.mlp_ratio = mlp_ratio
172-
self.num_heads = num_heads
173-
self.head_dim = dim // num_heads
174-
self.proj_dropout = proj_dropout
175-
self.norm_eps = norm_eps
176-
177-
# layers
178-
gain = 1.0 / math.sqrt(dim)
179-
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
180-
self.to_q = nn.Linear(dim, dim)
181-
self.to_kv = nn.Linear(dim, dim * 2)
182-
self.proj = nn.Linear(dim, dim)
183-
self.norm = LayerNorm(dim, eps=norm_eps)
184-
self.mlp = nn.Sequential(
185-
nn.Linear(dim, int(dim * mlp_ratio)),
186-
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
187-
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
188-
189-
def forward(self, x):
190-
"""
191-
x: [B, L, C].
192-
"""
193-
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
194-
195-
# compute query, key, value
196-
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
197-
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
198-
199-
# compute attention
200-
x = flash_attention(q, k, v, version=2)
201-
x = x.reshape(b, 1, c)
202-
203-
# output
204-
x = self.proj(x)
205-
x = F.dropout(x, self.proj_dropout, self.training)
206-
207-
# mlp
208-
x = x + self.mlp(self.norm(x))
209-
return x[:, 0]
210-
211-
212156
class VisionTransformer(nn.Module):
213157

214158
def __init__(self,
@@ -275,9 +219,6 @@ def __init__(self,
275219
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
276220
elif pool_type == 'token_fc':
277221
self.head = nn.Linear(dim, out_dim)
278-
elif pool_type == 'attn_pool':
279-
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
280-
proj_dropout, norm_eps)
281222

282223
def forward(self, x, interpolation=False, use_31_block=False):
283224
b = x.size(0)
@@ -303,31 +244,6 @@ def forward(self, x, interpolation=False, use_31_block=False):
303244
return x
304245

305246

306-
class XLMRobertaWithHead(XLMRoberta):
307-
308-
def __init__(self, **kwargs):
309-
self.out_dim = kwargs.pop('out_dim')
310-
super().__init__(**kwargs)
311-
312-
# head
313-
mid_dim = (self.dim + self.out_dim) // 2
314-
self.head = nn.Sequential(
315-
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
316-
nn.Linear(mid_dim, self.out_dim, bias=False))
317-
318-
def forward(self, ids):
319-
# xlm-roberta
320-
x = super().forward(ids)
321-
322-
# average pooling
323-
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
324-
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
325-
326-
# head
327-
x = self.head(x)
328-
return x
329-
330-
331247
class XLMRobertaCLIP(nn.Module):
332248

333249
def __init__(self,

wanvideo/modules/model.py

Lines changed: 16 additions & 1 deletion
@@ -15,7 +15,7 @@

from tqdm import tqdm

-from ...utils import log
+from ...utils import log, get_module_memory_mb

def poly1d(coefficients, x):
    result = torch.zeros_like(x)

@@ -547,12 +547,27 @@ def block_swap(self, blocks_to_swap, offload_txt_emb=False, offload_img_emb=False
        self.blocks_to_swap = blocks_to_swap
        self.offload_img_emb = offload_img_emb
        self.offload_txt_emb = offload_txt_emb
+
+        total_offload_memory = 0
+        total_main_memory = 0

        for b, block in tqdm(enumerate(self.blocks), total=len(self.blocks), desc="Initializing block swap"):
+            block_memory = get_module_memory_mb(block)
+
            if b > self.blocks_to_swap:
                block.to(self.main_device)
+                total_main_memory += block_memory
            else:
                block.to(self.offload_device)
+                total_offload_memory += block_memory
+
+            #print(f"Block {b}: {block_memory:.2f}MB on {block.parameters().__next__().device}")
+        log.info("----------------------")
+        log.info(f"Block swap memory summary:")
+        log.info(f"Transformer blocks on {self.offload_device}: {total_offload_memory:.2f}MB")
+        log.info(f"Transformer blocks on {self.main_device}: {total_main_memory:.2f}MB")
+        log.info(f"Total Memory: {(total_offload_memory + total_main_memory):.2f}MB")
+        log.info("----------------------")

    def forward(
        self,

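For context on the hunk above: block_swap() only decides each block's resting device and logs how the parameter memory splits between offload_device and main_device; the actual swapping has to happen at inference time, when an offloaded block is moved in before it runs and pushed back out afterwards. A minimal sketch of that general pattern, assuming the same blocks, blocks_to_swap, main_device, and offload_device attributes (this is not the wrapper's actual forward() code):

def run_blocks(self, x, *args, **kwargs):
    # Generic block-swap loop: offloaded blocks (index <= blocks_to_swap)
    # are pulled onto the main device just-in-time, then returned to the
    # offload device so only one swapped block occupies VRAM at a time.
    for b, block in enumerate(self.blocks):
        swapped = b <= self.blocks_to_swap
        if swapped:
            block.to(self.main_device)
        x = block(x, *args, **kwargs)
        if swapped:
            block.to(self.offload_device, non_blocking=True)
    return x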
wanvideo/modules/xlm_roberta.py

Lines changed: 0 additions & 170 deletions
This file was deleted.
