-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathextract_attention_maps.py
69 lines (53 loc) · 2.03 KB
/
extract_attention_maps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import torch
import pickle
import argparse
import typing as t
from v1t import data
from v1t.utils import utils
from v1t.models.model import Model
from v1t.utils.scheduler import Scheduler
from v1t.utils.attention_rollout import extract_attention_maps
from torch.utils.data import DataLoader
def extract(
    ds: t.Dict[str, DataLoader],
    model: "Model",
    device: t.Union[str, torch.device] = "cpu",
    *,
    skip_mouse_ids: t.Collection[str] = ("S0", "S1"),
):
    """Compute attention rollout maps for every mouse in ``ds``.

    Args:
        ds: mapping from mouse id to that mouse's DataLoader.
        model: model forwarded unchanged to ``extract_attention_maps``.
        device: device forwarded to ``extract_attention_maps``; either a
            device string (default ``"cpu"``) or a ``torch.device``.
        skip_mouse_ids: mouse ids to leave out of the result. The default
            ``("S0", "S1")`` keeps the previously hard-coded exclusion.
            NOTE(review): presumably these are the Sensorium competition
            mice -- confirm why they are excluded.

    Returns:
        Dict mapping each non-skipped mouse id (in ``ds`` iteration order)
        to whatever ``extract_attention_maps`` returned for that mouse.
    """
    return {
        mouse_id: extract_attention_maps(ds=mouse_ds, model=model, device=device)
        for mouse_id, mouse_ds in ds.items()
        if mouse_id not in skip_mouse_ids
    }
def main(args):
    """Restore a trained model and pickle its attention rollout maps.

    Steps, in the order they run:
      1. Require ``args.output_dir`` to be an existing directory.
      2. Call ``utils.get_device(args)``, ``utils.set_random_seed(1234)``
         and ``utils.load_args(args)``.
      3. Build the validation and test loaders with
         ``data.get_training_ds`` (its first, training, result is unused).
      4. Build ``Model`` on the validation loaders, call
         ``model.train(False)`` and restore weights through
         ``Scheduler.restore(force=True)``.
      5. Run ``extract`` on both splits and pickle the dict
         ``{"validation": ..., "test": ...}`` to
         ``<output_dir>/attention_rollout_maps.pkl``.

    Raises:
        FileNotFoundError: if ``args.output_dir`` is not a directory.
    """
    if not os.path.isdir(args.output_dir):
        raise FileNotFoundError(f"Cannot find {args.output_dir}.")

    # Keep this call order: the steps below read fields of ``args``
    # (device, dataset, mouse_ids, batch_size, output_dir) that these
    # helpers may set or overwrite.
    utils.get_device(args)
    utils.set_random_seed(1234)
    utils.load_args(args)

    # NOTE(review): ``mouse_ids`` is not a CLI flag; presumably
    # ``utils.load_args`` populates it -- confirm.
    _train_ds, val_ds, test_ds = data.get_training_ds(
        args,
        data_dir=args.dataset,
        mouse_ids=args.mouse_ids,
        batch_size=args.batch_size,
        device=args.device,
    )

    model = Model(args, ds=val_ds)
    model.train(False)  # evaluation mode; no training happens here
    checkpoint = Scheduler(args, model=model, save_optimizer=False)
    checkpoint.restore(force=True)

    results = {}
    for split, split_ds, banner in (
        ("validation", val_ds, "Extract attention rollout maps from validation set."),
        ("test", test_ds, "\nExtract attention rollout maps from test set."),
    ):
        print(banner)
        results[split] = extract(ds=split_ds, model=model, device=args.device)

    out_path = os.path.join(args.output_dir, "attention_rollout_maps.pkl")
    with open(out_path, "wb") as handle:
        pickle.dump(results, handle)
    print(f"Saved attention maps to {out_path}.")
if __name__ == "__main__":
    # Command-line entry point: --output_dir is mandatory, every other flag
    # has a default (--device=None is handed to main(), which calls
    # utils.get_device on the parsed args).
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--dataset", dict(type=str, default="../data/sensorium")),
        ("--output_dir", dict(type=str, required=True)),
        ("--batch_size", dict(type=int, default=1)),
        ("--device", dict(type=str, default=None)),
    ):
        cli.add_argument(flag, **options)
    main(cli.parse_args())