diff --git a/deploy/.env.example b/deploy/.env.example new file mode 100644 index 000000000..6b0f4de4f --- /dev/null +++ b/deploy/.env.example @@ -0,0 +1,3 @@ +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_REGION= diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 000000000..2e53469ce --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,11 @@ +``` +# First time setup +cd deploy +cp .env.example .env  # fill in AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION +uv venv +source .venv/bin/activate +uv pip install -e . + +# Subsequent usage +python deploy/models/omniparser/deploy.py start +``` diff --git a/deploy/deploy/models/omniparser/.dockerignore b/deploy/deploy/models/omniparser/.dockerignore new file mode 100644 index 000000000..213bee701 --- /dev/null +++ b/deploy/deploy/models/omniparser/.dockerignore @@ -0,0 +1,20 @@ +__pycache__ +*.pyc +*.pyo +*.pyd +.Python +env +pip-log.txt +pip-delete-this-directory.txt +.tox +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.log +.pytest_cache +.env +.venv +.DS_Store diff --git a/deploy/deploy/models/omniparser/Dockerfile b/deploy/deploy/models/omniparser/Dockerfile new file mode 100644 index 000000000..f14ea7ac8 --- /dev/null +++ b/deploy/deploy/models/omniparser/Dockerfile @@ -0,0 +1,59 @@ +FROM nvidia/cuda:12.3.1-devel-ubuntu22.04 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + git-lfs \ + wget \ + libgl1 \ + libglib2.0-0 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && git lfs install + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p /opt/conda && \ + rm miniconda.sh +ENV PATH="/opt/conda/bin:$PATH" + +RUN conda create -n omni python=3.12 && \ + echo "source activate omni" > ~/.bashrc +ENV CONDA_DEFAULT_ENV=omni +ENV PATH="/opt/conda/envs/omni/bin:$PATH" + +WORKDIR /app + +RUN git clone https://github.com/microsoft/OmniParser.git && \ + cd OmniParser && \ + git lfs install && \ + git lfs pull + +WORKDIR /app/OmniParser + +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + pip uninstall -y opencv-python opencv-python-headless && \ + pip install --no-cache-dir opencv-python-headless==4.8.1.78 && \ + pip install -r requirements.txt && \ + pip install huggingface_hub fastapi uvicorn + +# Download V2 weights +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + mkdir -p /app/OmniParser/weights && \ + cd /app/OmniParser && \ + rm -rf weights/icon_detect weights/icon_caption weights/icon_caption_florence && \ + for folder in icon_caption icon_detect; do \ + huggingface-cli download microsoft/OmniParser-v2.0 --local-dir weights --repo-type model --include "$folder/*"; \ + done && \ + mv weights/icon_caption weights/icon_caption_florence + +# Pre-download OCR models during build +RUN . 
/opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + cd /app/OmniParser && \ + python3 -c "import easyocr; reader = easyocr.Reader(['en']); print('Downloaded EasyOCR model')" && \ + python3 -c "from paddleocr import PaddleOCR; ocr = PaddleOCR(lang='en', use_angle_cls=False, use_gpu=False, show_log=False); print('Downloaded PaddleOCR model')" + +CMD ["python3", "/app/OmniParser/omnitool/omniparserserver/omniparserserver.py", \ + "--som_model_path", "/app/OmniParser/weights/icon_detect/model.pt", \ + "--caption_model_path", "/app/OmniParser/weights/icon_caption_florence", \ + "--device", "cuda", \ + "--BOX_TRESHOLD", "0.05", \ + "--host", "0.0.0.0", \ + "--port", "8000"] diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py new file mode 100644 index 000000000..c0cac4f49 --- /dev/null +++ b/deploy/deploy/models/omniparser/client.py @@ -0,0 +1,128 @@ +"""Client module for interacting with the OmniParser server.""" + +import base64 +import fire +import requests + +from loguru import logger +from PIL import Image, ImageDraw + + +def image_to_base64(image_path: str) -> str: + """Convert an image file to base64 string. + + Args: + image_path: Path to the image file + + Returns: + str: Base64 encoded string of the image + """ + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def plot_results( + original_image_path: str, + som_image_base64: str, + parsed_content_list: list[dict[str, list[float]]], +) -> None: + """Plot parsing results on the original image. + + Args: + original_image_path: Path to the original image + som_image_base64: Base64 encoded SOM image + parsed_content_list: List of parsed content with bounding boxes + """ + # Open original image + image = Image.open(original_image_path) + width, height = image.size + + # Create drawable image + draw = ImageDraw.Draw(image) + + # Draw bounding boxes and labels + for item in parsed_content_list: + # Get normalized coordinates and convert to pixel coordinates + x1, y1, x2, y2 = item["bbox"] + x1 = int(x1 * width) + y1 = int(y1 * height) + x2 = int(x2 * width) + y2 = int(y2 * height) + + label = item["content"] + + # Draw rectangle + draw.rectangle([(x1, y1), (x2, y2)], outline="red", width=2) + + # Draw label background + text_bbox = draw.textbbox((x1, y1), label) + draw.rectangle( + [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2], + fill="white", + ) + + # Draw label text + draw.text((x1, y1), label, fill="red") + + # Show image + image.show() + + +def parse_image( + image_path: str, + server_url: str, +) -> None: + """Parse an image using the OmniParser server. 
+ + Args: + image_path: Path to the image file + server_url: URL of the OmniParser server + """ + # Remove trailing slash from server_url if present + server_url = server_url.rstrip("/") + + # Convert image to base64 + base64_image = image_to_base64(image_path) + + # Prepare request + url = f"{server_url}/parse/" + payload = {"base64_image": base64_image} + + try: + # First, check if the server is available + probe_url = f"{server_url}/probe/" + probe_response = requests.get(probe_url) + probe_response.raise_for_status() + logger.info("Server is available") + + # Make request to API + response = requests.post(url, json=payload) + response.raise_for_status() + + # Parse response + result = response.json() + som_image_base64 = result["som_image_base64"] + parsed_content_list = result["parsed_content_list"] + + # Plot results + plot_results(image_path, som_image_base64, parsed_content_list) + + # Print latency + logger.info(f"API Latency: {result['latency']:.2f} seconds") + + except requests.exceptions.ConnectionError: + logger.error(f"Error: Could not connect to server at {server_url}") + logger.error("Please check if the server is running and the URL is correct") + except requests.exceptions.RequestException as e: + logger.error(f"Error making request to API: {e}") + except Exception as e: + logger.error(f"Error: {e}") + + +def main() -> None: + """Main entry point for the client application.""" + fire.Fire(parse_image) + + +if __name__ == "__main__": + main() diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py new file mode 100644 index 000000000..b951378bb --- /dev/null +++ b/deploy/deploy/models/omniparser/deploy.py @@ -0,0 +1,785 @@ +"""Deployment module for OmniParser on AWS EC2.""" + +import os +import subprocess +import time + +from botocore.exceptions import ClientError +from loguru import logger +from pydantic_settings import BaseSettings +import boto3 +import fire +import paramiko + + +CLEANUP_ON_FAILURE = False + + +class Config(BaseSettings): + """Configuration settings for deployment.""" + + AWS_ACCESS_KEY_ID: str + AWS_SECRET_ACCESS_KEY: str + AWS_REGION: str + + PROJECT_NAME: str = "omniparser" + REPO_URL: str = "https://github.com/microsoft/OmniParser.git" + AWS_EC2_AMI: str = "ami-06835d15c4de57810" + AWS_EC2_DISK_SIZE: int = 128 # GB + AWS_EC2_INSTANCE_TYPE: str = "g4dn.xlarge" # (T4 16GB $0.526/hr x86_64) + AWS_EC2_USER: str = "ubuntu" + PORT: int = 8000 # FastAPI port + COMMAND_TIMEOUT: int = 600 # 10 minutes + + class Config: + """Pydantic configuration class.""" + + env_file = ".env" + env_file_encoding = "utf-8" + + @property + def CONTAINER_NAME(self) -> str: + """Get the container name.""" + return f"{self.PROJECT_NAME}-container" + + @property + def AWS_EC2_KEY_NAME(self) -> str: + """Get the EC2 key pair name.""" + return f"{self.PROJECT_NAME}-key" + + @property + def AWS_EC2_KEY_PATH(self) -> str: + """Get the path to the EC2 key file.""" + return f"./{self.AWS_EC2_KEY_NAME}.pem" + + @property + def AWS_EC2_SECURITY_GROUP(self) -> str: + """Get the EC2 security group name.""" + return f"{self.PROJECT_NAME}-SecurityGroup" + + +config = Config() + + +def create_key_pair( + key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH +) -> str | None: + """Create an EC2 key pair. 
+ + Args: + key_name: Name of the key pair + key_path: Path where to save the key file + + Returns: + str | None: Key name if successful, None otherwise + """ + ec2_client = boto3.client("ec2", region_name=config.AWS_REGION) + try: + key_pair = ec2_client.create_key_pair(KeyName=key_name) + private_key = key_pair["KeyMaterial"] + + with open(key_path, "w") as key_file: + key_file.write(private_key) + os.chmod(key_path, 0o400) # Set read-only permissions + + logger.info(f"Key pair {key_name} created and saved to {key_path}") + return key_name + except ClientError as e: + logger.error(f"Error creating key pair: {e}") + return None + + +def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: + """Get existing security group or create a new one. + + Args: + ports: List of ports to open in the security group + + Returns: + str | None: Security group ID if successful, None otherwise + """ + ec2 = boto3.client("ec2", region_name=config.AWS_REGION) + + ip_permissions = [ + { + "IpProtocol": "tcp", + "FromPort": port, + "ToPort": port, + "IpRanges": [{"CidrIp": "0.0.0.0/0"}], + } + for port in ports + ] + + try: + response = ec2.describe_security_groups( + GroupNames=[config.AWS_EC2_SECURITY_GROUP] + ) + security_group_id = response["SecurityGroups"][0]["GroupId"] + logger.info( + f"Security group '{config.AWS_EC2_SECURITY_GROUP}' already exists: " + f"{security_group_id}" + ) + + for ip_permission in ip_permissions: + try: + ec2.authorize_security_group_ingress( + GroupId=security_group_id, IpPermissions=[ip_permission] + ) + logger.info(f"Added inbound rule for port {ip_permission['FromPort']}") + except ClientError as e: + if e.response["Error"]["Code"] == "InvalidPermission.Duplicate": + logger.info( + f"Rule for port {ip_permission['FromPort']} already exists" + ) + else: + logger.error( + f"Error adding rule for port {ip_permission['FromPort']}: {e}" + ) + + return security_group_id + except ClientError as e: + if e.response["Error"]["Code"] == "InvalidGroup.NotFound": + try: + response = ec2.create_security_group( + GroupName=config.AWS_EC2_SECURITY_GROUP, + Description="Security group for OmniParser deployment", + TagSpecifications=[ + { + "ResourceType": "security-group", + "Tags": [{"Key": "Name", "Value": config.PROJECT_NAME}], + } + ], + ) + security_group_id = response["GroupId"] + logger.info( + f"Created security group '{config.AWS_EC2_SECURITY_GROUP}' " + f"with ID: {security_group_id}" + ) + + ec2.authorize_security_group_ingress( + GroupId=security_group_id, IpPermissions=ip_permissions + ) + logger.info(f"Added inbound rules for ports {ports}") + + return security_group_id + except ClientError as e: + logger.error(f"Error creating security group: {e}") + return None + else: + logger.error(f"Error describing security groups: {e}") + return None + + +def deploy_ec2_instance( + ami: str = config.AWS_EC2_AMI, + instance_type: str = config.AWS_EC2_INSTANCE_TYPE, + project_name: str = config.PROJECT_NAME, + key_name: str = config.AWS_EC2_KEY_NAME, + disk_size: int = config.AWS_EC2_DISK_SIZE, +) -> tuple[str | None, str | None]: + """Deploy a new EC2 instance or return existing one. 
+ + Args: + ami: AMI ID to use for the instance + instance_type: EC2 instance type + project_name: Name tag for the instance + key_name: Name of the key pair to use + disk_size: Size of the root volume in GB + + Returns: + tuple[str | None, str | None]: Instance ID and public IP if successful + """ + ec2 = boto3.resource("ec2") + ec2_client = boto3.client("ec2") + + # Check for existing instances first + instances = ec2.instances.filter( + Filters=[ + {"Name": "tag:Name", "Values": [config.PROJECT_NAME]}, + { + "Name": "instance-state-name", + "Values": ["running", "pending", "stopped"], + }, + ] + ) + + existing_instance = None + for instance in instances: + existing_instance = instance + if instance.state["Name"] == "running": + logger.info( + f"Instance already running: ID - {instance.id}, " + f"IP - {instance.public_ip_address}" + ) + break + elif instance.state["Name"] == "stopped": + logger.info(f"Starting existing stopped instance: ID - {instance.id}") + ec2_client.start_instances(InstanceIds=[instance.id]) + instance.wait_until_running() + instance.reload() + logger.info( + f"Instance started: ID - {instance.id}, " + f"IP - {instance.public_ip_address}" + ) + break + + # If we found an existing instance, ensure we have its key + if existing_instance: + if not os.path.exists(config.AWS_EC2_KEY_PATH): + logger.warning( + f"Key file {config.AWS_EC2_KEY_PATH} not found for existing instance." + ) + logger.warning( + "You'll need to use the original key file to connect to this instance." + ) + logger.warning( + "Consider terminating the instance with 'deploy.py stop' and starting " + "fresh." + ) + return None, None + return existing_instance.id, existing_instance.public_ip_address + + # No existing instance found, create new one with new key pair + security_group_id = get_or_create_security_group_id() + if not security_group_id: + logger.error( + "Unable to retrieve security group ID. Instance deployment aborted." 
+ ) + return None, None + + # Create new key pair + try: + if os.path.exists(config.AWS_EC2_KEY_PATH): + logger.info(f"Removing existing key file {config.AWS_EC2_KEY_PATH}") + os.remove(config.AWS_EC2_KEY_PATH) + + try: + ec2_client.delete_key_pair(KeyName=key_name) + logger.info(f"Deleted existing key pair {key_name}") + except ClientError: + pass # Key pair doesn't exist, which is fine + + if not create_key_pair(key_name): + logger.error("Failed to create key pair") + return None, None + except Exception as e: + logger.error(f"Error managing key pair: {e}") + return None, None + + # Create new instance + ebs_config = { + "DeviceName": "/dev/sda1", + "Ebs": { + "VolumeSize": disk_size, + "VolumeType": "gp3", + "DeleteOnTermination": True, + }, + } + + new_instance = ec2.create_instances( + ImageId=ami, + MinCount=1, + MaxCount=1, + InstanceType=instance_type, + KeyName=key_name, + SecurityGroupIds=[security_group_id], + BlockDeviceMappings=[ebs_config], + TagSpecifications=[ + { + "ResourceType": "instance", + "Tags": [{"Key": "Name", "Value": project_name}], + }, + ], + )[0] + + new_instance.wait_until_running() + new_instance.reload() + logger.info( + f"New instance created: ID - {new_instance.id}, " + f"IP - {new_instance.public_ip_address}" + ) + return new_instance.id, new_instance.public_ip_address + + +def configure_ec2_instance( + instance_id: str | None = None, + instance_ip: str | None = None, + max_ssh_retries: int = 20, + ssh_retry_delay: int = 20, + max_cmd_retries: int = 20, + cmd_retry_delay: int = 30, +) -> tuple[str | None, str | None]: + """Configure an EC2 instance with necessary dependencies and Docker setup. + + This function either configures an existing EC2 instance specified by instance_id + and instance_ip, or deploys and configures a new instance. It installs Docker and + other required dependencies, and sets up the environment for running containers. + + Args: + instance_id: Optional ID of an existing EC2 instance to configure. + If None, a new instance will be deployed. + instance_ip: Optional IP address of an existing EC2 instance. + Required if instance_id is provided. + max_ssh_retries: Maximum number of SSH connection attempts. + Defaults to 20 attempts. + ssh_retry_delay: Delay in seconds between SSH connection attempts. + Defaults to 20 seconds. + max_cmd_retries: Maximum number of command execution retries. + Defaults to 20 attempts. + cmd_retry_delay: Delay in seconds between command execution retries. + Defaults to 30 seconds. 
+ + Returns: + tuple[str | None, str | None]: A tuple containing: + - The instance ID (str) or None if configuration failed + - The instance's public IP address (str) or None if configuration failed + + Raises: + RuntimeError: If command execution fails + paramiko.SSHException: If SSH connection fails + Exception: For other unexpected errors during configuration + """ + if not instance_id: + ec2_instance_id, ec2_instance_ip = deploy_ec2_instance() + else: + ec2_instance_id = instance_id + ec2_instance_ip = instance_ip + + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + ssh_retries = 0 + while ssh_retries < max_ssh_retries: + try: + ssh_client.connect( + hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key + ) + break + except Exception as e: + ssh_retries += 1 + logger.error(f"SSH connection attempt {ssh_retries} failed: {e}") + if ssh_retries < max_ssh_retries: + logger.info(f"Retrying SSH connection in {ssh_retry_delay} seconds...") + time.sleep(ssh_retry_delay) + else: + logger.error("Maximum SSH connection attempts reached. Aborting.") + return None, None + + commands = [ + "sudo apt-get update", + "sudo apt-get install -y ca-certificates curl gnupg", + "sudo install -m 0755 -d /etc/apt/keyrings", + ( + "curl -fsSL https://download.docker.com/linux/ubuntu/gpg | " + "sudo dd of=/etc/apt/keyrings/docker.gpg" + ), + "sudo chmod a+r /etc/apt/keyrings/docker.gpg", + ( + 'echo "deb [arch="$(dpkg --print-architecture)" ' + "signed-by=/etc/apt/keyrings/docker.gpg] " + "https://download.docker.com/linux/ubuntu " + '"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | ' + "sudo tee /etc/apt/sources.list.d/docker.list > /dev/null" + ), + "sudo apt-get update", + ( + "sudo apt-get install -y docker-ce docker-ce-cli containerd.io " + "docker-buildx-plugin docker-compose-plugin" + ), + "sudo systemctl start docker", + "sudo systemctl enable docker", + "sudo usermod -a -G docker ${USER}", + "sudo docker system prune -af --volumes", + f"sudo docker rm -f {config.PROJECT_NAME}-container || true", + ] + + for command in commands: + logger.info(f"Executing command: {command}") + cmd_retries = 0 + while cmd_retries < max_cmd_retries: + stdin, stdout, stderr = ssh_client.exec_command(command) + exit_status = stdout.channel.recv_exit_status() + + if exit_status == 0: + logger.info("Command executed successfully") + break + else: + error_message = stderr.read() + if "Could not get lock" in str(error_message): + cmd_retries += 1 + logger.warning( + f"dpkg is locked, retrying in {cmd_retry_delay} seconds... 
" + f"Attempt {cmd_retries}/{max_cmd_retries}" + ) + time.sleep(cmd_retry_delay) + else: + logger.error( + f"Error in command: {command}, Exit Status: {exit_status}, " + f"Error: {error_message}" + ) + break + + ssh_client.close() + return ec2_instance_id, ec2_instance_ip + + +def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: + """Execute a command and handle its output safely.""" + logger.info(f"Executing: {command}") + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=config.COMMAND_TIMEOUT, + # get_pty=True + ) + + # Stream output in real-time + while not stdout.channel.exit_status_ready(): + if stdout.channel.recv_ready(): + try: + line = stdout.channel.recv(1024).decode("utf-8", errors="replace") + if line.strip(): # Only log non-empty lines + logger.info(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stdout: {e}") + + if stdout.channel.recv_stderr_ready(): + try: + line = stdout.channel.recv_stderr(1024).decode( + "utf-8", errors="replace" + ) + if line.strip(): # Only log non-empty lines + logger.error(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stderr: {e}") + + exit_status = stdout.channel.recv_exit_status() + + # Capture any remaining output + try: + remaining_stdout = stdout.read().decode("utf-8", errors="replace") + if remaining_stdout.strip(): + logger.info(remaining_stdout.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stdout: {e}") + + try: + remaining_stderr = stderr.read().decode("utf-8", errors="replace") + if remaining_stderr.strip(): + logger.error(remaining_stderr.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stderr: {e}") + + if exit_status != 0: + error_msg = f"Command failed with exit status {exit_status}: {command}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully executed: {command}") + + +class Deploy: + """Class handling deployment operations for OmniParser.""" + + @staticmethod + def start() -> None: + """Start a new deployment of OmniParser on EC2.""" + try: + instance_id, instance_ip = configure_ec2_instance() + assert instance_ip, f"invalid {instance_ip=}" + + # Trigger driver installation via login shell + Deploy.ssh(non_interactive=True) + + # Get the directory containing deploy.py + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Define files to copy + files_to_copy = { + "Dockerfile": os.path.join(current_dir, "Dockerfile"), + ".dockerignore": os.path.join(current_dir, ".dockerignore"), + } + + # Copy files to instance + for filename, filepath in files_to_copy.items(): + if os.path.exists(filepath): + logger.info(f"Copying {filename} to instance...") + subprocess.run( + [ + "scp", + "-i", + config.AWS_EC2_KEY_PATH, + "-o", + "StrictHostKeyChecking=no", + filepath, + f"{config.AWS_EC2_USER}@{instance_ip}:~/{filename}", + ], + check=True, + ) + else: + logger.warning(f"File not found: {filepath}") + + # Connect to instance and execute commands + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + logger.info(f"Connecting to {instance_ip}...") + ssh_client.connect( + hostname=instance_ip, + username=config.AWS_EC2_USER, + pkey=key, + timeout=30, + ) + + setup_commands = [ + "rm -rf OmniParser", # Clean up any existing repo + f"git clone {config.REPO_URL}", + "cp Dockerfile .dockerignore OmniParser/", + ] + + # 
Execute setup commands + for command in setup_commands: + logger.info(f"Executing setup command: {command}") + execute_command(ssh_client, command) + + # Build and run Docker container + docker_commands = [ + # Remove any existing container + f"sudo docker rm -f {config.CONTAINER_NAME} || true", + # Remove any existing image + f"sudo docker rmi {config.PROJECT_NAME} || true", + # Build new image + ( + "cd OmniParser && sudo docker build --progress=plain " + f"-t {config.PROJECT_NAME} ." + ), + # Run new container + ( + "sudo docker run -d -p 8000:8000 --gpus all --name " + f"{config.CONTAINER_NAME} {config.PROJECT_NAME}" + ), + ] + + # Execute Docker commands + for command in docker_commands: + logger.info(f"Executing Docker command: {command}") + execute_command(ssh_client, command) + + # Wait for container to start and check its logs + logger.info("Waiting for container to start...") + time.sleep(10) # Give container time to start + execute_command(ssh_client, f"docker logs {config.CONTAINER_NAME}") + + # Wait for server to become responsive + logger.info("Waiting for server to become responsive...") + max_retries = 30 + retry_delay = 10 + server_ready = False + + for attempt in range(max_retries): + try: + # Check if server is responding + check_command = f"curl -s http://localhost:{config.PORT}/probe/" + execute_command(ssh_client, check_command) + server_ready = True + break + except Exception as e: + logger.warning( + f"Server not ready (attempt {attempt + 1}/{max_retries}): " + f"{e}" + ) + if attempt < max_retries - 1: + logger.info( + f"Waiting {retry_delay} seconds before next attempt..." + ) + time.sleep(retry_delay) + + if not server_ready: + raise RuntimeError("Server failed to start properly") + + # Final status check + execute_command(ssh_client, f"docker ps | grep {config.CONTAINER_NAME}") + + server_url = f"http://{instance_ip}:{config.PORT}" + logger.info(f"Deployment complete. 
Server running at: {server_url}") + + # Verify server is accessible from outside + try: + import requests + + response = requests.get(f"{server_url}/probe/", timeout=10) + if response.status_code == 200: + logger.info("Server is accessible from outside!") + else: + logger.warning( + f"Server responded with status code: {response.status_code}" + ) + except Exception as e: + logger.warning(f"Could not verify external access: {e}") + + except Exception as e: + logger.error(f"Error during deployment: {e}") + # Get container logs for debugging + try: + execute_command(ssh_client, f"docker logs {config.CONTAINER_NAME}") + except Exception as exc: + logger.warning(f"{exc=}") + raise + + finally: + ssh_client.close() + + except Exception as e: + logger.error(f"Deployment failed: {e}") + if CLEANUP_ON_FAILURE: + # Attempt cleanup on failure + try: + Deploy.stop() + except Exception as cleanup_error: + logger.error(f"Cleanup after failure also failed: {cleanup_error}") + raise + + logger.info("Deployment completed successfully!") + + @staticmethod + def status() -> None: + """Check the status of deployed instances.""" + ec2 = boto3.resource("ec2") + instances = ec2.instances.filter( + Filters=[{"Name": "tag:Name", "Values": [config.PROJECT_NAME]}] + ) + + for instance in instances: + public_ip = instance.public_ip_address + if public_ip: + server_url = f"http://{public_ip}:{config.PORT}" + logger.info( + f"Instance ID: {instance.id}, State: {instance.state['Name']}, " + f"URL: {server_url}" + ) + else: + logger.info( + f"Instance ID: {instance.id}, State: {instance.state['Name']}, " + f"URL: Not available (no public IP)" + ) + + @staticmethod + def ssh(non_interactive: bool = False) -> None: + """SSH into the running instance. + + Args: + non_interactive: If True, run in non-interactive mode + """ + # Get instance IP + ec2 = boto3.resource("ec2") + instances = ec2.instances.filter( + Filters=[ + {"Name": "tag:Name", "Values": [config.PROJECT_NAME]}, + {"Name": "instance-state-name", "Values": ["running"]}, + ] + ) + + instance = next(iter(instances), None) + if not instance: + logger.error("No running instance found") + return + + ip = instance.public_ip_address + if not ip: + logger.error("Instance has no public IP") + return + + # Check if key file exists + if not os.path.exists(config.AWS_EC2_KEY_PATH): + logger.error(f"Key file not found: {config.AWS_EC2_KEY_PATH}") + return + + if non_interactive: + # Simulate full login by forcing all initialization scripts + ssh_command = [ + "ssh", + "-o", + "StrictHostKeyChecking=no", # Automatically accept new host keys + "-o", + "UserKnownHostsFile=/dev/null", # Prevent writing to known_hosts + "-i", + config.AWS_EC2_KEY_PATH, + f"{config.AWS_EC2_USER}@{ip}", + "-t", # Allocate a pseudo-terminal + "-tt", # Force pseudo-terminal allocation + "bash --login -c 'exit'", # Force full login shell and exit immediately + ] + else: + # Build and execute SSH command + ssh_command = ( + f"ssh -i {config.AWS_EC2_KEY_PATH} -o StrictHostKeyChecking=no " + f"{config.AWS_EC2_USER}@{ip}" + ) + logger.info(f"Connecting with: {ssh_command}") + os.system(ssh_command) + return + + # Execute the SSH command for non-interactive mode + try: + subprocess.run(ssh_command, check=True) + except subprocess.CalledProcessError as e: + logger.error(f"SSH connection failed: {e}") + + @staticmethod + def stop( + project_name: str = config.PROJECT_NAME, + security_group_name: str = config.AWS_EC2_SECURITY_GROUP, + ) -> None: + """Terminates the EC2 instance and deletes the associated 
security group. + + Args: + project_name (str): The project name used to tag the instance. + Defaults to config.PROJECT_NAME. + security_group_name (str): The name of the security group to delete. + Defaults to config.AWS_EC2_SECURITY_GROUP. + """ + ec2_resource = boto3.resource("ec2") + ec2_client = boto3.client("ec2") + + # Terminate EC2 instances + instances = ec2_resource.instances.filter( + Filters=[ + {"Name": "tag:Name", "Values": [project_name]}, + { + "Name": "instance-state-name", + "Values": [ + "pending", + "running", + "shutting-down", + "stopped", + "stopping", + ], + }, + ] + ) + + for instance in instances: + logger.info(f"Terminating instance: ID - {instance.id}") + instance.terminate() + instance.wait_until_terminated() + logger.info(f"Instance {instance.id} terminated successfully.") + + # Delete security group + try: + ec2_client.delete_security_group(GroupName=security_group_name) + logger.info(f"Deleted security group: {security_group_name}") + except ClientError as e: + if e.response["Error"]["Code"] == "InvalidGroup.NotFound": + logger.info( + f"Security group {security_group_name} does not exist or already " + "deleted." + ) + else: + logger.error(f"Error deleting security group: {e}") + + +if __name__ == "__main__": + fire.Fire(Deploy) diff --git a/deploy/pyproject.toml b/deploy/pyproject.toml new file mode 100644 index 000000000..835b62424 --- /dev/null +++ b/deploy/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "deploy" +version = "0.1.0" +authors = [ + { name="Richard Abrich", email="richard@openadapt.ai" }, +] +description = "Deployment tools for OpenAdapt models" +requires-python = ">=3.10" +dependencies = [ + "boto3>=1.36.22", + "fire>=0.7.0", + "loguru>=0.7.0", + "paramiko>=3.5.1", + "pillow>=11.1.0", + "pydantic>=2.10.6", + "pydantic-settings>=2.7.1", + "requests>=2.32.3", +]
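For reference, here is a minimal smoke test of a deployed server, written against the request/response handling in `client.py` above (the `/probe/` and `/parse/` routes, the `base64_image` request field, and the `som_image_base64` / `parsed_content_list` / `latency` response fields). The instance IP, image path, and timeouts below are placeholders, not part of the PR.

```python
"""Smoke-test a deployed OmniParser server (sketch; values below are placeholders)."""

import base64

import requests

SERVER_URL = "http://<instance-ip>:8000"  # placeholder: URL printed by `deploy.py start`
IMAGE_PATH = "screenshot.png"  # placeholder: any local screenshot

# Confirm the server is reachable before sending work.
requests.get(f"{SERVER_URL}/probe/", timeout=10).raise_for_status()

# POST a base64-encoded image to the /parse/ endpoint, as client.py does.
with open(IMAGE_PATH, "rb") as image_file:
    payload = {"base64_image": base64.b64encode(image_file.read()).decode("utf-8")}
response = requests.post(f"{SERVER_URL}/parse/", json=payload, timeout=120)
response.raise_for_status()

# The response carries an annotated (SOM) image plus parsed elements and latency.
result = response.json()
print(f"latency: {result['latency']:.2f}s")
for item in result["parsed_content_list"]:
    print(item["content"], item["bbox"])
```

The same round trip is available from the command line via `python deploy/models/omniparser/client.py <image_path> <server_url>`, since the `fire` entry point exposes `parse_image`'s positional arguments.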