diff --git a/README.md b/README.md index 6fb8b98..1eb03a7 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,32 @@ Prompt enhancement plays a **crucial role** in enabling our model to generate hi We highly recommend you to try the [PromptEnhancer-32B model](https://huggingface.co/PromptEnhancer/PromptEnhancer-32B) for higher-quality prompt enhancement. +#### Cloud-Based Prompt Enhancement (MiniMax) + +If you don't have enough GPU memory for the local reprompt model, you can use MiniMax's Cloud API as a lightweight alternative. This offloads prompt enhancement to the cloud so your GPU is fully available for image generation. + +1. Install the OpenAI SDK: `pip install openai` +2. Set your API key: `export MINIMAX_API_KEY="your-api-key"` +3. Use `reprompt_model="minimax"` when creating the pipeline: + +```python +pipe = HunyuanImagePipeline.from_pretrained( + model_name="hunyuanimage-v2.1", + reprompt_model="minimax", # Use MiniMax Cloud API for prompt enhancement + use_fp8=True, +) + +image = pipe( + prompt="A cute penguin wearing a scarf", + use_reprompt=True, # Enable cloud prompt enhancement + use_refiner=True, +) +``` + +You can also customize the model and base URL via environment variables: +- `MINIMAX_MODEL`: Model name (default: `MiniMax-M2.7`) +- `MINIMAX_BASE_URL`: API base URL (default: `https://api.minimax.io/v1`) + ### Text to Image HunyuanImage-2.1 **only supports 2K** image generation (e.g. 2048x2048 for 1:1 images, 2560x1536 for 16:9 images, etc.). 
diff --git a/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py b/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py index 4fbeeff..170a4cc 100644 --- a/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py +++ b/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py @@ -16,7 +16,7 @@ from hyimage.common.constants import PRECISION_TO_TYPE from hyimage.common.format_prompt import MultilingualPromptFormat from hyimage.models.text_encoder import PROMPT_TEMPLATE -from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT, HUNYUANIMAGE_REPROMPT_32B +from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT, HUNYUANIMAGE_REPROMPT_32B, HUNYUANIMAGE_REPROMPT_MINIMAX from hyimage.models.text_encoder.byT5 import load_glyph_byT5_v2 from hyimage.models.hunyuan.modules.hunyuanimage_dit import load_hunyuan_dit_state_dict from hyimage.diffusion.cfg_utils import AdaptiveProjectedGuidance, rescale_noise_cfg @@ -84,7 +84,7 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, repr dit_config=dit_config, vae_config=HUNYUANIMAGE_V2_1_VAE_32x(), text_encoder_config=HUNYUANIMAGE_V2_1_TEXT_ENCODER(), - reprompt_config=HUNYUANIMAGE_REPROMPT_32B() if reprompt_model == "hunyuanimage-reprompt-32b" else HUNYUANIMAGE_REPROMPT(), + reprompt_config=cls._resolve_reprompt_config(reprompt_model), shift=4 if use_distilled else 5, default_guidance_scale=3.25 if use_distilled else 3.5, default_sampling_steps=8 if use_distilled else 50, @@ -94,6 +94,16 @@ else: raise ValueError(f"Unsupported HunyuanImage version: {version}. 
Only 'v2.1' is supported") + @staticmethod + def _resolve_reprompt_config(reprompt_model): + """Resolve reprompt model name to its configuration.""" + if reprompt_model == "hunyuanimage-reprompt-32b": + return HUNYUANIMAGE_REPROMPT_32B() + elif reprompt_model in ("hunyuanimage-reprompt-minimax", "minimax"): + return HUNYUANIMAGE_REPROMPT_MINIMAX() + else: + return HUNYUANIMAGE_REPROMPT() + class HunyuanImagePipeline: """ @@ -241,7 +251,13 @@ def _load_reprompt_model(self): if self.config.enable_stage1_offloading: self.offload() reprompt_config = self.config.reprompt_config - self._reprompt_model = instantiate(reprompt_config.model, models_root_path=reprompt_config.load_from, enable_offloading=self.config.enable_reprompt_model_offloading) + # Cloud reprompt models don't need models_root_path or offloading + is_cloud = not reprompt_config.load_from + kwargs = {} + if not is_cloud: + kwargs["models_root_path"] = reprompt_config.load_from + kwargs["enable_offloading"] = self.config.enable_reprompt_model_offloading + self._reprompt_model = instantiate(reprompt_config.model, **kwargs) loguru.logger.info("✓ Reprompt model loaded") except Exception as e: raise RuntimeError(f"Error loading reprompt model: {e}") from e @@ -896,6 +912,12 @@ def from_pretrained(cls, model_name: str = "hunyuanimage-v2.1", use_distilled: b Args: model_name: Model name, supports "hunyuanimage-v2.1", "hunyuanimage-v2.1-distilled" use_distilled: Whether to use distilled model (overrides model_name if specified) + reprompt_model: Reprompt model to use for prompt enhancement. 
+ Supported values: + - "hunyuanimage-reprompt-32b" (default): Local 32B model + - "hunyuanimage-reprompt": Local smaller model + - "hunyuanimage-reprompt-minimax" or "minimax": MiniMax Cloud API + (requires MINIMAX_API_KEY environment variable) **kwargs: Additional configuration options Returns: diff --git a/hyimage/models/model_zoo.py b/hyimage/models/model_zoo.py index 6e83eac..b1539f1 100644 --- a/hyimage/models/model_zoo.py +++ b/hyimage/models/model_zoo.py @@ -152,11 +152,23 @@ def HUNYUANIMAGE_REPROMPT(**kwargs): def HUNYUANIMAGE_REPROMPT_32B(**kwargs): from hyimage.models.reprompt.reprompt_32b import RePrompt - + return RepromptConfig( model=L(RePrompt)( models_root_path=None, device_map="auto", ), load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/reprompt_32b", + ) + + +def HUNYUANIMAGE_REPROMPT_MINIMAX(**kwargs): + from hyimage.models.reprompt.reprompt_cloud import RePromptCloud + + return RepromptConfig( + model=L(RePromptCloud)( + models_root_path=None, + provider="minimax", + ), + load_from="", ) \ No newline at end of file diff --git a/hyimage/models/reprompt/reprompt_cloud.py b/hyimage/models/reprompt/reprompt_cloud.py new file mode 100644 index 0000000..97a3d74 --- /dev/null +++ b/hyimage/models/reprompt/reprompt_cloud.py @@ -0,0 +1,164 @@ +import os +import re +import logging + +logger = logging.getLogger(__name__) + +SYSTEM_PROMPT = ( + "你是一位图像生成提示词撰写专家,请根据用户输入的提示词,改写生成新的提示词,改写后的提示词要求:" + "1 改写后提示词包含的主体/动作/数量/风格/布局/关系/属性/文字等 必须和改写前的意图一致; " + "2 在宏观上遵循“总-分-总”的结构,确保信息的层次清晰;" + "3 客观中立,避免主观臆断和情感评价;" + "4 由主到次,始终先描述最重要的元素,再描述次要和背景元素;" + "5 逻辑清晰,严格遵循空间逻辑或主次逻辑,使读者能在大脑中重建画面;" + "6 结尾点题,必须用一句话总结图像的整体风格或类型。" +) + + +class RePromptCloud: + """ + Cloud-based prompt enhancement using MiniMax API (OpenAI-compatible). + + This class provides the same interface as the local RePrompt model but + uses MiniMax's cloud LLM API for prompt enhancement, eliminating the + need for local GPU resources for the reprompt model. 
+ + Environment variables: + MINIMAX_API_KEY: Required. Your MiniMax API key. + MINIMAX_MODEL: Optional. Model name (default: MiniMax-M2.7). + MINIMAX_BASE_URL: Optional. API base URL + (default: https://api.minimax.io/v1). + """ + + SUPPORTED_PROVIDERS = { + "minimax": { + "default_base_url": "https://api.minimax.io/v1", + "default_model": "MiniMax-M2.7", + "api_key_env": "MINIMAX_API_KEY", + }, + } + + def __init__( + self, + models_root_path=None, + device_map="auto", + enable_offloading=True, + provider="minimax", + api_key=None, + base_url=None, + model=None, + ): + """ + Initialize the cloud-based reprompt model. + + Args: + models_root_path: Ignored (kept for interface compatibility). + device_map: Ignored (kept for interface compatibility). + enable_offloading: Ignored (kept for interface compatibility). + provider: Cloud LLM provider name (default: "minimax"). + api_key: API key. If None, reads from environment variable. + base_url: API base URL. If None, uses provider default. + model: Model name. If None, uses provider default. + """ + if provider not in self.SUPPORTED_PROVIDERS: + raise ValueError( + f"Unsupported provider: {provider}. " + f"Supported: {list(self.SUPPORTED_PROVIDERS.keys())}" + ) + + provider_config = self.SUPPORTED_PROVIDERS[provider] + self.provider = provider + + self.api_key = api_key or os.environ.get(provider_config["api_key_env"]) + if not self.api_key: + raise ValueError( + f"API key required. Set {provider_config['api_key_env']} " + f"environment variable or pass api_key parameter." 
+ ) + + self.base_url = ( + base_url + or os.environ.get("MINIMAX_BASE_URL") + or provider_config["default_base_url"] + ) + self.model = ( + model + or os.environ.get("MINIMAX_MODEL") + or provider_config["default_model"] + ) + + try: + from openai import OpenAI + self._client = OpenAI( + api_key=self.api_key, + base_url=self.base_url, + ) + except ImportError: + raise ImportError( + "The 'openai' package is required for cloud-based prompt " + "enhancement. Install it with: pip install openai" + ) + + logger.info( + f"✓ Cloud reprompt model initialized " + f"(provider={self.provider}, model={self.model})" + ) + + def predict( + self, + prompt_cot, + sys_prompt=SYSTEM_PROMPT, + ): + """ + Generate a rewritten prompt using the cloud LLM API. + + Args: + prompt_cot: The original prompt to be rewritten. + sys_prompt: System prompt to guide the rewriting. + + Returns: + str: The rewritten prompt, or the original if generation fails. + """ + org_prompt_cot = prompt_cot + try: + # MiniMax temperature must be in (0.0, 1.0] + temperature = 0.1 + + response = self._client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": sys_prompt}, + {"role": "user", "content": org_prompt_cot}, + ], + max_tokens=2048, + temperature=temperature, + ) + + output_res = response.choices[0].message.content + + # Strip thinking tags if present (MiniMax M2.7/M2.5 may include them) + output_res = re.sub( + r"<think>.*?</think>\s*", "", output_res, flags=re.DOTALL + ) + + # Try to extract from <answer> tags (matches local model format) + answer_pattern = r"<answer>(.*?)</answer>" + answer_matches = re.findall(answer_pattern, output_res, re.DOTALL) + if answer_matches: + prompt_cot = answer_matches[0].strip() + else: + # Use the full response if no <answer> tags + prompt_cot = output_res.strip() + + except Exception as e: + prompt_cot = org_prompt_cot + logger.error( + f"✗ Cloud re-prompting failed, fall back to original prompt. 
" + f"Cause: {e}" + ) + + return prompt_cot + + def to(self, device, *args, **kwargs): + """No-op for cloud models (kept for interface compatibility).""" + return self diff --git a/tests/test_reprompt_cloud.py b/tests/test_reprompt_cloud.py new file mode 100644 index 0000000..d27bb7e --- /dev/null +++ b/tests/test_reprompt_cloud.py @@ -0,0 +1,489 @@ +""" +Unit and integration tests for MiniMax Cloud prompt enhancement (RePromptCloud). + +Run with: python -m pytest tests/test_reprompt_cloud.py -v +""" + +import importlib +import importlib.util +import os +import re +import sys +import unittest +from unittest.mock import MagicMock, patch + + +def _load_reprompt_cloud(): + """Load reprompt_cloud module directly, bypassing __init__.py import chain.""" + spec = importlib.util.spec_from_file_location( + "reprompt_cloud", + os.path.join( + os.path.dirname(__file__), + "..", + "hyimage", + "models", + "reprompt", + "reprompt_cloud.py", + ), + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# Pre-load the module once +_rc_module = _load_reprompt_cloud() +RePromptCloud = _rc_module.RePromptCloud +SYSTEM_PROMPT = _rc_module.SYSTEM_PROMPT + + +def _make_cloud(**kwargs): + """Create a RePromptCloud with mocked openai.""" + mock_openai_module = MagicMock() + mock_openai_cls = MagicMock() + mock_openai_module.OpenAI = mock_openai_cls + defaults = {"api_key": "test-key-123"} + defaults.update(kwargs) + with patch.dict("sys.modules", {"openai": mock_openai_module}): + model = RePromptCloud(**defaults) + return model, mock_openai_cls + + +def _mock_response(content): + """Create a mock OpenAI chat completion response.""" + mock_choice = MagicMock() + mock_choice.message.content = content + mock_resp = MagicMock() + mock_resp.choices = [mock_choice] + return mock_resp + + +class TestRePromptCloudInit(unittest.TestCase): + """Test RePromptCloud initialization and configuration.""" + + def test_init_with_explicit_api_key(self): + model, 
mock_cls = _make_cloud(api_key="explicit-key") + self.assertEqual(model.api_key, "explicit-key") + self.assertEqual(model.provider, "minimax") + self.assertEqual(model.model, "MiniMax-M2.7") + self.assertEqual(model.base_url, "https://api.minimax.io/v1") + mock_cls.assert_called_once_with( + api_key="explicit-key", + base_url="https://api.minimax.io/v1", + ) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "env-key-456"}) + def test_init_with_env_api_key(self): + model, _ = _make_cloud(api_key=None) + self.assertEqual(model.api_key, "env-key-456") + + def test_init_missing_api_key_raises(self): + env = {k: v for k, v in os.environ.items() if k != "MINIMAX_API_KEY"} + with patch.dict(os.environ, env, clear=True): + with self.assertRaises(ValueError) as ctx: + _make_cloud(api_key=None) + self.assertIn("MINIMAX_API_KEY", str(ctx.exception)) + + def test_init_unsupported_provider_raises(self): + with self.assertRaises(ValueError) as ctx: + _make_cloud(provider="unsupported") + self.assertIn("Unsupported provider", str(ctx.exception)) + + @patch.dict( + os.environ, + { + "MINIMAX_API_KEY": "env-key", + "MINIMAX_MODEL": "MiniMax-M2.5", + "MINIMAX_BASE_URL": "https://custom.api.url/v1", + }, + ) + def test_init_env_overrides(self): + model, _ = _make_cloud(api_key=None) + self.assertEqual(model.model, "MiniMax-M2.5") + self.assertEqual(model.base_url, "https://custom.api.url/v1") + + def test_init_explicit_params_override_env(self): + model, _ = _make_cloud( + api_key="param-key", + model="custom-model", + base_url="https://param.url/v1", + ) + self.assertEqual(model.api_key, "param-key") + self.assertEqual(model.model, "custom-model") + self.assertEqual(model.base_url, "https://param.url/v1") + + def test_init_ignores_local_model_params(self): + """Cloud model should accept but ignore local model parameters.""" + model, _ = _make_cloud( + models_root_path="/some/path", + device_map="cpu", + enable_offloading=False, + ) + self.assertIsNotNone(model) + + def 
test_supported_providers(self): + self.assertIn("minimax", RePromptCloud.SUPPORTED_PROVIDERS) + config = RePromptCloud.SUPPORTED_PROVIDERS["minimax"] + self.assertEqual(config["api_key_env"], "MINIMAX_API_KEY") + self.assertEqual(config["default_model"], "MiniMax-M2.7") + self.assertEqual(config["default_base_url"], "https://api.minimax.io/v1") + + def test_default_model_is_m2_7(self): + model, _ = _make_cloud() + self.assertEqual(model.model, "MiniMax-M2.7") + + +class TestRePromptCloudPredict(unittest.TestCase): + """Test RePromptCloud.predict() method.""" + + def _make_model(self): + model, _ = _make_cloud() + return model + + def test_predict_plain_response(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "A detailed description of a sunset over the ocean with golden light." + ) + ) + result = model.predict("sunset over ocean") + self.assertEqual( + result, + "A detailed description of a sunset over the ocean with golden light.", + ) + + def test_predict_with_answer_tags(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "Here is the result:\n<answer>Enhanced prompt text here</answer>" + ) + ) + result = model.predict("test prompt") + self.assertEqual(result, "Enhanced prompt text here") + + def test_predict_strips_think_tags(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "<think>Let me analyze this prompt...</think>\n" + "A beautiful sunset over a calm ocean." 
+ ) + ) + result = model.predict("sunset") + self.assertEqual(result, "A beautiful sunset over a calm ocean.") + + def test_predict_strips_think_tags_with_answer(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "<think>Reasoning here</think>\n" + "<answer>Clean enhanced prompt</answer>" + ) + ) + result = model.predict("test") + self.assertEqual(result, "Clean enhanced prompt") + + def test_predict_uses_correct_api_params(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response("enhanced text") + ) + model.predict("test prompt") + call_kwargs = model._client.chat.completions.create.call_args[1] + self.assertEqual(call_kwargs["model"], "MiniMax-M2.7") + self.assertEqual(call_kwargs["max_tokens"], 2048) + self.assertGreater(call_kwargs["temperature"], 0.0) + self.assertLessEqual(call_kwargs["temperature"], 1.0) + + def test_predict_passes_system_prompt(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response("enhanced") + ) + model.predict("test", sys_prompt="Custom system prompt") + call_kwargs = model._client.chat.completions.create.call_args[1] + messages = call_kwargs["messages"] + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[0]["content"], "Custom system prompt") + self.assertEqual(messages[1]["role"], "user") + self.assertEqual(messages[1]["content"], "test") + + def test_predict_uses_default_system_prompt(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response("enhanced") + ) + model.predict("test prompt") + call_kwargs = model._client.chat.completions.create.call_args[1] + messages = call_kwargs["messages"] + self.assertEqual(messages[0]["content"], SYSTEM_PROMPT) + + def test_predict_fallback_on_api_error(self): + model = self._make_model() + 
model._client.chat.completions.create = MagicMock( + side_effect=Exception("API connection failed") + ) + result = model.predict("original prompt") + self.assertEqual(result, "original prompt") + + def test_predict_fallback_on_empty_response(self): + model = self._make_model() + mock_choice = MagicMock() + mock_choice.message.content = None + mock_resp = MagicMock() + mock_resp.choices = [mock_choice] + model._client.chat.completions.create = MagicMock(return_value=mock_resp) + result = model.predict("original prompt") + self.assertEqual(result, "original prompt") + + def test_predict_chinese_prompt(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "一幅精美的山水画,远处的青山在薄雾中若隐若现,近处的溪流清澈见底。" + ) + ) + result = model.predict("中国山水画") + self.assertIn("山水", result) + + def test_predict_multiline_think_tags(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "<think>\nLine 1\nLine 2\nLine 3\n</think>\n\n" + "Final enhanced prompt." 
+ ) + ) + result = model.predict("test") + self.assertEqual(result, "Final enhanced prompt.") + self.assertNotIn("<think>", result) + + def test_predict_long_prompt(self): + model = self._make_model() + long_prompt = "A scene with " + ", ".join(f"element {i}" for i in range(50)) + model._client.chat.completions.create = MagicMock( + return_value=_mock_response("Enhanced: " + long_prompt) + ) + result = model.predict(long_prompt) + self.assertIn("Enhanced:", result) + + def test_predict_empty_prompt(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response("A blank canvas.") + ) + result = model.predict("") + self.assertEqual(result, "A blank canvas.") + + def test_predict_special_characters(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response("Enhanced with 'quotes' and extra details") + ) + result = model.predict("prompt with 'quotes'") + self.assertIn("quotes", result) + + def test_predict_preserves_original_on_timeout(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + side_effect=TimeoutError("Request timed out") + ) + result = model.predict("my prompt") + self.assertEqual(result, "my prompt") + + def test_predict_with_nested_tags(self): + model = self._make_model() + model._client.chat.completions.create = MagicMock( + return_value=_mock_response( + "<think>step 1</think>\n" + "<answer>A beautiful scene</answer>" + ) + ) + result = model.predict("scene") + self.assertEqual(result, "A beautiful scene") + + +class TestRePromptCloudInterface(unittest.TestCase): + """Test interface compatibility with local RePrompt models.""" + + def test_to_method_is_noop(self): + model, _ = _make_cloud() + result = model.to("cuda") + self.assertIs(result, model) + + def test_to_method_chaining(self): + model, _ = _make_cloud() + result = model.to("cuda").to("cpu").to("cuda:0") + self.assertIs(result, model) + + def test_has_predict_method(self): + model, 
_ = _make_cloud() + self.assertTrue(callable(getattr(model, "predict", None))) + + def test_has_to_method(self): + model, _ = _make_cloud() + self.assertTrue(callable(getattr(model, "to", None))) + + def test_predict_signature_matches_local(self): + """Verify predict() accepts same args as local RePrompt.""" + import inspect + + sig = inspect.signature(RePromptCloud.predict) + params = list(sig.parameters.keys()) + self.assertIn("prompt_cot", params) + self.assertIn("sys_prompt", params) + + +class TestModelZooMiniMax(unittest.TestCase): + """Test MiniMax reprompt configuration in model_zoo.""" + + def test_minimax_reprompt_config_exists(self): + # Import model_zoo components that don't need torch/diffusers + from hyimage.common.config.base_config import RepromptConfig + from hyimage.common.config.lazy import LazyCall as L, LazyObject + + config = RepromptConfig( + model=L(RePromptCloud)(models_root_path=None, provider="minimax"), + load_from="", + ) + self.assertIsInstance(config, RepromptConfig) + self.assertEqual(config.load_from, "") + self.assertIsInstance(config.model, LazyObject) + + def test_minimax_config_instantiates_with_api_key(self): + from hyimage.common.config.lazy import LazyCall as L, instantiate + + config_model = L(RePromptCloud)(models_root_path=None, provider="minimax") + mock_openai_module = MagicMock() + with patch.dict("sys.modules", {"openai": mock_openai_module}): + with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}): + model = instantiate(config_model) + self.assertIsNotNone(model) + self.assertEqual(model.provider, "minimax") + self.assertEqual(model.model, "MiniMax-M2.7") + + def test_minimax_config_cloud_flag(self): + """Cloud reprompt config has empty load_from (no local model path).""" + from hyimage.common.config.base_config import RepromptConfig + from hyimage.common.config.lazy import LazyCall as L + + config = RepromptConfig( + model=L(RePromptCloud)(models_root_path=None, provider="minimax"), + load_from="", + ) + # 
Pipeline uses empty load_from to detect cloud mode + self.assertFalse(bool(config.load_from)) + + +class TestPipelineRepromptResolution(unittest.TestCase): + """Test that the pipeline correctly resolves reprompt model names. + + Since the pipeline module has heavy dependencies (torch, diffusers), + we test the resolution logic by verifying the model_zoo factory + functions produce correct configs. + """ + + def test_minimax_factory_produces_empty_load_from(self): + """HUNYUANIMAGE_REPROMPT_MINIMAX should produce config with empty load_from.""" + from hyimage.common.config.base_config import RepromptConfig + from hyimage.common.config.lazy import LazyCall as L + + # Replicate what model_zoo.HUNYUANIMAGE_REPROMPT_MINIMAX does + config = RepromptConfig( + model=L(RePromptCloud)(models_root_path=None, provider="minimax"), + load_from="", + ) + self.assertEqual(config.load_from, "") + + def test_local_factory_produces_nonempty_load_from(self): + """Local reprompt configs should have non-empty load_from.""" + from hyimage.common.config.base_config import RepromptConfig + from hyimage.common.config.lazy import LazyCall as L + + config = RepromptConfig( + model=L(MagicMock)(models_root_path=None), + load_from="./ckpts/reprompt", + ) + self.assertTrue(bool(config.load_from)) + + def test_resolution_logic(self): + """Test the _resolve_reprompt_config logic without importing the pipeline.""" + from hyimage.common.config.base_config import RepromptConfig + from hyimage.common.config.lazy import LazyCall as L + + # Simulate what _resolve_reprompt_config does + def resolve(name): + if name == "hunyuanimage-reprompt-32b": + return RepromptConfig( + model=L(MagicMock)(models_root_path=None), + load_from="./ckpts/reprompt_32b", + ) + elif name in ("hunyuanimage-reprompt-minimax", "minimax"): + return RepromptConfig( + model=L(RePromptCloud)( + models_root_path=None, provider="minimax" + ), + load_from="", + ) + else: + return RepromptConfig( + 
model=L(MagicMock)(models_root_path=None), + load_from="./ckpts/reprompt", + ) + + self.assertEqual(resolve("minimax").load_from, "") + self.assertEqual(resolve("hunyuanimage-reprompt-minimax").load_from, "") + self.assertNotEqual(resolve("hunyuanimage-reprompt-32b").load_from, "") + self.assertNotEqual(resolve("hunyuanimage-reprompt").load_from, "") + + +class TestRePromptCloudIntegration(unittest.TestCase): + """Integration tests for MiniMax Cloud prompt enhancement. + + These tests require a valid MINIMAX_API_KEY environment variable + and network access to the MiniMax API. + """ + + @unittest.skipUnless( + os.environ.get("MINIMAX_API_KEY"), + "MINIMAX_API_KEY not set; skipping integration tests", + ) + def test_live_prompt_enhancement_english(self): + model = RePromptCloud() + result = model.predict("A cute cat sitting on a windowsill") + self.assertIsInstance(result, str) + self.assertGreater(len(result), 10) + self.assertGreater(len(result), len("A cute cat sitting on a windowsill")) + + @unittest.skipUnless( + os.environ.get("MINIMAX_API_KEY"), + "MINIMAX_API_KEY not set; skipping integration tests", + ) + def test_live_prompt_enhancement_chinese(self): + model = RePromptCloud() + result = model.predict("一只可爱的猫咪坐在窗台上") + self.assertIsInstance(result, str) + self.assertGreater(len(result), 5) + + @unittest.skipUnless( + os.environ.get("MINIMAX_API_KEY"), + "MINIMAX_API_KEY not set; skipping integration tests", + ) + def test_live_predict_interface_compatibility(self): + """Verify cloud model has same interface as local models.""" + model = RePromptCloud() + result = model.predict( + "sunset over mountains", + sys_prompt="Enhance this image prompt with more details.", + ) + self.assertIsInstance(result, str) + self.assertGreater(len(result), 0) + same_model = model.to("cuda") + self.assertIs(same_model, model) + + +if __name__ == "__main__": + unittest.main()