From d9b7c5557a78f854a38aa8bba7436d2753d368b0 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 22 Dec 2022 13:39:58 -0500 Subject: [PATCH] fix: RUNPOD_ bucket prefix & passing all job inputs to user & handle multiple images --- docs/serverless/worker.md | 6 +++--- infer.py | 13 ++++++++++--- runpod/serverless/modules/inference.py | 6 +++--- runpod/serverless/modules/job.py | 10 +++++----- runpod/serverless/modules/logging.py | 9 ++++++--- runpod/serverless/modules/upload.py | 8 ++++---- runpod/serverless/pod_worker.py | 2 +- setup.cfg | 2 +- 8 files changed, 33 insertions(+), 23 deletions(-) diff --git a/docs/serverless/worker.md b/docs/serverless/worker.md index ed244190..c640e555 100644 --- a/docs/serverless/worker.md +++ b/docs/serverless/worker.md @@ -16,9 +16,9 @@ RUNPOD_WEBHOOK_PING= # URL to ping RUNPOD_PING_INTERVAL= # Interval in milliseconds to ping the API (Default: 10000) # S3 Bucket -BUCKET_ENDPOINT_URL= # S3 bucket endpoint url -BUCKET_ACCESS_KEY_ID= # S3 bucket access key id -BUCKET_SECRET_ACCESS_KEY= # S3 bucket secret access key +RUNPOD_BUCKET_ENDPOINT_URL= # S3 bucket endpoint url +RUNPOD_BUCKET_ACCESS_KEY_ID= # S3 bucket access key id +RUNPOD_BUCKET_SECRET_ACCESS_KEY= # S3 bucket secret access key ``` ### Additional Variables diff --git a/infer.py b/infer.py index f9d392c7..681c77e8 100644 --- a/infer.py +++ b/infer.py @@ -6,12 +6,11 @@ # pylint: disable=unused-argument,too-few-public-methods -def setup(): - ''' Loads the model. ''' - def validator(): ''' + Optional validator function. Lists the expected inputs of the model, and their types. + If there are any conflicts the job request is errored out. ''' return { 'prompt': { @@ -20,9 +19,17 @@ def validator(): } } + def run(model_inputs): ''' Predicts the output of the model. Returns output path, with the seed used to generate the image. + + If errors are encountered, return a dictionary with the key "error". + The error can be a string or list of strings. ''' + + # Return Errors + # return {"error": "Error Message"} + return {"image": "/path/to/image.png", "seed": "1234"} diff --git a/runpod/serverless/modules/inference.py b/runpod/serverless/modules/inference.py index a5d54c41..8ded98e1 100644 --- a/runpod/serverless/modules/inference.py +++ b/runpod/serverless/modules/inference.py @@ -48,14 +48,14 @@ def input_validation(self, model_inputs): return input_errors - def run(self, model_inputs): + def run(self, job): ''' Predicts the output of the model. ''' - input_errors = self.input_validation(model_inputs) + input_errors = self.input_validation(job['input']) if input_errors: return { "error": input_errors } - return infer.run(model_inputs) + return infer.run(job) diff --git a/runpod/serverless/modules/job.py b/runpod/serverless/modules/job.py index 0eb9c613..15e053b6 100644 --- a/runpod/serverless/modules/job.py +++ b/runpod/serverless/modules/job.py @@ -51,26 +51,26 @@ def get(worker_id): return None -def run(job_id, job_input): +def run(job): ''' Run the job. Returns list of URLs and Job Time ''' time_job_started = time.time() - log(f"Started working on {job_id} at {time_job_started} UTC") + log(f"Started working on {job['id']} at {time_job_started} UTC") model = inference.Model() - job_output = model.run(job_input) + job_output = model.run(job) if "error" in job_output: return { "error": job_output["error"] } - object_url = upload.upload_image(job_id, job_output["image"]) - job_output["image"] = object_url + object_urls = upload.upload_image(job['id'], job_output["images"]) + job_output["images"] = object_urls job_duration = time.time() - time_job_started job_duration_ms = int(job_duration * 1000) diff --git a/runpod/serverless/modules/logging.py b/runpod/serverless/modules/logging.py index 1bd8320a..7a37d39b 100644 --- a/runpod/serverless/modules/logging.py +++ b/runpod/serverless/modules/logging.py @@ -36,6 +36,9 @@ def log_secret(secret_name, secret, level='INFO'): log_secret('RUNPOD_WEBHOOK_GET_JOB', os.environ.get('RUNPOD_WEBHOOK_GET_JOB', None)) log_secret('RUNPOD_WEBHOOK_POST_OUTPUT', os.environ.get('RUNPOD_WEBHOOK_POST_OUTPUT', None)) -log_secret('BUCKET_ENDPOINT_URL', os.environ.get('BUCKET_ENDPOINT_URL', None)) -log_secret('BUCKET_ACCESS_KEY_ID', os.environ.get('BUCKET_ACCESS_KEY_ID', None)) -log_secret('BUCKET_SECRET_ACCESS_KEY', os.environ.get('BUCKET_SECRET_ACCESS_KEY', None)) +log_secret('RUNPOD_BUCKET_ENDPOINT_URL', os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None)) +log_secret('RUNPOD_BUCKET_ACCESS_KEY_ID', os.environ.get('RUNPOD_BUCKET_ACCESS_KEY_ID', None)) +log_secret( + 'RUNPOD_BUCKET_SECRET_ACCESS_KEY', + os.environ.get('RUNPOD_BUCKET_SECRET_ACCESS_KEY', None) +) diff --git a/runpod/serverless/modules/upload.py b/runpod/serverless/modules/upload.py index 91c0526f..9144afc5 100644 --- a/runpod/serverless/modules/upload.py +++ b/runpod/serverless/modules/upload.py @@ -21,12 +21,12 @@ } ) -if os.environ.get('BUCKET_ENDPOINT_URL', None) is not None: +if os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None) is not None: boto_client = bucket_session.client( 's3', - endpoint_url=os.environ.get('BUCKET_ENDPOINT_URL', None), - aws_access_key_id=os.environ.get('BUCKET_ACCESS_KEY_ID', None), - aws_secret_access_key=os.environ.get('BUCKET_SECRET_ACCESS_KEY', None), + endpoint_url=os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None), + aws_access_key_id=os.environ.get('RUNPOD_BUCKET_ACCESS_KEY_ID', None), + aws_secret_access_key=os.environ.get('RUNPOD_BUCKET_SECRET_ACCESS_KEY', None), config=boto_config ) else: diff --git a/runpod/serverless/pod_worker.py b/runpod/serverless/pod_worker.py index aad1550a..f7f218ec 100644 --- a/runpod/serverless/pod_worker.py +++ b/runpod/serverless/pod_worker.py @@ -30,7 +30,7 @@ def start_worker(): job.error(worker_life.worker_id, next_job['id'], "No input provided.") continue - job_results = job.run(next_job['id'], next_job['input']) + job_results = job.run(next_job) if 'error' in job_results: job.error(worker_life.worker_id, next_job['id'], job_results['error']) diff --git a/setup.cfg b/setup.cfg index 00ea96d0..91920a33 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = runpod -version = 0.3.1 +version = 0.4.0 description = Official Python library for RunPod API & SDK. long_description = file: README.md long_description_content_type = text/markdown