Skip to content

Commit f517072

Browse files
authored
Merge pull request #88 from gluttony-10/main
Merge the code with online quantization
2 parents 45e300c + 13f3f35 commit f517072

4 files changed

Lines changed: 129 additions & 19 deletions

File tree

README.md

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,14 @@ cd BAGEL
109109
conda create -n bagel python=3.10 -y
110110
conda activate bagel
111111
pip install -r requirements.txt
112+
pip install flash_attn==2.5.8 --no-build-isolation
112113
```
113114

114115
2️⃣ Download pretrained checkpoint
115116
```python
116117
from huggingface_hub import snapshot_download
117118

118-
save_dir = "/path/to/save/BAGEL-7B-MoT"
119+
save_dir = "models/BAGEL-7B-MoT"
119120
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
120121
cache_dir = save_dir + "/cache"
121122

@@ -129,14 +130,22 @@ snapshot_download(cache_dir=cache_dir,
129130

130131
```
131132

132-
3️⃣ Go to [`inference.ipynb`](inference.ipynb) to start playing with BAGEL!
133-
134-
4️⃣ Use Gradio WebUI to start playing with BAGEL!
133+
3️⃣ Use Gradio WebUI to start playing with BAGEL!
135134
```bash
136-
pip install gradio
135+
# For 32GB+ VRAM GPU or multi GPUs.
137136
python app.py
138137
```
139138

139+
```bash
140+
# For 12~32GB VRAM GPU, NF4 quantization is recommended; --zh enables the Chinese interface.
141+
python app.py --mode 2 --zh
142+
```
143+
144+
```bash
145+
# For 22~32GB VRAM GPU; INT8 quantization is generally not recommended.
146+
python app.py --mode 3
147+
```
148+
140149
## 🔥 Train & Eval
141150

142151
### Train

app.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,22 @@
1919
from modeling.qwen2 import Qwen2Tokenizer
2020

2121
import argparse
22+
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
2223

2324
parser = argparse.ArgumentParser()
2425
parser.add_argument("--server_name", type=str, default="127.0.0.1")
2526
parser.add_argument("--server_port", type=int, default=7860)
2627
parser.add_argument("--share", action="store_true")
2728
parser.add_argument("--model_path", type=str, default="models/BAGEL-7B-MoT")
29+
parser.add_argument("--mode", type=int, default=1)
30+
parser.add_argument("--zh", action="store_true")
2831
args = parser.parse_args()
2932

3033
# Model Initialization
3134
model_path = args.model_path #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT to models/BAGEL-7B-MoT
3235

36+
model_path = args.model_path
37+
3338
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
3439
llm_config.qk_norm = True
3540
llm_config.tie_word_embeddings = False
@@ -94,17 +99,35 @@
9499
for k in same_device_modules:
95100
if k in device_map:
96101
device_map[k] = first_device
97-
98-
model = load_checkpoint_and_dispatch(
99-
model,
100-
checkpoint=os.path.join(model_path, "ema.safetensors"),
101-
device_map=device_map,
102-
offload_buffers=True,
103-
offload_folder="offload",
104-
dtype=torch.bfloat16,
105-
force_hooks=True,
106-
).eval()
107102

103+
if args.mode == 1:
104+
model = load_checkpoint_and_dispatch(
105+
model,
106+
checkpoint=os.path.join(model_path, "ema.safetensors"),
107+
device_map=device_map,
108+
offload_buffers=True,
109+
offload_folder="offload",
110+
dtype=torch.bfloat16,
111+
force_hooks=True,
112+
).eval()
113+
elif args.mode == 2:
114+
bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4")
115+
model = load_and_quantize_model(
116+
model,
117+
weights_location=os.path.join(model_path, "ema.safetensors"),
118+
bnb_quantization_config=bnb_quantization_config,
119+
device_map=device_map,
120+
offload_folder="offload",
121+
).eval()
122+
else:
123+
bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, torch_dtype=torch.float32)
124+
model = load_and_quantize_model(
125+
model,
126+
weights_location=os.path.join(model_path, "ema.safetensors"),
127+
bnb_quantization_config=bnb_quantization_config,
128+
device_map=device_map,
129+
offload_folder="offload",
130+
).eval()
108131

109132
# Inferencer Preparing
110133
inferencer = InterleaveInferencer(
@@ -366,7 +389,7 @@ def process_text_to_image(prompt, show_thinking, cfg_text_scale,
366389
with gr.Row():
367390
edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
368391
value="text_channel", label="CFG Renorm Type",
369-
info="If the genrated image is blurry, use 'global")
392+
info="If the genrated image is blurry, use 'global'")
370393
edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
371394
label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
372395

@@ -513,10 +536,73 @@ def process_understanding(image, prompt, show_thinking, do_sample,
513536
</div>
514537
""")
515538

539+
UI_TRANSLATIONS = {
540+
"📝 Text to Image":"📝 文生图",
541+
"Prompt":"提示词",
542+
"Thinking":"思考模式",
543+
"Inference Hyperparameters":"推理参数",
544+
"Seed":"种子",
545+
"0 for random seed, positive for reproducible results":"0为随机种子,正数表示可重复结果",
546+
"Image Ratio":"图片比例",
547+
"The longer size is fixed to 1024":"长边固定为1024",
548+
"CFG Text Scale":"CFG 文本规模",
549+
"Controls how strongly the model follows the text prompt (4.0-8.0)":"控制模型是否遵循文本提示(4.0-8.0)",
550+
"CFG Interval":"CFG 间隔",
551+
"Start of CFG application interval (end is fixed at 1.0)":"CFG 应用间隔的开始(结束固定为1.0)",
552+
"CFG Renorm Type":"CFG 重新归一化类型",
553+
"If the genrated image is blurry, use 'global'":"如果生成的图像模糊,请使用'global'",
554+
"CFG Renorm Min":"CFG 重新归一化最小值",
555+
"1.0 disables CFG-Renorm":"1.0 禁用 CFG 重新归一化",
556+
"Timesteps":"时间步数",
557+
"Total denoising steps":"总去噪步数",
558+
"Timestep Shift":"时间偏移",
559+
"Higher values for layout, lower for details":"布局更高,细节更低",
560+
"Sampling":"采样",
561+
"Enable sampling for text generation":"为文本生成启用采样",
562+
"Max Think Tokens":"最大思考标记数",
563+
"Maximum number of tokens for thinking":"思考的最大标记数",
564+
"Temperature":"温度",
565+
"Controls randomness in text generation":"控制文本生成的随机性",
566+
"Thinking Process":"思考过程",
567+
"Generated Image":"生成图像",
568+
"Generate":"开始生成",
569+
"🖌️ Image Edit":"🖌️ 图像编辑",
570+
"Input Image":"图像输入",
571+
"Result":"结果",
572+
"Controls how strongly the model follows the text prompt":"控制模型是否遵循文本提示的强度",
573+
"CFG Image Scale":"CFG图像规模",
574+
"Controls how much the model preserves input image details":"控制模型是否保留输入图像细节的强度",
575+
"Submit":"开始生成",
576+
"🖼️ Image Understanding":"🖼️ 图像理解",
577+
"Controls randomness in text generation (0=deterministic, 1=creative)":"控制文本生成的随机性(0=确定,1= creative)",
578+
"Max New Tokens":"最大新标记",
579+
"Maximum length of generated text, including potential thinking":"生成文本的最大长度,包括可能的思考",
580+
}
581+
582+
def apply_localization(block):
583+
def process_component(component):
584+
if not component:
585+
return
586+
587+
for attr in ['label', 'info', 'placeholder']:
588+
if hasattr(component, attr):
589+
text = getattr(component, attr)
590+
if text in UI_TRANSLATIONS:
591+
setattr(component, attr, UI_TRANSLATIONS[text])
592+
593+
if hasattr(component, 'children'):
594+
for child in component.children:
595+
process_component(child)
596+
597+
process_component(block)
598+
return block
599+
516600
if __name__ == "__main__":
601+
if args.zh:
602+
demo = apply_localization(demo)
517603
demo.launch(
518604
server_name=args.server_name,
519605
server_port=args.server_port,
520606
share=args.share,
521607
inbrowser=True,
522-
)
608+
)

modeling/bagel/bagel.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from .qwen2_navit import NaiveCache
2121
from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
2222

23+
from tqdm import tqdm
24+
2325

2426
class BagelConfig(PretrainedConfig):
2527
def __init__(
@@ -387,6 +389,8 @@ def forward_cache_update_vit(
387389
packed_vit_token_embed = self.connector(packed_vit_token_embed)
388390
pos_emb = self.vit_pos_embed(packed_vit_position_ids)
389391
packed_vit_token_embed = packed_vit_token_embed + pos_emb
392+
if packed_vit_token_embed.dtype != packed_sequence.dtype:
393+
packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
390394
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
391395

392396
extra_inputs = {}
@@ -516,6 +520,8 @@ def forward_cache_update_vae(
516520
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
517521
packed_timestep_embeds = self.time_embedder(packed_timesteps)
518522
packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
523+
if packed_latent.dtype != packed_sequence.dtype:
524+
packed_latent = packed_latent.to(packed_sequence.dtype)
519525
packed_sequence[packed_vae_token_indexes] = packed_latent
520526

521527
extra_inputs = {}
@@ -675,7 +681,7 @@ def generate_image(
675681
dts = timesteps[:-1] - timesteps[1:]
676682
timesteps = timesteps[:-1]
677683

678-
for i, t in enumerate(timesteps):
684+
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
679685

680686
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
681687
if t > cfg_interval[0] and t <= cfg_interval[1]:
@@ -762,6 +768,8 @@ def _forward_flow(
762768
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
763769
packed_timestep_embeds = self.time_embedder(timestep)
764770
x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
771+
if x_t.dtype != packed_sequence.dtype:
772+
x_t = x_t.to(packed_sequence.dtype)
765773
packed_sequence[packed_vae_token_indexes] = x_t
766774

767775
extra_inputs = {}

requirements.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ sentencepiece==0.1.99
1313
torch==2.5.1
1414
torchvision==0.20.1
1515
transformers==4.49.0
16-
flash_attn==2.5.8
16+
#flash_attn==2.5.8
1717
accelerate>=0.34.0
1818
wandb
19+
gradio
20+
setuptools
21+
wheel
22+
ninja
23+
bitsandbytes
24+
triton ; sys_platform != 'win32'
25+
triton-windows ; sys_platform == 'win32'

0 commit comments

Comments
 (0)