Merge pull request #71 from lsst-sitcom/tickets/DM-41884
DM-41884: Use more memory efficient nanquantile for large pixel spread.
erykoff authored Dec 11, 2023
2 parents b5e4c2b + f80d39a commit c69f3f1
Showing 2 changed files with 25 additions and 11 deletions.
28 changes: 18 additions & 10 deletions python/lsst/summit/utils/utils.py
@@ -932,7 +932,7 @@ def getFilterSeeingCorrection(filterName):
     raise ValueError(f"Unknown filter name: {filterName}")


-def getCdf(data, scale, nBinsMax=131072):
+def getCdf(data, scale, nBinsMax=300_000):
     """Return an approximate cumulative distribution function scaled to
     the [0, scale] range.
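A note on the new default: 300_000 matches the range cutoff this commit adds to getQuantiles below, presumably so the histogram path resolves any in-range spread at roughly one unit per bin. A back-of-the-envelope sketch of that arithmetic (an inference from the diff, not library code):

# Rough arithmetic relating the new nBinsMax default to the 300_000 range
# cutoff used in getQuantiles below. This is an inference from the diff,
# not a reproduction of getCdf internals.
nBinsMax = 300_000
for spread in (65_000, 131_072, 300_000):
    binWidth = spread / nBinsMax
    print(f"spread={spread:>7}: bin width ~ {binWidth:.3f}")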
@@ -987,9 +987,10 @@ def getQuantiles(data, nColors):
     colors.

     This is equivalent to using the numpy function:
-        np.quantile(data, np.linspace(0, 1, nColors + 1))
+        np.nanquantile(data, np.linspace(0, 1, nColors + 1))
     but with a coarser precision, yet sufficient for our use case. This
-    implementation gives a significant speed-up.
+    implementation gives a significant speed-up. In the case of large
+    ranges, np.nanquantile is used because it is more memory efficient.

     If all elements of ``data`` are nan then the output ``boundaries`` will
     also all be ``nan`` to keep the interface consistent.
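For readers unfamiliar with the trick the docstring describes, here is a minimal, self-contained sketch of approximating quantiles by inverting a coarse histogram-based CDF and checking the result against np.nanquantile. The bin count and variable names are illustrative assumptions, not the actual getCdf/getQuantiles internals:

# Illustrative sketch only: approximate quantiles via a coarse histogram CDF.
import numpy as np

rng = np.random.default_rng(42)
data = rng.normal(100.0, 5.0, (100, 100))
data[10, 10] = np.nan  # mirror the nan-safety concern in the tests

nColors = 256
nBins = 4096  # coarse on purpose; precision grows with the bin count

finite = data[np.isfinite(data)]
minVal, maxVal = finite.min(), finite.max()
counts, edges = np.histogram(finite, bins=nBins, range=(minVal, maxVal))
cdf = np.cumsum(counts) / finite.size * nColors  # scaled to [0, nColors]

# Invert the CDF at integer levels to get approximate quantile boundaries.
approx = np.asarray(
    [edges[np.argmax(cdf >= i)] for i in range(nColors)] + [maxVal]
)
exact = np.nanquantile(data, np.linspace(0, 1, nColors + 1))
print(np.max(np.abs(approx - exact)))  # small but nonzero: coarser precision

Roughly speaking, a single histogram pass over the data replaces the sorting work a full quantile computation needs, which is where the speed-up comes from.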
@@ -1007,15 +1008,22 @@ def getQuantiles(data, nColors):
         A monotonically increasing sequence of size (nColors + 1). These are
         the edges of nColors intervals.
     """
-    cdf, minVal, maxVal = getCdf(data, nColors)
-    if np.isnan(minVal):  # cdf calculation has failed because all data is nan
-        return np.asarray([np.nan for _ in range(nColors)])
+    if (np.nanmax(data) - np.nanmin(data)) > 300_000:
+        # Use slower but memory efficient nanquantile
+        logger = logging.getLogger(__name__)
+        logger.warning("Data range is very large; using slower quantile code.")
+        boundaries = np.nanquantile(data, np.linspace(0, 1, nColors + 1))
+    else:
+        cdf, minVal, maxVal = getCdf(data, nColors)
+        if np.isnan(minVal):  # cdf calculation has failed because all data is nan
+            return np.asarray([np.nan for _ in range(nColors)])

-    scale = (maxVal - minVal)/len(cdf)
+        scale = (maxVal - minVal)/len(cdf)

-    boundaries = np.asarray(
-        [np.argmax(cdf >= i)*scale + minVal for i in range(nColors)] + [maxVal]
-    )
+        boundaries = np.asarray(
+            [np.argmax(cdf >= i)*scale + minVal for i in range(nColors)] + [maxVal]
+        )
     return boundaries
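A hypothetical usage sketch of the patched function (assumes an environment with lsst.summit.utils available; the pixel values are made up, and the threshold behavior follows the diff above):

import logging

import numpy as np

from lsst.summit.utils.utils import getQuantiles

logging.basicConfig(level=logging.WARNING)

# Modest pixel spread: the fast histogram-CDF path runs, no warning emitted.
calm = np.random.normal(1000.0, 10.0, (2048, 2048))
edges = getQuantiles(calm, 256)  # 257 monotonically increasing boundaries

# A single hot pixel pushes the spread past 300_000, so the function falls
# back to np.nanquantile and warns:
#   "Data range is very large; using slower quantile code."
hot = calm.copy()
hot[0, 0] = 1.0e7
edges = getQuantiles(hot, 256)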


8 changes: 7 additions & 1 deletion tests/test_utils.py
@@ -238,7 +238,13 @@ def test_quantiles(self):
         for nColors, (mean, width, decimal) in itertools.product(colorRanges, dataRanges):
             data = np.random.normal(mean, width, (100, 100))
             data[10, 10] = np.nan  # check we're still nan-safe
-            edges1 = getQuantiles(data, nColors)
+            if np.nanmax(data) - np.nanmin(data) > 300_000:
+                with self.assertLogs(level="WARNING") as cm:
+                    edges1 = getQuantiles(data, nColors)
+                self.assertIn("Data range is very large", cm.output[0])
+            else:
+                with self.assertNoLogs(level="WARNING") as cm:
+                    edges1 = getQuantiles(data, nColors)
             edges2 = np.nanquantile(data, np.linspace(0, 1, nColors + 1))  # must check with nanquantile
             np.testing.assert_almost_equal(edges1, edges2, decimal=decimal)
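The test exercises both branches with unittest's assertLogs/assertNoLogs context managers (assertNoLogs requires Python 3.10 or newer). A minimal standalone illustration of the same pattern, with made-up names:

import logging
import unittest


def maybeWarn(spread):
    """Toy stand-in for getQuantiles' warning branch (illustrative only)."""
    if spread > 300_000:
        logging.getLogger(__name__).warning(
            "Data range is very large; using slower quantile code."
        )


class TestWarningPath(unittest.TestCase):
    def test_warns_on_large_spread(self):
        # assertLogs fails unless at least one WARNING-level record is emitted.
        with self.assertLogs(level="WARNING") as cm:
            maybeWarn(1_000_000)
        self.assertIn("Data range is very large", cm.output[0])

    def test_quiet_on_small_spread(self):
        # assertNoLogs fails if any WARNING-level record is emitted.
        with self.assertNoLogs(level="WARNING"):
            maybeWarn(10)


if __name__ == "__main__":
    unittest.main()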
