From 04c03bdf959bdd8eb5ff170f724d7712b41883d0 Mon Sep 17 00:00:00 2001 From: Joanne Bogart Date: Fri, 16 Feb 2024 15:39:26 -0800 Subject: [PATCH 1/4] allow parallelization for star flux computation --- skycatalogs/catalog_creator.py | 118 ++++++++++++++++++++++++++++----- 1 file changed, 100 insertions(+), 18 deletions(-) diff --git a/skycatalogs/catalog_creator.py b/skycatalogs/catalog_creator.py index 3cf36788..4e55a40c 100644 --- a/skycatalogs/catalog_creator.py +++ b/skycatalogs/catalog_creator.py @@ -185,6 +185,41 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed, return out_dict +def _do_star_flux_chunk(send_conn, star_collection, instrument_needed, + l_bnd, u_bnd): + ''' + end_conn output connection + star_collection information from main file + instrument_needed List of which calculations should be done + l_bnd, u_bnd demarcates slice to process + + returns + dict with keys id, lsst_flux_u, ... lsst_flux_y + ''' + out_dict = {} + + o_list = star_collection[l_bnd: u_bnd] + out_dict['id'] = [o.get_native_attribute('id') for o in o_list] + if 'lsst' in instrument_needed: + all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] + all_fluxes_transpose = zip(*all_fluxes) + for i, band in enumerate(LSST_BANDS): + v = all_fluxes_transpose.__next__() + out_dict[f'lsst_flux_{band}'] = v + + if 'roman' in instrument_needed: + all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] + all_fluxes_transpose = zip(*all_fluxes) + for i, band in enumerate(ROMAN_BANDS): + v = all_fluxes_transpose.__next__() + out_dict[f'roman_flux_{band}'] = v + + if send_conn: + send_conn.send(out_dict) + else: + return out_dict + + class CatalogCreator: def __init__(self, parts, area_partition=None, skycatalog_root=None, catalog_dir='.', galaxy_truth=None, @@ -990,6 +1025,9 @@ def _create_pointsource_flux_pixel(self, pixel): # For schema use self._ps_flux_schema # output_template should be derived from value for flux_file_template # in main 
catalog config. Cheat for now + + global _star_collection + output_filename = f'pointsource_flux_{pixel}.parquet' output_path = os.path.join(self._output_dir, output_filename) @@ -1000,36 +1038,80 @@ def _create_pointsource_flux_pixel(self, pixel): self._logger.info(f'Skipping regeneration of {output_path}') return + # NOTE: For now there is only one collection in the object list + # because stars are in a single row group object_list = self._cat.get_object_type_by_hp(pixel, 'star') - last_row_ix = len(object_list) - 1 - writer = None + obj_coll = object_list.get_collections()[0] + _star_collection = obj_coll - # Write out as a single rowgroup as was done for main catalog l_bnd = 0 - u_bnd = last_row_ix + 1 + u_bnd = len(_star_collection) + n_parallel = self._flux_parallel + + if n_parallel == 1: + n_per = u_bnd - l_bnd + else: + n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel) + fields_needed = self._ps_flux_schema.names + instrument_needed = ['lsst'] # for now + + writer = None + rg_written = 0 + + lb = l_bnd + u = min(l_bnd + n_per, u_bnd) + readers = [] + + if n_parallel == 1: + out_dict = _do_star_flux_chunk(None, _star_collection, + instrument_needed, lb, u) + else: + # Expect to be able to do about 1500/minute/process + out_dict = {} + for field in fields_needed: + out_dict[field] = [] + + tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion + self._logger.info(f'Using timeout value {tm} for {n_per} sources') + p_list = [] + for i in range(n_parallel): + conn_rd, conn_wrt = Pipe(duplex=False) + readers.append(conn_rd) + + # For debugging call directly + proc = Process(target=_do_star_flux_chunk, + name=f'proc_{i}', + args=(conn_wrt, _star_collection, + instrument_needed, lb, u)) + proc.start() + p_list.append(proc) + lb = u + u = min(lb + n_per, u_bnd) + + self._logger.debug('Processes started') + for i in range(n_parallel): + ready = readers[i].poll(tm) + if not ready: + self._logger.error(f'Process {i} timed out after {tm} sec') + sys.exit(1) + 
dat = readers[i].recv() + for field in fields_needed: + out_dict[field] += dat[field] + for p in p_list: + p.join() - o_list = object_list[l_bnd: u_bnd] - self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}') - out_dict = {} - out_dict['id'] = [o.get_native_attribute('id') for o in o_list] - all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] - all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(LSST_BANDS): - self._logger.debug(f'Band {band} is number {i}') - v = all_fluxes_transpose.__next__() - out_dict[f'lsst_flux_{band}'] = v - if i == 1: - self._logger.debug(f'Len of flux column: {len(v)}') - self._logger.debug(f'Type of flux column: {type(v)}') out_df = pd.DataFrame.from_dict(out_dict) out_table = pa.Table.from_pandas(out_df, schema=self._ps_flux_schema) + if not writer: writer = pq.ParquetWriter(output_path, self._ps_flux_schema) writer.write_table(out_table) + rg_written += 1 + writer.close() - # self._logger.debug(f'#row groups written to flux file: {rg_written}') + self._logger.debug(f'# row groups written to flux file: {rg_written}') if self._provenance == 'yaml': self.write_provenance_file(output_path) From 98889c4c4c5708b085d1c06b2003b59d5a9ec7a3 Mon Sep 17 00:00:00 2001 From: Joanne Bogart Date: Tue, 14 May 2024 19:29:39 -0700 Subject: [PATCH 2/4] address reviewer comments --- skycatalogs/catalog_creator.py | 38 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/skycatalogs/catalog_creator.py b/skycatalogs/catalog_creator.py index 4e55a40c..194a23c7 100644 --- a/skycatalogs/catalog_creator.py +++ b/skycatalogs/catalog_creator.py @@ -168,16 +168,16 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed, if 'lsst' in instrument_needed: all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(LSST_BANDS): - v = all_fluxes_transpose.__next__() - out_dict[f'lsst_flux_{band}'] = v + 
colnames = [f'lsst_flux_{band}' for band in LSST_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) if 'roman' in instrument_needed: all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(ROMAN_BANDS): - v = all_fluxes_transpose.__next__() - out_dict[f'roman_flux_{band}'] = v + colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) if send_conn: send_conn.send(out_dict) @@ -188,10 +188,13 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed, def _do_star_flux_chunk(send_conn, star_collection, instrument_needed, l_bnd, u_bnd): ''' - end_conn output connection - star_collection information from main file - instrument_needed List of which calculations should be done - l_bnd, u_bnd demarcates slice to process + send_conn output connection, used to send results to + parent process + star_collection ObjectCollection. Information from main skyCatalogs + star file + instrument_needed List of which calculations should be done. Currently + supported instrument names are 'lsst' and 'roman' + l_bnd, u_bnd demarcates slice to process returns dict with keys id, lsst_flux_u, ... 
lsst_flux_y @@ -203,16 +206,16 @@ def _do_star_flux_chunk(send_conn, star_collection, instrument_needed, if 'lsst' in instrument_needed: all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(LSST_BANDS): - v = all_fluxes_transpose.__next__() - out_dict[f'lsst_flux_{band}'] = v + colnames = [f'lsst_flux_{band}' for band in LSST_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) if 'roman' in instrument_needed: all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(ROMAN_BANDS): - v = all_fluxes_transpose.__next__() - out_dict[f'roman_flux_{band}'] = v + colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) if send_conn: send_conn.send(out_dict) @@ -1041,8 +1044,7 @@ def _create_pointsource_flux_pixel(self, pixel): # NOTE: For now there is only one collection in the object list # because stars are in a single row group object_list = self._cat.get_object_type_by_hp(pixel, 'star') - obj_coll = object_list.get_collections()[0] - _star_collection = obj_coll + _star_collection = object_list.get_collections()[0] l_bnd = 0 u_bnd = len(_star_collection) From 864634c65bf14b0cf2cb87d891680a11451607a4 Mon Sep 17 00:00:00 2001 From: Joanne Bogart Date: Wed, 15 May 2024 23:26:56 -0700 Subject: [PATCH 3/4] omit unnecessary global declarations --- skycatalogs/catalog_creator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/skycatalogs/catalog_creator.py b/skycatalogs/catalog_creator.py index 194a23c7..eb1d1dc4 100644 --- a/skycatalogs/catalog_creator.py +++ b/skycatalogs/catalog_creator.py @@ -773,8 +773,6 @@ def _create_galaxy_flux_pixel(self, pixel): self._sed_gen.generate_pixel(pixel) writer = None - global _galaxy_collection - global _instrument_needed _instrument_needed = [] for field in 
self._gal_flux_needed:         if 'lsst' in field and 'lsst' not in _instrument_needed: @@ -1028,9 +1026,6 @@ def _create_pointsource_flux_pixel(self, pixel):         # For schema use self._ps_flux_schema         # output_template should be derived from value for flux_file_template         # in main catalog config. Cheat for now -         global _star_collection -         output_filename = f'pointsource_flux_{pixel}.parquet'         output_path = os.path.join(self._output_dir, output_filename)  From d0f7962d8ab83d76d256a59ab593e2130a029db1 Mon Sep 17 00:00:00 2001 From: Joanne Bogart Date: Wed, 15 May 2024 23:54:03 -0700 Subject: [PATCH 4/4] bug fix for Config.get_config_value ---  skycatalogs/utils/config_utils.py | 8 +++++++-  1 file changed, 7 insertions(+), 1 deletion(-)  diff --git a/skycatalogs/utils/config_utils.py b/skycatalogs/utils/config_utils.py index 00450b89..23a01228 100644 --- a/skycatalogs/utils/config_utils.py +++ b/skycatalogs/utils/config_utils.py @@ -150,7 +150,13 @@ def get_config_value(self, key_path, silent=False):              d = d[i]              if not isinstance(d, dict):                  raise ValueError(f'intermediate {d} is not a dict') -        return d[path_items[-1]] + +        if path_items[-1] in d: +            return d[path_items[-1]] +        else: +            if silent: +                return None +            raise ValueError(f'Item {path_items[-1]} not found')       def add_key(self, k, v):          '''