-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathadd_impl_for_deit.patch
More file actions
112 lines (104 loc) · 4.69 KB
/
add_impl_for_deit.patch
File metadata and controls
112 lines (104 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
Subject: [PATCH] Refactor augment.py imports and update main.py with new arguments and attention methods
---
Index: augment.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/augment.py b/augment.py
--- a/augment.py (revision 7e160fe43f0252d17191b71cbb5826254114ea5b)
+++ b/augment.py (revision 01e623e7e4d8cb54df2ea4706ec80caed6c27267)
@@ -9,7 +9,7 @@
import torch
from torchvision import transforms
-from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
+from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
import numpy as np
from torchvision import datasets, transforms
Index: main.py
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/main.py b/main.py
--- a/main.py (revision 7e160fe43f0252d17191b71cbb5826254114ea5b)
+++ b/main.py (revision 01e623e7e4d8cb54df2ea4706ec80caed6c27267)
@@ -23,9 +23,6 @@
from samplers import RASampler
from augment import new_data_aug_generator
-import models
-import models_v2
-
import utils
@@ -45,7 +42,11 @@
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: 0.1)')
-
+ parser.add_argument("--method", default=None, help="method for attention (int_attention, idx_softmax_only, exaq_attention, quant_only)")
+ parser.add_argument("--inp-quant-bit", type=int, default=8)
+ parser.add_argument("--quant-bit", type=int, default=5)
+ parser.add_argument("--zero-thr", type=float, default=6.6)
+ parser.add_argument("--bitwidth", type=int, default=3, help="bitwidth for EXAQ (default: 3)")
parser.add_argument('--model-ema', action='store_true')
parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=True)
@@ -206,7 +207,7 @@
cudnn.benchmark = True
- dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+ dataset_train, args.nb_classes = build_dataset(is_train=False, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
if args.distributed:
@@ -262,7 +263,7 @@
print(f"Creating model: {args.model}")
model = create_model(
args.model,
- pretrained=False,
+ pretrained=True,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
@@ -410,6 +411,43 @@
loss_scaler.load_state_dict(checkpoint['scaler'])
lr_scheduler.step(args.start_epoch)
if args.eval:
+ import torch.nn.functional as F
+ from functools import partial
+ import pysimulation
+
+ if args.method in ("int_attention", "intattention", "int-attention"):
+ F.scaled_dot_product_attention = partial(
+ pysimulation.int_attention,
+ inp_quant_bit=args.inp_quant_bit,
+ quant_bit=args.quant_bit,
+ zero_thr=args.zero_thr,
+ )
+ elif args.method in ("idx_softmax_only", "idx_softmax", "idxsoftmaxonly"):
+    # idx_softmax_only expects (query, key, value, ..., inp_quant_bit, quant_bit, zero_thr)
+ F.scaled_dot_product_attention = partial(
+ pysimulation.idx_softmax_only,
+ inp_quant_bit=args.inp_quant_bit,
+ quant_bit=args.quant_bit,
+ zero_thr=args.zero_thr,
+ )
+    elif args.method in ("exaq_attention", "exaq", "exaq-attention"):
+ F.scaled_dot_product_attention = partial(
+ pysimulation.exaq_attention,
+ bitwidth=args.bitwidth,
+ )
+    elif args.method in ("quant_only", "quantonly", "quant-only"):
+ F.scaled_dot_product_attention = partial(
+ pysimulation.quant_only,
+ inp_quant_bit=args.inp_quant_bit,
+ )
+ elif args.method is None:
+ # No method specified: leave default PyTorch implementation
+        print("[INFO] No attention method specified, leaving default attention.")
+ else:
+ # Unknown method: leave default PyTorch implementation
+ print(f"[WARN] Unknown attention method '{args.method}', leaving default attention.")
+
+ print(f"Using {args.method} for scaled dot product attention")
test_stats = evaluate(data_loader_val, model, device)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return