Problems reproducing values given by np.interp() #12

Open
JCBrouwer opened this issue Mar 25, 2021 · 0 comments
Hello, I'm trying to rewrite some histogram matching code in pytorch which relies on some 1D interpolations.

Most of the values torchinterp1d returns match my NumPy result, but a few are off by more than an order of magnitude from what I expect.

Here's some code to reproduce the issue:

import numpy as np
import torch
from torchinterp1d import Interp1d

interp1d = Interp1d()

# histogram matching with numpy

random_state = np.random.RandomState(12345)  #  not all seeds have this issue, but this is one that does

bins = 64

target = random_state.normal(size=(128 * 128)) * 2   #  some random data between about -8 and 8
source = random_state.normal(size=(128 * 128)) * 2
matched = np.empty_like(target)

lo = min(target.min(), source.min())
hi = max(target.max(), source.max())

target_hist_np, bin_edges_np = np.histogram(target, bins=bins, range=[lo, hi])
source_hist_np, _ = np.histogram(source, bins=bins, range=[lo, hi])

target_cdf_np = target_hist_np.cumsum()
target_cdf_np = target_cdf_np / target_cdf_np[-1]

source_cdf_np = source_hist_np.cumsum()
source_cdf_np = source_cdf_np / source_cdf_np[-1]

remapped_cdf_np = np.interp(target_cdf_np, source_cdf_np, bin_edges_np[1:])

matched_np = np.interp(target, bin_edges_np[1:], remapped_cdf_np, left=0, right=bins)

# now with pytorch

target = torch.from_numpy(target)
source = torch.from_numpy(source)

target_hist = torch.histc(target, bins, lo, hi)
source_hist = torch.histc(source, bins, lo, hi)

assert np.allclose(target_hist_np, target_hist.numpy())
assert np.allclose(source_hist_np, source_hist.numpy())

target_cdf = target_hist.cumsum(0)
target_cdf = target_cdf / target_cdf[-1]

assert np.allclose(target_cdf_np, target_cdf.numpy())

source_cdf = source_hist.cumsum(0)
source_cdf = source_cdf / source_cdf[-1]

assert np.allclose(source_cdf_np, source_cdf.numpy())

bin_edges = torch.linspace(lo, hi, bins + 1)

assert np.allclose(bin_edges_np, bin_edges.numpy())

remapped_cdf = interp1d(source_cdf, bin_edges[1:], target_cdf).squeeze()
# ^^^ the first positions of this come out around -137 instead of about -8?!

print(remapped_cdf_np)
print(remapped_cdf.numpy())
assert np.allclose(remapped_cdf_np, remapped_cdf.numpy())  # fails

matched = interp1d(bin_edges[1:], remapped_cdf, target)

assert np.allclose(matched_np, matched.numpy())

The above code gives me the output:

[-8.04819874 -8.04819874 -8.04819874 -7.03412467 -6.52708763 -6.34600297
 -6.27356911 -6.10455677 -5.89329133 -5.55526664 -5.28932075 -5.00597652
 -4.81837282 -4.66795183 -4.43309052 -4.17367044 -3.93438144 -3.670879
 -3.44365304 -3.19894227 -2.97056192 -2.7420723  -2.47732906 -2.21208839
 -1.96009338 -1.69844422 -1.44815496 -1.20431557 -0.94311239 -0.68723275
 -0.44108403 -0.18912467  0.055417    0.30790917  0.5585027   0.81660576
  1.0688232   1.33458219  1.60847022  1.85890728  2.12938742  2.38900627
  2.66416974  2.93036861  3.17321839  3.41920686  3.64490881  3.92116168
  4.1785585   4.43336298  4.75240842  4.99895072  5.34133486  5.58747586
  5.77523205  5.9234885   6.12066957  6.29370601  6.36613988  6.52911607
  7.6699494   7.6699494   7.6699494   7.92346792]
[-1.37849957e+02 -1.37849957e+02 -1.37849957e+02 -7.28813667e+00
 -6.78085313e+00 -6.34605302e+00 -6.27363935e+00 -6.10459291e+00
 -5.89331551e+00 -5.55530036e+00 -5.28934614e+00 -5.00599635e+00
 -4.81837965e+00 -4.66795432e+00 -4.43309173e+00 -4.17367124e+00
 -3.93438180e+00 -3.67087919e+00 -3.44365297e+00 -3.19894198e+00
 -2.97056138e+00 -2.74207334e+00 -2.47732997e+00 -2.21208787e+00
 -1.96009280e+00 -1.69844372e+00 -1.44815450e+00 -1.20431574e+00
 -9.43111794e-01 -6.87232221e-01 -4.41083541e-01 -1.89125193e-01
  5.54165081e-02  3.07908676e-01  5.58502213e-01  8.16605249e-01
  1.06882271e+00  1.33458229e+00  1.60847020e+00  1.85890732e+00
  2.12938745e+00  2.38900625e+00  2.66416942e+00  2.93036823e+00
  3.17321798e+00  3.41920633e+00  3.64490848e+00  3.92116086e+00
  4.17855749e+00  4.43336134e+00  4.75240455e+00  4.99894268e+00
  5.34132004e+00  5.58744816e+00  5.77521847e+00  5.92348256e+00
  6.12062090e+00  6.29366587e+00  6.36607955e+00  6.52899272e+00
  6.90914693e+00  6.90914693e+00  6.90914693e+00  7.92322078e+00]
Traceback (most recent call last):
  File "histmatch.py", line 256, in <module>
    assert np.allclose(remapped_cdf_np, remapped_cdf.numpy())  # fails
AssertionError

The first array printed is from np.interp and the second is from torchinterp1d, computed from the same inputs (as evidenced by the earlier asserts not triggering). Note that the argument order of torchinterp1d is slightly different from that of np.interp, but I believe the two calls should produce the same result.

In fact, most of the printed values agree. Take the last value of the array, for example: 7.92322078e+00 is close to 7.92346792. The same holds for most of the array, but not at its ends. The first 3 values are more than an order of magnitude lower than expected (around -138 instead of -8.05), and a few neighbours also differ noticeably: entries 4 and 5 (-7.288 and -6.781 versus -7.034 and -6.527) and entries 61 to 63 (6.909 versus 7.670).
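One possible explanation, which I have not verified against torchinterp1d's source, is a difference in how out-of-range queries are handled. By default np.interp clamps queries below xp[0] to fp[0] and queries above xp[-1] to fp[-1]. An implementation that instead extrapolates along the first segment would produce large values when that segment is steep, which is typical for a CDF with very few samples in the outer bins. The sketch below uses made-up numbers (x0, x1, f0, f1 and x are hypothetical, not taken from the data above) to show how a value near -8 can turn into one near -136:

```python
# Hypothetical first segment of a CDF built from 128 * 128 samples:
# the CDF rises by a single sample while the bin edge moves by 0.25.
n = 128 * 128
x0, x1 = 512 / n, 513 / n      # assumed CDF values at the first two knots
f0, f1 = -8.0, -7.75           # assumed bin edges at those knots
x = 0.0                        # query that lies below the first knot


def lerp_extrapolate(x, x0, x1, f0, f1):
    """Follow the segment's slope, even outside [x0, x1]."""
    slope = (f1 - f0) / (x1 - x0)
    return f0 + slope * (x - x0)


def lerp_clamped(x, x0, x1, f0, f1):
    """Hold the end values outside [x0, x1], as np.interp does by default."""
    t = min(max((x - x0) / (x1 - x0), 0.0), 1.0)
    return f0 + t * (f1 - f0)


extrapolated = lerp_extrapolate(x, x0, x1, f0, f1)
clamped = lerp_clamped(x, x0, x1, f0, f1)
print(extrapolated, clamped)   # -136.0 -8.0
```

If this is what happens, checking whether target_cdf[:3] is smaller than source_cdf[0] would confirm it.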

To be concrete, these two lines give different results for the same inputs:

remapped_cdf_np = np.interp(x=target_cdf_np, xp=source_cdf_np, fp=bin_edges_np[1:])
remapped_cdf = interp1d(x=source_cdf, y=bin_edges[1:], xnew=target_cdf).squeeze()

What's going on here? Is there a way to exactly reproduce numpy's results with pytorch?
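For comparison, here is a minimal sketch of a clamped linear interpolation written directly with torch.searchsorted. The function name interp_like_numpy is my own and is not part of torchinterp1d or PyTorch. It assumes 1D tensors of the same dtype, with xp sorted in non-decreasing order, and is meant to mimic np.interp's default end behaviour rather than to be a verified drop-in replacement:

```python
import torch


def interp_like_numpy(x, xp, fp):
    """Clamped piecewise-linear interpolation for 1D tensors (sketch)."""
    # Index of the knot to the right of each query; clamping keeps
    # out-of-range queries on the first or last segment.
    idx = torch.searchsorted(xp, x, right=True).clamp(1, xp.numel() - 1)
    x0, x1 = xp[idx - 1], xp[idx]
    f0, f1 = fp[idx - 1], fp[idx]

    dx = x1 - x0
    flat = dx == 0  # repeated knots, e.g. a CDF across empty bins
    safe_dx = torch.where(flat, torch.ones_like(dx), dx)
    # On a flat segment pick an end value instead of dividing by zero.
    t = torch.where(flat, (x >= x1).to(x.dtype), (x - x0) / safe_dx)
    # Clamping t to [0, 1] holds fp[0] / fp[-1] outside the knot range.
    t = t.clamp(0, 1)
    return f0 + t * (f1 - f0)


xp = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float64)
fp = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
x = torch.tensor([-1.0, 0.0, 0.5, 1.0, 2.0], dtype=torch.float64)
result = interp_like_numpy(x, xp, fp)
print(result)  # values: 1.0, 3.0, 3.5, 4.0, 4.0
```

In the script above this would be called as interp_like_numpy(target_cdf, source_cdf, bin_edges[1:].double()). The cast is there because torch.linspace returns float32 by default while the histograms computed from the float64 data are float64.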
