Skip to content

Commit b71503a

Browse files
committed
Updated how prompt seeds are generated
These are now returned from the get_seeds function which decides whether if should be the same as image seeds or generated separately. Fixes #535
1 parent 78d599c commit b71503a

4 files changed

Lines changed: 92 additions & 37 deletions

File tree

sd_dynamic_prompts/dynamic_prompting.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,24 +464,25 @@ def process(
464464
else:
465465
negative_generator = generator
466466

467-
all_seeds = None
468-
if num_images and not unlink_seed_from_prompt:
469-
p.all_seeds, p.all_subseeds = get_seeds(
467+
if num_images:
468+
image_seeds, image_subseeds, prompt_seeds = get_seeds(
470469
p,
471470
num_images,
472471
use_fixed_seed,
473472
is_combinatorial,
474473
combinatorial_batches,
474+
unlink_seed_from_prompt,
475475
)
476-
all_seeds = p.all_seeds
476+
p.all_seeds = image_seeds
477+
p.all_subseeds = image_subseeds
477478

478479
all_prompts, all_negative_prompts = generate_prompts(
479480
generator,
480481
negative_generator,
481482
original_prompt,
482483
original_negative_prompt,
483484
num_images,
484-
all_seeds,
485+
prompt_seeds,
485486
)
486487

487488
except GeneratorException as e:
@@ -493,12 +494,13 @@ def process(
493494
p.n_iter = math.ceil(updated_count / p.batch_size)
494495

495496
if num_images != updated_count:
496-
p.all_seeds, p.all_subseeds = get_seeds(
497+
p.all_seeds, p.all_subseeds, _ = get_seeds(
497498
p,
498499
updated_count,
499500
use_fixed_seed,
500501
is_combinatorial,
501502
combinatorial_batches,
503+
unlink_seed_from_prompt,
502504
)
503505

504506
if updated_count > 1:

sd_dynamic_prompts/helpers.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,29 @@
11
from __future__ import annotations
22

33
import logging
4+
import random
45
from pathlib import Path
56

67
from dynamicprompts.generators.promptgenerator import PromptGenerator
78

89
logger = logging.getLogger(__name__)
910

1011

12+
def get_fixed_seed(seed):
13+
# Copied from auto1111 modules/processing.py
14+
if seed is None or seed == "" or seed == -1:
15+
return int(random.randrange(4294967294))
16+
17+
return seed
18+
19+
1120
def get_seeds(
1221
p,
1322
num_seeds,
1423
use_fixed_seed,
1524
is_combinatorial=False,
1625
combinatorial_batches=1,
26+
unlink_seed_from_prompt=False,
1727
):
1828
if p.subseed_strength != 0:
1929
seed = int(p.all_seeds[0])
@@ -24,22 +34,27 @@ def get_seeds(
2434

2535
if use_fixed_seed:
2636
if is_combinatorial:
27-
all_seeds = []
28-
all_subseeds = [subseed] * num_seeds
37+
image_seeds = []
38+
image_subseeds = [subseed] * num_seeds
2939
for i in range(combinatorial_batches):
30-
all_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
40+
image_seeds.extend([seed + i] * (num_seeds // combinatorial_batches))
3141
else:
32-
all_seeds = [seed] * num_seeds
33-
all_subseeds = [subseed] * num_seeds
42+
image_seeds = [seed] * num_seeds
43+
image_subseeds = [subseed] * num_seeds
3444
else:
3545
if p.subseed_strength == 0:
36-
all_seeds = [seed + i for i in range(num_seeds)]
46+
image_seeds = [seed + i for i in range(num_seeds)]
3747
else:
38-
all_seeds = [seed] * num_seeds
48+
image_seeds = [seed] * num_seeds
3949

40-
all_subseeds = [subseed + i for i in range(num_seeds)]
50+
image_subseeds = [subseed + i for i in range(num_seeds)]
51+
52+
if unlink_seed_from_prompt:
53+
prompt_seeds = [get_fixed_seed(None) for _ in range(num_seeds)]
54+
else:
55+
prompt_seeds = image_seeds
4156

42-
return all_seeds, all_subseeds
57+
return image_seeds, image_subseeds, prompt_seeds
4358

4459

4560
def should_freeze_prompt(p):

tests/prompts/test_frozenprompt_generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55

66
def test_repeats_correctly():
7-
generator = FrozenPromptGenerator(RandomPromptGenerator())
7+
generator = FrozenPromptGenerator(
8+
RandomPromptGenerator(unlink_seed_from_prompt=True),
9+
)
810
template = "{A|B|C|D|E|F|G|H|I|J|K}"
911
prompts = generator.generate(template, 10)
1012

@@ -15,5 +17,4 @@ def test_repeats_correctly():
1517

1618
assert len(prompts2) == 10
1719
assert len(set(prompts2)) == 1
18-
1920
assert prompts[0] != prompts2[0]

tests/prompts/test_helpers.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,67 +22,104 @@ def processing():
2222
def test_get_seeds_with_fixed_seed(processing):
2323
num_seeds = 10
2424

25-
seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
26-
assert seeds == [processing.seed] * num_seeds
27-
assert subseeds == [processing.subseed] * num_seeds
25+
image_seeds, image_subseeds, _ = get_seeds(
26+
processing,
27+
num_seeds,
28+
use_fixed_seed=True,
29+
)
30+
assert image_seeds == [processing.seed] * num_seeds
31+
assert image_subseeds == [processing.subseed] * num_seeds
2832

2933
processing.subseed_strength = 0.5
3034

31-
seeds, subseeds = get_seeds(processing, num_seeds, use_fixed_seed=True)
32-
assert seeds == [processing.all_seeds[0]] * num_seeds
33-
assert subseeds == [processing.all_subseeds[0]] * num_seeds
35+
image_seeds, image_subseeds, _ = get_seeds(
36+
processing,
37+
num_seeds,
38+
use_fixed_seed=True,
39+
)
40+
assert image_seeds == [processing.all_seeds[0]] * num_seeds
41+
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds
3442

3543

3644
def test_get_seeds_with_fixed_seed_batched_combinatorial(processing):
3745
num_seeds = 10
3846
combinatorial_batches = 3
39-
seeds, subseeds = get_seeds(
47+
image_seeds, image_subseeds, _ = get_seeds(
4048
processing,
4149
num_seeds,
4250
use_fixed_seed=True,
4351
is_combinatorial=True,
4452
combinatorial_batches=combinatorial_batches,
4553
)
4654
seed0 = processing.seed
47-
assert seeds == (
55+
assert image_seeds == (
4856
[seed0] * (num_seeds // 3)
4957
+ [seed0 + 1] * (num_seeds // 3)
5058
+ [seed0 + 2] * (num_seeds // 3)
5159
)
52-
assert subseeds == [processing.subseed] * num_seeds
60+
assert image_subseeds == [processing.subseed] * num_seeds
5361

5462
processing.subseed_strength = 0.5
5563

56-
seeds, subseeds = get_seeds(
64+
image_seeds, image_subseeds, _ = get_seeds(
5765
processing,
5866
num_seeds,
5967
use_fixed_seed=True,
6068
is_combinatorial=True,
6169
combinatorial_batches=combinatorial_batches,
6270
)
6371
seed0 = processing.all_seeds[0]
64-
assert seeds == (
72+
assert image_seeds == (
6573
[seed0] * (num_seeds // 3)
6674
+ [seed0 + 1] * (num_seeds // 3)
6775
+ [seed0 + 2] * (num_seeds // 3)
6876
)
69-
assert subseeds == [processing.all_subseeds[0]] * num_seeds
77+
assert image_subseeds == [processing.all_subseeds[0]] * num_seeds
7078

7179

7280
def test_get_seeds_with_random_seed(processing):
7381
num_seeds = 10
7482

75-
seed, subseed = processing.seed, processing.subseed
76-
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
77-
assert seeds == list(range(seed, seed + num_seeds))
78-
assert subseeds == list(range(subseed, subseed + num_seeds))
83+
image_seeds, image_subseeds = processing.seed, processing.subseed
84+
seeds, subseeds, _ = get_seeds(
85+
processing,
86+
num_seeds=num_seeds,
87+
use_fixed_seed=False,
88+
)
89+
assert seeds == list(range(image_seeds, image_seeds + num_seeds))
90+
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))
7991

8092
processing.subseed_strength = 0.5
8193

82-
seed, subseed = processing.all_seeds[0], processing.all_subseeds[0]
83-
seeds, subseeds = get_seeds(processing, num_seeds=num_seeds, use_fixed_seed=False)
84-
assert seeds == [seed] * num_seeds
85-
assert subseeds == list(range(subseed, subseed + num_seeds))
94+
image_seeds, image_subseeds = processing.all_seeds[0], processing.all_subseeds[0]
95+
seeds, subseeds, _ = get_seeds(
96+
processing,
97+
num_seeds=num_seeds,
98+
use_fixed_seed=False,
99+
)
100+
assert seeds == [image_seeds] * num_seeds
101+
assert subseeds == list(range(image_subseeds, image_subseeds + num_seeds))
102+
103+
104+
@pytest.mark.parametrize("use_fixed_seed", [True, False])
105+
def test_get_with_unlinked_seed(processing, use_fixed_seed):
106+
num_seeds = 10
107+
108+
image_seeds, _, prompt_seeds = get_seeds(
109+
processing,
110+
num_seeds,
111+
use_fixed_seed=use_fixed_seed,
112+
unlink_seed_from_prompt=False,
113+
)
114+
assert image_seeds == prompt_seeds
115+
116+
image_seeds, _, prompt_seeds = get_seeds(
117+
processing,
118+
num_seeds,
119+
use_fixed_seed=use_fixed_seed,
120+
unlink_seed_from_prompt=True,
121+
)
122+
assert image_seeds != prompt_seeds
86123

87124

88125
def test_load_magicprompt_models():

0 commit comments

Comments
 (0)