Hello, I'm trying to rewrite some histogram matching code in pytorch, which relies on some 1D interpolations.
I've noticed that while most of the values in my result from torchinterp1d are the same, there are a couple of values that are an order of magnitude off from what I expect.
Here's some code to reproduce the issue:
```python
import numpy as np
import torch
from torchinterp1d import Interp1d

interp1d = Interp1d()

# histogram matching with numpy
random_state = np.random.RandomState(12345)  # not all seeds have this issue, but this is one that does
bins = 64
target = random_state.normal(size=(128 * 128)) * 2  # some random data between about -8 and 8
source = random_state.normal(size=(128 * 128)) * 2
matched = np.empty_like(target)

lo = min(target.min(), source.min())
hi = max(target.max(), source.max())
target_hist_np, bin_edges_np = np.histogram(target, bins=bins, range=[lo, hi])
source_hist_np, _ = np.histogram(source, bins=bins, range=[lo, hi])
target_cdf_np = target_hist_np.cumsum()
target_cdf_np = target_cdf_np / target_cdf_np[-1]
source_cdf_np = source_hist_np.cumsum()
source_cdf_np = source_cdf_np / source_cdf_np[-1]
remapped_cdf_np = np.interp(target_cdf_np, source_cdf_np, bin_edges_np[1:])
matched_np = np.interp(target, bin_edges_np[1:], remapped_cdf_np, left=0, right=bins)

# now with pytorch
target = torch.from_numpy(target)
source = torch.from_numpy(source)
target_hist = torch.histc(target, bins, lo, hi)
source_hist = torch.histc(source, bins, lo, hi)
assert np.allclose(target_hist_np, target_hist.numpy())
assert np.allclose(source_hist_np, source_hist.numpy())
target_cdf = target_hist.cumsum(0)
target_cdf = target_cdf / target_cdf[-1]
assert np.allclose(target_cdf_np, target_cdf.numpy())
source_cdf = source_hist.cumsum(0)
source_cdf = source_cdf / source_cdf[-1]
assert np.allclose(source_cdf_np, source_cdf.numpy())
bin_edges = torch.linspace(lo, hi, bins + 1)
assert np.allclose(bin_edges_np, bin_edges.numpy())
remapped_cdf = interp1d(source_cdf, bin_edges[1:], target_cdf).squeeze()
# ^^^ first positions of this have -100 values all of a sudden?!
print(remapped_cdf_np)
print(remapped_cdf.numpy())
assert np.allclose(remapped_cdf_np, remapped_cdf.numpy())  # fails
matched = interp1d(bin_edges[1:], remapped_cdf, target)
assert np.allclose(matched_np, matched.numpy())
```
The values printed in the second array are from torchinterp1d, while the top values are from np.interp, for the same inputs (as evidenced by the earlier asserts not triggering). Note that the argument order for torchinterp1d is slightly different from np.interp's, but I believe they should produce the same result.
In fact, most of the printed values are the same. Take the last value of the array, for example: 7.92322078e+00 is pretty close to 7.92346792. The same holds for almost all values in the array except the first 3, which come out around -140, an order of magnitude off from the rest.
To be concrete, these two lines from the snippet above give different results for the same inputs:
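```python
remapped_cdf_np = np.interp(target_cdf_np, source_cdf_np, bin_edges_np[1:])
remapped_cdf = interp1d(source_cdf, bin_edges[1:], target_cdf).squeeze()
```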
What's going on here? Is there a way to exactly reproduce numpy's results with pytorch?
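One difference I'm aware of: np.interp never extrapolates; queries outside the range of the x-coordinates are clamped to the endpoint values. I don't know whether torchinterp1d does the same, but if it extrapolates linearly instead, that could blow up where target_cdf falls outside [source_cdf[0], source_cdf[-1]]. Here's a minimal sketch of what I mean, plus an untested clamping workaround (interp1d_clamped is just a name I made up, and it assumes the x-coordinates are sorted ascending, which a CDF is):

```python
import numpy as np
import torch
from torchinterp1d import Interp1d

# np.interp clamps out-of-range queries to the endpoint values:
xp = np.array([0.0, 1.0, 2.0])
fp = np.array([0.0, 10.0, 20.0])
print(np.interp([-5.0, 0.5, 5.0], xp, fp))  # [ 0.  5. 20.] -- no extrapolation

# hypothetical workaround: clamp the query points into [x[0], x[-1]] before
# interpolating, so out-of-range queries map to the endpoints like np.interp
def interp1d_clamped(x, y, x_new):
    x_new = torch.clamp(x_new, min=x[0].item(), max=x[-1].item())
    return Interp1d()(x, y, x_new)
```

If that's actually what's happening, it would explain why only the first few values (the smallest target_cdf entries) are affected, but I'd still like to know what torchinterp1d's intended out-of-range behavior is.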