Fixed issues with refactoring
MitchellAV committed Feb 8, 2025
1 parent f537a63 commit bb52155
Showing 3 changed files with 116 additions and 30 deletions.
131 changes: 107 additions & 24 deletions workers/src/pvinsight-validation-runner.py
@@ -19,6 +19,7 @@
from typing import (
Any,
Callable,
Sequence,
Tuple,
TypeVar,
TypedDict,
@@ -27,7 +28,11 @@
)
import pandas as pd
import os
from importlib import import_module
from collections import ChainMap
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import json
import requests
import tarfile
@@ -42,7 +47,6 @@
RUNNER_ERROR_PREFIX,
RunnerException,
SubmissionException,
SubmissionFunctionArgs,
create_blank_error_report,
create_docker_image_for_submission,
dask_multiprocess,
@@ -53,12 +57,11 @@
pull_from_s3,
request_to_API_w_credentials,
submission_task,
timeout,
timing,
is_local,
)

from metric_operations import performance_metrics_map, metric_operations_map

P = ParamSpec("P")

FAILED = "failed"
@@ -117,7 +120,7 @@ def push_to_s3(
s3.upload_file(local_file_path, S3_BUCKET_NAME, s3_file_path)


def convert_compressed_file_path_to_directory(compressed_file_path: str):
def convert_compressed_file_path_to_directory(compressed_file_path):
path_components = compressed_file_path.split("/")
path_components[-1] = path_components[-1].split(".")[0]
path_components = "/".join(path_components)
Expand Down Expand Up @@ -256,7 +259,7 @@ def remove_unallowed_starting_characters(file_name: str) -> str | None:


def get_module_file_name(module_dir: str):
for _, _, files in os.walk(module_dir, topdown=True):
for root, _, files in os.walk(module_dir, topdown=True):
for name in files:
if name.endswith(".py"):
return name.split("/")[-1]
@@ -270,11 +273,36 @@ def get_module_name(module_dir: str):
return get_module_file_name(module_dir)[:-3]


def generate_histogram(
dataframe, x_axis, title, color_code=None, number_bins=30
):
"""
Generate a histogram for a distribution. Option to color code the
histogram by the color_code column parameter.
"""
sns.displot(
dataframe, x=x_axis, hue=color_code, multiple="stack", bins=number_bins
)
plt.title(title)
plt.tight_layout()
return plt


def generate_scatter_plot(dataframe, x_axis, y_axis, title):
"""
Generate a scatterplot between an x- and a y-variable.
"""
sns.scatterplot(data=dataframe, x=x_axis, y=y_axis)
plt.title(title)
plt.tight_layout()
return plt
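
The two helpers above draw onto the current matplotlib figure and return the pyplot module. A minimal usage sketch, assuming a results DataFrame with illustrative column names (not taken from the diff):

# Hypothetical usage of the plotting helpers; column names are illustrative.
results_df = pd.DataFrame(
    {"run_time": [1.2, 0.8, 1.5], "mean_absolute_error": [0.02, 0.05, 0.03]}
)
plot = generate_histogram(results_df, x_axis="run_time", title="Run time distribution")
plot.savefig("run_time_histogram.png")
plot.close()
plot = generate_scatter_plot(
    results_df, x_axis="run_time", y_axis="mean_absolute_error", title="Error vs. run time"
)
plot.savefig("error_vs_runtime.png")
plot.close()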


@timing(verbose=True, logger=logger)
def run_user_submission(
fn: Callable[P, pd.Series[float]],
*args: P.args,
**kwargs: P.kwargs,
fn: Callable[P, pd.Series],
*args,
**kwargs,
):
return fn(*args, **kwargs)

@@ -368,7 +396,7 @@ def run( # noqa: C901

logger.info(f"Creating docker image for submission...")

_, image_tag = create_docker_image_for_submission(
image, image_tag = create_docker_image_for_submission(
docker_dir,
image_tag,
python_version,
@@ -384,6 +412,8 @@
# os.path.join(new_dir, submission_file_name),
# )

# Generate list for us to store all of our results for the module
results_list = list()
# Load in data set that we're going to analyze.

# Make GET requests to the Django API to get the system metadata
@@ -456,6 +486,26 @@
performance_metrics: list[str] = config_data["performance_metrics"]
logger.info(f"performance_metrics: {performance_metrics}")

# Get the name of the function we want to import associated with this
# test
# # Import designated module via importlib
# module = import_module(module_name)
# try:
# submission_function: Callable = getattr(module, function_name)
# function_parameters = list(
# inspect.signature(submission_function).parameters
# )
# except AttributeError:
# logger.error(
# f"function {function_name} not found in module {module_name}"
# )
# logger.info(f"update submission status to {FAILED}")
# update_submission_status(submission_id, FAILED)
# error_code = 6
# raise RunnerException(
# *get_error_by_code(error_code, runner_error_codes, logger)
# )

total_number_of_files = len(file_metadata_df)
logger.info(f"total_number_of_files: {total_number_of_files}")

@@ -581,6 +631,17 @@ class SubmissionFunctionInfo(TypedDict):

metrics_dict: dict[str, str | float] = {}

def m_mean(df: pd.DataFrame, column: str):
return df[column].mean()

def m_median(df: pd.DataFrame, column: str):
return df[column].median()

metric_operations_mapping = {
"mean": m_mean,
"median": m_median,
}
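
A short sketch of how this mapping is meant to be used, mirroring the lookup that appears further down in this diff; results_df and the "run_time" column are assumptions for illustration:

# Hypothetical lookup-and-apply, matching the pattern used later in run().
operation_function = metric_operations_mapping["mean"]
mean_run_time = operation_function(results_df, "run_time")  # results_df["run_time"].mean()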

# perfomance_metrics_mapping = [
# "mean_absolute_error",
# "absolute_error",
@@ -637,7 +698,7 @@ class SubmissionFunctionInfo(TypedDict):
operations = metrics_operations[key]

for operation in operations:
if operation not in metric_operations_map:
if operation not in metric_operations_mapping:
# TODO: add error code
logger.error(
f"operation {operation} not found in metric_operations_mapping"
@@ -646,7 +707,7 @@ class SubmissionFunctionInfo(TypedDict):
*get_error_by_code(500, runner_error_codes, logger)
)

operation_function = metric_operations_map[operation]
operation_function = metric_operations_mapping[operation]

metric_result = operation_function(results_df, key)

@@ -881,16 +942,16 @@ def prepare_function_args_for_parallel_processing(

logger.info(f"submission_args: {submission_args}")

function_args = SubmissionFunctionArgs(
submission_id=submission_id,
image_tag=image_tag,
memory_limit=memory_limit,
submission_file_name=submission_file_name,
submission_function_name=submission_function_name,
submission_args=submission_args,
volume_data_dir=volume_data_dir,
volume_results_dir=volume_results_dir,
logger=logger,
function_args = (
submission_id,
image_tag,
memory_limit,
submission_file_name,
submission_function_name,
submission_args,
volume_data_dir,
volume_results_dir,
logger,
)

function_args_list = append_to_list(function_args, function_args_list)
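
With SubmissionFunctionArgs removed, each entry in function_args_list is now a plain tuple, so the consumer must unpack it positionally. A sketch of that unpacking, assuming the field order shown above (the receiving signature is not part of this diff):

# Hypothetical positional unpacking on the worker side; order assumed from the tuple above.
(
    submission_id,
    image_tag,
    memory_limit,
    submission_file_name,
    submission_function_name,
    submission_args,
    volume_data_dir,
    volume_results_dir,
    logger,
) = function_args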
@@ -940,7 +1001,7 @@ def run_submission(


def loop_over_files_and_generate_results(
func_arguments_list: list[SubmissionFunctionArgs],
func_arguments_list: list[Tuple],
) -> int:

# func_arguments_list = prepare_function_args_for_parallel_processing(
@@ -975,7 +1036,7 @@ def loop_over_files_and_generate_results(
logger=logger,
)

is_errors_list = [error for error, _ in test_errors]
is_errors_list = [error for error, error_code in test_errors]
number_of_errors += sum(is_errors_list)

if number_of_errors == NUM_FILES_TO_TEST:
@@ -1016,7 +1077,7 @@ def loop_over_files_and_generate_results(
raise RunnerException(
*get_error_by_code(500, runner_error_codes, logger)
)
is_errors_list = [error for error, _ in rest_errors]
is_errors_list = [error for error, error_code in rest_errors]

number_of_errors += sum(is_errors_list)

@@ -1201,6 +1262,28 @@ def generate_performance_metrics_for_submission(
# Loop through the rest of the performance metrics and calculate them
# (this predominantly applies to error metrics)

def p_absolute_error(output: pd.Series, ground_truth: pd.Series):
difference = output - ground_truth
absolute_difference = np.abs(difference)
return absolute_difference

def p_mean_absolute_error(output: pd.Series, ground_truth: pd.Series):
output.index = ground_truth.index
difference = output - ground_truth
absolute_difference = np.abs(difference)
mean_absolute_error = np.mean(absolute_difference)
return mean_absolute_error

def p_error(output: pd.Series, ground_truth: pd.Series):
difference = output - ground_truth
return difference

performance_metrics_map = {
"absolute_error": p_absolute_error,
"mean_absolute_error": p_mean_absolute_error,
"error": p_error,
}
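
For reference, a minimal sketch of these metric functions applied to two aligned series (values are illustrative only):

# Hypothetical example exercising the metric functions defined above.
output = pd.Series([1.0, 2.0, 3.0])
ground_truth = pd.Series([1.0, 2.5, 2.0])
p_error(output, ground_truth)                # 0.0, -0.5, 1.0
p_absolute_error(output, ground_truth)       # 0.0, 0.5, 1.0
p_mean_absolute_error(output, ground_truth)  # 0.5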

for metric in performance_metrics:

if metric == "runtime":
7 changes: 5 additions & 2 deletions workers/src/submission_worker.py
@@ -327,6 +327,9 @@ def extract_analysis_data( # noqa: C901
f"Ground truth data file {analysis_file} not found for analysis {analysis_id}",
)

logger.info(f"files for analysis: {files_for_analysis}")
logger.info(f"analytical files: {analytical_files}")

if not all(file in analytical_files for file in files_for_analysis):
raise FileNotFoundError(
10, f"Analytical data files not found for analysis {analysis_id}"
@@ -395,8 +398,8 @@ def load_analysis(
)

shutil.copy(
os.path.join("/root/worker/src", "meteric_operations.py"),
os.path.join(current_evaluation_dir, "meteric_operations.py"),
os.path.join("/root/worker/src", "metric_operations.py"),
os.path.join(current_evaluation_dir, "metric_operations.py"),
)

# Copy the error codes file into the current evaluation directory
8 changes: 4 additions & 4 deletions workers/src/utility.py
@@ -333,7 +333,7 @@ def handle_exceeded_resources(

def dask_multiprocess(
func: Callable[P, T],
function_args_list: list[SubmissionFunctionArgs],
function_args_list: list[tuple[U, ...]],
n_workers: int | None = None,
threads_per_worker: int | None = None,
memory_per_run: float | int | None = None,
@@ -385,7 +385,7 @@

lazy_results: list[Delayed] = []
for args in function_args_list:
submission_fn_args = args.to_tuple()
submission_fn_args = args
logger_if_able(f"args: {submission_fn_args}", logger, "INFO")

lazy_result = cast(
@@ -1028,7 +1028,7 @@ class ErrorReport(TypedDict):
error_code: str
error_type: str
error_message: str
error_rate: str
error_rate: int
file_errors: dict[str, Any]


@@ -1041,7 +1041,7 @@
"error_code": "",
"error_type": "",
"error_message": "",
"error_rate": "",
"error_rate": 0,
"file_errors": {"errors": []},
}
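
With error_rate now an int, the blank report built above takes this shape (a sketch limited to the keys visible in this hunk; any fields elided by the diff are omitted):

# Hypothetical resulting blank report (only fields visible in the hunk shown).
blank_report = {
    "error_code": "",
    "error_type": "",
    "error_message": "",
    "error_rate": 0,  # previously the empty string ""
    "file_errors": {"errors": []},
}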
