Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions sd_dynamic_prompts/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,24 +464,26 @@ def process(
else:
negative_generator = generator

all_seeds = None
if num_images and not unlink_seed_from_prompt:
p.all_seeds, p.all_subseeds = get_seeds(
prompt_seeds = p.all_seeds
if num_images:
image_seeds, image_subseeds, prompt_seeds = get_seeds(
p,
num_images,
use_fixed_seed,
is_combinatorial,
combinatorial_batches,
unlink_seed_from_prompt,
)
all_seeds = p.all_seeds
p.all_seeds = image_seeds
p.all_subseeds = image_subseeds

all_prompts, all_negative_prompts = generate_prompts(
generator,
negative_generator,
original_prompt,
original_negative_prompt,
num_images,
all_seeds,
prompt_seeds,
)

except GeneratorException as e:
Expand All @@ -493,12 +495,13 @@ def process(
p.n_iter = math.ceil(updated_count / p.batch_size)

if num_images != updated_count:
p.all_seeds, p.all_subseeds = get_seeds(
p.all_seeds, p.all_subseeds, _ = get_seeds(
p,
updated_count,
use_fixed_seed,
is_combinatorial,
combinatorial_batches,
unlink_seed_from_prompt,
)

if updated_count > 1:
Expand Down
33 changes: 24 additions & 9 deletions sd_dynamic_prompts/helpers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import logging
import random
from pathlib import Path

from dynamicprompts.generators.promptgenerator import PromptGenerator

logger = logging.getLogger(__name__)


def get_fixed_seed(seed):
# Copied from auto1111 modules/processing.py
if seed is None or seed == "" or seed == -1:
return int(random.randrange(4294967294))

return seed


def get_seeds(
p,
num_seeds,
use_fixed_seed,
is_combinatorial=False,
combinatorial_batches=1,
unlink_seed_from_prompt=False,
):
if p.subseed_strength != 0:
seed = int(p.all_seeds[0])
Expand All @@ -24,22 +34,27 @@ def get_seeds(

if use_fixed_seed:
if is_combinatorial:
all_seeds = []
all_subseeds = [subseed] * num_seeds
image_seeds = []
image_subseeds = [subseed] * num_seeds
for i in range(combinatorial_batches):
all_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
image_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
else:
all_seeds = [seed] * num_seeds
all_subseeds = [subseed] * num_seeds
image_seeds = [seed] * num_seeds
image_subseeds = [subseed] * num_seeds
else:
if p.subseed_strength == 0:
all_seeds = [seed + i for i in range(num_seeds)]
image_seeds = [seed + i for i in range(num_seeds)]
else:
all_seeds = [seed] * num_seeds
image_seeds = [seed] * num_seeds

all_subseeds = [subseed + i for i in range(num_seeds)]
image_subseeds = [subseed + i for i in range(num_seeds)]

if unlink_seed_from_prompt:
prompt_seeds = [get_fixed_seed(None) for _ in range(num_seeds)]
else:
prompt_seeds = image_seeds

return all_seeds, all_subseeds
return image_seeds, image_subseeds, prompt_seeds


def should_freeze_prompt(p):
Expand Down
5 changes: 3 additions & 2 deletions tests/prompts/test_frozenprompt_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@


def test_repeats_correctly():
generator = FrozenPromptGenerator(RandomPromptGenerator())
generator = FrozenPromptGenerator(
RandomPromptGenerator(unlink_seed_from_prompt=True),
)
template = "{A|B|C|D|E|F|G|H|I|J|K}"
prompts = generator.generate(template, 10)

Expand All @@ -15,5 +17,4 @@ def test_repeats_correctly():

assert len(prompts2) == 10
assert len(set(prompts2)) == 1

assert prompts[0] != prompts2[0]
77 changes: 57 additions & 20 deletions tests/prompts/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,67 +22,104 @@ def processing():
def test_get_seeds_with_fixed_seed(processing):
num_seeds = 10

seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
assert seeds == [processing.seed] * num_seeds
assert subseeds == [processing.subseed] * num_seeds
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
)
assert image_seeds == [processing.seed] * num_seeds
assert image_subseeds == [processing.subseed] * num_seeds

processing.subseed_strength = 0.5

seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
assert seeds == [processing.all_seeds[0]] * num_seeds
assert subseeds == [processing.all_subseeds[0]] * num_seeds
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
)
assert image_seeds == [processing.all_seeds[0]] * num_seeds
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds


def test_get_seeds_with_fixed_seed_batched_combinatorial(processing):
num_seeds = 10
combinatorial_batches = 3
seeds, subseeds = get_seeds(
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
is_combinatorial=True,
combinatorial_batches=combinatorial_batches,
)
seed0 = processing.seed
assert seeds == (
assert image_seeds == (
[seed0] * (num_seeds // 3)
+ [seed0 + 1] * (num_seeds // 3)
+ [seed0 + 2] * (num_seeds // 3)
)
assert subseeds == [processing.subseed] * num_seeds
assert image_subseeds == [processing.subseed] * num_seeds

processing.subseed_strength = 0.5

seeds, subseeds = get_seeds(
image_seeds, image_subseeds, _ = get_seeds(
processing,
num_seeds,
use_fixed_seed=True,
is_combinatorial=True,
combinatorial_batches=combinatorial_batches,
)
seed0 = processing.all_seeds[0]
assert seeds == (
assert image_seeds == (
[seed0] * (num_seeds // 3)
+ [seed0 + 1] * (num_seeds // 3)
+ [seed0 + 2] * (num_seeds // 3)
)
assert subseeds == [processing.all_subseeds[0]] * num_seeds
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds


def test_get_seeds_with_random_seed(processing):
num_seeds = 10

seed, subseed = processing.seed, processing.subseed
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
assert seeds == list(range(seed, seed + num_seeds))
assert subseeds == list(range(subseed, subseed + num_seeds))
image_seeds, image_subseeds = processing.seed, processing.subseed
seeds, subseeds, _ = get_seeds(
processing,
num_seeds=num_seeds,
use_fixed_seed=False,
)
assert seeds == list(range(image_seeds, image_seeds + num_seeds))
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))

processing.subseed_strength = 0.5

seed, subseed = processing.all_seeds[0], processing.all_subseeds[0]
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
assert seeds == [seed] * num_seeds
assert subseeds == list(range(subseed, subseed + num_seeds))
image_seeds, image_subseeds = processing.all_seeds[0], processing.all_subseeds[0]
seeds, subseeds, _ = get_seeds(
processing,
num_seeds=num_seeds,
use_fixed_seed=False,
)
assert seeds == [image_seeds] * num_seeds
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))


@pytest.mark.parametrize("use_fixed_seed", [True, False])
def test_get_with_unlinked_seed(processing, use_fixed_seed):
num_seeds = 10

image_seeds, _, prompt_seeds = get_seeds(
processing,
num_seeds,
use_fixed_seed=use_fixed_seed,
unlink_seed_from_prompt=False,
)
assert image_seeds == prompt_seeds

image_seeds, _, prompt_seeds = get_seeds(
processing,
num_seeds,
use_fixed_seed=use_fixed_seed,
unlink_seed_from_prompt=True,
)
assert image_seeds != prompt_seeds


def test_load_magicprompt_models():
Expand Down