From ca2f3d395a24c730edb5857178449fe24c497472 Mon Sep 17 00:00:00 2001
From: mirageciel
Date: Wed, 3 Apr 2024 15:04:38 +0900
Subject: [PATCH] Chore: to get just one text input from user

---
 ai/TextReID/lib/data/metrics/evaluation.py | 35 +++++++++++-----------
 ai/TextReID/lib/engine/inference.py        | 10 +++----
 2 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/ai/TextReID/lib/data/metrics/evaluation.py b/ai/TextReID/lib/data/metrics/evaluation.py
index f9c80a4156..6e2eedfc9f 100644
--- a/ai/TextReID/lib/data/metrics/evaluation.py
+++ b/ai/TextReID/lib/data/metrics/evaluation.py
@@ -111,7 +111,9 @@ def evaluation(
         image_ids.append(image_id)
         pids.append(pid)
         image_global.append(prediction[0])
-        if len(prediction) == 2:  # since only one text query was given, some entries will have no text embedding
+        if len(prediction) == 2:
+            # Since only one text query is given, the text embedding is attached only to the first image of each batch,
+            # because the similarity check is performed per batch.
             text_global.append(prediction[1])
 
     pids = list(map(int, pids))
@@ -131,20 +133,19 @@ def evaluation(
     writer = SummaryWriter()
 
     # return top 10 results
-    for i in range(4):
-        sorted_indices = torch.argsort(similarity[i], descending=True)
-        sorted_values = similarity[i][sorted_indices]
-        top_k = 10
-        images = []
-        similarities = ""
-        print(cap[i])
-        for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
-            image_id, pid = dataset.get_id_info(idx)
-            img, caption, idx, query = dataset.__getitem__(index)
-            images.append(img)
-            print(f"Index: {index}, Similarity: {value}, pid: {pid}")
-            similarities += str(value) + "\t"
-        grid_img = make_grid(images, nrow=10)
-        writer.add_image(f"Image Grid for Query {i}", grid_img)
-        writer.add_text(f"Captions for Query {i}", cap[i])
+    sorted_indices = torch.argsort(similarity[0], descending=True)
+    sorted_values = similarity[0][sorted_indices]
+    top_k = 10
+    images = []
+    # similarities = ""
+    print(cap[0])
+    for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
+        image_id, pid = dataset.get_id_info(idx)
+        img, caption, idx, query = dataset.__getitem__(index)
+        images.append(img)
+        print(f"Index: {index}, Similarity: {value}, pid: {pid}")
+        # similarities += str(value) + "\t"
+    grid_img = make_grid(images, nrow=10)
+    writer.add_image("Query <%s>" % cap[0], grid_img)
+    # writer.add_text(f"Captions for Query {i}", cap[i])
     writer.close()
\ No newline at end of file
diff --git a/ai/TextReID/lib/engine/inference.py b/ai/TextReID/lib/engine/inference.py
index e0930c5bca..8b09f5d3fd 100644
--- a/ai/TextReID/lib/engine/inference.py
+++ b/ai/TextReID/lib/engine/inference.py
@@ -18,16 +18,16 @@ def compute_on_dataset(model, data_loader, cap, device):
     model.eval()
     results_dict = defaultdict(list)
+    caption = input("\nText Query Input: ")
+    cap.append(caption)
+    caption = encode(caption)
+    caption = Caption([torch.tensor(caption)])
     for batch in tqdm(data_loader):
         images, captions, image_ids = batch
         images = images.to(device)
         # captions = [captions[0].to(device)]  # use only the first caption
-        caption = input("\nText Query Input: ")
-        cap.append(caption)
-        caption = encode(caption)
-        caption = torch.tensor(caption)
-        captions = [Caption([caption]).to(device)]
         # captions = [caption.to(device) for caption in captions]
+        captions = [caption.to(device)]
         with torch.no_grad():
             output = model(images, captions)
         for result in output:
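
Note: in inference.py, the change moves the text-query prompt out of the batch loop, so the user is asked for input once and the same encoded Caption is reused against every image batch. A minimal sketch of that flow, assuming encode() and Caption behave as used in the hunk above; run_single_query, encode_fn, and caption_cls are hypothetical names for illustration, not functions in this repo:

    import torch

    def run_single_query(model, data_loader, device, encode_fn, caption_cls):
        # Read and encode the text query once, before iterating over the gallery.
        query_text = input("\nText Query Input: ")
        tokens = torch.tensor(encode_fn(query_text))
        query = caption_cls([tokens]).to(device)  # mirrors Caption([torch.tensor(...)]) above
        results = []
        model.eval()
        with torch.no_grad():
            for images, _, image_ids in data_loader:
                # The same encoded query is paired with every image batch.
                output = model(images.to(device), [query])
                results.extend(zip(image_ids, output))
        return query_text, results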
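
On the evaluation.py side, only the first (and now only) query's similarity row is kept, and its top-10 gallery images are written to TensorBoard as a single grid. A self-contained sketch of that retrieval-and-logging step, assuming similarity is a (num_queries, num_gallery) tensor; log_top_k and get_image are hypothetical stand-ins for the dataset access done via dataset.__getitem__ in the patch:

    import torch
    from torchvision.utils import make_grid
    from torch.utils.tensorboard import SummaryWriter

    def log_top_k(similarity, query_text, get_image, k=10):
        # similarity[0]: scores of the single text query against every gallery image
        scores = similarity[0]
        top_values, top_indices = torch.topk(scores, k)
        images = [get_image(int(i)) for i in top_indices]  # each a CHW float tensor
        writer = SummaryWriter()
        writer.add_image("Query <%s>" % query_text, make_grid(images, nrow=k))
        writer.close()
        return top_indices, top_values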