Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@ Prompt enhancement plays a **crucial role** in enabling our model to generate hi

We highly recommend you to try the [PromptEnhancer-32B model](https://huggingface.co/PromptEnhancer/PromptEnhancer-32B) for higher-quality prompt enhancement.

#### Cloud-Based Prompt Enhancement (MiniMax)

If you don't have enough GPU memory for the local reprompt model, you can use MiniMax's Cloud API as a lightweight alternative. This offloads prompt enhancement to the cloud so your GPU is fully available for image generation.

1. Install the OpenAI SDK: `pip install openai`
2. Set your API key: `export MINIMAX_API_KEY="your-api-key"`
3. Use `reprompt_model="minimax"` when creating the pipeline:

```python
pipe = HunyuanImagePipeline.from_pretrained(
model_name="hunyuanimage-v2.1",
reprompt_model="minimax", # Use MiniMax Cloud API for prompt enhancement
use_fp8=True,
)

image = pipe(
prompt="A cute penguin wearing a scarf",
use_reprompt=True, # Enable cloud prompt enhancement
use_refiner=True,
)
```

You can also customize the model and base URL via environment variables:
- `MINIMAX_MODEL`: Model name (default: `MiniMax-M2.7`)
- `MINIMAX_BASE_URL`: API base URL (default: `https://api.minimax.io/v1`)


### Text to Image
HunyuanImage-2.1 **only supports 2K** image generation (e.g. 2048x2048 for 1:1 images, 2560x1536 for 16:9 images, etc.).
Expand Down
28 changes: 25 additions & 3 deletions hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from hyimage.common.constants import PRECISION_TO_TYPE
from hyimage.common.format_prompt import MultilingualPromptFormat
from hyimage.models.text_encoder import PROMPT_TEMPLATE
from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT, HUNYUANIMAGE_REPROMPT_32B
from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT, HUNYUANIMAGE_REPROMPT_32B, HUNYUANIMAGE_REPROMPT_MINIMAX
from hyimage.models.text_encoder.byT5 import load_glyph_byT5_v2
from hyimage.models.hunyuan.modules.hunyuanimage_dit import load_hunyuan_dit_state_dict
from hyimage.diffusion.cfg_utils import AdaptiveProjectedGuidance, rescale_noise_cfg
Expand Down Expand Up @@ -84,7 +84,7 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, repr
dit_config=dit_config,
vae_config=HUNYUANIMAGE_V2_1_VAE_32x(),
text_encoder_config=HUNYUANIMAGE_V2_1_TEXT_ENCODER(),
reprompt_config=HUNYUANIMAGE_REPROMPT_32B() if reprompt_model == "hunyuanimage-reprompt-32b" else HUNYUANIMAGE_REPROMPT(),
reprompt_config=self._resolve_reprompt_config(reprompt_model),
shift=4 if use_distilled else 5,
default_guidance_scale=3.25 if use_distilled else 3.5,
default_sampling_steps=8 if use_distilled else 50,
Expand All @@ -94,6 +94,16 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, repr
else:
raise ValueError(f"Unsupported HunyuanImage version: {version}. Only 'v2.1' is supported")

@staticmethod
def _resolve_reprompt_config(reprompt_model):
    """Return the RepromptConfig for *reprompt_model*.

    Recognized names map to the 32B local model or the MiniMax cloud
    model; any other value falls back to the default local reprompt model.
    """
    factories = {
        "hunyuanimage-reprompt-32b": HUNYUANIMAGE_REPROMPT_32B,
        "hunyuanimage-reprompt-minimax": HUNYUANIMAGE_REPROMPT_MINIMAX,
        "minimax": HUNYUANIMAGE_REPROMPT_MINIMAX,
    }
    return factories.get(reprompt_model, HUNYUANIMAGE_REPROMPT)()


class HunyuanImagePipeline:
"""
Expand Down Expand Up @@ -241,7 +251,13 @@ def _load_reprompt_model(self):
if self.config.enable_stage1_offloading:
self.offload()
reprompt_config = self.config.reprompt_config
self._reprompt_model = instantiate(reprompt_config.model, models_root_path=reprompt_config.load_from, enable_offloading=self.config.enable_reprompt_model_offloading)
# Cloud reprompt models don't need models_root_path or offloading
is_cloud = not reprompt_config.load_from
kwargs = {}
if not is_cloud:
kwargs["models_root_path"] = reprompt_config.load_from
kwargs["enable_offloading"] = self.config.enable_reprompt_model_offloading
self._reprompt_model = instantiate(reprompt_config.model, **kwargs)
loguru.logger.info("✓ Reprompt model loaded")
except Exception as e:
raise RuntimeError(f"Error loading reprompt model: {e}") from e
Expand Down Expand Up @@ -896,6 +912,12 @@ def from_pretrained(cls, model_name: str = "hunyuanimage-v2.1", use_distilled: b
Args:
model_name: Model name, supports "hunyuanimage-v2.1", "hunyuanimage-v2.1-distilled"
use_distilled: Whether to use distilled model (overrides model_name if specified)
reprompt_model: Reprompt model to use for prompt enhancement.
Supported values:
- "hunyuanimage-reprompt-32b" (default): Local 32B model
- "hunyuanimage-reprompt": Local smaller model
- "hunyuanimage-reprompt-minimax" or "minimax": MiniMax Cloud API
(requires MINIMAX_API_KEY environment variable)
**kwargs: Additional configuration options

Returns:
Expand Down
14 changes: 13 additions & 1 deletion hyimage/models/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,23 @@ def HUNYUANIMAGE_REPROMPT(**kwargs):

def HUNYUANIMAGE_REPROMPT_32B(**kwargs):
    """Configuration for the local 32B reprompt model.

    Weights are loaded from ``<model_root>/reprompt_32b``; the device map
    is left to the loader ("auto").
    """
    # Imported lazily so merely importing the zoo does not pull in the model code.
    from hyimage.models.reprompt.reprompt_32b import RePrompt

    model_spec = L(RePrompt)(
        models_root_path=None,
        device_map="auto",
    )
    return RepromptConfig(
        model=model_spec,
        load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/reprompt_32b",
    )


def HUNYUANIMAGE_REPROMPT_MINIMAX(**kwargs):
    """Configuration for the cloud-based MiniMax reprompt model.

    ``load_from`` is the empty string because no local weights exist for
    this model; enhancement goes through the MiniMax API instead.
    """
    # Imported lazily so merely importing the zoo does not pull in the cloud client code.
    from hyimage.models.reprompt.reprompt_cloud import RePromptCloud

    cloud_spec = L(RePromptCloud)(
        models_root_path=None,
        provider="minimax",
    )
    return RepromptConfig(model=cloud_spec, load_from="")
164 changes: 164 additions & 0 deletions hyimage/models/reprompt/reprompt_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import os
import re
import logging

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = (
"你是一位图像生成提示词撰写专家,请根据用户输入的提示词,改写生成新的提示词,改写后的提示词要求:"
"1 改写后提示词包含的主体/动作/数量/风格/布局/关系/属性/文字等 必须和改写前的意图一致; "
"2 在宏观上遵循“总-分-总”的结构,确保信息的层次清晰;"
"3 客观中立,避免主观臆断和情感评价;"
"4 由主到次,始终先描述最重要的元素,再描述次要和背景元素;"
"5 逻辑清晰,严格遵循空间逻辑或主次逻辑,使读者能在大脑中重建画面;"
"6 结尾点题,必须用一句话总结图像的整体风格或类型。"
)


class RePromptCloud:
"""
Cloud-based prompt enhancement using MiniMax API (OpenAI-compatible).

This class provides the same interface as the local RePrompt model but
uses MiniMax's cloud LLM API for prompt enhancement, eliminating the
need for local GPU resources for the reprompt model.

Environment variables:
MINIMAX_API_KEY: Required. Your MiniMax API key.
MINIMAX_MODEL: Optional. Model name (default: MiniMax-M2.7).
MINIMAX_BASE_URL: Optional. API base URL
(default: https://api.minimax.io/v1).
"""

SUPPORTED_PROVIDERS = {
"minimax": {
"default_base_url": "https://api.minimax.io/v1",
"default_model": "MiniMax-M2.7",
"api_key_env": "MINIMAX_API_KEY",
},
}

def __init__(
self,
models_root_path=None,
device_map="auto",
enable_offloading=True,
provider="minimax",
api_key=None,
base_url=None,
model=None,
):
"""
Initialize the cloud-based reprompt model.

Args:
models_root_path: Ignored (kept for interface compatibility).
device_map: Ignored (kept for interface compatibility).
enable_offloading: Ignored (kept for interface compatibility).
provider: Cloud LLM provider name (default: "minimax").
api_key: API key. If None, reads from environment variable.
base_url: API base URL. If None, uses provider default.
model: Model name. If None, uses provider default.
"""
if provider not in self.SUPPORTED_PROVIDERS:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Supported: {list(self.SUPPORTED_PROVIDERS.keys())}"
)

provider_config = self.SUPPORTED_PROVIDERS[provider]
self.provider = provider

self.api_key = api_key or os.environ.get(provider_config["api_key_env"])
if not self.api_key:
raise ValueError(
f"API key required. Set {provider_config['api_key_env']} "
f"environment variable or pass api_key parameter."
)

self.base_url = (
base_url
or os.environ.get("MINIMAX_BASE_URL")
or provider_config["default_base_url"]
)
self.model = (
model
or os.environ.get("MINIMAX_MODEL")
or provider_config["default_model"]
)

try:
from openai import OpenAI
self._client = OpenAI(
api_key=self.api_key,
base_url=self.base_url,
)
except ImportError:
raise ImportError(
"The 'openai' package is required for cloud-based prompt "
"enhancement. Install it with: pip install openai"
)

logger.info(
f"✓ Cloud reprompt model initialized "
f"(provider={self.provider}, model={self.model})"
)

def predict(
self,
prompt_cot,
sys_prompt=SYSTEM_PROMPT,
):
"""
Generate a rewritten prompt using the cloud LLM API.

Args:
prompt_cot: The original prompt to be rewritten.
sys_prompt: System prompt to guide the rewriting.

Returns:
str: The rewritten prompt, or the original if generation fails.
"""
org_prompt_cot = prompt_cot
try:
# MiniMax temperature must be in (0.0, 1.0]
temperature = 0.1

response = self._client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": sys_prompt},
{"role": "user", "content": org_prompt_cot},
],
max_tokens=2048,
temperature=temperature,
)

output_res = response.choices[0].message.content

# Strip thinking tags if present (MiniMax M2.7/M2.5 may include them)
output_res = re.sub(
r"<think>.*?</think>\s*", "", output_res, flags=re.DOTALL
)

# Try to extract from <answer> tags (matches local model format)
answer_pattern = r"<answer>(.*?)</answer>"
answer_matches = re.findall(answer_pattern, output_res, re.DOTALL)
if answer_matches:
prompt_cot = answer_matches[0].strip()
else:
# Use the full response if no <answer> tags
prompt_cot = output_res.strip()

except Exception as e:
prompt_cot = org_prompt_cot
logger.error(
f"✗ Cloud re-prompting failed, fall back to original prompt. "
f"Cause: {e}"
)

return prompt_cot

def to(self, device, *args, **kwargs):
"""No-op for cloud models (kept for interface compatibility)."""
return self
Loading