-
Notifications
You must be signed in to change notification settings - Fork 372
Expand file tree
/
Copy pathconditioning_saver.py
More file actions
69 lines (54 loc) · 2.37 KB
/
conditioning_saver.py
File metadata and controls
69 lines (54 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from datetime import datetime
from pathlib import Path
import comfy.utils
import folder_paths
import torch
from comfy_api.latest import io, ui
from .nodes_registry import comfy_node
_SAFETENSORS_SUFFIX = ".safetensors"


def _sanitize_conditioning_filename(filename: str) -> str:
    """Return a safe base file name (without extension) derived from *filename*.

    Only alphanumerics, "_", "-" and "." are kept, so path separators are
    dropped and the result cannot point outside the target folder.
    """
    sanitized = "".join(
        c for c in filename if c.isalnum() or c in ("_", "-", ".")
    )
    # The caller appends the extension itself; drop a user-typed one so that
    # "foo.safetensors" does not become "foo.safetensors.safetensors".
    if sanitized.lower().endswith(_SAFETENSORS_SUFFIX):
        sanitized = sanitized[: -len(_SAFETENSORS_SUFFIX)]
    # Leading dots would produce hidden / dot-only names such as ".." .
    sanitized = sanitized.lstrip(".")
    return sanitized or "conditioning"


def _collect_conditioning_tensors(
    conditioning: list, target_dtype: torch.dtype
) -> dict:
    """Build the name -> tensor mapping that is written to the file.

    For entry ``idx`` of *conditioning* (a list of ``(tensor, options)`` pairs)
    this stores ``conditioning_data_{idx}`` cast to *target_dtype* and, when the
    options carry a non-None ``attention_mask``, ``attention_mask_{idx}`` in its
    original dtype.
    """
    tensors: dict = {}
    for idx, (cond_tensor, cond_options) in enumerate(conditioning):
        # copy=True gives every saved tensor its own storage: the same tensor
        # object can appear in several entries (e.g. one encode combined with
        # itself), and the safetensors writer refuses tensors sharing memory.
        tensors[f"conditioning_data_{idx}"] = (
            cond_tensor.detach()
            .to(device="cpu", dtype=target_dtype, copy=True)
            .contiguous()
        )
        mask = cond_options.get("attention_mask")
        if mask is not None:
            tensors[f"attention_mask_{idx}"] = (
                mask.detach().to(device="cpu", copy=True).contiguous()
            )
    return tensors


@comfy_node(name="LTXVSaveConditioning")
class LTXVSaveConditioning(io.ComfyNode):
    """Output node that saves a CONDITIONING input to a .safetensors file.

    The file is written into the first configured ``embeddings`` folder, with
    one ``conditioning_data_{idx}`` tensor (and optional ``attention_mask_{idx}``)
    per conditioning entry plus string metadata describing the save.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's inputs: conditioning, target file name and dtype."""
        return io.Schema(
            node_id="LTXVSaveConditioning",
            display_name="🅛🅣🅧 LTXV Save Conditioning",
            category="lightricks/LTXV",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.String.Input("filename", default="conditioning"),
                io.Combo.Input("dtype", options=["bfloat16", "float16"]),
            ],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, conditioning: list, filename: str, dtype: str) -> io.NodeOutput:
        """Write *conditioning* to ``<embeddings>/<filename>.safetensors``.

        Args:
            conditioning: list of ``(tensor, options dict)`` pairs.
            filename: requested file name; sanitized, and ".safetensors" is
                appended (a user-typed ".safetensors" suffix is not doubled).
            dtype: "bfloat16" selects bfloat16; anything else selects float16.
                Only the conditioning tensors are cast; masks keep their dtype.

        Returns:
            A NodeOutput whose UI text reports the saved file name and size.

        Raises:
            ValueError: if *conditioning* is empty.
        """
        if not conditioning:
            raise ValueError("Conditioning is empty")

        embeddings_folder = Path(folder_paths.get_folder_paths("embeddings")[0])
        embeddings_folder.mkdir(parents=True, exist_ok=True)
        base_name = _sanitize_conditioning_filename(filename)
        output_path = embeddings_folder / f"{base_name}{_SAFETENSORS_SUFFIX}"

        target_dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float16
        tensors_to_save = _collect_conditioning_tensors(conditioning, target_dtype)

        # safetensors metadata values must be strings.
        metadata = {
            "num_conditionings": str(len(conditioning)),
            "dtype": dtype,
            "created_at": str(datetime.now()),
        }
        comfy.utils.save_torch_file(
            tensors_to_save, str(output_path), metadata=metadata
        )

        file_size_mb = output_path.stat().st_size / (1024 * 1024)
        return io.NodeOutput(
            ui=ui.PreviewText(f"Saved: {output_path.name} ({file_size_mb:.2f} MB)")
        )