allow parallelization for star flux computation #85

Merged · 4 commits · merged May 16, 2024

Changes from all commits
skycatalogs/catalog_creator.py (131 changes: 105 additions & 26 deletions)
@@ -168,16 +168,54 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed,
     if 'lsst' in instrument_needed:
         all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
         all_fluxes_transpose = zip(*all_fluxes)
-        for i, band in enumerate(LSST_BANDS):
-            v = all_fluxes_transpose.__next__()
-            out_dict[f'lsst_flux_{band}'] = v
+        colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)

     if 'roman' in instrument_needed:
         all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list]
         all_fluxes_transpose = zip(*all_fluxes)
-        for i, band in enumerate(ROMAN_BANDS):
-            v = all_fluxes_transpose.__next__()
-            out_dict[f'roman_flux_{band}'] = v
+        colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)

     if send_conn:
         send_conn.send(out_dict)
     else:
         return out_dict
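Editorial note: the new pattern in both functions transposes per-object flux rows into per-band columns in one step, replacing manual __next__() calls on the zip iterator. A minimal self-contained sketch with invented flux values:

    LSST_BANDS = ['u', 'g', 'r', 'i', 'z', 'y']

    # One row per object, one flux per band, standing in for the values
    # returned by get_LSST_fluxes(as_dict=False).
    all_fluxes = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                  [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]]

    # zip(*rows) transposes object-major rows into band-major columns.
    all_fluxes_transpose = zip(*all_fluxes)
    colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
    flux_dict = dict(zip(colnames, all_fluxes_transpose))
    assert flux_dict['lsst_flux_u'] == (1.0, 10.0)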


+def _do_star_flux_chunk(send_conn, star_collection, instrument_needed,
+                        l_bnd, u_bnd):
Review comment on lines +188 to +189 (the new function signature):

Collaborator: I think it would be a bit cleaner to simplify this interface by doing the slicing in the calling code. So the new interface would effectively become

    def _do_star_flux_chunk(send_conn, o_list, instrument_needed):

and in the calling code it would be used like this:

    _do_star_flux_chunk(send_conn, star_collection[l_bnd: u_bnd], instrument_needed)

Collaborator Author: Yes, see the comment above about the global declaration.

+    '''
+    send_conn          output connection, used to send results to the
+                       parent process
+    star_collection    ObjectCollection. Information from the main
+                       skyCatalogs star file
+    instrument_needed  list of which calculations should be done. Currently
+                       supported instrument names are 'lsst' and 'roman'
+    l_bnd, u_bnd       demarcate the slice to process
+
+    returns
+        dict with keys id, lsst_flux_u, ... lsst_flux_y
+    '''
+    out_dict = {}
+
+    o_list = star_collection[l_bnd: u_bnd]
+    out_dict['id'] = [o.get_native_attribute('id') for o in o_list]
+    if 'lsst' in instrument_needed:
+        all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
+        all_fluxes_transpose = zip(*all_fluxes)
+        colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)
+
+    if 'roman' in instrument_needed:
+        all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list]
+        all_fluxes_transpose = zip(*all_fluxes)
+        colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)
+
+    if send_conn:
+        send_conn.send(out_dict)
@@ -735,8 +773,6 @@ def _create_galaxy_flux_pixel(self, pixel):
             self._sed_gen.generate_pixel(pixel)

         writer = None
-        global _galaxy_collection
-        global _instrument_needed
         _instrument_needed = []
         for field in self._gal_flux_needed:
             if 'lsst' in field and 'lsst' not in _instrument_needed:
Expand Down Expand Up @@ -1000,36 +1036,79 @@ def _create_pointsource_flux_pixel(self, pixel):
             self._logger.info(f'Skipping regeneration of {output_path}')
             return

         # NOTE: For now there is only one collection in the object list
         # because stars are in a single row group
         object_list = self._cat.get_object_type_by_hp(pixel, 'star')
-        last_row_ix = len(object_list) - 1
         writer = None
+        _star_collection = object_list.get_collections()[0]

         # Write out as a single rowgroup as was done for main catalog
         l_bnd = 0
-        u_bnd = last_row_ix + 1
+        u_bnd = len(_star_collection)
+        n_parallel = self._flux_parallel
+
+        if n_parallel == 1:
+            n_per = u_bnd - l_bnd
+        else:
+            n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel)
+        fields_needed = self._ps_flux_schema.names
+        instrument_needed = ['lsst']    # for now
Review comment from jchiang87 (Collaborator), May 14, 2024:

It seems like instrument_needed should be set as an instance-level attribute in .__init__(...), instead of hard-wired here.

Collaborator Author: Yes, it should ultimately come from the way the outer script was called, but currently the assumption is that lsst is always included and roman is optional. I could see making that assumption explicit in the outer script and flowing it down from there, but I think not in this PR.
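As an aside on the chunking arithmetic a few lines up: n_per rounds the per-worker count up, and the timeout deliberately budgets well below the expected throughput. A worked example with a hypothetical pixel of 10,000 stars split across 4 processes:

    l_bnd, u_bnd, n_parallel = 0, 10_000, 4
    n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel)  # 2501
    # Worker slices: [0, 2501), [2501, 5002), [5002, 7503), [7503, 10000);
    # u = min(lb + n_per, u_bnd) clamps the final slice.
    tm = max(int((n_per*60)/500), 5)  # 300 sec: budgets 500 sources/min
                                      # against the expected ~1500/min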


+        writer = None
+        rg_written = 0
+
+        lb = l_bnd
+        u = min(l_bnd + n_per, u_bnd)
+        readers = []
+
+        if n_parallel == 1:
+            out_dict = _do_star_flux_chunk(None, _star_collection,
+                                           instrument_needed, lb, u)
+        else:
+            # Expect to be able to do about 1500/minute/process
+            out_dict = {}
+            for field in fields_needed:
+                out_dict[field] = []
+
+            tm = max(int((n_per*60)/500), 5)  # Give ourselves a cushion
+            self._logger.info(f'Using timeout value {tm} for {n_per} sources')
+            p_list = []
+            for i in range(n_parallel):
+                conn_rd, conn_wrt = Pipe(duplex=False)
+                readers.append(conn_rd)
+
+                # For debugging call directly
+                proc = Process(target=_do_star_flux_chunk,
+                               name=f'proc_{i}',
+                               args=(conn_wrt, _star_collection,
+                                     instrument_needed, lb, u))
+                proc.start()
+                p_list.append(proc)
+                lb = u
+                u = min(lb + n_per, u_bnd)
+
+            self._logger.debug('Processes started')
+            for i in range(n_parallel):
+                ready = readers[i].poll(tm)
+                if not ready:
+                    self._logger.error(f'Process {i} timed out after {tm} sec')
+                    sys.exit(1)
+                dat = readers[i].recv()
+                for field in fields_needed:
+                    out_dict[field] += dat[field]
+            for p in p_list:
+                p.join()
+
-        o_list = object_list[l_bnd: u_bnd]
-        self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}')
-        out_dict = {}
-        out_dict['id'] = [o.get_native_attribute('id') for o in o_list]
-        all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
-        all_fluxes_transpose = zip(*all_fluxes)
-        for i, band in enumerate(LSST_BANDS):
-            self._logger.debug(f'Band {band} is number {i}')
-            v = all_fluxes_transpose.__next__()
-            out_dict[f'lsst_flux_{band}'] = v
-            if i == 1:
-                self._logger.debug(f'Len of flux column: {len(v)}')
-                self._logger.debug(f'Type of flux column: {type(v)}')

         out_df = pd.DataFrame.from_dict(out_dict)
         out_table = pa.Table.from_pandas(out_df,
                                          schema=self._ps_flux_schema)

         if not writer:
             writer = pq.ParquetWriter(output_path, self._ps_flux_schema)
         writer.write_table(out_table)

         rg_written += 1

         writer.close()
-        # self._logger.debug(f'#row groups written to flux file: {rg_written}')
+        self._logger.debug(f'# row groups written to flux file: {rg_written}')
         if self._provenance == 'yaml':
             self.write_provenance_file(output_path)
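For reference, the parent/worker pattern above reduced to a self-contained sketch; the worker and its payload are generic stand-ins rather than the skyCatalogs API:

    from multiprocessing import Pipe, Process

    def worker(send_conn, lo, hi):
        # Stand-in for _do_star_flux_chunk: build the result dict for
        # one slice and ship it back through the pipe.
        send_conn.send({'id': list(range(lo, hi))})

    if __name__ == '__main__':
        n_parallel, u_bnd, n_per = 3, 10, 4
        readers, p_list, lb = [], [], 0
        for i in range(n_parallel):
            conn_rd, conn_wrt = Pipe(duplex=False)   # one pipe per worker
            readers.append(conn_rd)
            u = min(lb + n_per, u_bnd)
            proc = Process(target=worker, name=f'proc_{i}',
                           args=(conn_wrt, lb, u))
            proc.start()
            p_list.append(proc)
            lb = u

        out_dict = {'id': []}
        for i in range(n_parallel):
            if not readers[i].poll(30):              # up to 30 sec per worker
                raise TimeoutError(f'proc_{i} timed out')
            out_dict['id'] += readers[i].recv()['id']
        for p in p_list:
            p.join()
        print(out_dict['id'])                        # [0, 1, ..., 9]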

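The write path itself is standard pandas/pyarrow. A sketch with an invented two-column schema standing in for _ps_flux_schema:

    import pandas as pd
    import pyarrow as pa
    import pyarrow.parquet as pq

    schema = pa.schema([('id', pa.string()),
                        ('lsst_flux_u', pa.float64())])
    out_dict = {'id': ['star_1', 'star_2'], 'lsst_flux_u': [1.5, 2.5]}

    out_df = pd.DataFrame.from_dict(out_dict)
    out_table = pa.Table.from_pandas(out_df, schema=schema)
    writer = pq.ParquetWriter('star_flux.parquet', schema)
    writer.write_table(out_table)  # each write_table call adds a row group
    writer.close()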
skycatalogs/utils/config_utils.py (8 changes: 7 additions & 1 deletion)

@@ -150,7 +150,13 @@ def get_config_value(self, key_path, silent=False):
             d = d[i]
             if not isinstance(d, dict):
                 raise ValueError(f'intermediate {d} is not a dict')
-        return d[path_items[-1]]
+
+        if path_items[-1] in d:
+            return d[path_items[-1]]
+        else:
+            if silent:
+                return None
+            raise ValueError(f'Item {path_items[-1]} not found')
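A usage sketch of the changed lookup; the instance, its contents, and the '/'-separated key path are assumptions for illustration:

    # Suppose config wraps {'object_types': {'star': {'sed_model': 'phoenix'}}}
    # and key_path splits on '/' into path_items (assumed here).
    config.get_config_value('object_types/star/sed_model')            # 'phoenix'
    config.get_config_value('object_types/star/zpoint', silent=True)  # None
    config.get_config_value('object_types/star/zpoint')               # ValueError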

     def add_key(self, k, v):
         '''