
Commit

Chore: to get just one text input from user
chaews0327 committed Apr 3, 2024
1 parent ef0a67f commit ca2f3d3
Showing 2 changed files with 23 additions and 22 deletions.
ai/TextReID/lib/data/metrics/evaluation.py (35 changes: 18 additions & 17 deletions)
@@ -111,7 +111,9 @@ def evaluation(
         image_ids.append(image_id)
         pids.append(pid)
         image_global.append(prediction[0])
-        if len(prediction) == 2:  # since only one text query is given, some predictions will have no text embedding
+        if len(prediction) == 2:
+            # Since only one text query is given, the text embedding is attached only to
+            # the first image of each batch, because similarity was computed batch by batch.
             text_global.append(prediction[1])

     pids = list(map(int, pids))
@@ -131,20 +133,19 @@

     writer = SummaryWriter()
     # return the top 10 results
-    for i in range(4):
-        sorted_indices = torch.argsort(similarity[i], descending=True)
-        sorted_values = similarity[i][sorted_indices]
-        top_k = 10
-        images = []
-        similarities = ""
-        print(cap[i])
-        for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
-            image_id, pid = dataset.get_id_info(idx)
-            img, caption, idx, query = dataset.__getitem__(index)
-            images.append(img)
-            print(f"Index: {index}, Similarity: {value}, pid: {pid}")
-            similarities += str(value) + "\t"
-        grid_img = make_grid(images, nrow=10)
-        writer.add_image(f"Image Grid for Query {i}", grid_img)
-        writer.add_text(f"Captions for Query {i}", cap[i])
+    sorted_indices = torch.argsort(similarity[0], descending=True)
+    sorted_values = similarity[0][sorted_indices]
+    top_k = 10
+    images = []
+    # similarities = ""
+    print(cap[0])
+    for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
+        image_id, pid = dataset.get_id_info(idx)
+        img, caption, idx, query = dataset.__getitem__(index)
+        images.append(img)
+        print(f"Index: {index}, Similarity: {value}, pid: {pid}")
+        # similarities += str(value) + "\t"
+    grid_img = make_grid(images, nrow=10)
+    writer.add_image("Query <%s>"%cap[0], grid_img)
+    # writer.add_text(f"Captions for Query {i}", cap[i])
     writer.close()
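
For readers following along, here is a minimal standalone sketch of what the changed blocks amount to after this commit: collect the per-image embeddings (only the first prediction of each batch carries the single text embedding), then sort the one query's similarity scores and log its top-10 gallery images to TensorBoard. The `predictions` layout, the `similarity` tensor, and the `dataset` interface (`get_id_info`, `__getitem__` returning `(img, caption, idx, query)`) are assumptions inferred from the diff, not verified against the rest of the repository.

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid


def collect_globals(predictions):
    # `predictions` is assumed to map an image index to [image_global] or
    # [image_global, text_global]; only the first image of each batch carries
    # the text embedding, since inference attached the one user query per batch.
    image_global, text_global = [], []
    for _, prediction in predictions.items():
        image_global.append(prediction[0])
        if len(prediction) == 2:
            text_global.append(prediction[1])
    # Stacking into single tensors is an assumption about the embedding shapes.
    return torch.stack(image_global), torch.stack(text_global)


def log_topk_for_single_query(similarity_row, dataset, query_text, top_k=10):
    # Sort one query's similarity scores and log its top-k gallery images.
    writer = SummaryWriter()
    sorted_indices = torch.argsort(similarity_row, descending=True)
    sorted_values = similarity_row[sorted_indices]

    images = []
    for index, value in zip(sorted_indices[:top_k], sorted_values[:top_k]):
        img, caption, idx, query = dataset[int(index)]    # retrieved gallery item
        image_id, pid = dataset.get_id_info(int(index))   # ids for inspection
        images.append(img)
        print(f"Index: {int(index)}, Similarity: {float(value):.4f}, pid: {pid}")

    grid_img = make_grid(images, nrow=top_k)              # one row with the top-k hits
    writer.add_image("Query <%s>" % query_text, grid_img)
    writer.close()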
ai/TextReID/lib/engine/inference.py (10 changes: 5 additions & 5 deletions)
@@ -18,16 +18,16 @@
 def compute_on_dataset(model, data_loader, cap, device):
     model.eval()
     results_dict = defaultdict(list)
+    caption = input("\nText Query Input: ")
+    cap.append(caption)
+    caption = encode(caption)
+    caption = Caption([torch.tensor(caption)])
     for batch in tqdm(data_loader):
         images, captions, image_ids = batch
         images = images.to(device)
         # captions = [captions[0].to(device)]  # use only the first caption
-        caption = input("\nText Query Input: ")
-        cap.append(caption)
-        caption = encode(caption)
-        caption = torch.tensor(caption)
-        captions = [Caption([caption]).to(device)]
         # captions = [caption.to(device) for caption in captions]
+        captions = [caption.to(device)]
         with torch.no_grad():
             output = model(images, captions)
         for result in output:
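
A sketch of how `compute_on_dataset` reads after this change: the text query is read and encoded once, before the batch loop, and the same `Caption` object is reused for every image batch. The import paths for `encode` and `Caption`, and the way `results_dict` is filled at the end (the diff is cut off there), are placeholders and assumptions, not the repository's actual code.

from collections import defaultdict

import torch
from tqdm import tqdm

# Placeholder imports: the real locations of `encode` and `Caption` in this
# TextReID fork are assumed here, not verified.
from lib.utils.caption import Caption, encode  # hypothetical module path


def compute_on_dataset(model, data_loader, cap, device):
    # Sketch of the post-commit flow: ask for the text query once, then reuse
    # the encoded Caption for every image batch.
    model.eval()
    results_dict = defaultdict(list)

    # Read and encode the single user query before the batch loop.
    caption_text = input("\nText Query Input: ")
    cap.append(caption_text)
    caption = Caption([torch.tensor(encode(caption_text))])

    for batch in tqdm(data_loader):
        images, _, image_ids = batch
        images = images.to(device)
        captions = [caption.to(device)]  # the same query for every batch

        with torch.no_grad():
            output = model(images, captions)
        # How results are keyed is not shown in the diff; pairing each output
        # with its image id is an assumption.
        for image_id, result in zip(image_ids, output):
            results_dict[image_id].append(result)
    return results_dict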
