
Commit 3088115

ftgreat committed
Support Qwen & Llama
Signed-off-by: ftgreat <ftgreat@163.com>
1 parent 0ae8c6b commit 3088115

File tree

9 files changed: 509 additions, 2 deletions


examples/Llama/bmtrain_mgpu.sh

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# Defined by User
export TRIGGER_FILE=bmtrain_mgpu.sh
export SCRIPT_FILE=llama_pretrain.py

# ENVS
export PROJ_HOME=$PWD
export PRE_LOAD_DIR=$PROJ_HOME/checkpoints_in
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_IB_GID_INDEX=0
export NCCL_IB_HCA=mlx5_2,mlx5_5
export NCCL_DEBUG=debug
export OMP_NUM_THREADS=4

echo "[INFO] $0: hostfile configfile model_name exp_name exp_version"
set -u
hostfile=$1
configfile=$2
model_name=$3
exp_name=$4
exp_version=$5
set +u

# DIST
export HOSTFILE=$hostfile
export CONFIGFILE=$configfile
export NODE_ADDR=$(ifconfig -a|grep inet|grep -v 127.0.0.1|grep -v inet6|awk '{print $2;}'|tr -d "addr:")
export GPU_NUM_PER_NODE=$(awk -F" |=" '{ranks[$1]=$NF;}END{print ranks["'$NODE_ADDR'"];}' $HOSTFILE)
export NODES_NUM=$(cat $HOSTFILE | wc -l)
export MASTER_ADDR=$(head -n1 $HOSTFILE | awk '{print $1;}')
export RANK=$(awk '{ranks[$1]=(FNR-1);}END{print ranks["'$NODE_ADDR'"];}' $HOSTFILE)
export MASTER_PORT=23456

## wandb
export WANDB_MODE=offline

## EXP
export MODEL_NAME=$model_name
export EXP_NAME=$exp_name
export WANDB_DIR=$PROJ_HOME/wandb/${EXP_NAME}/$exp_version
mkdir -p $PROJ_HOME/checkpoints_out
export SAVE_DIR=$PROJ_HOME/checkpoints_out/${EXP_NAME}/$exp_version
mkdir -p $SAVE_DIR
mkdir -p $WANDB_DIR
## Backup ckpts & scripts into exp versions
cp -r $PRE_LOAD_DIR/$MODEL_NAME $SAVE_DIR
cp -r $PROJ_HOME/$TRIGGER_FILE $SAVE_DIR
cp -r $hostfile $SAVE_DIR
cp -r $configfile $SAVE_DIR

export EPOCH_NUM=1
export BATCH_SIZE=6
export GRADIENT_ACCUM_STEPS=1
# NOTE: only the last LR / WARMUP_RATE assignment below takes effect.
export LR=3.0e-4
export LR=1.0e-5
export LR=6.0e-5
export WARMUP_RATE=0.008
export WARMUP_RATE=0.02
export WARMUP_RATE=0.1
export WARMUP_RATE=0.2

## EXTRA OPTS
OPTS=" --batch_size $BATCH_SIZE \
       --epochs $EPOCH_NUM \
       --gradient_accumulation_steps $GRADIENT_ACCUM_STEPS \
       --lr $LR \
       --warm_up $WARMUP_RATE \
       --weight_decay 0.1 \
       --adam_beta1 0.9 \
       --adam_beta2 0.95 \
       --save_dir $SAVE_DIR \
       --pre_load_dir $PRE_LOAD_DIR \
       --experiment_name $EXP_NAME \
       --model_name $MODEL_NAME \
       --wandb_dir $WANDB_DIR \
       --yaml_config $CONFIGFILE"

## Trigger the job on each node (bmt or ddp).

mkdir -p $PRE_LOAD_DIR
torchrun \
  --nproc_per_node $GPU_NUM_PER_NODE \
  --nnodes $NODES_NUM \
  --node_rank $RANK \
  --master_addr $MASTER_ADDR \
  --master_port $MASTER_PORT \
  $SCRIPT_FILE \
  --not_call_launch \
  $OPTS
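
For reference only (not part of the commit), the lookups that the awk one-liners above perform can be sketched in Python. The hostfile is assumed to list one node per line in a DeepSpeed-style "ip slots=N" layout; that layout is an inference from the `-F" |="` field splitting, not something the commit states.

def parse_hostfile(path, node_addr):
    """Replicates MASTER_ADDR / RANK / GPU_NUM_PER_NODE / NODES_NUM above.

    Assumed hostfile layout, one node per line: "<node_ip> slots=<gpus>"
    """
    master_addr, rank, gpus_per_node, nodes_num = None, None, None, 0
    with open(path) as f:
        for i, line in enumerate(f):
            fields = line.replace("=", " ").split()
            if not fields:
                continue
            ip, slots = fields[0], int(fields[-1])
            if i == 0:
                master_addr = ip          # first line is the master node
            if ip == node_addr:
                rank, gpus_per_node = i, slots
            nodes_num += 1
    return master_addr, rank, gpus_per_node, nodes_num

# e.g. parse_hostfile("hostfile", "192.168.1.10")
# with a hostfile line "192.168.1.10 slots=8" -> ("192.168.1.10", 0, 8, 1)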

examples/Llama/generate.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.data.tokenizer import Tokenizer
import transformers

state_dict = "./checkpoints_in/"
model_name = 'Llama-3.1-8B'

loader = AutoLoader("llama3",
                    model_dir=state_dict,
                    model_name=model_name,
                    device='cuda',
                    use_cache=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()

model.eval()

model.cuda()

print("model loaded")

text = "Gravity is "

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=1024
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

print("content:", content)

examples/Llama/llama-pretrain.yaml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
batch_size: 1
gradient_accumulation_steps: 1
lr: 1.0e-5
warm_up: 0.01
save_interval: 100
log_interval: 1
bmt_loss_scale: 1.0
save_optim: True
save_rng: True
eps: 1.e-8
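
Since llama_pretrain.py loads this file with yaml.SafeLoader, one PyYAML detail is worth keeping in mind: scientific-notation numbers only resolve to floats when the mantissa contains a dot. A quick check, illustrative only (the path is an assumption):

import yaml  # PyYAML, the loader the training script uses

with open("examples/Llama/llama-pretrain.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# "1.0e-5" and "1.e-8" resolve to float under PyYAML's YAML 1.1 rules;
# a bare "1e-5" (no dot) would come back as a string instead.
assert isinstance(cfg["lr"], float)
assert isinstance(cfg["eps"], float)
print(cfg)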
examples/Llama/llama_bmt_monkey_patch.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
from typing import Optional, Tuple
import warnings

import torch
from torch import nn
import transformers

from transformers.cache_utils import Cache, DynamicCache
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.utils import TransformersKwargs, auto_docstring
from transformers.utils.generic import check_model_inputs
from transformers.masking_utils import create_causal_mask


def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position: torch.Tensor = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = create_causal_mask(
        config=self.config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    hidden_states = inputs_embeds
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        hidden_states = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

    hidden_states = self.norm(hidden_states)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
    )


def replace_llama_attn_with_bmt():
    print("replace_llama_attn_with_bmt")
    transformers.models.llama.modeling_llama.LlamaModel.forward = forward
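
For context (illustrative, not part of the commit): the assignment is made on the LlamaModel class itself, so every instance, including the one nested inside LlamaForCausalLM, dispatches to the replacement forward once the function has been called. A minimal sanity check:

# Minimal sketch: verify the class-level patch is in place before training.
import transformers
from llama_bmt_monkey_patch import forward, replace_llama_attn_with_bmt

replace_llama_attn_with_bmt()
assert transformers.models.llama.modeling_llama.LlamaModel.forward is forward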

examples/Llama/llama_pretrain.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import os
import torch
from torch.utils.data import Dataset
import gc
import json

gc.collect()
torch.cuda.empty_cache()

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from llama_bmt_monkey_patch import (
    replace_llama_attn_with_bmt,
)

from flagai.env_args import EnvArgs
from flagai.env_trainer_v1 import EnvTrainer
from flagai.data.dataset.indexed_dataset.build_index_mappings import _build_train_valid_test_datasets, _build_train_valid_test_weighted_datasets
import bmtrain as bmt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every parameter can also be passed on the command line.
# For example: python train_env_trainer.py --epochs=300 --batch_size=4 --env_type=pytorch
env_args = EnvArgs(
    env_type="bmtrain",
    experiment_name="llama3",
    batch_size=1,
    gradient_accumulation_steps=1,
    lr=2e-4,
    weight_decay=1e-3,
    epochs=10000,
    log_interval=1,
    eval_interval=5000,
    num_gpus=1,
    load_dir=None,
    pytorch_device=device,
    save_dir="checkpoints_out",
    checkpoint_activations=False,
    save_interval=100,
    fp16=True,
    training_script=__file__,
)
env_args = env_args.parse_args()
#env_args.wandb = False

# Overwrite defaults with values from the YAML config, if one was given.
if env_args.yaml_config:
    import yaml
    file_data = open(env_args.yaml_config, 'r', encoding="utf-8").read()
    data = yaml.load_all(file_data, Loader=yaml.SafeLoader)
    delattr(env_args, 'yaml_config')
    arg_dict = env_args.__dict__
    for subdata in data:
        for key, value in subdata.items():
            if isinstance(value, list):
                for v in value:
                    arg_dict[key].append(v)
            else:
                arg_dict[key] = value
trainer = EnvTrainer(env_args)

# Trainer as Trigger
if not env_args.not_call_launch:
    import sys
    sys.exit(0)

print(f"Trainer effective env_args={env_args} local_rank={os.environ['LOCAL_RANK']}",
      flush=True)
checkpoints = env_args.pre_load_dir
model_name = env_args.model_name

print('*' * 20, "model_name", model_name, flush=True)

cache_dir = os.path.join(checkpoints, model_name)
print('*' * 20, "cache_dir", cache_dir)
tokenizer = AutoTokenizer.from_pretrained(cache_dir)
print('*' * 20, "tokenizer", tokenizer)

# Stagger model loading across local ranks to avoid memory OOM.
if env_args.bmt_async_load:
    import time
    time.sleep(10 * 60 * (int(os.environ['LOCAL_RANK']) % 4))

config_file = os.path.join(cache_dir, 'config.json')
with open(config_file, 'r') as f:
    model_args = json.load(f)

# bmt
replace_llama_attn_with_bmt()

model = LlamaForCausalLM.from_pretrained(cache_dir)

## bmt_pre_load

trainer.pre_train(model)

print('*' * 20, "model", model, flush=True)

## Use prebuilt datasets
data_prefix = '../indexed_dataset/data/demo_text_document'
data_impl = 'mmap'
splits_string = '90,10'
train_valid_test_num_samples = [90, 10]
seq_length = 1024
seed = 2023
skip_warmup = True

train_dataset, valid_dataset, _ = _build_train_valid_test_datasets(
    data_prefix, data_impl, splits_string, train_valid_test_num_samples,
    seq_length, seed, skip_warmup)
print("Total train_dataset: ", len(train_dataset), flush=True)
print("Total valid_dataset: ", len(valid_dataset), flush=True)


def collate_fn(batch):

    def padding(indice, max_length, pad_idx=0):
        pad_indice = [
            item.tolist() + [pad_idx] * max(0, max_length - len(item.tolist()))
            for item in indice
        ]
        return torch.tensor(pad_indice)

    input_ids = [data["input_ids"] for data in batch]
    max_length = max([len(t) for t in input_ids])
    input_ids = padding(input_ids, max_length)[:, :seq_length]

    data = {"input_ids": input_ids, "labels": input_ids}
    return data


trainer.do_train(train_dataset=train_dataset,
                 valid_dataset=None,
                 collate_fn=collate_fn,
                 optimizer=None,
                 rank_split=False)
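
Illustrative only (not part of the commit): what collate_fn returns for a toy batch of unequal-length samples, assuming the function above is in scope. Sequences are right-padded with 0 to the batch maximum, truncated to seq_length, and labels are simply the same tensor as input_ids, as is standard for causal-LM training.

import torch

batch = [{"input_ids": torch.tensor([1, 2, 3])},
         {"input_ids": torch.tensor([4, 5])}]
out = collate_fn(batch)
print(out["input_ids"])                              # tensor([[1, 2, 3], [4, 5, 0]])
print(torch.equal(out["input_ids"], out["labels"]))  # True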
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
#!/bin/bash
#
# Defined by user
export PROJ_HOME=$PWD

echo "[INFO] $0: hostfile configfile model_name exp_name"
set -u
hostfile=$1
configfile=$2
model_name=$3
exp_name=$4
set +u
NODES_NUM=`cat $hostfile |wc -l`
echo "NODES_NUM": $NODES_NUM
if [ $NODES_NUM -ne 1 ];then
    echo "Make sure the hostfile contains exactly one node"
    exit 0
fi

exp_YYYYMMDDHH=$(date +"%Y%m%d%H")
echo "exp_YYYYMMDDHH": $exp_YYYYMMDDHH

SAVE_DIR=$PROJ_HOME/checkpoints_out/${exp_name}/$exp_YYYYMMDDHH
LOGFILE=$SAVE_DIR/$configfile.log.txt
echo "LOGFILE": $LOGFILE

cd $PROJ_HOME;
mkdir -p $SAVE_DIR;
bash bmtrain_mgpu.sh $hostfile $configfile $model_name $exp_name $exp_YYYYMMDDHH 1>$LOGFILE 2>&1 &
