class BridgeTowerITCHead(nn.Module):
    """Single linear projection used as an image-text-contrastive (ITC) head."""

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        # Attribute name is part of the checkpoint's state-dict keys.
        self.fc = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        """Project ``x`` from ``hidden_size`` to ``embed_size`` features."""
        return self.fc(x)


class _BridgeTowerTextModelWrapper(nn.Module):
    """Holds a ``BridgeTowerTextModel`` under the ``text_model`` attribute.

    NOTE(review): presumably exists so parameter names line up with the full
    BridgeTower checkpoint layout -- confirm against the checkpoint keys.
    """

    def __init__(self, config):
        super().__init__()
        self.text_model = BridgeTowerTextModel(config)

    def forward(self, **kwargs):
        """Forward all keyword arguments unchanged to the wrapped text model."""
        return self.text_model(**kwargs)


class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
    """Text-only encoder that returns L2-normalized ITC text embeddings."""

    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = _BridgeTowerTextModelWrapper(config.text_config)
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Embed tokenized text.

        Only ``input_ids`` and ``attention_mask`` are used; the remaining
        parameters are accepted but ignored by this implementation.

        Returns:
            The projected, L2-normalized vector taken at sequence position 0
            of the last hidden state.
        """
        # Hidden states are always requested because the embedding is read
        # from the final hidden layer.
        encoder_out = self.bridgetower(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_layer = encoder_out.hidden_states[-1]
        first_token = last_layer[:, 0, :]
        projected = self.itc_text_head(first_token)
        return F.normalize(projected, dim=-1, p=2)
class BridgeTowerForITC(BridgeTowerPreTrainedModel):
    """BridgeTower with text, image and cross-modal contrastive (ITC) heads."""

    def __init__(self, config):
        super().__init__(config)

        self.bridgetower = BridgeTowerModel(config)

        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        """Compute stacked, L2-normalized text / image / cross-modal embeddings.

        ``output_hidden_states`` must be truthy because the heads read the
        encoder hidden states. ``labels`` is accepted but unused.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` stack the three
            embeddings on dim -2 (index 0 text, 1 image, 2 cross-modal), or,
            when ``return_dict`` is false, a tuple starting with ``logits``
            followed by the remaining encoder outputs.
        """
        assert output_hidden_states, 'output_hidden_states should be set to True for BridgeTowerForITC'
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooler_output = outputs.pooler_output if return_dict else outputs[2]
        # In tuple mode there is no ``.hidden_states`` attribute; the hidden
        # states follow the three feature tensors at index 3.
        hidden_states_txt, hidden_states_img, _ = outputs.hidden_states if return_dict else outputs[3]

        final_hidden_txt = hidden_states_txt[-1]
        final_hidden_img = hidden_states_img[-1]

        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(final_hidden_img)
        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
        ).expand_as(image_embeds_with_ln)

        final_hidden_img = (
            self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
            + image_token_type_embeddings
        )

        final_hidden_txt = F.normalize(self.itc_text_head(final_hidden_txt[:, 0, :]), dim=-1, p=2)
        final_hidden_img = F.normalize(self.itc_image_head(final_hidden_img[:, 0, :]), dim=-1, p=2)
        final_hidden_cross = F.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)

        logits = torch.stack([final_hidden_txt, final_hidden_img, final_hidden_cross], dim=-2)

        if not return_dict:
            # ``tuple(logits)`` would split the tensor along the batch axis;
            # return it whole, followed by hidden states / attentions.
            return (logits,) + tuple(outputs[3:])

        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class BridgeTowerEmbeddings(BaseModel, Embeddings):
    """ BridgeTower embedding model """
    model_name: str = "BridgeTower/bridgetower-large-itm-mlm-itc"
    device: str = "cpu"
    text_model : Any
    processor: Any
    model: Any

    def __init__(self, **kwargs: Any):
        """Initialize the BridgeTowerEmbeddings class"""
        super().__init__(**kwargs)
        self.text_model = BridgeTowerTextFeatureExtractor.from_pretrained(self.model_name).to(self.device)
        self.processor = BridgeTowerProcessor.from_pretrained(self.model_name)
        self.model = BridgeTowerForITC.from_pretrained(self.model_name).to(self.device)

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using BridgeTower.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Padding is required to build one tensor from texts whose token
        # counts differ; truncation keeps over-long texts within model limits.
        encodings = self.processor.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.text_model(**encodings)
        embeddings = outputs.cpu().numpy().tolist()
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using BridgeTower.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    def _embed_pair_batch(self, texts: List[str], image_paths: List[str]) -> List[List[float]]:
        """Embed one batch of (text, image path) pairs with the ITC model."""
        pil_images = [
            transform.to_pil_image(read_image(path, mode=ImageReadMode.RGB))
            for path in image_paths
        ]
        batch = self.processor(
            pil_images, texts, return_tensors='pt', max_length=100, padding='max_length', truncation=True
        ).to(self.device)
        with torch.no_grad():
            batch_embeddings = self.model(**batch, output_hidden_states=True)
        # Index 2 on the stacked axis is the cross-modal embedding.
        return batch_embeddings.logits[:, 2, :].detach().cpu().numpy().tolist()

    def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
        """Embed a list of image-text pairs using BridgeTower.

        Args:
            texts: The list of texts to embed.
            images: The list of path-to-images to embed
            batch_size: the batch size to process, default to 2
        Returns:
            List of embeddings, one for each image-text pairs.
        """

        # the length of texts must be equal to the length of images
        assert len(texts)==len(images), "the len of captions should be equal to the len of images"

        texts = list(texts)
        images = list(images)
        # A non-positive batch size means "everything in a single batch",
        # which is what the previous counting loop did in that case.
        step = batch_size if batch_size > 0 else max(len(texts), 1)

        embeddings = []
        for start in range(0, len(texts), step):
            stop = start + step
            embeddings.extend(self._embed_pair_batch(texts[start:stop], images[start:stop]))
        return embeddings


def create_upload_folder(upload_path):
    """Create a directory to store uploaded video data"""
    # ``exist_ok=True`` already makes this a no-op for an existing directory.
    Path(upload_path).mkdir(parents=True, exist_ok=True)


def load_json_file(file_path):
    """Read contents of json file"""
    with open(file_path, 'r') as file:
        return json.load(file)


def clear_upload_folder(upload_path):
    """Clear the upload directory"""
    # Bottom-up walk so every directory is already empty when it is removed.
    # ``upload_path`` itself is kept.
    for root, dir_names, file_names in os.walk(upload_path, topdown=False):
        for file_name in file_names:
            os.remove(os.path.join(root, file_name))
        for dir_name in dir_names:
            os.rmdir(os.path.join(root, dir_name))
def generate_video_id():
    """Generates a unique identifier for a video file"""
    return str(uuid.uuid4())


def convert_video_to_audio(video_path: str, output_audio_path: str):
    """Converts video to audio using MoviePy library that uses `ffmpeg` under the hood.

    :param video_path: file path of video file (.mp4)
    :param output_audio_path: file path of audio file (.wav) to be created
    """
    video_clip = VideoFileClip(video_path)
    audio_clip = None
    try:
        audio_clip = video_clip.audio
        audio_clip.write_audiofile(output_audio_path)
    finally:
        # Release the clip handles even when extraction fails.
        video_clip.close()
        if audio_clip is not None:
            audio_clip.close()


def load_whisper_model(model_name: str = "base"):
    """Load a whisper model for generating video transcripts"""
    return whisper.load_model(model_name)


def extract_transcript_from_audio(whisper_model, audio_path: str):
    """Generate transcript from audio file

    :param whisper_model: a pre-loaded whisper model object
    :param audio_path: file path of audio file (.wav)
    """
    options = dict(task="translate", best_of=5, language='en')
    return whisper_model.transcribe(audio_path, **options)


def format_timestamp_for_transcript(seconds: float, always_include_hours: bool = True, fractionalSeperator: str = '.'):
    """Format a time in seconds as ``[HH:]MM:SS<sep>mmm`` for video transcripts.

    :param seconds: time offset in seconds
    :param always_include_hours: emit the ``HH:`` field even when it is zero
    :param fractionalSeperator: separator between seconds and milliseconds
    """
    total_ms = round(seconds * 1000.0)
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    secs, millis = divmod(total_ms, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{secs:02d}{fractionalSeperator}{millis:03d}"


def write_vtt(transcript: dict, vtt_path: str):
    """Write transcripts to a .vtt file

    :param transcript: mapping with a ``'segments'`` list; each segment provides
        ``'start'``, ``'end'`` (seconds) and ``'text'``
    :param vtt_path: destination file; overwritten if it already exists
    """
    # Write mode (not append) so the file holds exactly one WEBVTT header even
    # when the path is reused.
    with open(vtt_path, 'w', encoding='utf-8') as file:
        file.write("WEBVTT\n\n")
        for segment in transcript['segments']:
            # '-->' is the cue-timing separator and must not appear in cue text.
            text = (segment['text']).replace('-->', '->')
            file.write(f"{format_timestamp_for_transcript(segment['start'])} --> {format_timestamp_for_transcript(segment['end'])}\n")
            file.write(f"{text.strip()}\n\n")


def delete_audio_file(audio_path: str):
    """Delete audio file after extracting transcript"""
    os.remove(audio_path)


def time_to_frame(time: float, fps: float):
    """Convert time in seconds into frame number (never negative)"""
    # Clamp: at t == 0 the raw formula would yield frame -1.
    return max(0, int(time * fps - 1))


def str2time(strtime: str):
    """Get time in seconds from an ``HH:MM:SS[.fff]`` string (quotes are stripped)"""
    strtime = strtime.strip('"')
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]

    total_seconds = hrs * 60**2 + mins * 60 + seconds

    return total_seconds


def convert_img_to_base64(image):
    "Convert image to base64 string"
    # The image is JPEG-encoded first; the success flag is not inspected.
    _, buffer = cv2.imencode('.jpg', image)
    encoded_string = base64.b64encode(buffer)
    return encoded_string.decode()
def extract_frames_and_annotations_from_transcripts(video_id: str, video_path: str, vtt_path: str, output_dir: str):
    """Extract frames (.jpg) and annotations (.json) from video file (.mp4) and captions file (.vtt)

    One frame is grabbed at the midpoint of every caption. Frames are written
    to ``<output_dir>/frames/frame_<idx>.jpg`` and the collected metadata to
    ``<output_dir>/annotations.json``.
    """
    # Set up location to store frames and annotations
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'frames'), exist_ok=True)

    # Load video and get fps
    vidcap = cv2.VideoCapture(video_path)
    annotations = []
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)

        # read captions file
        captions = webvtt.read(vtt_path)

        for idx, caption in enumerate(captions):
            start_time = str2time(caption.start)
            end_time = str2time(caption.end)

            mid_time = (end_time + start_time) / 2
            text = caption.text.replace('\n', ' ')

            frame_no = time_to_frame(mid_time, fps)
            mid_time_ms = mid_time * 1000
            vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
            success, frame = vidcap.read()

            if success:
                # Save frame for further processing
                img_fname = f"frame_{idx}"
                img_fpath = os.path.join(output_dir, 'frames', img_fname + '.jpg')
                cv2.imwrite(img_fpath, frame)

                # Convert image to base64 encoded string
                b64_img_str = convert_img_to_base64(frame)

                # Create annotations for frame from transcripts
                annotations.append({
                    'video_id': video_id,
                    'video_name' : os.path.basename(video_path),
                    'b64_img_str': b64_img_str,
                    'caption': text,
                    'time': mid_time_ms,
                    'frame_no': frame_no,
                    'sub_video_id': idx,
                })
    finally:
        # Always release the capture handle, even if reading or parsing fails.
        vidcap.release()

    # Save transcript annotations as json file for further processing
    with open(os.path.join(output_dir, 'annotations.json'), 'w') as f:
        json.dump(annotations, f)


def use_lvm(endpoint: str, img_b64_string: str, prompt: str ="Provide a short description for this scene.", timeout=None):
    """Generate image captions/descriptions using LVM microservice

    :param endpoint: URL of the LVM service
    :param img_b64_string: base64-encoded image
    :param prompt: instruction sent along with the image
    :param timeout: optional ``requests`` timeout in seconds; ``None`` waits
        indefinitely (the previous behaviour)
    :raises requests.HTTPError: if the service answers with an error status
    """
    inputs = {"image": img_b64_string, "prompt": prompt, "max_new_tokens": 32}
    response = requests.post(url=endpoint, data=json.dumps(inputs), timeout=timeout)
    # Surface HTTP failures directly instead of a confusing KeyError or
    # JSON decode error further down.
    response.raise_for_status()
    return response.json()["text"]


def extract_frames_and_generate_captions(video_id: str, video_path: str, lvm_endpoint: str, output_dir: str, key_frame_per_second: int = 1):
    """Extract frames (.jpg) and annotations (.json) from video file (.mp4) by generating captions using LVM microservice

    Roughly ``key_frame_per_second`` frames per second of video are sampled,
    captioned through ``lvm_endpoint`` and stored the same way as in the
    transcript-based extraction.

    :raises ValueError: if ``key_frame_per_second`` is not positive
    """
    if key_frame_per_second <= 0:
        raise ValueError("key_frame_per_second must be a positive number")

    # Set up location to store frames and annotations
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'frames'), exist_ok=True)

    # Load video and get fps
    vidcap = cv2.VideoCapture(video_path)
    annotations = []
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)

        # At least 1: rounding to 0 (fps below half the requested rate) would
        # make the modulo below divide by zero.
        hop = max(1, round(fps / key_frame_per_second))
        curr_frame = 0
        idx = -1

        while True:
            ret, frame = vidcap.read()
            if not ret:
                break

            if curr_frame % hop == 0:
                idx += 1

                # CAP_PROP_POS_MSEC is already in milliseconds; no scaling, so
                # 'time' uses the same unit as the transcript-based path.
                mid_time_ms = vidcap.get(cv2.CAP_PROP_POS_MSEC)

                frame_no = curr_frame
                # The frame stays in OpenCV's native BGR order, which is what
                # imwrite/imencode expect; converting to RGB first would swap
                # the red and blue channels of the stored image.

                # Save frame for further processing
                img_fname = f"frame_{idx}"
                img_fpath = os.path.join(output_dir, 'frames', img_fname + '.jpg')
                cv2.imwrite(img_fpath, frame)

                # Convert image to base64 encoded string
                b64_img_str = convert_img_to_base64(frame)

                # Caption generation using LVM microservice
                caption = use_lvm(lvm_endpoint, b64_img_str)
                text = caption.strip().replace('\n', ' ')

                # Create annotations for frame from generated captions
                annotations.append({
                    'video_id': video_id,
                    'video_name' : os.path.basename(video_path),
                    'b64_img_str': b64_img_str,
                    'caption': text,
                    'time': mid_time_ms,
                    'frame_no': frame_no,
                    'sub_video_id': idx,
                })

            curr_frame += 1
    finally:
        # Always release the capture handle, even if captioning fails.
        vidcap.release()

    # Save caption annotations as json file for further processing
    with open(os.path.join(output_dir, 'annotations.json'), 'w') as f:
        json.dump(annotations, f)
Start Microservice with Python (Option 1) + +## 1.1 Install Requirements + +```bash +apt update +apt install default-jre + +# Install ffmpeg static build +wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz +mkdir ffmpeg-git-amd64-static +tar -xvf ffmpeg-git-amd64-static.tar.xz -C ffmpeg-git-amd64-static --strip-components 1 +export PATH=$(pwd)/ffmpeg-git-amd64-static:$PATH +cp $(pwd)/ffmpeg-git-amd64-static/ffmpeg /usr/local/bin/ + +pip install -r requirements.txt +``` + +## 1.2 Start Redis Stack Server + +Please refer to this [readme](../../../vectorstores/langchain/redis/README.md). + +## 1.3 Set Up Environment Variables + +```bash +export REDIS_URL="redis://${your_ip}:6379" +export INDEX_NAME=${your_index_name} +export PYTHONPATH=${path_to_comps} +``` + +## 1.4 Start LVM Microservice + +Please refer to this [readme](../../../lvms/README.md) to start the LVM microservice. + +After LVM is up, set up environment variables. + +```bash +export LVM_ENDPOINT="http://localhost:9399/v1/lvm" +``` + +## 1.5 Start Document Preparation Microservice for Redis with Python Script + +Start the document preparation microservice for Redis with the command below. + +```bash +python prepare_videodoc_redis.py +``` + +# 🚀2. Start Microservice with Docker (Option 2) + +## 2.1 Start Redis Stack Server + +Please refer to this [readme](../../../vectorstores/langchain/redis/README.md). + +## 2.2 Set Up Environment Variables + +```bash +export EMBED_MODEL="BridgeTower/bridgetower-large-itm-mlm-itc" +export LVM_ENDPOINT="http://${your_ip}:9399/v1/lvm" +export REDIS_URL="redis://${your_ip}:6379" +export INDEX_NAME=${your_index_name} +export HUGGINGFACEHUB_API_TOKEN=${your_hf_api_token} +``` + +## 2.3 Build Docker Image + +```bash +cd ../../../../../ +docker build -t opea/dataprep-redis:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/dataprep/redis/multimodal_langchain/docker/Dockerfile . 
+``` + +## 2.4 Run Docker with CLI (Option A) + +```bash +docker run -d --name="dataprep-redis-server" -p 6007:6007 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e REDIS_URL=$REDIS_URL -e INDEX_NAME=$INDEX_NAME -e LVM_ENDPOINT=$LVM_ENDPOINT -e HUGGINGFACEHUB_API_TOKEN=$HUGGINGFACEHUB_API_TOKEN opea/dataprep-redis:latest +``` + +## 2.5 Run with Docker Compose (Option B - deprecated, will move to genAIExample in future) + +```bash +cd comps/dataprep/redis/multimodal_langchain/docker +docker compose -f docker-compose-dataprep-redis.yaml up -d +``` + +# 🚀3. Check Microservice Status + +```bash +docker container logs -f dataprep-redis-server +``` + +# 🚀4. Consume Microservice + +## 4.1 Consume videos_with_transcripts API + +Once the document preparation microservice for Redis is started, the user can use the command below to invoke the microservice to convert videos and their transcripts to embeddings and save them to the database. + +Make sure the file path after `files=@` is correct. + +- Single video-transcript pair upload + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./video1.mp4" \ + -F "files=@./video1.vtt" \ + http://dataprep-redis-service:6007/v1/dataprep/videos_with_transcripts +``` + +- Multiple video-transcript pair upload + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./video1.mp4" \ + -F "files=@./video1.vtt" \ + -F "files=@./video2.mp4" \ + -F "files=@./video2.vtt" \ + http://dataprep-redis-service:6007/v1/dataprep/videos_with_transcripts +``` + +## 4.2 Consume generate_transcripts API + +If transcripts are not available for videos, transcripts will be extracted from them. The user can use the command below to invoke the microservice to convert videos and their extracted transcripts to embeddings and save them to the database. 
+ +- Single video upload + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./video1.mp4" \ + http://dataprep-redis-service:6007/v1/dataprep/generate_transcripts +``` + +- Multiple video upload + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./video1.mp4" \ + -F "files=@./video2.mp4" \ + http://dataprep-redis-service:6007/v1/dataprep/generate_transcripts +``` + +## 4.3 Consume generate_captions API + +If uploaded videos lack audio or recognizable speech, captions will be generated for frames using LVM. The user can use below command to invoke the microservice to convert videos and generated captions to embeddings and save to the database. + +- Single video upload + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./video1.mp4" \ + http://dataprep-redis-service:6007/v1/dataprep/generate_captions +``` + +- Multiple video upload + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./video1.mp4" \ + -F "files=@./video2.mp4" \ + http://dataprep-redis-service:6007/v1/dataprep/generate_captions +``` + +## 4.4 Consume get_videos API + +To get names of uploaded videos, use the following command. + +```bash +curl -X POST \ + -H "Content-Type: application/json" \ + http://dataprep-redis-service:6007/v1/dataprep/get_videos +``` + +## 4.5 Consume delete_videos API + +To delete uploaded videos and clear the database, use the following command. 
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
from urllib.parse import quote

# Models
EMBED_MODEL = os.getenv("EMBED_MODEL", "BridgeTower/bridgetower-large-itm-mlm-itc")
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "large-v2")

# Redis Connection Information
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))

# Lvm Microservice Information
LVM_ENDPOINT = os.getenv("LVM_ENDPOINT", "http://localhost:9399/v1/lvm")


def get_boolean_env_var(var_name, default_value=False):
    """Retrieve the boolean value of an environment variable.

    Args:
        var_name (str): The name of the environment variable to retrieve.
        default_value (bool): The default value to return if the variable
        is not found or holds an unrecognized value.

    Returns:
        bool: The value of the environment variable, interpreted as a boolean.
    """
    true_values = {"true", "1", "t", "y", "yes"}
    false_values = {"false", "0", "f", "n", "no"}

    # Normalize case and surrounding whitespace before matching
    value = os.getenv(var_name, "").strip().lower()

    if value in true_values:
        return True
    if value in false_values:
        return False
    return default_value


def format_redis_conn_from_env():
    """Build the Redis connection URL from environment variables.

    ``REDIS_URL`` wins when it is set and non-empty. Otherwise the URL is
    assembled from ``REDIS_SSL``, ``REDIS_USERNAME``, ``REDIS_PASSWORD`` and the
    module-level ``REDIS_HOST`` / ``REDIS_PORT``.

    Returns:
        str: A ``redis://`` or ``rediss://`` connection URL.
    """
    redis_url = os.getenv("REDIS_URL", None)
    if redis_url:
        return redis_url

    using_ssl = get_boolean_env_var("REDIS_SSL", False)
    start = "rediss://" if using_ssl else "redis://"

    # if using RBAC
    password = os.getenv("REDIS_PASSWORD", None)
    username = os.getenv("REDIS_USERNAME", "default")
    if password is not None:
        # Percent-encode the credentials so characters such as '@', ':' or '/'
        # cannot be mistaken for URL delimiters.
        start += f"{quote(username, safe='')}:{quote(password, safe='')}@"

    return start + f"{REDIS_HOST}:{REDIS_PORT}"


REDIS_URL = format_redis_conn_from_env()

# Vector Index Configuration
INDEX_NAME = os.getenv("INDEX_NAME", "mm-rag-redis")

current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
REDIS_SCHEMA = os.getenv("REDIS_SCHEMA", "schema.yml")
TIMEOUT_SECONDS = int(os.getenv("TIMEOUT_SECONDS", 600))
schema_path = os.path.join(parent_dir, REDIS_SCHEMA)
INDEX_SCHEMA = schema_path
https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz && \ + mkdir ffmpeg-git-amd64-static && tar -xvf ffmpeg-git-amd64-static.tar.xz -C ffmpeg-git-amd64-static --strip-components 1 && \ + export PATH=/root/ffmpeg-git-amd64-static:$PATH && \ + cp /root/ffmpeg-git-amd64-static/ffmpeg /usr/local/bin/ + +RUN useradd -m -s /bin/bash user && \ + mkdir -p /home/user && \ + chown -R user /home/user/ + +USER user + +COPY comps /home/user/comps + +RUN pip install --no-cache-dir --upgrade pip setuptools && \ + if [ ${ARCH} = "cpu" ]; then pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \ + pip install --no-cache-dir -r /home/user/comps/dataprep/redis/multimodal_langchain/requirements.txt + +ENV PYTHONPATH=$PYTHONPATH:/home/user + +USER root + +RUN mkdir -p /home/user/comps/dataprep/redis/multimodal_langchain/uploaded_files && chown -R user /home/user/comps/dataprep/redis/multimodal_langchain/uploaded_files + +USER user + +WORKDIR /home/user/comps/dataprep/redis/multimodal_langchain + +ENTRYPOINT ["python", "prepare_videodoc_redis.py"] + diff --git a/comps/dataprep/redis/multimodal_langchain/docker/docker-compose-dataprep-redis.yaml b/comps/dataprep/redis/multimodal_langchain/docker/docker-compose-dataprep-redis.yaml new file mode 100644 index 000000000..5e73c6ebb --- /dev/null +++ b/comps/dataprep/redis/multimodal_langchain/docker/docker-compose-dataprep-redis.yaml @@ -0,0 +1,30 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +version: "3" +services: + redis-vector-db: + image: redis/redis-stack:7.2.0-v9 + container_name: redis-vector-db + ports: + - "6379:6379" + - "8001:8001" + dataprep-redis: + image: opea/dataprep-redis:latest + container_name: dataprep-redis-server + ports: + - "6007:6007" + ipc: host + environment: + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + REDIS_URL: ${REDIS_URL} + INDEX_NAME: ${INDEX_NAME} + LANGCHAIN_API_KEY: 
${LANGCHAIN_API_KEY} + LVM_ENDPOINT: ${LVM_ENDPOINT} + restart: unless-stopped + +networks: + default: + driver: bridge diff --git a/comps/dataprep/redis/multimodal_langchain/prepare_videodoc_redis.py b/comps/dataprep/redis/multimodal_langchain/prepare_videodoc_redis.py new file mode 100644 index 000000000..35e18a8ed --- /dev/null +++ b/comps/dataprep/redis/multimodal_langchain/prepare_videodoc_redis.py @@ -0,0 +1,495 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import shutil +import uuid +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Type, Union + +from config import EMBED_MODEL, WHISPER_MODEL, INDEX_NAME, INDEX_SCHEMA, REDIS_URL, LVM_ENDPOINT +from fastapi import File, HTTPException, UploadFile +from langchain_core.embeddings import Embeddings +from langchain_core.utils import get_from_dict_or_env +from langchain_community.vectorstores import Redis +from langchain_community.vectorstores.redis.base import _generate_field_schema, _prepare_metadata +from langchain_community.utilities.redis import _array_to_buffer +from langsmith import traceable + +from comps import opea_microservices, register_microservice +from comps.dataprep.multimodal_utils import ( + BridgeTowerEmbeddings, + create_upload_folder, + load_json_file, + clear_upload_folder, + generate_video_id, + convert_video_to_audio, + load_whisper_model, + extract_transcript_from_audio, + write_vtt, + delete_audio_file, + extract_frames_and_annotations_from_transcripts, + extract_frames_and_generate_captions +) + + +device = "cpu" +upload_folder = "./uploaded_files/" + + +class MultimodalRedis(Redis): + """ Redis vector database to process multimodal data""" + + @classmethod + def from_text_image_pairs_return_keys( + cls: Type[Redis], + texts: List[str], + images: List[str], + embedding: Embeddings = BridgeTowerEmbeddings, + metadatas: Optional[List[dict]] = None, + index_name: Optional[str] = None, + index_schema: 
Optional[Union[Dict[str, str], str, os.PathLike]] = None, + vector_schema: Optional[Dict[str, Union[str, int]]] = None, + **kwargs: Any, + ): + """ + Args: + texts (List[str]): List of texts to add to the vectorstore. + images (List[str]): List of path-to-images to add to the vectorstore. + embedding (Embeddings): Embeddings to use for the vectorstore. + metadatas (Optional[List[dict]], optional): Optional list of metadata + dicts to add to the vectorstore. Defaults to None. + index_name (Optional[str], optional): Optional name of the index to + create or add to. Defaults to None. + index_schema (Optional[Union[Dict[str, str], str, os.PathLike]], optional): + Optional fields to index within the metadata. Overrides generated + schema. Defaults to None. + vector_schema (Optional[Dict[str, Union[str, int]]], optional): Optional + vector schema to use. Defaults to None. + **kwargs (Any): Additional keyword arguments to pass to the Redis client. + + Returns: + Tuple[Redis, List[str]]: Tuple of the Redis instance and the keys of + the newly created documents. + + Raises: + ValueError: If the number of metadatas does not match the number of texts. 
+ """ + # the length of texts must be equal to the length of images + assert len(texts)==len(images), "the len of captions should be equal to the len of images" + + redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL") + + if "redis_url" in kwargs: + kwargs.pop("redis_url") + + # flag to use generated schema + if "generate" in kwargs: + kwargs.pop("generate") + + # see if the user specified keys + keys = None + if "keys" in kwargs: + keys = kwargs.pop("keys") + + # Name of the search index if not given + if not index_name: + index_name = uuid.uuid4().hex + + # type check for metadata + if metadatas: + if isinstance(metadatas, list) and len(metadatas) != len(texts): # type: ignore # noqa: E501 + raise ValueError("Number of metadatas must match number of texts") + if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)): + raise ValueError("Metadatas must be a list of dicts") + generated_schema = _generate_field_schema(metadatas[0]) + + if not index_schema: + index_schema = generated_schema + + # Create instance + instance = cls( + redis_url, + index_name, + embedding, + index_schema=index_schema, + vector_schema=vector_schema, + **kwargs, + ) + # Add data to Redis + keys = instance.add_text_image_pairs(texts, images, metadatas, keys=keys) + return instance, keys + + def add_text_image_pairs( + self, + texts: Iterable[str], + images: Iterable[str], + metadatas: Optional[List[dict]] = None, + embeddings: Optional[List[List[float]]] = None, + batch_size: int = 2, + clean_metadata: bool = True, + **kwargs: Any, + ) -> List[str]: + + """Add more embeddings of text-image pairs to the vectorstore. + + Args: + texts (Iterable[str]): Iterable of strings/text to add to the vectorstore. + images: Iterable[str]: Iterable of strings/text of path-to-image to add to the vectorstore. + metadatas (Optional[List[dict]], optional): Optional list of metadatas. + Defaults to None. 
def prepare_data_and_metadata_from_annotation(annotation, path_to_frames, title, description, num_transcript_concat_for_ingesting=2, num_transcript_concat_for_inference=7):
    """Build the parallel text, image-path and metadata lists for ingestion.

    For the annotation entry at position ``i`` the captions of the entries in
    ``[i - radius, i + radius]`` (clamped to the bounds of ``annotation``) are
    joined with single spaces. Two radii are used: one for the text that is
    ingested and one for the longer transcript stored for inference.

    Args:
        annotation: Indexable sequence of dicts. Each entry is read for the
            keys ``sub_video_id``, ``caption``, ``video_id``, ``b64_img_str``,
            ``time`` and ``video_name``.
        path_to_frames: Directory the ``frame_<sub_video_id>.jpg`` paths are
            built under.
        title: Value stored under ``title`` in every metadata dict.
        description: Value stored under ``description`` in every metadata dict.
        num_transcript_concat_for_ingesting: Number of neighbouring captions
            taken on each side for the ingested text.
        num_transcript_concat_for_inference: Number of neighbouring captions
            taken on each side for the inference transcript.

    Returns:
        Tuple ``(text_list, image_list, metadatas)`` with one item per
        annotation entry, in annotation order.
    """
    total = len(annotation)

    def joined_captions(center, radius):
        # Clamp the window so entries near either end simply get fewer captions.
        first = max(0, center - radius)
        stop = min(total, center + radius + 1)
        return ' '.join([annotation[k]['caption'] for k in range(first, stop)])

    text_list, image_list, metadatas = [], [], []
    for position, entry in enumerate(annotation):
        frame_path = os.path.join(path_to_frames, f"frame_{entry['sub_video_id']}.jpg")
        ingest_caption = joined_captions(position, num_transcript_concat_for_ingesting)
        inference_caption = joined_captions(position, num_transcript_concat_for_inference)

        video_id, b64_img_str, frame_time, source_video = (
            entry['video_id'],
            entry['b64_img_str'],
            entry['time'],
            entry['video_name'],
        )

        text_list.append(ingest_caption)
        image_list.append(frame_path)

        record = {
            'content' : ingest_caption,
            'b64_img_str': b64_img_str,
            'video_id': video_id,
            'source_video' : source_video,
            'time_of_frame_ms' : float(frame_time),
            'embedding_type' : 'pair',
            'title' : title,
            'description' : description,
            'transcript_for_inference' : inference_caption,
        }
        metadatas.append(record)

    return text_list, image_list, metadatas
" + text for text in text_list], + images=image_list, + embedding=embeddings, + metadatas=metadatas, + index_name=INDEX_NAME, + index_schema=INDEX_SCHEMA, + redis_url=REDIS_URL, + ) + + +def drop_index(index_name, redis_url=REDIS_URL): + print(f"dropping index {index_name}") + try: + assert Redis.drop_index(index_name=index_name, delete_documents=True, redis_url=redis_url) + print(f"index {index_name} deleted") + except Exception as e: + print(f"index {index_name} delete failed: {e}") + return False + return True + + +@register_microservice(name="opea_service@prepare_videodoc_redis", endpoint="/v1/dataprep/generate_transcripts", host="0.0.0.0", port=6007) +@traceable(run_type="tool") +async def ingest_videos( + files: List[UploadFile] = File(None) +): + """Upload videos with speech, generate transcripts using whisper and ingest into redis""" + + if files: + video_files = [] + for file in files: + if os.path.splitext(file.filename)[1] == ".mp4": + video_files.append(file) + else: + raise HTTPException(status_code=400, detail=f"File {file.filename} is not an mp4 file. 
Please upload mp4 files only.") + + # Load whisper model + whisper_model = load_whisper_model(model_name=WHISPER_MODEL) + + # Load embeddings model + embeddings = BridgeTowerEmbeddings(model_name=EMBED_MODEL, device=device) + + for video_file in video_files: + print(f"Processing video {video_file.filename}") + + # Assign unique identifier to video + video_id = generate_video_id() + + # Create video file name by appending identifier + video_name = os.path.splitext(video_file.filename)[0] + video_file_name = f"{video_name}_{video_id}.mp4" + video_dir_name = os.path.splitext(video_file_name)[0] + + # Save video file in upload_directory + with open(os.path.join(upload_folder, video_file_name), 'wb') as f: + shutil.copyfileobj(video_file.file, f) + + # Convert mp4 to temporary wav file + audio_file = video_dir_name + ".wav" + convert_video_to_audio(os.path.join(upload_folder, video_file_name), os.path.join(upload_folder, audio_file)) + + # Extract transcript from audio + transcripts = extract_transcript_from_audio(whisper_model, os.path.join(upload_folder, audio_file)) + + # Save transcript as vtt file and delete audio file + vtt_file = video_file_name + ".vtt" + write_vtt(transcripts, os.path.join(upload_folder, vtt_file)) + delete_audio_file(os.path.join(upload_folder, audio_file)) + + # Store frames and caption annotations in a new directory + extract_frames_and_annotations_from_transcripts(video_id, os.path.join(upload_folder, video_file_name), os.path.join(upload_folder, vtt_file), os.path.join(upload_folder, video_dir_name)) + + # Delete temporary vtt file + os.remove(os.path.join(upload_folder, vtt_file)) + + # Ingest multimodal data into redis + ingest_multimodal(video_file_name, video_name, video_name, os.path.join(upload_folder, video_dir_name), embeddings) + + # Delete temporary video directory containing frames and annotations + shutil.rmtree(os.path.join(upload_folder, video_dir_name)) + + print(f"Processed video {video_file.filename}") + + return 
{"status": 200, "message": "Data preparation succeeded"} + + raise HTTPException(status_code=400, detail="Must provide atleast one video (.mp4) file.") + + +@register_microservice(name="opea_service@prepare_videodoc_redis", endpoint="/v1/dataprep/generate_captions", host="0.0.0.0", port=6007) +@traceable(run_type="tool") +async def ingest_videos( + files: List[UploadFile] = File(None) +): + """Upload videos without speech (only background music or no audio), generate captions using lvm microservice and ingest into redis""" + + if files: + video_files = [] + for file in files: + if os.path.splitext(file.filename)[1] == ".mp4": + video_files.append(file) + else: + raise HTTPException(status_code=400, detail=f"File {file.filename} is not an mp4 file. Please upload mp4 files only.") + + # Load embeddings model + embeddings = BridgeTowerEmbeddings(model_name=EMBED_MODEL, device=device) + + for video_file in video_files: + print(f"Processing video {video_file.filename}") + + # Assign unique identifier to video + video_id = generate_video_id() + + # Create video file name by appending identifier + video_name = os.path.splitext(video_file.filename)[0] + video_file_name = f"{video_name}_{video_id}.mp4" + video_dir_name = os.path.splitext(video_file_name)[0] + + # Save video file in upload_directory + with open(os.path.join(upload_folder, video_file_name), 'wb') as f: + shutil.copyfileobj(video_file.file, f) + + # Store frames and caption annotations in a new directory + extract_frames_and_generate_captions(video_id, os.path.join(upload_folder, video_file_name), LVM_ENDPOINT, os.path.join(upload_folder, video_dir_name)) + + # Ingest multimodal data into redis + ingest_multimodal(video_file_name, video_name, video_name, os.path.join(upload_folder, video_dir_name), embeddings) + + # Delete temporary video directory containing frames and annotations + shutil.rmtree(os.path.join(upload_folder, video_dir_name)) + + print(f"Processed video {video_file.filename}") + + return 
@register_microservice(name="opea_service@prepare_videodoc_redis", endpoint="/v1/dataprep/videos_with_transcripts", host="0.0.0.0", port=6007)
@traceable(run_type="tool")
async def ingest_videos(
    files: List[UploadFile] = File(None)
):
    """Upload videos together with their captions (.vtt) files and ingest them into redis.

    Every uploaded ``<name>.mp4`` must come with a ``<name>.vtt`` in the same
    request. Files with any other extension are skipped with a log message.
    For each video the file is saved in the upload folder, frames and
    annotations are produced by
    ``extract_frames_and_annotations_from_transcripts``, the result is handed
    to ``ingest_multimodal``, and the temporary captions file and frame
    directory are removed afterwards.

    NOTE(review): other endpoints in this module are also defined under the
    name ``ingest_videos``; only the last definition stays reachable by name.
    Consider giving each endpoint a unique function name.

    Args:
        files: Uploaded video (.mp4) and captions (.vtt) files.

    Returns:
        dict: ``{"status": 200, "message": "Data preparation succeeded"}``
        once every video has been processed.

    Raises:
        HTTPException: 400 when no files are provided, when a video has no
            matching captions file, or when no .mp4 file was uploaded.
    """
    if not files:
        raise HTTPException(status_code=400, detail="Must provide atleast one pair consisting of video (.mp4) and captions (.vtt)")

    video_files = []
    captions_by_name = {}
    for file in files:
        extension = os.path.splitext(file.filename)[1]
        if extension == ".mp4":
            video_files.append(file)
        elif extension == ".vtt":
            # Keep the first upload seen for a given name.
            captions_by_name.setdefault(file.filename, file)
        else:
            print(f"Skipping file {file.filename} because of unsupported format.")

    # Check if every video file has a captions file
    for video_file in video_files:
        file_prefix = os.path.splitext(video_file.filename)[0]
        if (file_prefix + ".vtt") not in captions_by_name:
            raise HTTPException(status_code=400, detail=f"No captions file {file_prefix}.vtt found for {video_file.filename}")

    if not video_files:
        # Must be raised, not returned: a returned exception object is treated
        # as a normal response body and the client never sees the 400.
        raise HTTPException(status_code=400, detail="The uploaded files have unsupported formats. Please upload atleast one video file (.mp4) with captions (.vtt)")

    # Load embeddings model
    embeddings = BridgeTowerEmbeddings(model_name=EMBED_MODEL, device=device)

    for video_file in video_files:
        print(f"Processing video {video_file.filename}")

        # Assign unique identifier to video
        video_id = generate_video_id()

        # The filename is client-controlled: drop any directory components so
        # the files written below always stay inside upload_folder.
        safe_name = os.path.basename(video_file.filename)

        # Create video file name by appending identifier
        video_name = os.path.splitext(safe_name)[0]
        video_file_name = f"{video_name}_{video_id}.mp4"
        video_dir_name = os.path.splitext(video_file_name)[0]

        video_path = os.path.join(upload_folder, video_file_name)
        vtt_path = os.path.join(upload_folder, video_dir_name + ".vtt")
        frames_dir = os.path.join(upload_folder, video_dir_name)

        # Save video file in upload directory
        with open(video_path, 'wb') as f:
            shutil.copyfileobj(video_file.file, f)

        # Save captions file in upload directory; its presence was validated above
        captions_file = captions_by_name[os.path.splitext(video_file.filename)[0] + ".vtt"]
        with open(vtt_path, 'wb') as f:
            shutil.copyfileobj(captions_file.file, f)

        # Store frames and caption annotations in a new directory
        extract_frames_and_annotations_from_transcripts(video_id, video_path, vtt_path, frames_dir)

        # Delete temporary vtt file
        os.remove(vtt_path)

        # Ingest multimodal data into redis
        ingest_multimodal(video_file_name, video_dir_name, video_dir_name, frames_dir, embeddings)

        # Delete temporary video directory containing frames and annotations
        shutil.rmtree(frames_dir)

        print(f"Processed video {video_file.filename}")

    return {"status": 200, "message": "Data preparation succeeded"}
@register_microservice(
    name="opea_service@prepare_videodoc_redis", endpoint="/v1/dataprep/get_videos", host="0.0.0.0", port=6007
)
@traceable(run_type="tool")
async def rag_get_file_structure():
    """Return the names of the entries stored in the upload folder.

    Returns:
        list: Result of listing ``upload_folder``, or an empty list when the
        folder does not exist.
    """
    if Path(upload_folder).exists():
        return os.listdir(upload_folder)

    print("No file uploaded, return empty list.")
    return []


@register_microservice(
    name="opea_service@prepare_videodoc_redis", endpoint="/v1/dataprep/delete_videos", host="0.0.0.0", port=6007
)
@traceable(run_type="tool")
async def delete_videos():
    """Drop the redis index and clear the upload folder.

    Returns:
        dict: ``{"status": True}`` after the index was dropped and the upload
        folder cleared.

    Raises:
        HTTPException: 409 when ``drop_index`` reports failure; the upload
            folder is left untouched in that case.
    """
    if drop_index(index_name=INDEX_NAME):
        clear_upload_folder(upload_folder)
        print("Successfully deleted all uploaded videos.")
        return {"status": True}

    raise HTTPException(status_code=409, detail="Uploaded videos could not be deleted. Index does not exist")


if __name__ == "__main__":
    # Make sure the upload folder exists before the service starts.
    create_upload_folder(upload_folder)
    opea_microservices["opea_service@prepare_videodoc_redis"].start()
+1,20 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +text: +- name: content +- name: b64_img_str +- name: video_id +- name: source_video +- name: embedding_type +- name: title +- name: description +- name: transcript_for_inference +numeric: +- name: time_of_frame_ms +vector: +- name: content_vector + algorithm: HNSW + datatype: FLOAT32 + dims: 512 + distance_metric: COSINE