Skip to content

Commit

Permalink
Run futures in blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
domfournier committed Jan 16, 2025
1 parent 8a8c968 commit 3e1ec4a
Showing 1 changed file with 170 additions and 75 deletions.
245 changes: 170 additions & 75 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import numpy as np
from dask.distributed import Client
from ..data_misfit import L2DataMisfit
from simpeg.meta.dask_sim import _validate_type_or_future_of_type, _reduce

from simpeg.meta.dask_sim import _reduce
from simpeg.utils import validate_list_of_types
from operator import add


Expand Down Expand Up @@ -45,6 +45,83 @@ def _get_jtj_diag(objfct, model):
return jtj.flatten()


def _validate_type_or_future_of_type(
    property_name,
    objects,
    obj_type,
    client,
    workers=None,
    return_workers=False,
):
    """
    Scatter objects to dask workers in blocks and verify their type remotely.

    The objects are validated locally with ``validate_list_of_types``, then
    distributed round-robin over ``workers``. Every full pass over the worker
    list closes one block, so each block holds at most ``len(workers)``
    futures and the i-th future of a block lives on ``workers[i]``.

    Parameters
    ----------
    property_name : str
        Name used in error messages.
    objects : list
        Objects to distribute. Each one must expose
        ``simulation.simulations[0]``, whose ``worker`` attribute is set to
        the worker the object is sent to.
    obj_type : type
        Type every object must be an instance of.
    client : dask.distributed.Client
        Client used to scatter the objects and submit the remote checks.
    workers : list, optional
        Worker identifiers, each in a form accepted by the ``workers``
        argument of ``client.scatter`` / ``client.submit``. When ``None``,
        one ``(worker_address,)`` tuple is built for every entry of
        ``client.cluster.workers``.
    return_workers : bool, optional
        If ``True``, also return the list of workers that was used.

    Returns
    -------
    workload : list of list of Future
        Blocks of futures, in the order the objects were given.
    workers : list
        Only returned when ``return_workers`` is ``True``.

    Raises
    ------
    ValueError
        If there are objects to distribute but no worker to receive them.
    TypeError
        If a scattered future does not resolve to an instance of
        ``obj_type``. NOTE(review): ``validate_list_of_types`` presumably
        raises on invalid local input as well -- confirm against its
        implementation.
    """
    # Honour an explicit placement from the caller; only fall back to the
    # cluster's worker list when none was supplied.
    if workers is None:
        workers = [
            (worker.worker_address,) for worker in client.cluster.workers.values()
        ]

    objects = validate_list_of_types(
        property_name, objects, obj_type, ensure_unique=True
    )

    n_workers = len(workers)
    if len(objects) > 0 and n_workers == 0:
        raise ValueError(
            f"Cannot distribute {property_name}: no dask workers are available."
        )

    workload = [[]]
    for index, obj in enumerate(objects):
        slot = index % n_workers
        # Every worker already holds one object of the current block:
        # open a new block before assigning the next one.
        if slot == 0 and index > 0:
            workload.append([])

        worker = workers[slot]
        obj.simulation.simulations[0].worker = worker
        future = client.scatter([obj], workers=worker)[0]

        if hasattr(obj, "name"):
            future.name = obj.name

        workload[-1].append(future)

    # Run the type check on the worker that holds each object, so the object
    # is not moved between workers just to be validated.
    checks = [
        client.submit(lambda v: not isinstance(v, obj_type), future, workers=worker)
        for block in workload
        for future, worker in zip(block, workers)
    ]
    is_not_obj = np.array(client.gather(checks))
    if np.any(is_not_obj):
        raise TypeError(f"{property_name} futures must be an instance of {obj_type}")

    if return_workers:
        return workload, workers
    return workload


class DaskComboMisfits(ComboObjectiveFunction):
"""
A composite objective function for distributed computing.
Expand Down Expand Up @@ -131,15 +208,23 @@ def deriv(self, m, f=None):
# f = self.fields(m)

derivs = []
for multiplier, objfct, worker in zip(
self.multipliers, self._futures, self._workers
):
if multiplier == 0.0: # don't evaluate the fct
continue
count = 0
for futures in self._futures:
for objfct, worker in zip(futures, self._workers):
if self.multipliers[count] == 0.0: # don't evaluate the fct
continue

derivs.append(
client.submit(
_deriv,
objfct,
self.multipliers[count],
m_future,
workers=worker,
)
)
count += 1

derivs.append(
client.submit(_deriv, objfct, multiplier, m_future, workers=worker)
)
derivs = _reduce(client, add, derivs)
return derivs

Expand All @@ -162,23 +247,24 @@ def deriv2(self, m, v=None, f=None):
# f = self.fields(m)

derivs = []
for multiplier, objfct, worker in zip(
self.multipliers, self._futures, self._workers
):
if multiplier == 0.0: # don't evaluate the fct
continue
count = 0
for futures in self._futures:
for objfct, worker in zip(futures, self._workers):
if self.multipliers[count] == 0.0: # don't evaluate the fct
continue

derivs.append(
client.submit(
_deriv2,
objfct,
multiplier,
m_future,
v_future,
# field,
workers=worker,
derivs.append(
client.submit(
_deriv2,
objfct,
self.multipliers[count],
m_future,
v_future,
# field,
workers=worker,
)
)
)
count += 1

derivs = _reduce(client, add, derivs)

Expand All @@ -193,38 +279,43 @@ def get_dpred(self, m, f=None):
client = self.client
m_future = self._m_as_future
dpred = []
for objfct, worker, field in zip(self._futures, self._workers, f):
dpred.append(
client.submit(
_calc_dpred,
objfct,
m_future,
field,
workers=worker,
for futures, fields in zip(self._futures, f):
for objfct, worker, field in zip(futures, self._workers, fields):
dpred.append(
client.submit(
_calc_dpred,
objfct,
m_future,
field,
workers=worker,
)
)
)
return client.gather(dpred)

def getJtJdiag(self, m, f=None):
self.model = m
m_future = self._m_as_future
if getattr(self, "_jtjdiag", None) is None:

jtj_diag = []
jtj_diag = 0.0
client = self.client
# if f is None:
# f = self.fields(m)
for objfct, worker in zip(self._futures, self._workers):
jtj_diag.append(
client.submit(
_get_jtj_diag,
objfct,
m_future,
# field,
workers=worker,
for futures in self._futures:
work = []
for objfct, worker in zip(futures, self._workers):
work.append(
client.submit(
_get_jtj_diag,
objfct,
m_future,
# field,
workers=worker,
)
)
)
self._jtjdiag = _reduce(client, add, jtj_diag)
jtj_diag += _reduce(client, add, work)

self._jtjdiag = jtj_diag

return self._jtjdiag

Expand All @@ -236,15 +327,17 @@ def fields(self, m):
return self._stashed_fields
# The above should pass the model to all the internal simulations.
f = []
for objfct, worker in zip(self._futures, self._workers):
f.append(
client.submit(
_calc_fields,
objfct,
m_future,
workers=worker,
for futures in self._futures:
f.append([])
for objfct, worker in zip(futures, self._workers):
f[-1].append(
client.submit(
_calc_fields,
objfct,
m_future,
workers=worker,
)
)
)
self._stashed_fields = f
return f

Expand All @@ -268,17 +361,18 @@ def model(self, value):
client = self.client
[self._m_as_future] = client.scatter([value], broadcast=True)

futures = []
for objfct, worker in zip(self._futures, self._workers):
futures.append(
client.submit(
_store_model,
objfct,
self._m_as_future,
workers=worker,
stores = []
for futures in self._futures:
for objfct, worker in zip(futures, self._workers):
stores.append(
client.submit(
_store_model,
objfct,
self._m_as_future,
workers=worker,
)
)
)
self.client.gather(futures) # blocking call to ensure all models were stored
self.client.gather(stores) # blocking call to ensure all models were stored
self._model = value

@property
Expand All @@ -297,9 +391,9 @@ def objfcts(self, objfcts):
workers=self.workers,
return_workers=True,
)
for objfct, future in zip(objfcts, futures):
if hasattr(objfct, "name"):
future.name = objfct.name
# for objfct, future in zip(objfcts, futures):
# if hasattr(objfct, "name"):
# future.name = objfct.name

self._objfcts = objfcts
self._futures = futures
Expand All @@ -315,14 +409,15 @@ def residuals(self, m, f=None):
client = self.client
m_future = self._m_as_future
residuals = []
for objfct, worker, field in zip(self._futures, self._workers, f):
residuals.append(
client.submit(
_calc_residual,
objfct,
m_future,
field,
workers=worker,
for futures, fields in zip(self._futures, f):
for objfct, worker, field in zip(futures, self._workers, fields):
residuals.append(
client.submit(
_calc_residual,
objfct,
m_future,
field,
workers=worker,
)
)
)
return client.gather(residuals)

0 comments on commit 3e1ec4a

Please sign in to comment.