from modeling.qwen2 import Qwen2Tokenizer

import argparse
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

# Command-line interface for the Gradio demo server.
parser = argparse.ArgumentParser()
parser.add_argument("--server_name", type=str, default="127.0.0.1",
                    help="Host/IP for the Gradio server.")
parser.add_argument("--server_port", type=int, default=7860,
                    help="Port for the Gradio server.")
parser.add_argument("--share", action="store_true",
                    help="Create a public Gradio share link.")
parser.add_argument("--model_path", type=str, default="models/BAGEL-7B-MoT",
                    help="Path to the downloaded BAGEL-7B-MoT checkpoint directory.")
parser.add_argument("--mode", type=int, default=1,
                    help="Weight loading mode: 1 = bfloat16 dispatch (default), "
                         "2 = 4-bit NF4 quantization, any other value = 8-bit quantization.")
parser.add_argument("--zh", action="store_true",
                    help="Translate the UI labels to Chinese before launching.")
args = parser.parse_args()

# Model Initialization
# Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT to models/BAGEL-7B-MoT
model_path = args.model_path

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
9499 for k in same_device_modules :
95100 if k in device_map :
96101 device_map [k ] = first_device
# Load the EMA checkpoint weights according to --mode:
#   1     -> full-precision bfloat16 dispatch across devices (default)
#   2     -> 4-bit NF4 quantization (bf16 compute, no double quant)
#   other -> 8-bit quantization
checkpoint_path = os.path.join(model_path, "ema.safetensors")

if args.mode == 1:
    model = load_checkpoint_and_dispatch(
        model,
        checkpoint=checkpoint_path,
        device_map=device_map,
        offload_buffers=True,
        offload_folder="offload",
        dtype=torch.bfloat16,
        force_hooks=True,
    ).eval()
else:
    # Both quantized paths share the same loader call; only the
    # BnbQuantizationConfig differs between 4-bit and 8-bit.
    if args.mode == 2:
        bnb_quantization_config = BnbQuantizationConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
            bnb_4bit_quant_type="nf4",
        )
    else:
        bnb_quantization_config = BnbQuantizationConfig(
            load_in_8bit=True,
            torch_dtype=torch.float32,
        )
    model = load_and_quantize_model(
        model,
        weights_location=checkpoint_path,
        bnb_quantization_config=bnb_quantization_config,
        device_map=device_map,
        offload_folder="offload",
    ).eval()
108131
109132# Inferencer Preparing
110133inferencer = InterleaveInferencer (
@@ -366,7 +389,7 @@ def process_text_to_image(prompt, show_thinking, cfg_text_scale,
366389 with gr .Row ():
edit_cfg_renorm_type = gr.Dropdown(
    choices=["global", "local", "text_channel"],
    value="text_channel",
    label="CFG Renorm Type",
    # NOTE: no trailing space, and the "genrated" typo is kept on purpose —
    # this string must match the UI_TRANSLATIONS key byte-for-byte or the
    # Chinese localization lookup for this widget silently fails.
    info="If the genrated image is blurry, use 'global'",
)
370393 edit_cfg_renorm_min = gr .Slider (minimum = 0.0 , maximum = 1.0 , value = 0.0 , step = 0.1 , interactive = True ,
371394 label = "CFG Renorm Min" , info = "1.0 disables CFG-Renorm" )
372395
@@ -513,10 +536,73 @@ def process_understanding(image, prompt, show_thinking, do_sample,
513536</div>
514537""" )
515538
# English -> Chinese strings for the Gradio UI. Keys must match the English
# label/info/placeholder text byte-for-byte, or the lookup silently misses.
UI_TRANSLATIONS = {
    "📝 Text to Image": "📝 文生图",
    "Prompt": "提示词",
    "Thinking": "思考模式",
    "Inference Hyperparameters": "推理参数",
    "Seed": "种子",
    "0 for random seed, positive for reproducible results": "0为随机种子,正数表示可重复结果",
    "Image Ratio": "图片比例",
    "The longer size is fixed to 1024": "长边固定为1024",
    "CFG Text Scale": "CFG 文本规模",
    "Controls how strongly the model follows the text prompt (4.0-8.0)": "控制模型遵循文本提示的强度(4.0-8.0)",
    "CFG Interval": "CFG 间隔",
    "Start of CFG application interval (end is fixed at 1.0)": "CFG 应用间隔的开始(结束固定为1.0)",
    "CFG Renorm Type": "CFG 重新归一化类型",
    "If the genrated image is blurry, use 'global'": "如果生成的图像模糊,请使用'global'",
    "CFG Renorm Min": "CFG 重新归一化最小值",
    "1.0 disables CFG-Renorm": "1.0 禁用 CFG 重新归一化",
    "Timesteps": "时间步数",
    "Total denoising steps": "总去噪步数",
    "Timestep Shift": "时间偏移",
    "Higher values for layout, lower for details": "布局更高,细节更低",
    "Sampling": "采样",
    "Enable sampling for text generation": "为文本生成启用采样",
    "Max Think Tokens": "最大思考标记数",
    "Maximum number of tokens for thinking": "思考的最大标记数",
    "Temperature": "温度",
    "Controls randomness in text generation": "控制文本生成的随机性",
    "Thinking Process": "思考过程",
    "Generated Image": "生成图像",
    "Generate": "开始生成",
    "🖌️ Image Edit": "🖌️ 图像编辑",
    "Input Image": "图像输入",
    "Result": "结果",
    "Controls how strongly the model follows the text prompt": "控制模型遵循文本提示的强度",
    "CFG Image Scale": "CFG图像规模",
    "Controls how much the model preserves input image details": "控制模型保留输入图像细节的强度",
    "Submit": "开始生成",
    "🖼️ Image Understanding": "🖼️ 图像理解",
    "Controls randomness in text generation (0=deterministic, 1=creative)": "控制文本生成的随机性(0=确定性,1=创造性)",
    "Max New Tokens": "最大新标记",
    "Maximum length of generated text, including potential thinking": "生成文本的最大长度,包括可能的思考",
}


def apply_localization(block):
    """Recursively translate a Gradio component tree in place.

    Walks ``block`` and its ``children``, replacing any ``label``, ``info``
    or ``placeholder`` text found in ``UI_TRANSLATIONS`` with its Chinese
    equivalent. Returns the same ``block`` object for call-site chaining.
    """

    def process_component(component):
        if not component:
            return

        for attr in ("label", "info", "placeholder"):
            if hasattr(component, attr):
                text = getattr(component, attr)
                # Guard on str: these attrs may be None (or other objects),
                # and non-hashable values would make the dict lookup raise.
                if isinstance(text, str) and text in UI_TRANSLATIONS:
                    setattr(component, attr, UI_TRANSLATIONS[text])

        if hasattr(component, "children"):
            for child in component.children:
                process_component(child)

    process_component(block)
    return block
599+
if __name__ == "__main__":
    # Optionally translate the UI to Chinese before launching the server.
    if args.zh:
        demo = apply_localization(demo)
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
        inbrowser=True,
    )