diff --git a/upload-ami/src/upload_ami/delete_deprecated_images.py b/upload-ami/src/upload_ami/delete_deprecated_images.py index 06a0693..997fa40 100644 --- a/upload-ami/src/upload_ami/delete_deprecated_images.py +++ b/upload-ami/src/upload_ami/delete_deprecated_images.py @@ -33,7 +33,9 @@ def delete_deprecated_images(ec2: EC2Client, dry_run: bool) -> None: if current_time >= image["DeprecationTime"]: assert "ImageId" in image assert "Name" in image - logger.info(f"Deleting image {image['Name']} : {image['ImageId']}. DeprecationTime: {image['DeprecationTime']}") + logger.info( + f"Deleting image {image['Name']} : {image['ImageId']}. DeprecationTime: {image['DeprecationTime']}" + ) try: ec2.deregister_image(ImageId=image["ImageId"], DryRun=dry_run) except botocore.exceptions.ClientError as e: diff --git a/upload-ami/src/upload_ami/smoke_test.py b/upload-ami/src/upload_ami/smoke_test.py index c42101b..fe83e81 100644 --- a/upload-ami/src/upload_ami/smoke_test.py +++ b/upload-ami/src/upload_ami/smoke_test.py @@ -1,13 +1,19 @@ +from re import sub import boto3 +import botocore.exceptions import time import argparse import logging +import subprocess +import os from mypy_boto3_ec2 import EC2Client from mypy_boto3_ec2.literals import InstanceTypeType -def smoke_test(image_id: str, run_id: str, cancel: bool) -> None: +def smoke_test( + *, image_id: str, instance_type: InstanceTypeType | None, run_id: str, cancel: bool +) -> None: ec2: EC2Client = boto3.client("ec2") images = ec2.describe_images(Owners=["self"], ImageIds=[image_id]) @@ -15,59 +21,89 @@ def smoke_test(image_id: str, run_id: str, cancel: bool) -> None: image = images["Images"][0] assert "Architecture" in image architecture = image["Architecture"] - instance_type: InstanceTypeType - if architecture == "x86_64": + if architecture == "x86_64" and instance_type is None: instance_type = "t3.nano" - elif architecture == "arm64": + elif architecture == "arm64" and instance_type is None: instance_type = "t4g.nano" else: raise Exception("Unknown architecture: " + architecture) logging.info("Starting instance") - run_instances = ec2.run_instances( - ImageId=image_id, - InstanceType=instance_type, - MinCount=1, - MaxCount=1, - ClientToken=image_id + run_id if run_id else image_id, - InstanceMarketOptions={"MarketType": "spot"}, - ) + + key_pairs = ec2.describe_key_pairs(KeyNames=[image_id]) + if len(key_pairs["KeyPairs"]) > 0: + logging.warn(f"Deleting existing key pair from previous run {image_id}") + ec2.delete_key_pair(KeyName=image_id) + key_pair = ec2.create_key_pair(KeyName=image_id, KeyType="ed25519", KeyFormat="pem") + private_key = key_pair["KeyMaterial"] + + try: + run_instances = ec2.run_instances( + ImageId=image_id, + InstanceType=instance_type, + MinCount=1, + MaxCount=1, + KeyName=image_id, + ClientToken=image_id + run_id if run_id else image_id, + InstanceMarketOptions={"MarketType": "spot"}, + ) + except botocore.exceptions.ClientError as error: + if error.response["Error"]["Code"] == "IdempotentInstanceTerminated": + logging.warn(error) + else: + raise error instance = run_instances["Instances"][0] assert "InstanceId" in instance instance_id = instance["InstanceId"] + assert "State" in instance + assert "Name" in instance["State"] try: - if not cancel and instance["State"]["Name"] != "terminated": - # This basically waits for DHCP to have finished; as it uses ARP to check if the instance is healthy - logging.info(f"Waiting for instance {instance_id} to be running") - ec2.get_waiter("instance_running").wait(InstanceIds=[instance_id]) - logging.info(f"Waiting for instance {instance_id} to be healthy") - ec2.get_waiter("instance_status_ok").wait(InstanceIds=[instance_id]) - tries = 5 + if cancel: + return + assert "PublicIpAddress" in instance + # This basically waits for DHCP to have finished; as it uses ARP to check if the instance is healthy + logging.info(f"Waiting for instance {instance_id} to be running") + ec2.get_waiter("instance_running").wait(InstanceIds=[instance_id]) + logging.info(f"Waiting for instance {instance_id} to be healthy") + ec2.get_waiter("instance_status_ok").wait(InstanceIds=[instance_id]) + + with open(f"{image_id}.pem", "w") as f: + f.write(private_key) + os.chmod(f"{image_id}.pem", 0o600) + + ssh_command = [ + "ssh", + "-o", + "StrictHostKeyChecking=accept-new", + "-i", + f"{image_id}.pem", + "ec2-user@" + instance["PublicIpAddress"], + "echo", + "Hello, World!", + ] + logging.info(f"Running: {' '.join(ssh_command)}") + subprocess.run(ssh_command, check=True) + + tries = 5 + console_output = ec2.get_console_output(InstanceId=instance_id, Latest=True) + output = console_output.get("Output") + while not output and tries > 0: + time.sleep(10) + logging.info( + f"Waiting for console output to become available ({tries} tries left)" + ) console_output = ec2.get_console_output(InstanceId=instance_id, Latest=True) output = console_output.get("Output") - while not output and tries > 0: - time.sleep(10) - logging.info( - f"Waiting for console output to become available ({tries} tries left)" - ) - console_output = ec2.get_console_output( - InstanceId=instance_id, Latest=True - ) - output = console_output.get("Output") - tries -= 1 - logging.info(f"Console output: {output}") - except Exception as e: - logging.error(f"Error: {e}") - raise + tries -= 1 + logging.info(f"Console output: {output}") + finally: logging.info(f"Terminating instance {instance_id}") - assert "State" in instance - assert "Name" in instance["State"] - if instance["State"]["Name"] != "terminated": - ec2.terminate_instances(InstanceIds=[instance_id]) - ec2.get_waiter("instance_terminated").wait(InstanceIds=[instance_id]) + ec2.delete_key_pair(KeyName=image_id) + ec2.terminate_instances(InstanceIds=[instance_id]) + ec2.get_waiter("instance_terminated").wait(InstanceIds=[instance_id]) def main() -> None: @@ -76,10 +112,16 @@ def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--image-id", required=True) parser.add_argument("--run-id", required=False) + parser.add_argument("--instance-type", required=False, type=str) parser.add_argument("--cancel", action="store_true", required=False) args = parser.parse_args() - smoke_test(args.image_id, args.run_id, args.cancel) + smoke_test( + image_id=args.image_id, + instance_type=args.instance_type, + run_id=args.run_id, + cancel=args.cancel, + ) if __name__ == "__main__":