From 379e2897248a9c57ab75b57b8acbd8164f22c4a3 Mon Sep 17 00:00:00 2001
From: Logan Walker <loganaw@umich.edu>
Date: Thu, 18 Jul 2024 13:49:40 -0400
Subject: [PATCH] Create sndif_utils

---
 src/pySISF/sndif_utils | 78 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 src/pySISF/sndif_utils

diff --git a/src/pySISF/sndif_utils b/src/pySISF/sndif_utils
new file mode 100644
index 0000000..225edb2
--- /dev/null
+++ b/src/pySISF/sndif_utils
@@ -0,0 +1,78 @@
+#   ---------------------------------------------------------------------------------
+#   Copyright (c) University of Michigan 2020-2024. All rights reserved.
+#   Licensed under the MIT License. See LICENSE in project root for information.
+#   ---------------------------------------------------------------------------------
+
+import zipfile
+import zstd
+import tqdm
+import concurrent
+
+import numpy as np
+from numba import njit
+
+def load_from_zip(
+    file_name,
+    stack_size=2000,
+    stack_select=None,
+    thread_count=1,
+    chunk_batch=1,
+    correction_image=None,
+    shift=None,
+):
+    zf = zipfile.ZipFile(file_name, mode="r")
+
+    file_list = list(zf.namelist())
+    file_list.sort(key=lambda x: int(x.split("_")[1]))
+
+    if stack_select is not None:
+        file_list = file_list[stack_select]
+
+    out = bytearray()
+    n = 0
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_count) as executor:
+        for r in tqdm.tqdm(
+            executor.map(lambda x: zstd.ZSTD_uncompress(zf.read(x)), file_list),
+            total=len(file_list),
+        ):
+            out += r
+            n += 1
+    zf.close()
+
+    while n < stack_size:
+        n += 1
+        out += b'0' * (2 * 2304 * 2304)
+
+    outnp = np.frombuffer(out, dtype=np.uint16)
+    outnp = outnp.reshape(stack_size, 2304, 2304)
+
+    if correction_image is not None:
+        for i in range(stack_size):
+            np.multiply(
+                outnp[i, ...], correction_image, out=outnp[i, ...], casting="unsafe"
+            )
+
+    outnp = np.moveaxis(outnp, 0, -1)
+
+    return outnp
+
+@njit
+def downsample(in_array, out_array):
+    for i, j in zip(in_array.shape, out_array.shape):
+        if i // 2 != j:
+            raise ValueError(f"Invalid casting ({i}/2) != ({j})")
+
+    for i in range(out_array.shape[0]):
+        for j in range(out_array.shape[1]):
+            for k in range(out_array.shape[2]):
+                total: float = 0.0
+                n: int = 0
+
+                for ii in range(2):
+                    for jj in range(2):
+                        for kk in range(2):
+                            total += in_array[(i * 2) + ii, (j * 2) + jj, (k * 2) + kk]
+                            n += 1
+
+                out_array[i, j, k] = int(total / n)