Skip to content

Commit

Permalink
WIP: refactoring data retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
niklexical committed Aug 7, 2023
1 parent 592cd9a commit 7d81913
Showing 1 changed file with 10 additions and 140 deletions.
150 changes: 10 additions & 140 deletions survhive/tests/test_data_gen_final.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,158 +2,28 @@
import pandas as pd
import os
import torch
from pathlib import Path

# Data generation functions
# TODO : Need to change this relative path, getcwd will differ. For ex my cwd is rootdir, so it throws FileNotFound Error
# use Path(__file__).parent.parent.parent OR os.path.dirname(os.path.abspath(__file__))
path = os.getcwd()
# use Path(__file__).parent.parent.parent for PROJECT_ROOT_DIR (OR os.path.dirname(os.path.abspath(__file__)))

TEST_DIR = Path(__file__).parent

def numpy_test_data_1d(scenario="default"):
if scenario == "default":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "first_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "last_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "high_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "low_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "all_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "no_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)

def numpy_test_data_1d(scenario="default", dims=1):
file_path = TEST_DIR / "test_data" / f"survival_simulation_25_{scenario}.csv"
df = pd.read_csv(file_path)
linear_predictor = df.preds.to_numpy(dtype=np.float32)
time = df.time.to_numpy(dtype=np.float32) # .reshape(-1)
event = df.event.to_numpy(dtype=np.float32) # .reshape(-1)
return linear_predictor, time, event


def numpy_test_data_2d(scenario="default"):
if scenario == "default":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "first_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "last_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "high_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "low_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "all_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "no_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
file_path = TEST_DIR / "test_data" / f"survival_simulation_25_{scenario}.csv"
df = pd.read_csv(file_path)

pred_1d = df.preds.to_numpy(dtype=np.float32).reshape(25, 1)
linear_predictor = np.hstack((pred_1d, pred_1d))
time = df.time.to_numpy(dtype=np.float32)
event = df.event.to_numpy(dtype=np.float32)
return linear_predictor, time, event


def torch_test_data_1d(scenario="default"):
if scenario == "default":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "first_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "last_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "high_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "low_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "all_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "no_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
linear_predictor = df.preds.to_numpy(dtype=np.float32)
time = df.time.to_numpy(dtype=np.float32) # .reshape(-1)
event = df.event.to_numpy(dtype=np.float32) # .reshape(-1)
return (
torch.from_numpy(linear_predictor),
torch.from_numpy(time),
torch.from_numpy(event),
)


def torch_test_data_2d(scenario="default"):
if scenario == "default":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "first_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "last_five_zero":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "high_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "low_event_ratio":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "all_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
elif scenario == "no_events":
df = pd.read_csv(
path + "/test_data/survival_simulation_25_" + scenario + ".csv"
)
pred_1d = df.preds.to_numpy(dtype=np.float32).reshape(25, 1)
linear_predictor = np.hstack((pred_1d, pred_1d))
time = df.time.to_numpy(dtype=np.float32) # .reshape(-1)
event = df.event.to_numpy(dtype=np.float32) # .reshape(-1)
return (
torch.from_numpy(linear_predictor),
torch.from_numpy(time),
torch.from_numpy(event),
)

0 comments on commit 7d81913

Please sign in to comment.