Skip to content

Commit

Permalink
Feat : 2차 탐색 ai 모델 도입
Browse files Browse the repository at this point in the history
Feat : 2차 탐색 ai 모델 도입
  • Loading branch information
begong313 authored May 15, 2024
2 parents 6279410 + cc60f17 commit 924ee67
Show file tree
Hide file tree
Showing 27 changed files with 429 additions and 165 deletions.
5 changes: 3 additions & 2 deletions ai/TextReID/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def main():
detect(args)

def findByText(root="./", config_file="configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml", checkpoint_file="output/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048/best.pth",
local_rank=0, opts=[], load_result=False, search_num=0,query="",data_dir = "./datasets/", save_folder = "./output/output.json"):
local_rank=0, opts=[], load_result=False, search_num=0,query="",data_dir = "./datasets/", save_folder = "./output/output.json", result_num = 10):
# 매개변수를 Namespace 객체로 묶기
args = Namespace(
root=root,
Expand All @@ -100,7 +100,8 @@ def findByText(root="./", config_file="configs/cuhkpedes/moco_gru_cliprn50_ls_bs
search_num=search_num,
query = query,
data_dir = data_dir,
save_folder = save_folder
save_folder = save_folder,
top_k = result_num
)
detect(args)

Expand Down
87 changes: 80 additions & 7 deletions ai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
sys.path.append(str(Path(__file__).parent)+"/yolov5_crowdhuman")
sys.path.append(str(Path(__file__).parent)+"/yolov8")
sys.path.append(str(Path(__file__).parent)+"/TextReID")
sys.path.append(str(Path(__file__).parent)+"/imageSearch")

from TextReID.test_net import findByText
from yolov5_crowdhuman.detect import run_detection
from yolov8.run import run_Yolo
from imageSearch.imgSearch_server import run_Image_to_Image

app = FastAPI(port = 8080)

Expand All @@ -46,17 +48,31 @@ class TotalInput(BaseModel):
step : str
query : str



class DetectResult(BaseModel):
searchId : int
missingPeopleId : int
query : str
data : list

class SecondInput(BaseModel):
missingPeopleId : int
firstSearchId : int
secondSearchId : int
topK:int
queryImagePath : List[str]

class SecondDetectResult(BaseModel):
data : list
secondSearchId : int
# @app.get('/')
# async def root():
# run_Yolo([CCTVInfo(id=1,longitude=1,latitude=1)],'/home/jongbin/Desktop/test/2-results','2021-09-01T000000')
# @app.post("/test")
# async def test(input : TotalInput):

@app.post('/run', response_model=DetectResult)
async def test(input :TotalInput):
print(input)
async def firstDetection(input :TotalInput):
if input.query == none:
raise HTTPException(status_code=400, detail="Query cannot be None")
yolo_save_path = f"/home/jongbin/Desktop/yolo/{input.searchId}" #경로는 각자 환경에 맞게 조장하시오
run_Yolo(input.cctvId,yolo_save_path,input.startTime) #todo start time 따라 input다르게 만들기
result_dir = await runTextReID(input, yolo_save_path) #text-re-id돌리고 결과 json파일 받아오기
Expand All @@ -65,15 +81,40 @@ async def test(input :TotalInput):
with open(result_json_dir, 'r') as file:
result = json.load(file)

return DetectResult(searchId= input.searchId, missingPeopleId= input.missingPeopleId,query = input.query, data = result[1:])
return DetectResult(searchId= input.searchId, missingPeopleId= input.missingPeopleId, data = result[1:])
#2차탐색
@app.post("/second", response_model=DetectResult)
async def secondDetection(input:SecondInput):
print(input)
img_download_url = "/home/jongbin/Desktop/imgDown"
data_path = f"/home/jongbin/Desktop/yolo/{input.firstSearchId}"
result = []
for img_path in input.queryImagePath:
#img_path는 s3주소 특정폴더에 다운로드하고 경로가져오기
local_image_path = os.path.join(img_download_url, os.path.basename(img_path))
download_image_from_s3(img_path, local_image_path)
print("dfasdfsad",local_image_path)
data_to_save = run_Image_to_Image(data_path,10, local_image_path)
for output in data_to_save['output']:
#해당 주소에 있는 이미지 s3업로드하고 이미지 주소 result에 넣기
local_output_path = output['output_dir']
s3_key = f"missingPeopleId={input.missingPeopleId}/searchHistoryId={input.secondSearchId}/step=second/{local_output_path}"
s3_key = s3_key.replace(' ', '-').replace(':', '').replace('/', '+')
s3_url = upload_image_to_s3(local_output_path, s3_key)
a = {"img_path" : s3_url, "cctvId" : 1, "similarity" :0 }
result.append(a)

print(result)
return DetectResult(searchId= input.secondSearchId, missingPeopleId= 1, data = result)

async def runTextReID(input : TotalInput, yolo_save_path:str):
root_path = os.getcwd() + "/TextReID"
print(root_path)
## 저장경로 지정
home_path = os.path.expanduser("~")
result_dir = os.path.join(home_path, "Desktop", "result", str(input.searchId) ,"output.json")
findByText(root_path, search_num=input.searchId, query = input.query, data_dir = yolo_save_path, save_folder = result_dir)
# findByText(root_path, search_num=input.searchId, query = input.query, data_dir = yolo_save_path, save_folder = result_dir)
findByText(root_path, search_num=input.searchId, query = "a man wearing a white shirt and black long pants. he has short hair.", data_dir = yolo_save_path, save_folder = result_dir)
return result_dir


Expand Down Expand Up @@ -119,3 +160,35 @@ async def uploadS3(json_file_path:str, missingPeopleId:int, searchId:int, step:s

return updated_json_path

def download_image_from_s3(s3_url, download_path):
"""
S3 URL로부터 이미지를 다운로드하여 지정된 경로에 저장합니다.
"""
# S3 URL을 분석하여 버킷 이름과 키를 추출
s3_bucket, s3_key = parse_s3_url(s3_url)
s3_client.download_file(s3_bucket, s3_key, download_path)

def parse_s3_url(s3_url):
"""
S3 URL에서 버킷 이름과 키를 추출합니다.
"""
if not s3_url.startswith("https://"):
raise ValueError("S3 URL은 'https://'로 시작해야 합니다.")

# URL의 앞부분 제거
s3_url = s3_url.replace("https://", "")

# 버킷 이름 추출
s3_bucket = s3_url.split(".s3.amazonaws.com")[0]

# 객체 키 추출
s3_key = s3_url.split(".s3.amazonaws.com/")[1]

return s3_bucket, s3_key

def upload_image_to_s3(local_path, s3_key):
"""
로컬 경로의 이미지를 S3 버킷의 지정된 키에 업로드합니다.
"""
s3_client.upload_file(local_path, bucket_name, s3_key)
return f"https://{bucket_name}.s3.amazonaws.com/{s3_key}"
Empty file added ai/imageSearch/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion ai/imageSearch/imgSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# ===== 사전학습 모델 세팅 및 특징 추출기 세팅
start_time = time.time()

BATCH = 16
BATCH = 8
model_ckpt = "microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft"
# model_ckpt = "microsoft/swinv2-tiny-patch4-window16-256"
processor = ViTImageProcessor.from_pretrained(model_ckpt)
Expand Down
119 changes: 63 additions & 56 deletions ai/imageSearch/imgSearch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,20 @@
from transformers import ViTImageProcessor, AutoModel
import argparse
import numpy as np

from datasets import load_dataset, DatasetDict

# ===== 입력 옵션 파싱
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, help='YOLO 검출 결과 이미지가 저장된 폴더 경로')
parser.add_argument('--topk', type=int, default=10, help='유사도 상위 #개 출력')
parser.add_argument('--query', type=str, help='query img 경로')

args = parser.parse_args()

# parser = argparse.ArgumentParser()
# parser.add_argument('--data_path', type=str, help='YOLO 검출 결과 이미지가 저장된 폴더 경로')
# parser.add_argument('--topk', type=int, default=10, help='유사도 상위 #개 출력')
# parser.add_argument('--query', type=str, help='query img 경로')

# ===== 사전학습 모델 세팅 및 특징 추출기 세팅
start_time = time.time()

BATCH = 16
# args = parser.parse_args()
model_ckpt = "microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft"
# model_ckpt = "microsoft/swinv2-tiny-patch4-window16-256"
processor = ViTImageProcessor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)


# ===== 데이터셋 로드 및 경로 추가
def load_image_paths(data_dir):
image_paths = []
for root, dirs, files in os.walk(data_dir):
for file in files:
if file.lower().endswith('.jpg'):
full_path = os.path.join(root, file)
image_paths.append(full_path)
return image_paths

dataset = load_dataset("imagefolder", data_dir=args.data_path)
image_paths = load_image_paths(args.data_path)

# 각 데이터셋 구성 요소에 대해 이미지 경로 추가
for split in dataset.keys():
dataset[split] = dataset[split].add_column("image_path", image_paths[:len(dataset[split])])


# ===== 임베딩 추출 (img > feature > embeddings)
def extract_embeddings(example):
if isinstance(example['image'], str):
Expand All @@ -59,51 +34,83 @@ def extract_embeddings(example):

return {'embeddings': features.squeeze()}

def run_Image_to_Image(data_path:str,topk:int,query:str):
start_time = time.time()

BATCH = 8
# model_ckpt = "microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft"
# # model_ckpt = "microsoft/swinv2-tiny-patch4-window16-256"
# processor = ViTImageProcessor.from_pretrained(model_ckpt)
# model = AutoModel.from_pretrained(model_ckpt)

# ===== 데이터셋에 임베딩 추출 함수 적용
dataset = dataset['train'].map(
extract_embeddings, batched=True, batch_size=BATCH
)

# ===== 데이터셋 로드 및 경로 추가
dataset = load_dataset("imagefolder", data_dir=data_path)
image_paths = load_image_paths(data_path)
# 각 데이터셋 구성 요소에 대해 이미지 경로 추가
for split in dataset.keys():
dataset[split] = dataset[split].add_column("image_path", image_paths[:len(dataset[split])])
print(split)
# ===== 데이터셋에 임베딩 추출 함수 적용
dataset = dataset['train'].map(
extract_embeddings, batched=True, batch_size=BATCH
)


# ===== 임베딩을 기반으로 Faiss 인덱스 생성
dataset.add_faiss_index(column='embeddings')
# ===== 결과 저장
scores, retrieved_examples = get_neighbors(query, dataset, topk)



sec = time.time() - start_time
times = str(datetime.timedelta(seconds=sec))
short = times.split(".")[0]
print(f"The 2nd search has ended. \nThe time required: {short} sec\n")
return save_results(query, scores, retrieved_examples)

# ===== 사전학습 모델 세팅 및 특징 추출기 세팅
def load_image_paths(data_dir):
image_paths = []
for root, dirs, files in os.walk(data_dir):
for file in files:
if file.lower().endswith('.jpg'):
full_path = os.path.join(root, file)
image_paths.append(full_path)
return image_paths


# ===== 임베딩을 기반으로 Faiss 인덱스 생성
dataset.add_faiss_index(column='embeddings')


# ===== 유사 이미지 검색
def get_neighbors(query_image_path, top_k=10):
def get_neighbors(query_image_path, dataset, top_k=10):
query_image = Image.open(query_image_path)
inputs = processor(images=query_image, return_tensors="pt")
outputs = model(**inputs)
qi_embedding = outputs.last_hidden_state[:, 0].detach().numpy().squeeze()
scores, retrieved_examples = dataset.get_nearest_examples('embeddings', qi_embedding, k=top_k)
return scores, retrieved_examples


# ===== 결과 저장을 위한 JSON 형식 정의
def default(o):
if isinstance(o, np.float32):
return float(o)

def save_results(query_image_path, scores, retrieved_examples, sorted_indices):
def save_results(query_image_path, scores, retrieved_examples):
data_to_save = {
"query_dir": query_image_path,
"output": [
{"output_dir": retrieved_examples['image_path'][i], "score": scores[i]}
for i in sorted_indices
{"output_dir": retrieved_examples['image_path'][i], "score": float(scores[i])}
for i in range(len(scores))
]
}
with open('results.json', 'w', encoding='utf-8') as f:
json.dump(data_to_save, f, ensure_ascii=False, indent=4, default=default)


# ===== 결과 저장
scores, retrieved_examples = get_neighbors(args.query, args.topk)
sorted_indices = np.argsort(scores)[::-1]

save_results(args.query, scores, retrieved_examples, sorted_indices)

sec = time.time() - start_time
times = str(datetime.timedelta(seconds=sec))
short = times.split(".")[0]
print(f"The 2nd search has ended. \nThe time required: {short} sec\n")
return data_to_save

# ===== 이미지 그리드 생성
def image_grid(imgs, rows, cols):
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
3 changes: 3 additions & 0 deletions ai/imageSearch/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
python3 imgSearch.py --data_path /home/jongbin/Desktop/yolo/164 --topk 20 --query /home/jongbin/Desktop/yolo/164/1_2024-05-14_16-07-56_155.jpg

/home/jongbin/Desktop/yolo/164
1 change: 1 addition & 0 deletions ai/yolov5_crowdhuman/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def detect(opt,save_img=False):
date_obj += nfps * timedelta(seconds=10)
else:
date_obj += nfps * timedelta(seconds=2)
opt.cctv_id = "1"
img_name = opt.cctv_id+"_"+date_obj.strftime('%Y-%m-%d_%H-%M-%S ')

if len(det):
Expand Down
Loading

0 comments on commit 924ee67

Please sign in to comment.