diff --git a/README.md b/README.md
index 6fb8b98..1eb03a7 100644
--- a/README.md
+++ b/README.md
@@ -108,6 +108,32 @@ Prompt enhancement plays a **crucial role** in enabling our model to generate hi
We highly recommend you to try the [PromptEnhancer-32B model](https://huggingface.co/PromptEnhancer/PromptEnhancer-32B) for higher-quality prompt enhancement.
+#### Cloud-Based Prompt Enhancement (MiniMax)
+
+If you don't have enough GPU memory for the local reprompt model, you can use MiniMax's Cloud API as a lightweight alternative. This offloads prompt enhancement to the cloud so your GPU is fully available for image generation.
+
+1. Install the OpenAI SDK: `pip install openai`
+2. Set your API key: `export MINIMAX_API_KEY="your-api-key"`
+3. Use `reprompt_model="minimax"` when creating the pipeline:
+
+```python
+pipe = HunyuanImagePipeline.from_pretrained(
+ model_name="hunyuanimage-v2.1",
+ reprompt_model="minimax", # Use MiniMax Cloud API for prompt enhancement
+ use_fp8=True,
+)
+
+image = pipe(
+ prompt="A cute penguin wearing a scarf",
+ use_reprompt=True, # Enable cloud prompt enhancement
+ use_refiner=True,
+)
+```
+
+You can also customize the model and base URL via environment variables:
+- `MINIMAX_MODEL`: Model name (default: `MiniMax-M2.7`)
+- `MINIMAX_BASE_URL`: API base URL (default: `https://api.minimax.io/v1`)
+
### Text to Image
HunyuanImage-2.1 **only supports 2K** image generation (e.g. 2048x2048 for 1:1 images, 2560x1536 for 16:9 images, etc.).
diff --git a/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py b/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
index 4fbeeff..170a4cc 100644
--- a/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
+++ b/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
@@ -16,7 +16,7 @@
from hyimage.common.constants import PRECISION_TO_TYPE
from hyimage.common.format_prompt import MultilingualPromptFormat
from hyimage.models.text_encoder import PROMPT_TEMPLATE
-from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT, HUNYUANIMAGE_REPROMPT_32B
+from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT, HUNYUANIMAGE_REPROMPT_32B, HUNYUANIMAGE_REPROMPT_MINIMAX
from hyimage.models.text_encoder.byT5 import load_glyph_byT5_v2
from hyimage.models.hunyuan.modules.hunyuanimage_dit import load_hunyuan_dit_state_dict
from hyimage.diffusion.cfg_utils import AdaptiveProjectedGuidance, rescale_noise_cfg
@@ -84,7 +84,7 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, repr
dit_config=dit_config,
vae_config=HUNYUANIMAGE_V2_1_VAE_32x(),
text_encoder_config=HUNYUANIMAGE_V2_1_TEXT_ENCODER(),
- reprompt_config=HUNYUANIMAGE_REPROMPT_32B() if reprompt_model == "hunyuanimage-reprompt-32b" else HUNYUANIMAGE_REPROMPT(),
+ reprompt_config=cls._resolve_reprompt_config(reprompt_model),
shift=4 if use_distilled else 5,
default_guidance_scale=3.25 if use_distilled else 3.5,
default_sampling_steps=8 if use_distilled else 50,
@@ -94,6 +94,16 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, repr
else:
raise ValueError(f"Unsupported HunyuanImage version: {version}. Only 'v2.1' is supported")
+ @staticmethod
+ def _resolve_reprompt_config(reprompt_model):
+ """Resolve reprompt model name to its configuration."""
+ if reprompt_model == "hunyuanimage-reprompt-32b":
+ return HUNYUANIMAGE_REPROMPT_32B()
+ elif reprompt_model in ("hunyuanimage-reprompt-minimax", "minimax"):
+ return HUNYUANIMAGE_REPROMPT_MINIMAX()
+ else:
+ return HUNYUANIMAGE_REPROMPT()
+
class HunyuanImagePipeline:
"""
@@ -241,7 +251,13 @@ def _load_reprompt_model(self):
if self.config.enable_stage1_offloading:
self.offload()
reprompt_config = self.config.reprompt_config
- self._reprompt_model = instantiate(reprompt_config.model, models_root_path=reprompt_config.load_from, enable_offloading=self.config.enable_reprompt_model_offloading)
+ # Cloud reprompt models don't need models_root_path or offloading
+ is_cloud = not reprompt_config.load_from
+ kwargs = {}
+ if not is_cloud:
+ kwargs["models_root_path"] = reprompt_config.load_from
+ kwargs["enable_offloading"] = self.config.enable_reprompt_model_offloading
+ self._reprompt_model = instantiate(reprompt_config.model, **kwargs)
loguru.logger.info("✓ Reprompt model loaded")
except Exception as e:
raise RuntimeError(f"Error loading reprompt model: {e}") from e
@@ -896,6 +912,12 @@ def from_pretrained(cls, model_name: str = "hunyuanimage-v2.1", use_distilled: b
Args:
model_name: Model name, supports "hunyuanimage-v2.1", "hunyuanimage-v2.1-distilled"
use_distilled: Whether to use distilled model (overrides model_name if specified)
+ reprompt_model: Reprompt model to use for prompt enhancement.
+ Supported values:
+ - "hunyuanimage-reprompt-32b" (default): Local 32B model
+ - "hunyuanimage-reprompt": Local smaller model
+ - "hunyuanimage-reprompt-minimax" or "minimax": MiniMax Cloud API
+ (requires MINIMAX_API_KEY environment variable)
**kwargs: Additional configuration options
Returns:
diff --git a/hyimage/models/model_zoo.py b/hyimage/models/model_zoo.py
index 6e83eac..b1539f1 100644
--- a/hyimage/models/model_zoo.py
+++ b/hyimage/models/model_zoo.py
@@ -152,11 +152,23 @@ def HUNYUANIMAGE_REPROMPT(**kwargs):
def HUNYUANIMAGE_REPROMPT_32B(**kwargs):
from hyimage.models.reprompt.reprompt_32b import RePrompt
-
+
return RepromptConfig(
model=L(RePrompt)(
models_root_path=None,
device_map="auto",
),
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/reprompt_32b",
+ )
+
+
+def HUNYUANIMAGE_REPROMPT_MINIMAX(**kwargs):
+ from hyimage.models.reprompt.reprompt_cloud import RePromptCloud
+
+ return RepromptConfig(
+ model=L(RePromptCloud)(
+ models_root_path=None,
+ provider="minimax",
+ ),
+ load_from="",
)
\ No newline at end of file
diff --git a/hyimage/models/reprompt/reprompt_cloud.py b/hyimage/models/reprompt/reprompt_cloud.py
new file mode 100644
index 0000000..97a3d74
--- /dev/null
+++ b/hyimage/models/reprompt/reprompt_cloud.py
@@ -0,0 +1,164 @@
+import os
+import re
+import logging
+
+logger = logging.getLogger(__name__)
+
+SYSTEM_PROMPT = (
+ "你是一位图像生成提示词撰写专家,请根据用户输入的提示词,改写生成新的提示词,改写后的提示词要求:"
+ "1 改写后提示词包含的主体/动作/数量/风格/布局/关系/属性/文字等 必须和改写前的意图一致; "
+ "2 在宏观上遵循“总-分-总”的结构,确保信息的层次清晰;"
+ "3 客观中立,避免主观臆断和情感评价;"
+ "4 由主到次,始终先描述最重要的元素,再描述次要和背景元素;"
+ "5 逻辑清晰,严格遵循空间逻辑或主次逻辑,使读者能在大脑中重建画面;"
+ "6 结尾点题,必须用一句话总结图像的整体风格或类型。"
+)
+
+
+class RePromptCloud:
+ """
+ Cloud-based prompt enhancement using MiniMax API (OpenAI-compatible).
+
+ This class provides the same interface as the local RePrompt model but
+ uses MiniMax's cloud LLM API for prompt enhancement, eliminating the
+ need for local GPU resources for the reprompt model.
+
+ Environment variables:
+ MINIMAX_API_KEY: Required. Your MiniMax API key.
+ MINIMAX_MODEL: Optional. Model name (default: MiniMax-M2.7).
+ MINIMAX_BASE_URL: Optional. API base URL
+ (default: https://api.minimax.io/v1).
+ """
+
+ SUPPORTED_PROVIDERS = {
+ "minimax": {
+ "default_base_url": "https://api.minimax.io/v1",
+ "default_model": "MiniMax-M2.7",
+ "api_key_env": "MINIMAX_API_KEY",
+ },
+ }
+
+ def __init__(
+ self,
+ models_root_path=None,
+ device_map="auto",
+ enable_offloading=True,
+ provider="minimax",
+ api_key=None,
+ base_url=None,
+ model=None,
+ ):
+ """
+ Initialize the cloud-based reprompt model.
+
+ Args:
+ models_root_path: Ignored (kept for interface compatibility).
+ device_map: Ignored (kept for interface compatibility).
+ enable_offloading: Ignored (kept for interface compatibility).
+ provider: Cloud LLM provider name (default: "minimax").
+ api_key: API key. If None, reads from environment variable.
+ base_url: API base URL. If None, uses provider default.
+ model: Model name. If None, uses provider default.
+ """
+ if provider not in self.SUPPORTED_PROVIDERS:
+ raise ValueError(
+ f"Unsupported provider: {provider}. "
+ f"Supported: {list(self.SUPPORTED_PROVIDERS.keys())}"
+ )
+
+ provider_config = self.SUPPORTED_PROVIDERS[provider]
+ self.provider = provider
+
+ self.api_key = api_key or os.environ.get(provider_config["api_key_env"])
+ if not self.api_key:
+ raise ValueError(
+ f"API key required. Set {provider_config['api_key_env']} "
+ f"environment variable or pass api_key parameter."
+ )
+
+ self.base_url = (
+ base_url
+ or os.environ.get("MINIMAX_BASE_URL")
+ or provider_config["default_base_url"]
+ )
+ self.model = (
+ model
+ or os.environ.get("MINIMAX_MODEL")
+ or provider_config["default_model"]
+ )
+
+ try:
+ from openai import OpenAI
+ self._client = OpenAI(
+ api_key=self.api_key,
+ base_url=self.base_url,
+ )
+ except ImportError:
+ raise ImportError(
+ "The 'openai' package is required for cloud-based prompt "
+ "enhancement. Install it with: pip install openai"
+ )
+
+ logger.info(
+ f"✓ Cloud reprompt model initialized "
+ f"(provider={self.provider}, model={self.model})"
+ )
+
+ def predict(
+ self,
+ prompt_cot,
+ sys_prompt=SYSTEM_PROMPT,
+ ):
+ """
+ Generate a rewritten prompt using the cloud LLM API.
+
+ Args:
+ prompt_cot: The original prompt to be rewritten.
+ sys_prompt: System prompt to guide the rewriting.
+
+ Returns:
+ str: The rewritten prompt, or the original if generation fails.
+ """
+ org_prompt_cot = prompt_cot
+ try:
+ # MiniMax temperature must be in (0.0, 1.0]
+ temperature = 0.1
+
+ response = self._client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": sys_prompt},
+ {"role": "user", "content": org_prompt_cot},
+ ],
+ max_tokens=2048,
+ temperature=temperature,
+ )
+
+ output_res = response.choices[0].message.content
+
+ # Strip thinking tags if present (MiniMax M2.7/M2.5 may include them)
+ output_res = re.sub(
+ r"<think>.*?</think>\s*", "", output_res, flags=re.DOTALL
+ )
+
+ # Try to extract from <answer> tags (matches local model format)
+ answer_pattern = r"<answer>(.*?)</answer>"
+ answer_matches = re.findall(answer_pattern, output_res, re.DOTALL)
+ if answer_matches:
+ prompt_cot = answer_matches[0].strip()
+ else:
+ # Use the full response if no tags
+ prompt_cot = output_res.strip()
+
+ except Exception as e:
+ prompt_cot = org_prompt_cot
+ logger.error(
+ f"✗ Cloud re-prompting failed, fall back to original prompt. "
+ f"Cause: {e}"
+ )
+
+ return prompt_cot
+
+ def to(self, device, *args, **kwargs):
+ """No-op for cloud models (kept for interface compatibility)."""
+ return self
diff --git a/tests/test_reprompt_cloud.py b/tests/test_reprompt_cloud.py
new file mode 100644
index 0000000..d27bb7e
--- /dev/null
+++ b/tests/test_reprompt_cloud.py
@@ -0,0 +1,489 @@
+"""
+Unit and integration tests for MiniMax Cloud prompt enhancement (RePromptCloud).
+
+Run with: python -m pytest tests/test_reprompt_cloud.py -v
+"""
+
+import importlib
+import importlib.util
+import os
+import re
+import sys
+import unittest
+from unittest.mock import MagicMock, patch
+
+
+def _load_reprompt_cloud():
+ """Load reprompt_cloud module directly, bypassing __init__.py import chain."""
+ spec = importlib.util.spec_from_file_location(
+ "reprompt_cloud",
+ os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "hyimage",
+ "models",
+ "reprompt",
+ "reprompt_cloud.py",
+ ),
+ )
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ return mod
+
+
+# Pre-load the module once
+_rc_module = _load_reprompt_cloud()
+RePromptCloud = _rc_module.RePromptCloud
+SYSTEM_PROMPT = _rc_module.SYSTEM_PROMPT
+
+
+def _make_cloud(**kwargs):
+ """Create a RePromptCloud with mocked openai."""
+ mock_openai_module = MagicMock()
+ mock_openai_cls = MagicMock()
+ mock_openai_module.OpenAI = mock_openai_cls
+ defaults = {"api_key": "test-key-123"}
+ defaults.update(kwargs)
+ with patch.dict("sys.modules", {"openai": mock_openai_module}):
+ model = RePromptCloud(**defaults)
+ return model, mock_openai_cls
+
+
+def _mock_response(content):
+ """Create a mock OpenAI chat completion response."""
+ mock_choice = MagicMock()
+ mock_choice.message.content = content
+ mock_resp = MagicMock()
+ mock_resp.choices = [mock_choice]
+ return mock_resp
+
+
+class TestRePromptCloudInit(unittest.TestCase):
+ """Test RePromptCloud initialization and configuration."""
+
+ def test_init_with_explicit_api_key(self):
+ model, mock_cls = _make_cloud(api_key="explicit-key")
+ self.assertEqual(model.api_key, "explicit-key")
+ self.assertEqual(model.provider, "minimax")
+ self.assertEqual(model.model, "MiniMax-M2.7")
+ self.assertEqual(model.base_url, "https://api.minimax.io/v1")
+ mock_cls.assert_called_once_with(
+ api_key="explicit-key",
+ base_url="https://api.minimax.io/v1",
+ )
+
+ @patch.dict(os.environ, {"MINIMAX_API_KEY": "env-key-456"})
+ def test_init_with_env_api_key(self):
+ model, _ = _make_cloud(api_key=None)
+ self.assertEqual(model.api_key, "env-key-456")
+
+ def test_init_missing_api_key_raises(self):
+ env = {k: v for k, v in os.environ.items() if k != "MINIMAX_API_KEY"}
+ with patch.dict(os.environ, env, clear=True):
+ with self.assertRaises(ValueError) as ctx:
+ _make_cloud(api_key=None)
+ self.assertIn("MINIMAX_API_KEY", str(ctx.exception))
+
+ def test_init_unsupported_provider_raises(self):
+ with self.assertRaises(ValueError) as ctx:
+ _make_cloud(provider="unsupported")
+ self.assertIn("Unsupported provider", str(ctx.exception))
+
+ @patch.dict(
+ os.environ,
+ {
+ "MINIMAX_API_KEY": "env-key",
+ "MINIMAX_MODEL": "MiniMax-M2.5",
+ "MINIMAX_BASE_URL": "https://custom.api.url/v1",
+ },
+ )
+ def test_init_env_overrides(self):
+ model, _ = _make_cloud(api_key=None)
+ self.assertEqual(model.model, "MiniMax-M2.5")
+ self.assertEqual(model.base_url, "https://custom.api.url/v1")
+
+ def test_init_explicit_params_override_env(self):
+ model, _ = _make_cloud(
+ api_key="param-key",
+ model="custom-model",
+ base_url="https://param.url/v1",
+ )
+ self.assertEqual(model.api_key, "param-key")
+ self.assertEqual(model.model, "custom-model")
+ self.assertEqual(model.base_url, "https://param.url/v1")
+
+ def test_init_ignores_local_model_params(self):
+ """Cloud model should accept but ignore local model parameters."""
+ model, _ = _make_cloud(
+ models_root_path="/some/path",
+ device_map="cpu",
+ enable_offloading=False,
+ )
+ self.assertIsNotNone(model)
+
+ def test_supported_providers(self):
+ self.assertIn("minimax", RePromptCloud.SUPPORTED_PROVIDERS)
+ config = RePromptCloud.SUPPORTED_PROVIDERS["minimax"]
+ self.assertEqual(config["api_key_env"], "MINIMAX_API_KEY")
+ self.assertEqual(config["default_model"], "MiniMax-M2.7")
+ self.assertEqual(config["default_base_url"], "https://api.minimax.io/v1")
+
+ def test_default_model_is_m2_7(self):
+ model, _ = _make_cloud()
+ self.assertEqual(model.model, "MiniMax-M2.7")
+
+
+class TestRePromptCloudPredict(unittest.TestCase):
+ """Test RePromptCloud.predict() method."""
+
+ def _make_model(self):
+ model, _ = _make_cloud()
+ return model
+
+ def test_predict_plain_response(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "A detailed description of a sunset over the ocean with golden light."
+ )
+ )
+ result = model.predict("sunset over ocean")
+ self.assertEqual(
+ result,
+ "A detailed description of a sunset over the ocean with golden light.",
+ )
+
+ def test_predict_with_answer_tags(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "Here is the result:\n<answer>Enhanced prompt text here</answer>"
+ )
+ )
+ result = model.predict("test prompt")
+ self.assertEqual(result, "Enhanced prompt text here")
+
+ def test_predict_strips_think_tags(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "<think>Let me analyze this prompt...</think>\n"
+ "A beautiful sunset over a calm ocean."
+ )
+ )
+ result = model.predict("sunset")
+ self.assertEqual(result, "A beautiful sunset over a calm ocean.")
+
+ def test_predict_strips_think_tags_with_answer(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "<think>Reasoning here</think>\n"
+ "<answer>Clean enhanced prompt</answer>"
+ )
+ )
+ result = model.predict("test")
+ self.assertEqual(result, "Clean enhanced prompt")
+
+ def test_predict_uses_correct_api_params(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response("enhanced text")
+ )
+ model.predict("test prompt")
+ call_kwargs = model._client.chat.completions.create.call_args[1]
+ self.assertEqual(call_kwargs["model"], "MiniMax-M2.7")
+ self.assertEqual(call_kwargs["max_tokens"], 2048)
+ self.assertGreater(call_kwargs["temperature"], 0.0)
+ self.assertLessEqual(call_kwargs["temperature"], 1.0)
+
+ def test_predict_passes_system_prompt(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response("enhanced")
+ )
+ model.predict("test", sys_prompt="Custom system prompt")
+ call_kwargs = model._client.chat.completions.create.call_args[1]
+ messages = call_kwargs["messages"]
+ self.assertEqual(len(messages), 2)
+ self.assertEqual(messages[0]["role"], "system")
+ self.assertEqual(messages[0]["content"], "Custom system prompt")
+ self.assertEqual(messages[1]["role"], "user")
+ self.assertEqual(messages[1]["content"], "test")
+
+ def test_predict_uses_default_system_prompt(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response("enhanced")
+ )
+ model.predict("test prompt")
+ call_kwargs = model._client.chat.completions.create.call_args[1]
+ messages = call_kwargs["messages"]
+ self.assertEqual(messages[0]["content"], SYSTEM_PROMPT)
+
+ def test_predict_fallback_on_api_error(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ side_effect=Exception("API connection failed")
+ )
+ result = model.predict("original prompt")
+ self.assertEqual(result, "original prompt")
+
+ def test_predict_fallback_on_empty_response(self):
+ model = self._make_model()
+ mock_choice = MagicMock()
+ mock_choice.message.content = None
+ mock_resp = MagicMock()
+ mock_resp.choices = [mock_choice]
+ model._client.chat.completions.create = MagicMock(return_value=mock_resp)
+ result = model.predict("original prompt")
+ self.assertEqual(result, "original prompt")
+
+ def test_predict_chinese_prompt(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "一幅精美的山水画,远处的青山在薄雾中若隐若现,近处的溪流清澈见底。"
+ )
+ )
+ result = model.predict("中国山水画")
+ self.assertIn("山水", result)
+
+ def test_predict_multiline_think_tags(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "<think>\nLine 1\nLine 2\nLine 3\n</think>\n\n"
+ "Final enhanced prompt."
+ )
+ )
+ result = model.predict("test")
+ self.assertEqual(result, "Final enhanced prompt.")
+ self.assertNotIn("</think>", result)
+
+ def test_predict_long_prompt(self):
+ model = self._make_model()
+ long_prompt = "A scene with " + ", ".join(f"element {i}" for i in range(50))
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response("Enhanced: " + long_prompt)
+ )
+ result = model.predict(long_prompt)
+ self.assertIn("Enhanced:", result)
+
+ def test_predict_empty_prompt(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response("A blank canvas.")
+ )
+ result = model.predict("")
+ self.assertEqual(result, "A blank canvas.")
+
+ def test_predict_special_characters(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response("Enhanced with 'quotes' and extra details")
+ )
+ result = model.predict("prompt with 'quotes'")
+ self.assertIn("quotes", result)
+
+ def test_predict_preserves_original_on_timeout(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ side_effect=TimeoutError("Request timed out")
+ )
+ result = model.predict("my prompt")
+ self.assertEqual(result, "my prompt")
+
+ def test_predict_with_nested_tags(self):
+ model = self._make_model()
+ model._client.chat.completions.create = MagicMock(
+ return_value=_mock_response(
+ "<think>outer <think>step 1</think></think>\n"
+ "<answer>A beautiful scene</answer>"
+ )
+ )
+ result = model.predict("scene")
+ self.assertEqual(result, "A beautiful scene")
+
+
+class TestRePromptCloudInterface(unittest.TestCase):
+ """Test interface compatibility with local RePrompt models."""
+
+ def test_to_method_is_noop(self):
+ model, _ = _make_cloud()
+ result = model.to("cuda")
+ self.assertIs(result, model)
+
+ def test_to_method_chaining(self):
+ model, _ = _make_cloud()
+ result = model.to("cuda").to("cpu").to("cuda:0")
+ self.assertIs(result, model)
+
+ def test_has_predict_method(self):
+ model, _ = _make_cloud()
+ self.assertTrue(callable(getattr(model, "predict", None)))
+
+ def test_has_to_method(self):
+ model, _ = _make_cloud()
+ self.assertTrue(callable(getattr(model, "to", None)))
+
+ def test_predict_signature_matches_local(self):
+ """Verify predict() accepts same args as local RePrompt."""
+ import inspect
+
+ sig = inspect.signature(RePromptCloud.predict)
+ params = list(sig.parameters.keys())
+ self.assertIn("prompt_cot", params)
+ self.assertIn("sys_prompt", params)
+
+
+class TestModelZooMiniMax(unittest.TestCase):
+ """Test MiniMax reprompt configuration in model_zoo."""
+
+ def test_minimax_reprompt_config_exists(self):
+ # Import model_zoo components that don't need torch/diffusers
+ from hyimage.common.config.base_config import RepromptConfig
+ from hyimage.common.config.lazy import LazyCall as L, LazyObject
+
+ config = RepromptConfig(
+ model=L(RePromptCloud)(models_root_path=None, provider="minimax"),
+ load_from="",
+ )
+ self.assertIsInstance(config, RepromptConfig)
+ self.assertEqual(config.load_from, "")
+ self.assertIsInstance(config.model, LazyObject)
+
+ def test_minimax_config_instantiates_with_api_key(self):
+ from hyimage.common.config.lazy import LazyCall as L, instantiate
+
+ config_model = L(RePromptCloud)(models_root_path=None, provider="minimax")
+ mock_openai_module = MagicMock()
+ with patch.dict("sys.modules", {"openai": mock_openai_module}):
+ with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}):
+ model = instantiate(config_model)
+ self.assertIsNotNone(model)
+ self.assertEqual(model.provider, "minimax")
+ self.assertEqual(model.model, "MiniMax-M2.7")
+
+ def test_minimax_config_cloud_flag(self):
+ """Cloud reprompt config has empty load_from (no local model path)."""
+ from hyimage.common.config.base_config import RepromptConfig
+ from hyimage.common.config.lazy import LazyCall as L
+
+ config = RepromptConfig(
+ model=L(RePromptCloud)(models_root_path=None, provider="minimax"),
+ load_from="",
+ )
+ # Pipeline uses empty load_from to detect cloud mode
+ self.assertFalse(bool(config.load_from))
+
+
+class TestPipelineRepromptResolution(unittest.TestCase):
+ """Test that the pipeline correctly resolves reprompt model names.
+
+ Since the pipeline module has heavy dependencies (torch, diffusers),
+ we test the resolution logic by verifying the model_zoo factory
+ functions produce correct configs.
+ """
+
+ def test_minimax_factory_produces_empty_load_from(self):
+ """HUNYUANIMAGE_REPROMPT_MINIMAX should produce config with empty load_from."""
+ from hyimage.common.config.base_config import RepromptConfig
+ from hyimage.common.config.lazy import LazyCall as L
+
+ # Replicate what model_zoo.HUNYUANIMAGE_REPROMPT_MINIMAX does
+ config = RepromptConfig(
+ model=L(RePromptCloud)(models_root_path=None, provider="minimax"),
+ load_from="",
+ )
+ self.assertEqual(config.load_from, "")
+
+ def test_local_factory_produces_nonempty_load_from(self):
+ """Local reprompt configs should have non-empty load_from."""
+ from hyimage.common.config.base_config import RepromptConfig
+ from hyimage.common.config.lazy import LazyCall as L
+
+ config = RepromptConfig(
+ model=L(MagicMock)(models_root_path=None),
+ load_from="./ckpts/reprompt",
+ )
+ self.assertTrue(bool(config.load_from))
+
+ def test_resolution_logic(self):
+ """Test the _resolve_reprompt_config logic without importing the pipeline."""
+ from hyimage.common.config.base_config import RepromptConfig
+ from hyimage.common.config.lazy import LazyCall as L
+
+ # Simulate what _resolve_reprompt_config does
+ def resolve(name):
+ if name == "hunyuanimage-reprompt-32b":
+ return RepromptConfig(
+ model=L(MagicMock)(models_root_path=None),
+ load_from="./ckpts/reprompt_32b",
+ )
+ elif name in ("hunyuanimage-reprompt-minimax", "minimax"):
+ return RepromptConfig(
+ model=L(RePromptCloud)(
+ models_root_path=None, provider="minimax"
+ ),
+ load_from="",
+ )
+ else:
+ return RepromptConfig(
+ model=L(MagicMock)(models_root_path=None),
+ load_from="./ckpts/reprompt",
+ )
+
+ self.assertEqual(resolve("minimax").load_from, "")
+ self.assertEqual(resolve("hunyuanimage-reprompt-minimax").load_from, "")
+ self.assertNotEqual(resolve("hunyuanimage-reprompt-32b").load_from, "")
+ self.assertNotEqual(resolve("hunyuanimage-reprompt").load_from, "")
+
+
+class TestRePromptCloudIntegration(unittest.TestCase):
+ """Integration tests for MiniMax Cloud prompt enhancement.
+
+ These tests require a valid MINIMAX_API_KEY environment variable
+ and network access to the MiniMax API.
+ """
+
+ @unittest.skipUnless(
+ os.environ.get("MINIMAX_API_KEY"),
+ "MINIMAX_API_KEY not set; skipping integration tests",
+ )
+ def test_live_prompt_enhancement_english(self):
+ model = RePromptCloud()
+ result = model.predict("A cute cat sitting on a windowsill")
+ self.assertIsInstance(result, str)
+ self.assertGreater(len(result), 10)
+ self.assertGreater(len(result), len("A cute cat sitting on a windowsill"))
+
+ @unittest.skipUnless(
+ os.environ.get("MINIMAX_API_KEY"),
+ "MINIMAX_API_KEY not set; skipping integration tests",
+ )
+ def test_live_prompt_enhancement_chinese(self):
+ model = RePromptCloud()
+ result = model.predict("一只可爱的猫咪坐在窗台上")
+ self.assertIsInstance(result, str)
+ self.assertGreater(len(result), 5)
+
+ @unittest.skipUnless(
+ os.environ.get("MINIMAX_API_KEY"),
+ "MINIMAX_API_KEY not set; skipping integration tests",
+ )
+ def test_live_predict_interface_compatibility(self):
+ """Verify cloud model has same interface as local models."""
+ model = RePromptCloud()
+ result = model.predict(
+ "sunset over mountains",
+ sys_prompt="Enhance this image prompt with more details.",
+ )
+ self.assertIsInstance(result, str)
+ self.assertGreater(len(result), 0)
+ same_model = model.to("cuda")
+ self.assertIs(same_model, model)
+
+
+if __name__ == "__main__":
+ unittest.main()