-
Notifications
You must be signed in to change notification settings - Fork 25
/
evaluate.py
409 lines (368 loc) · 16.1 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
import copy
import os
import torch
import argparse
from transformers import StoppingCriteria, StoppingCriteriaList
from math import ceil
from PIL import Image
import numpy as np
import torch.backends.cudnn as cudnn
from timechat.common.logger import setup_logger
from timechat.common.config import Config
from timechat.common.dist_utils import get_rank
from timechat.common.registry import registry
from timechat.conversation.conversation_video_batch import Chat, Conversation, default_conversation, SeparatorStyle, \
conv_llava_llama_2
import decord
decord.bridge.set_bridge('torch')
import logging
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
import pdb
import json
from pathlib import Path
import time
import datetime
from tqdm import tqdm
import random
random.seed(1234)
from utils.format_dvc import format_dvc_output
from utils.format_tvg import format_tvg_output
from utils.format_vhd import format_vhd_output
def read_txt(path):
    """Return only the first line of the text file at *path*, whitespace-stripped.

    Any further lines in the file are ignored.
    """
    with open(path, "r") as handle:
        first_line = handle.readline()
    return first_line.strip()
def load_data(args, anno_path, split=None):
    '''
    Load the "annotations" list of ``<anno_path>/<split>.caption_coco_format.json``.

    anno data example:
    {"annotations":
        [
            {
                "image_id": "xHr8X2Wpmno.mp4"
                ...
            },
            ...
        ]
    }

    When ``args.debug`` is truthy only the first 10 annotations are returned.
    '''
    annotation_file = os.path.join(anno_path, f'{split}.caption_coco_format.json')
    with open(annotation_file, 'r') as f:
        annotations = json.load(f)["annotations"]
    return annotations[:10] if args.debug else annotations
def merge_seg_caps(results):
    """Merge multiple generated captions from the same video into one event list.

    Each entry's ``"vname"`` is expected to have the form
    ``"<video file>_<start>_<end>"`` (this is how ``main`` builds it in
    ``--timestamp`` mode). The last two underscore-separated fields are parsed
    as floats and everything before them is used as the video key.

    Args:
        results: iterable of dicts with ``"vname"`` and ``"generated_cap"`` keys.

    Returns:
        dict mapping video file name to a list of
        ``{"timestamp": [start, end], "caption": generated_cap}``, in input order.
    """
    merge_results = {}
    for jterm in results:
        # rsplit from the right keeps underscores/dots inside the file name
        # intact and works for any extension, not only ".mp4".
        vid, start_str, end_str = jterm["vname"].rsplit("_", 2)
        merge_results.setdefault(vid, []).append(
            {"timestamp": [float(start_str), float(end_str)], "caption": jterm["generated_cap"]}
        )
    return merge_results
def save_result(args, output_dir, results, split_name='test', format=False):
    """Dump *results* as JSON into *output_dir*, creating the directory if needed.

    The file name is ``<dataset>_<split>_f<num_frames>_result[...].json`` where
    the optional suffix records whether ground-truth or predicted timestamps were
    fed to the model (``args.timestamp`` / ``args.timestamp_file``). Debug runs
    get a ``debug_`` prefix and formatted results an additional ``fmt_`` prefix
    (so ``fmt_debug_...`` when both apply).
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    if not args.timestamp:
        suffix = ''
    elif args.timestamp_file != '':
        suffix = '_with_pred_timestamp'
    else:
        suffix = '_with_gt_timestamp'
    prefix = ('fmt_' if format else '') + ('debug_' if args.debug else '')
    file_name = f'{prefix}{args.dataset}_{split_name}_f{args.num_frames}_result{suffix}.json'
    with open(os.path.join(output_dir, file_name), 'w') as f:
        json.dump(results, f)
def get_timestamp_from_file(timestamp_file):
    """Read a JSON mapping ``{vid: [{"timestamp": ..., ...}, ...]}``.

    Returns ``{vid: [timestamp, ...]}``, keeping only the ``"timestamp"`` value
    of each entry, in file order.
    """
    with open(timestamp_file, 'r') as f:
        data = json.load(f)
    return {vid: [vterm["timestamp"] for vterm in vlist] for vid, vlist in data.items()}
def format_dvc(datas):
    """Parse raw dense-video-captioning generations into per-video event lists.

    Args:
        datas: iterable of dicts with ``"vname"`` and ``"generated_cap"`` keys.

    Returns:
        dict mapping ``vname`` to a list of ``{"timestamp": ..., "caption": ...}``
        built from ``format_dvc_output``. Videos whose output could not be parsed
        map to an empty list and are printed and counted as parse failures.
    """
    fmt_datas = {}
    timestamp_count = []
    cnt = 0
    for jterm in datas:
        vid = jterm["vname"]
        caption = jterm["generated_cap"]
        timestamps, sents = format_dvc_output(caption)
        if len(timestamps) == 0:
            cnt += 1
            print(vid, caption)
        fmt_datas[vid] = [{"timestamp": ts, "caption": sent} for ts, sent in zip(timestamps, sents)]
        timestamp_count.append(len(timestamps))
    # Guard against empty input, which previously raised ZeroDivisionError.
    if timestamp_count:
        print(f"predict avg {sum(timestamp_count) / len(timestamp_count)} events per video")
    print(f'parse failed number: {cnt}')
    return fmt_datas
def format_tvg(datas):
    """Parse temporal-video-grounding generations, keyed by integer query id.

    Args:
        datas: iterable of dicts with ``"vname"``, ``"query"``,
            ``"generated_cap"`` and ``"id"`` keys.

    Returns:
        dict mapping ``int(id)`` to ``{"timestamp": ..., "query": ..., "vid": ...}``,
        where the timestamps come from ``format_tvg_output``. Empty parses are
        printed and counted.
    """
    parsed = {}
    failures = 0
    for item in datas:
        vid, query, gcap = item["vname"], item["query"], item["generated_cap"]
        qid = int(item["id"])
        timestamps = format_tvg_output(gcap)
        if len(timestamps) == 0:
            failures += 1
            print(vid, query + "\n", gcap, "\n")
        parsed[qid] = {"timestamp": timestamps, "query": query, "vid": vid}
    print(f'parse failed number: {failures}')
    return parsed
def format_vhd(datas, gts):
    """Parse video-highlight-detection generations into per-query saliency records.

    Args:
        datas: iterable of dicts with ``"vname"``, ``"query"``,
            ``"generated_cap"`` and ``"id"`` keys.
        gts: annotation dicts; each is looked up by its ``"image_id"`` and passed
            to ``format_vhd_output`` alongside the generation.

    Returns:
        list of ``{"qid", "query", "vid", "pred_saliency_scores"}`` dicts, one per
        input entry. Entries with no parsed highlights are printed and counted.
    """
    gt_by_vid = {gt["image_id"]: gt for gt in gts}
    records = []
    failures = 0
    for item in datas:
        vid, query, gcap = item["vname"], item["query"], item["generated_cap"]
        qid = item["id"]
        highlights, clipscores = format_vhd_output(gcap, gt_by_vid[vid])
        if len(highlights) == 0:
            failures += 1
            print(vid, query + "\n", gcap + "\n")
        records.append({
            "qid": qid,
            "query": query,
            "vid": vid,
            "pred_saliency_scores": clipscores,
        })
    print(f'parse failed number: {failures}')
    return records
def generate(chat, gr_videos, user_messages, num_beams, temperature, top_p, n_frms, chat_states=None, img_lists=None,
             model_type=None):
    """Run one batched round of video chat: upload frames, ask, and answer.

    Args:
        chat: the ``Chat`` wrapper around the loaded model.
        gr_videos: list of video paths, one per message.
        user_messages: list of prompts, one per video.
        num_beams, temperature, top_p: decoding settings forwarded to ``chat.answer``.
        n_frms: number of frames to sample per video.
        chat_states: conversation states to continue (e.g. for a follow-up
            question). When None, fresh ones are created.
        img_lists: per-conversation image-embedding lists. When None, empty ones
            are created.
        model_type: selects the conversation template for fresh states.
            ``'vicuna'`` uses ``default_conversation``; any other value uses
            ``conv_llava_llama_2`` with a system prompt. When None, falls back to
            the script-level ``args.model_type``, which is only defined when this
            file runs as ``__main__``.

    Returns:
        (responses, chat_states, img_lists), where responses is the first element
        of ``chat.answer``'s return value.
    """
    n = len(user_messages)
    if chat_states is None:
        chat_states = []
        for _ in range(n):
            template_kind = model_type if model_type is not None else args.model_type
            if template_kind == 'vicuna':
                chat_state = default_conversation.copy()
            else:
                chat_state = conv_llava_llama_2.copy()
                chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
            chat_states.append(chat_state)
    if img_lists is None:
        img_lists = [[] for _ in range(n)]
    # Called for its effect on chat_states / img_lists; its return value is unused.
    chat.upload_video_without_audio(gr_videos, chat_states, img_lists, n_frms=n_frms)
    for user_message, chat_state in zip(user_messages, chat_states):
        chat.ask(user_message, chat_state)
    responses = chat.answer(convs=chat_states,
                            img_lists=img_lists,
                            num_beams=num_beams,
                            temperature=temperature,
                            top_p=top_p,
                            max_new_tokens=512,
                            max_length=3000)[0]
    return responses, chat_states, img_lists
def _read_asr_transcript(asr_file, max_num_asr=15):
    """Build a compact, time-stamped transcript string from one ASR text file.

    Lines are taken with an integer stride of ``len(lines) // max_num_asr``
    (at least 1), starting from the second line. The stride is rounded down, so
    somewhat more than ``max_num_asr`` lines may be kept. Each selected line is
    parsed as ``start<TAB>end<TAB>text`` and rendered as
    ``"<start> - <end> seconds, <text> "``.

    Returns ``'None.'`` when the file does not exist or no line was selected.
    """
    if not os.path.exists(asr_file):
        return 'None.'
    with open(asr_file, 'r') as f:
        asrs = f.readlines()
    stride = max(len(asrs) // max_num_asr, 1)
    pieces = []
    # NOTE(review): index 0 is skipped (range starts at 1), presumably a header
    # line in these ASR files. Confirm against the ASR export format.
    for idx in range(1, len(asrs), stride):
        line = asrs[idx].strip()
        if not line.endswith('.'):
            line = line + '.'
        fields = line.split('\t')
        start, end, caption = float(fields[0]), float(fields[1]), fields[2]
        pieces.append(f"{start:.1f} - {end:.1f} seconds, {caption} ")
    return ''.join(pieces) or 'None.'


def main(args):
    """Run batched TimeChat inference on one data split and save the outputs.

    Steps:
    - Loads the model and visual processor described by ``args.cfg_path``.
    - Builds one prompt per video, or per (video, segment) in ``--timestamp`` mode.
    - Generates answers batch by batch.
    - Saves the raw results and then the task-specific formatted results.
    - Appends the config and arguments to ``<output_dir>/log.txt``.
    """
    num_beams = 1
    temperature = args.temperature
    top_p = args.top_p
    n_frms = args.num_frames
    eval_start_time = time.time()
    prompt = read_txt(args.prompt_file)

    # load model
    args.options = []
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    cfg = Config(args)
    model_config = cfg.model_cfg
    model_config.device_8bit = args.gpu_id
    model_config.ckpt = args.timechat_model_path
    if args.no_lora:
        model_config.lora = False
    # set after init_distributed_mode() to only log on master.
    setup_logger()
    cfg.pretty_print()
    message = '\n' + '\n'.join([f'{k:<25}: {v}' for k, v in vars(args).items()])
    logging.info(message)
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
    model.eval()
    vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
    print('Initialization Finished')

    # load data
    video_path = args.video_path
    anno_data = load_data(args, args.anno_path, split=args.split)
    if args.timestamp_file != '':
        pred_timestamps = get_timestamp_from_file(args.timestamp_file)
    vids = []
    vnames = []
    captions = []
    qids = []
    if args.sample_num > 0:
        # Sample part of the data to evaluate. Clamp so a sample_num larger than
        # the data (e.g. together with --debug) does not raise ValueError.
        anno_data = random.sample(anno_data, min(args.sample_num, len(anno_data)))
    for jterm in anno_data:
        vname = jterm["image_id"].split("/")[-1]
        vid_path = os.path.join(video_path, vname)
        if args.timestamp:
            # NOTE(review): int() truncates a fractional duration, so segments
            # ending within the final partial second are skipped. Confirm intended.
            duration = int(jterm["duration"])
            if args.timestamp_file == '':  # input the gt timestamps
                timestamp = jterm["segments"]
            else:  # input the pred timestamps
                timestamp = pred_timestamps[vname]
            for (start_time, end_time) in timestamp:
                # skip malformed / out-of-range annotated segments
                if start_time >= end_time or end_time > duration or start_time >= duration:
                    continue
                vids.append(vid_path)
                # merge_seg_caps later splits this "<vname>_<start>_<end>" back apart
                vnames.append(vname + "_" + str(start_time) + "_" + str(end_time))
        else:
            vids.append(vid_path)
            vnames.append(vname)
            captions.append(jterm["caption"])
            qids.append(jterm["id"])

    # Read once instead of once per batch.
    post_check_prompt = read_txt(args.post_check_prompt_file) if args.post_check else None

    # evaluate using batch
    results = []
    bz = args.batch_size
    epoch = ceil(len(vnames) / bz)
    for i in tqdm(range(epoch)):
        sid = i * bz
        eid = min((i + 1) * bz, len(vnames))
        paths = vids[sid:eid]
        prompts = []
        for pi in range(len(paths)):
            idx = sid + pi  # index into the full vnames/captions lists
            final_prompt = prompt
            if args.asr:
                # Use the global index; indexing with pi alone picked the wrong
                # video's transcript for every batch after the first.
                asr_file = os.path.join(args.asr_path, vnames[idx].split('.')[0] + '.txt')
                final_asr = _read_asr_transcript(asr_file)
                final_prompt = f'Transcribed speech: {final_asr} Based on the video content and possible transcribed speech, {final_prompt}'
            if args.task in ["tvg", "vhd"]:
                prompts.append(final_prompt.format(args.dataset, captions[idx].strip('.')))
            else:
                prompts.append(final_prompt)
        outputs, chat_states, img_lists = generate(chat, paths, prompts, num_beams, temperature, top_p, n_frms)
        if args.post_check:
            post_check_prompts = [post_check_prompt] * len(paths)
            outputs, chat_states, img_lists = generate(chat, paths, post_check_prompts, num_beams, temperature, top_p,
                                                       n_frms, chat_states, img_lists)
        for j, (output, chat_state) in enumerate(zip(outputs, chat_states)):
            if args.task in ["tvg", "vhd"]:
                results.append({
                    "vname": vnames[sid + j],
                    "generated_cap": output,
                    "query": captions[sid + j],
                    "id": qids[sid + j],
                    "prompt": chat_state.get_prompt()
                })
            else:
                results.append({
                    "vname": vnames[sid + j],
                    "generated_cap": output,
                    "prompt": chat_state.get_prompt()
                })
            if i < 5:
                print(chat_state.get_prompt())
                print(results[-1]["generated_cap"])
                print('*' * 50)

    if args.timestamp:
        results = merge_seg_caps(results)
    # save results
    save_result(args, args.output_dir, results, args.split)

    # format results to calculate metrics
    fmt_results = None
    if args.timestamp:
        # merge_seg_caps already produced {video: [{"timestamp", "caption"}, ...]};
        # the format_* parsers expect the raw per-sample list and fail on a dict.
        fmt_results = results
    elif args.task == "dvc":
        fmt_results = format_dvc(results)
    elif args.task == "tvg":
        fmt_results = format_tvg(results)
    elif args.task == "vhd":
        fmt_results = format_vhd(results, anno_data)
    else:
        print(f"Not support formatting samples for task {args.task}")
    # save format results (skipped for unsupported tasks instead of a NameError)
    if fmt_results is not None:
        save_result(args, args.output_dir, fmt_results, args.split, format=True)

    total_time = time.time() - eval_start_time
    # convert seconds to date
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluate time {}'.format(total_time_str))
    with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
        f.write(json.dumps(cfg.to_dict(), indent=4) + "\n")
        f.write(message + "\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_path', type=str, default='eval_configs/timechat.yaml')
    parser.add_argument('--anno_path', type=str, default='data/YouCook2-BB/YouCook2_asr_denseCap/')
    parser.add_argument('--video_path', type=str, default='data/YouCook2-BB/YouCook2_asr_denseCap/youcook2_6fps_224/')
    parser.add_argument('--model_type', type=str,
                        help="'vicuna' selects the Vicuna conversation template; any other value "
                             "(including unset) selects the LLaMA-2 template")
    parser.add_argument('--task', default='dvc',
                        help='dvc for dense video captioning; tvg for temporal video grounding; '
                             'vhd for video highlight detection')
    parser.add_argument('--dataset', default='youcook')
    parser.add_argument('--output_dir', default='debug')
    parser.add_argument('--split', default='val')
    parser.add_argument('--num_frames', type=int, default=8)
    parser.add_argument('--top_p', type=float, default=0.8)
    parser.add_argument('--temperature', type=float, default=1)
    parser.add_argument('--batch_size', type=int, default=16)
    # Kept as a string; it is interpolated into "cuda:{gpu_id}" device strings.
    parser.add_argument('--gpu_id', default='0')
    parser.add_argument('--timestamp', action='store_true', help='input the gt/predicted timestamps to the model')
    parser.add_argument('--timestamp_file', type=str, default='', help='the predicted timestamps file')
    parser.add_argument('--debug', action='store_true', help='the debug mode will only use 10 data samples')
    parser.add_argument('--prompt_file', default='prompts/dvc_description.txt')
    parser.add_argument('--timechat_model_path',
                        default='ckpt/timechat/train_stage2_llama2_7b_time64k_valley72k_bz32_f96_epoch3_open_i_instruct_qformer_lora_bind_time_ws32_mfp96_mtl2048/20231026060/checkpoint_2.pth')
    parser.add_argument('--sample_num', type=int, default=-1, help='fast inference by sampling N instances to evaluate')
    parser.add_argument('--example_output', action='store_true', help='output the example results')
    parser.add_argument('--no_lora', action='store_true')
    parser.add_argument('--post_check', action='store_true', help='post check the format of generated captions')
    parser.add_argument('--post_check_prompt_file', type=str, default='prompts/dvc_post_check.txt')
    parser.add_argument('--asr', action='store_true')
    parser.add_argument('--asr_path', type=str,
                        default='data/YouCook2-BB/YouCook2_asr_denseCap/whisper_outputs_with_time/small.en.cleaned/')
    args = parser.parse_args()
    main(args)