-
Notifications
You must be signed in to change notification settings - Fork 417
Expand file tree
/
Copy pathgenerate.py
More file actions
executable file
·40 lines (30 loc) · 994 Bytes
/
generate.py
File metadata and controls
executable file
·40 lines (30 loc) · 994 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.data.tokenizer import Tokenizer
import transformers
# Directory holding the converted model checkpoint files.
state_dict = "./checkpoints_in/"
model_name = 'Llama-3.1-8B'

# Load the Llama-3 model and its tokenizer through FlagAI's AutoLoader.
# use_cache=True enables the KV cache for faster autoregressive decoding.
loader = AutoLoader("llama3",
                    model_dir=state_dict,
                    model_name=model_name,
                    device='cuda',
                    use_cache=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()
model.eval()   # inference mode: disable dropout etc.
model.cuda()
print("model loaded")

# Prompt to complete.
text = "Gravity is "
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Conduct text completion. no_grad() stops autograd from recording the
# forward pass — generation needs no gradients, and tracking them would
# waste GPU memory on a 8B-parameter model.
with torch.no_grad():
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024
    )

# Slice off the prompt tokens so only the newly generated continuation
# is decoded, then strip leading/trailing newlines from the text.
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
print("content:", content)