-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpome.py
More file actions
42 lines (30 loc) · 1.51 KB
/
pome.py
File metadata and controls
42 lines (30 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import torch
import torch.nn.functional as F
import numpy as np
def pome_delta(delta, truncation, alpha):
    """Return a rank-truncated copy of ``delta`` with equalised singular values.

    The thin SVD of ``delta`` is taken in float64.  The first
    ``n = int(min(delta.shape) * truncation)`` singular directions (the largest
    ones, since ``torch.linalg.svd`` returns them in descending order) are
    kept, and every kept non-zero singular value is replaced by
    ``alpha * sqrt(sum(S ** 2) / n)``.  With ``alpha == 1`` and ``n`` non-zero
    kept values the result therefore has the same Frobenius norm as ``delta``.

    Args:
        delta: 2-D tensor, the difference between edited and base weights.
        truncation: Fraction in ``[0, 1]`` of singular directions to keep.
        alpha: Multiplier applied to the shared singular value.

    Returns:
        A float64 tensor with the same shape and device as ``delta``.

    Raises:
        ValueError: If ``delta`` is not 2-D or ``truncation`` is outside
            ``[0, 1]``.
    """
    if delta.dim() != 2:
        raise ValueError(
            f"expected a 2-D weight delta, got shape {tuple(delta.shape)}"
        )
    if not 0.0 <= truncation <= 1.0:
        raise ValueError(f"truncation must be within [0, 1], got {truncation}")

    n = int(min(delta.shape) * truncation)
    if n == 0:
        # Nothing is kept.  Returning zeros explicitly avoids the division by
        # zero below (which would otherwise create inf/nan intermediates).
        return torch.zeros_like(delta, dtype=torch.float64)

    U, S, Vh = torch.linalg.svd(delta.to(torch.float64), full_matrices=False)
    s_norm = torch.norm(S)
    # sign(S) is 1 for positive and 0 for exactly-zero singular values, so
    # directions that carry no energy stay at zero.
    S_new = torch.sign(S) * torch.sqrt(s_norm * s_norm / n) * alpha
    S_new[n:] = 0
    # Scaling the columns of U equals U @ diag(S_new) without materialising
    # the dense diagonal matrix.
    return (U * S_new) @ Vh


def apply_pome(model, vanilla_model, layers, truncation, alpha):
    """Edit the selected 2-D parameters of ``model`` in place.

    A parameter is selected when any string in ``layers`` is a substring of
    its name.  Each selected parameter is replaced by
    ``base + pome_delta(param - base, truncation, alpha)``, where ``base`` is
    the parameter of the same name in ``vanilla_model``.

    Args:
        model: Module whose parameters are edited.
        vanilla_model: Module providing the base weights.
        layers: Iterable of name substrings; ``None`` or empty edits nothing.
        truncation: Forwarded to :func:`pome_delta`.
        alpha: Forwarded to :func:`pome_delta`.

    Returns:
        The list of parameter names that were edited, in iteration order.

    Raises:
        KeyError: If a selected parameter has no same-named counterpart in
            ``vanilla_model``.
        ValueError: If a selected parameter and its counterpart differ in
            shape.
    """
    if not layers:
        return []

    # Pair parameters by name instead of by position so that a different
    # enumeration order or length cannot silently mispair tensors.
    base_params = dict(vanilla_model.named_parameters())
    edited = []
    for name, params in model.named_parameters():
        if not any(layer_str in name for layer_str in layers):
            continue
        if name not in base_params:
            raise KeyError(f"parameter {name!r} is missing from the base model")
        params_v = base_params[name]
        if params.dim() != 2:
            # The SVD edit is only defined here for matrices; biases and
            # other non-2-D tensors are left untouched.
            print("skip (not 2-D): ", name)
            continue
        if params.shape != params_v.shape:
            raise ValueError(
                f"shape mismatch for {name!r}: "
                f"{tuple(params.shape)} vs {tuple(params_v.shape)}"
            )
        print("name: ", name)
        # The two models may have been placed on different devices.
        base = params_v.data.to(params.device)
        approx_weight = pome_delta(params.data - base, truncation, alpha)
        # Cast back to the parameter's own dtype so the saved checkpoint does
        # not end up with mixed precisions.
        params.data = (base + approx_weight).to(params.dtype)
        edited.append(name)
    return edited


def build_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='POME')
    parser.add_argument('--base_model', type=str, required=True,
                        help='checkpoint loaded as the base (vanilla) model')
    parser.add_argument('--model', type=str, required=True,
                        help='checkpoint to edit; also provides the tokenizer')
    parser.add_argument('--output_path', type=str, required=True,
                        help='directory the edited model and tokenizer are saved to')
    parser.add_argument('--alpha', type=float,
                        help='multiplier applied to the equalised singular value')
    parser.add_argument('--truncation', type=float,
                        help='fraction in [0, 1] of singular directions to keep')
    parser.add_argument('--layer', nargs='+', type=str,
                        help='substrings selecting the parameter names to edit')
    return parser


def main(argv=None):
    """Parse arguments, edit the selected weights and save the result."""
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.layer:
        # Validate before loading any model so mistakes surface immediately.
        if args.alpha is None or args.truncation is None:
            parser.error('--alpha and --truncation are required with --layer')
        if not 0.0 <= args.truncation <= 1.0:
            parser.error('--truncation must be within [0, 1]')

    model = AutoModelForCausalLM.from_pretrained(args.model, device_map='auto')
    # device_map is a model-loading option; it is not passed to the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    vanilla_model = AutoModelForCausalLM.from_pretrained(args.base_model, device_map='auto')

    apply_pome(model, vanilla_model, args.layer, args.truncation, args.alpha)

    print("savemodel")
    model.save_pretrained(args.output_path)
    tokenizer.save_pretrained(args.output_path)


if __name__ == "__main__":
    main()