Skip to content

Commit

Permalink
fix: RUNPOD_ bucket prefix & passing all job inputs to user & handle …
Browse files Browse the repository at this point in the history
…multiple images
  • Loading branch information
Justin Merrell committed Dec 22, 2022
1 parent 1b3b2eb commit d9b7c55
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 23 deletions.
6 changes: 3 additions & 3 deletions docs/serverless/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ RUNPOD_WEBHOOK_PING= # URL to ping
RUNPOD_PING_INTERVAL= # Interval in milliseconds to ping the API (Default: 10000)

# S3 Bucket
BUCKET_ENDPOINT_URL= # S3 bucket endpoint url
BUCKET_ACCESS_KEY_ID= # S3 bucket access key id
BUCKET_SECRET_ACCESS_KEY= # S3 bucket secret access key
RUNPOD_BUCKET_ENDPOINT_URL= # S3 bucket endpoint url
RUNPOD_BUCKET_ACCESS_KEY_ID= # S3 bucket access key id
RUNPOD_BUCKET_SECRET_ACCESS_KEY= # S3 bucket secret access key
```

### Additional Variables
Expand Down
13 changes: 10 additions & 3 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
# pylint: disable=unused-argument,too-few-public-methods


def setup():
''' Loads the model. '''

def validator():
'''
Optional validator function.
Lists the expected inputs of the model, and their types.
If there are any conflicts the job request is errored out.
'''
return {
'prompt': {
Expand All @@ -20,9 +19,17 @@ def validator():
}
}


def run(model_inputs):
'''
Predicts the output of the model.
Returns output path, with the seed used to generate the image.
If errors are encountered, return a dictionary with the key "error".
The error can be a string or list of strings.
'''

# Return Errors
# return {"error": "Error Message"}

return {"image": "/path/to/image.png", "seed": "1234"}
6 changes: 3 additions & 3 deletions runpod/serverless/modules/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def input_validation(self, model_inputs):

return input_errors

def run(self, model_inputs):
def run(self, job):
'''
Predicts the output of the model.
'''
input_errors = self.input_validation(model_inputs)
input_errors = self.input_validation(job['input'])
if input_errors:
return {
"error": input_errors
}

return infer.run(model_inputs)
return infer.run(job)
10 changes: 5 additions & 5 deletions runpod/serverless/modules/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,26 @@ def get(worker_id):
return None


def run(job_id, job_input):
def run(job):
'''
Run the job.
Returns list of URLs and Job Time
'''
time_job_started = time.time()

log(f"Started working on {job_id} at {time_job_started} UTC")
log(f"Started working on {job['id']} at {time_job_started} UTC")

model = inference.Model()

job_output = model.run(job_input)
job_output = model.run(job)

if "error" in job_output:
return {
"error": job_output["error"]
}

object_url = upload.upload_image(job_id, job_output["image"])
job_output["image"] = object_url
object_urls = upload.upload_image(job['id'], job_output["images"])
job_output["images"] = object_urls

job_duration = time.time() - time_job_started
job_duration_ms = int(job_duration * 1000)
Expand Down
9 changes: 6 additions & 3 deletions runpod/serverless/modules/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def log_secret(secret_name, secret, level='INFO'):
log_secret('RUNPOD_WEBHOOK_GET_JOB', os.environ.get('RUNPOD_WEBHOOK_GET_JOB', None))
log_secret('RUNPOD_WEBHOOK_POST_OUTPUT', os.environ.get('RUNPOD_WEBHOOK_POST_OUTPUT', None))

log_secret('BUCKET_ENDPOINT_URL', os.environ.get('BUCKET_ENDPOINT_URL', None))
log_secret('BUCKET_ACCESS_KEY_ID', os.environ.get('BUCKET_ACCESS_KEY_ID', None))
log_secret('BUCKET_SECRET_ACCESS_KEY', os.environ.get('BUCKET_SECRET_ACCESS_KEY', None))
log_secret('RUNPOD_BUCKET_ENDPOINT_URL', os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None))
log_secret('RUNPOD_BUCKET_ACCESS_KEY_ID', os.environ.get('RUNPOD_BUCKET_ACCESS_KEY_ID', None))
log_secret(
'RUNPOD_BUCKET_SECRET_ACCESS_KEY',
os.environ.get('RUNPOD_BUCKET_SECRET_ACCESS_KEY', None)
)
8 changes: 4 additions & 4 deletions runpod/serverless/modules/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
}
)

if os.environ.get('BUCKET_ENDPOINT_URL', None) is not None:
if os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None) is not None:
boto_client = bucket_session.client(
's3',
endpoint_url=os.environ.get('BUCKET_ENDPOINT_URL', None),
aws_access_key_id=os.environ.get('BUCKET_ACCESS_KEY_ID', None),
aws_secret_access_key=os.environ.get('BUCKET_SECRET_ACCESS_KEY', None),
endpoint_url=os.environ.get('RUNPOD_BUCKET_ENDPOINT_URL', None),
aws_access_key_id=os.environ.get('RUNPOD_BUCKET_ACCESS_KEY_ID', None),
aws_secret_access_key=os.environ.get('RUNPOD_BUCKET_SECRET_ACCESS_KEY', None),
config=boto_config
)
else:
Expand Down
2 changes: 1 addition & 1 deletion runpod/serverless/pod_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def start_worker():
job.error(worker_life.worker_id, next_job['id'], "No input provided.")
continue

job_results = job.run(next_job['id'], next_job['input'])
job_results = job.run(next_job)

if 'error' in job_results:
job.error(worker_life.worker_id, next_job['id'], job_results['error'])
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = runpod
version = 0.3.1
version = 0.4.0
description = Official Python library for RunPod API & SDK.
long_description = file: README.md
long_description_content_type = text/markdown
Expand Down

0 comments on commit d9b7c55

Please sign in to comment.