Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .laion_loader import AestheticDataset
73 changes: 73 additions & 0 deletions datasets/laion_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import torch
from PIL import Image, ImageFile
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from diffusers.utils.torch_utils import randn_tensor

# Large web-scraped sets (LAION) routinely contain partially-downloaded JPEGs;
# let PIL decode what it can instead of raising on truncated files.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class AestheticDataset(torch.utils.data.Dataset):
    """Map-style dataset over LAION-style ``<name>.jpg`` / ``<name>.txt``
    sibling files found recursively under ``root_dir``.

    Each item yields the normalized image tensor, its caption, and a
    lightly noise-augmented copy of the image used as a conditioning signal.
    """

    def __init__(self, root_dir):
        """
        Args:
            root_dir: Directory tree scanned recursively for .jpg/.txt pairs.
        """
        self.root_dir = root_dir
        self.file_pairs = self._get_file_pairs()
        # Built once here rather than inside __getitem__: the pipeline is
        # identical for every sample, so rebuilding it per item was pure waste.
        self.transform = transforms.Compose([
            transforms.Resize((320, 576)),
            transforms.ToTensor(),
            # Map [0, 1] -> [-1, 1], the input range diffusion models expect.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def _get_file_pairs(self):
        """Return ``[(image_path, text_path), ...]`` for every ``.jpg`` that
        has a same-basename ``.txt`` sibling in the same directory."""
        file_pairs = []
        for subdir, _, files in os.walk(self.root_dir):
            # Set gives O(1) sibling lookup instead of an O(n) list scan per image.
            file_set = set(files)
            for image in files:
                basename, ext = os.path.splitext(image)
                if ext != '.jpg':
                    continue
                text_file = basename + '.txt'
                if text_file in file_set:
                    file_pairs.append((os.path.join(subdir, image),
                                       os.path.join(subdir, text_file)))
        return file_pairs

    def __len__(self):
        return len(self.file_pairs)

    def __getitem__(self, idx):
        """Return a dict with:
            pixel_values: (1, 3, 320, 576) float tensor in [-1, 1]
            text_prompt:  caption string read from the .txt sibling
            condition:    pixel_values plus Gaussian noise (std 0.02),
                          same leading frame dim
        """
        image_path, text_path = self.file_pairs[idx]
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.transform(image)
        with open(text_path, 'r', encoding='utf-8') as f:
            text_prompt = f.read().strip()
        # Small noise augmentation of the conditioning image
        # (cf. SVD's noise_aug_strength).
        condition = pixel_values + torch.randn_like(pixel_values) * 0.02
        # unsqueeze adds a frame dimension; DataLoader later adds the batch dim.
        return {"pixel_values": pixel_values.unsqueeze(0),
                "text_prompt": text_prompt,
                "condition": condition.unsqueeze(0)}


def visualize_data(dataloader, output_path='test.jpg'):
    """Render the first batch's first frame with its caption and save it.

    Args:
        dataloader: Yields dicts with "pixel_values" shaped (B, F, C, H, W)
            and "text_prompt" as a list of B strings, as produced by
            AestheticDataset wrapped in a DataLoader.
        output_path: Destination file for the rendered figure. Defaults to
            'test.jpg', the previously hard-coded destination.
    """
    for data in dataloader:
        # First sample, first frame; CHW -> HWC for matplotlib.
        pixel_values = data["pixel_values"][0][0].permute(1, 2, 0)
        text_prompt = data["text_prompt"][0]

        # NOTE(review): pixel_values are normalized to [-1, 1]; imshow clips
        # negatives, so the saved image looks washed out — consider rescaling
        # with (pixel_values + 1) / 2 if faithful preview matters.
        plt.imshow(pixel_values)
        plt.axis('off')
        # Caption placed just above the image axes.
        plt.text(0.5, 1.05, text_prompt, ha='center',
                 va='center', transform=plt.gca().transAxes)
        plt.savefig(output_path)
        break  # only the first batch is needed for a sanity check


if __name__ == '__main__':
    # Smoke test: build the dataset, wrap it in a DataLoader,
    # and dump one annotated frame to disk.
    root_dir = "/18940970966/laion-high-aesthetics-output"
    dataset = AestheticDataset(root_dir)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    visualize_data(loader)
140 changes: 140 additions & 0 deletions demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run SVD text to video generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from pipelines.pipeline_stable_video_diffusion_text import StableVideoDiffusionPipeline\n",
"from diffusers import UNetSpatioTemporalConditionModel\n",
"from diffusers.utils import load_image, export_to_video\n",
"from transformers import CLIPTextModel, CLIPTokenizer\n",
"from xtend import EmbeddingProjection\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(\n",
" 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'\n",
")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
" 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', torch_dtype=torch.float16\n",
")\n",
"unet = UNetSpatioTemporalConditionModel.from_pretrained(\n",
" \"/18940970966/diffusion_re/NEW_CODE/SVD_Xtend/outputs/checkpoint-11000/\",\n",
" subfolder=\"unet\",\n",
" low_cpu_mem_usage=True,\n",
" torch_dtype=torch.float16,\n",
")\n",
"pipe = StableVideoDiffusionPipeline.from_pretrained(\n",
" \"stabilityai/stable-video-diffusion-img2vid-xt\",\n",
" unet=unet,\n",
" text_encoder=text_encoder,\n",
" tokenizer=tokenizer,\n",
" torch_dtype=torch.float16, variant=\"fp16\", local_files_only=True,\n",
")\n",
"pipe.to(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipe.unet.embedding_projection = EmbeddingProjection(1024, 1024).cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from safetensors import safe_open\n",
"\n",
"tensors = {}\n",
"with safe_open(\"/18940970966/diffusion_re/NEW_CODE/SVD_Xtend/outputs/checkpoint-11000/unet/diffusion_pytorch_model.safetensors\", framework=\"pt\", device=0) as f:\n",
" for k in f.keys():\n",
" if 'embedding_projection' in k:\n",
" tensors[k.replace('embedding_projection.', '')] = f.get_tensor(k)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tensors.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipe.unet.embedding_projection.load_state_dict(tensors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image = load_image('/18940970966/diffusion_re/NEW_CODE/SVD_Xtend/bdd100k/images/track/train/00a0f008-a315437f/00a0f008-a315437f-0000002.jpg')\n",
"image = image.resize((1024, 576))\n",
"\n",
"generator = torch.manual_seed(123)\n",
"frames = pipe(image,\n",
" prompt='a car driving on the road',\n",
" negative_prompt='',\n",
" height=576, \n",
" width=1024, \n",
" num_frames=16,\n",
" max_guidance_scale=10,\n",
" min_guidance_scale=7,\n",
" decode_chunk_size=8, generator=generator, motion_bucket_id=0, fps=7, noise_aug_strength=0.02, num_inference_steps=50).frames[0]\n",
"\n",
"export_to_video(frames, \"generated.mp4\", fps=7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "my_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading