diff --git a/faslr/utilities/sample.py b/faslr/utilities/sample.py index d31d40b..9c67a3f 100644 --- a/faslr/utilities/sample.py +++ b/faslr/utilities/sample.py @@ -4,22 +4,33 @@ from chainladder import Triangle -def load_sample(sample_name: str) -> Triangle: +samples = { + 'mack97': 'mack_1997.csv', + 'us_industry_auto': 'friedland_us_industry_auto.csv', + 'uspp_incr_case': 'friedland_uspp_auto_increasing_case.csv', + 'xyz': 'friedland_xyz_auto_bi.csv', + 'us_auto_steady_state': 'friedland_us_auto_steady_state.csv' +} + + +def load_sample( + sample_name: str +) -> Triangle: path = os.path.dirname(os.path.abspath(__file__)) - if sample_name == "mack97": - df_csv = pd.read_csv( - os.path.join(path, "..", "samples", "mack_1997.csv")) - elif sample_name == "us_industry_auto": - df_csv = pd.read_csv( - os.path.join(path, "..", "samples", "friedland_us_industry_auto.csv")) - elif sample_name == "uspp_incr_case": - df_csv = pd.read_csv( - os.path.join(path, "..", "samples", "friedland_uspp_auto_increasing_case.csv")) - elif sample_name == "xyz": + + def join_path( + fname: str + ): + + joined = os.path.join(path, "..", "samples", fname) + return joined + + try: df_csv = pd.read_csv( - os.path.join(path, "..", "samples", "friedland_xyz_auto_bi.csv")) - else: + join_path(samples[sample_name]) + ) + except KeyError: raise Exception("Invalid sample name.") if sample_name != "mack97":