-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
allow parallelization for star flux computation #85
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,16 +168,54 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed, | |
if 'lsst' in instrument_needed: | ||
all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] | ||
all_fluxes_transpose = zip(*all_fluxes) | ||
for i, band in enumerate(LSST_BANDS): | ||
v = all_fluxes_transpose.__next__() | ||
out_dict[f'lsst_flux_{band}'] = v | ||
colnames = [f'lsst_flux_{band}' for band in LSST_BANDS] | ||
flux_dict = dict(zip(colnames, all_fluxes_transpose)) | ||
out_dict.update(flux_dict) | ||
|
||
if 'roman' in instrument_needed: | ||
all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] | ||
all_fluxes_transpose = zip(*all_fluxes) | ||
for i, band in enumerate(ROMAN_BANDS): | ||
v = all_fluxes_transpose.__next__() | ||
out_dict[f'roman_flux_{band}'] = v | ||
colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS] | ||
flux_dict = dict(zip(colnames, all_fluxes_transpose)) | ||
out_dict.update(flux_dict) | ||
|
||
if send_conn: | ||
send_conn.send(out_dict) | ||
else: | ||
return out_dict | ||
|
||
|
||
def _do_star_flux_chunk(send_conn, star_collection, instrument_needed, | ||
l_bnd, u_bnd): | ||
''' | ||
send_conn output connection, used to send results to | ||
parent process | ||
star_collection ObjectCollection. Information from main skyCatalogs | ||
star file | ||
instrument_needed List of which calculations should be done. Currently | ||
supported instrument names are 'lsst' and 'roman' | ||
l_bnd, u_bnd demarcates slice to process | ||
|
||
returns | ||
dict with keys id, lsst_flux_u, ... lsst_flux_y | ||
''' | ||
out_dict = {} | ||
|
||
o_list = star_collection[l_bnd: u_bnd] | ||
out_dict['id'] = [o.get_native_attribute('id') for o in o_list] | ||
if 'lsst' in instrument_needed: | ||
all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] | ||
all_fluxes_transpose = zip(*all_fluxes) | ||
colnames = [f'lsst_flux_{band}' for band in LSST_BANDS] | ||
flux_dict = dict(zip(colnames, all_fluxes_transpose)) | ||
out_dict.update(flux_dict) | ||
|
||
if 'roman' in instrument_needed: | ||
all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] | ||
all_fluxes_transpose = zip(*all_fluxes) | ||
colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS] | ||
flux_dict = dict(zip(colnames, all_fluxes_transpose)) | ||
out_dict.update(flux_dict) | ||
|
||
if send_conn: | ||
send_conn.send(out_dict) | ||
|
@@ -735,8 +773,6 @@ def _create_galaxy_flux_pixel(self, pixel): | |
self._sed_gen.generate_pixel(pixel) | ||
|
||
writer = None | ||
global _galaxy_collection | ||
global _instrument_needed | ||
_instrument_needed = [] | ||
for field in self._gal_flux_needed: | ||
if 'lsst' in field and 'lsst' not in _instrument_needed: | ||
|
@@ -1000,36 +1036,79 @@ def _create_pointsource_flux_pixel(self, pixel): | |
self._logger.info(f'Skipping regeneration of {output_path}') | ||
return | ||
|
||
# NOTE: For now there is only one collection in the object list | ||
# because stars are in a single row group | ||
object_list = self._cat.get_object_type_by_hp(pixel, 'star') | ||
last_row_ix = len(object_list) - 1 | ||
writer = None | ||
_star_collection = object_list.get_collections()[0] | ||
|
||
# Write out as a single rowgroup as was done for main catalog | ||
l_bnd = 0 | ||
u_bnd = last_row_ix + 1 | ||
u_bnd = len(_star_collection) | ||
n_parallel = self._flux_parallel | ||
|
||
if n_parallel == 1: | ||
n_per = u_bnd - l_bnd | ||
else: | ||
n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel) | ||
fields_needed = self._ps_flux_schema.names | ||
instrument_needed = ['lsst'] # for now | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it should ultimately comes from the way the outer script was called but currently the assumption is that lsst is always included and including roman is an option. I could see making that assumption explicit in the outer script and flowing only from there, but I think not in this PR. |
||
|
||
writer = None | ||
rg_written = 0 | ||
|
||
lb = l_bnd | ||
u = min(l_bnd + n_per, u_bnd) | ||
readers = [] | ||
|
||
if n_parallel == 1: | ||
out_dict = _do_star_flux_chunk(None, _star_collection, | ||
instrument_needed, lb, u) | ||
else: | ||
# Expect to be able to do about 1500/minute/process | ||
out_dict = {} | ||
for field in fields_needed: | ||
out_dict[field] = [] | ||
|
||
tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion | ||
self._logger.info(f'Using timeout value {tm} for {n_per} sources') | ||
p_list = [] | ||
for i in range(n_parallel): | ||
conn_rd, conn_wrt = Pipe(duplex=False) | ||
readers.append(conn_rd) | ||
|
||
# For debugging call directly | ||
proc = Process(target=_do_star_flux_chunk, | ||
name=f'proc_{i}', | ||
args=(conn_wrt, _star_collection, | ||
instrument_needed, lb, u)) | ||
proc.start() | ||
p_list.append(proc) | ||
lb = u | ||
u = min(lb + n_per, u_bnd) | ||
|
||
self._logger.debug('Processes started') | ||
for i in range(n_parallel): | ||
ready = readers[i].poll(tm) | ||
if not ready: | ||
self._logger.error(f'Process {i} timed out after {tm} sec') | ||
sys.exit(1) | ||
dat = readers[i].recv() | ||
for field in fields_needed: | ||
out_dict[field] += dat[field] | ||
for p in p_list: | ||
p.join() | ||
|
||
o_list = object_list[l_bnd: u_bnd] | ||
self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}') | ||
out_dict = {} | ||
out_dict['id'] = [o.get_native_attribute('id') for o in o_list] | ||
all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] | ||
all_fluxes_transpose = zip(*all_fluxes) | ||
for i, band in enumerate(LSST_BANDS): | ||
self._logger.debug(f'Band {band} is number {i}') | ||
v = all_fluxes_transpose.__next__() | ||
out_dict[f'lsst_flux_{band}'] = v | ||
if i == 1: | ||
self._logger.debug(f'Len of flux column: {len(v)}') | ||
self._logger.debug(f'Type of flux column: {type(v)}') | ||
out_df = pd.DataFrame.from_dict(out_dict) | ||
out_table = pa.Table.from_pandas(out_df, | ||
schema=self._ps_flux_schema) | ||
|
||
if not writer: | ||
writer = pq.ParquetWriter(output_path, self._ps_flux_schema) | ||
writer.write_table(out_table) | ||
|
||
rg_written += 1 | ||
|
||
writer.close() | ||
# self._logger.debug(f'#row groups written to flux file: {rg_written}') | ||
self._logger.debug(f'# row groups written to flux file: {rg_written}') | ||
if self._provenance == 'yaml': | ||
self.write_provenance_file(output_path) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be a bit cleaner to simplify this interface by doing the slicing in the calling code. So the new interface would effectively become:
and in the calling code, it would be used like this:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, see comment above about global declaration.