# Developer task runner. Run `make` (or `make help`) to list the documented
# targets: every target whose rule line carries a `## description` is shown.

# Every target that does not produce a file of the same name must be listed
# here, otherwise a file or directory called e.g. `test` or `check` would make
# the target look up to date and silently skip it.
.PHONY: check clean clean-build clean-pyc clean-test coverage dist docs help install lint lint/flake8 lint/black release servedocs test
.DEFAULT_GOAL := help

# Small Python helper that opens a local file in the default web browser.
define BROWSER_PYSCRIPT
import os, webbrowser, sys

from urllib.request import pathname2url

webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
endef
export BROWSER_PYSCRIPT

# Small Python helper that prints "target  description" for every rule line of
# the form `target: ... ## description`.
define PRINT_HELP_PYSCRIPT
import re, sys

for line in sys.stdin:
	match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
	if match:
		target, help = match.groups()
		print("%-20s %s" % (target, help))
endef
export PRINT_HELP_PYSCRIPT

BROWSER := python -c "$$BROWSER_PYSCRIPT"

help: ## list the documented targets
	@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

check: lint test ## run the linters followed by the tests

clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts

clean-build: ## remove build artifacts
	rm -fr build/
	rm -fr dist/
	rm -fr .eggs/
	find . -name '*.egg-info' -exec rm -fr {} +
	find . -name '*.egg' -exec rm -f {} +

clean-pyc: ## remove Python file artifacts
	find . -name '*.pyc' -exec rm -f {} +
	find . -name '*.pyo' -exec rm -f {} +
	find . -name '*~' -exec rm -f {} +
	find . -name '__pycache__' -exec rm -fr {} +

clean-test: ## remove test and coverage artifacts
	rm -fr .tox/
	rm -f .coverage
	rm -fr htmlcov/
	rm -fr .pytest_cache

# NOTE: black and isort are run without a check-only flag, so this target
# rewrites files in place before mypy and flake8 inspect them.
lint: ## format with black and isort, then run mypy and flake8
	black .
	isort .
	mypy ./pymatchseries
	flake8 ./pymatchseries

test: ## run tests quickly with the default Python
	pytest -vv

coverage: ## check code coverage quickly with the default Python
	coverage run --source pymatchseries -m pytest
	coverage report -m
	coverage html
	$(BROWSER) htmlcov/index.html

docs: ## generate Sphinx HTML documentation, including API docs
	rm -f docs/pymatchseries.rst
	rm -f docs/modules.rst
	sphinx-apidoc -o docs/ pymatchseries
	$(MAKE) -C docs clean
	$(MAKE) -C docs html
	$(BROWSER) docs/_build/html/index.html

servedocs: docs ## compile the docs watching for changes
	watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .

release: dist ## package and upload a release
	twine upload dist/*

dist: clean ## builds source and wheel package
	python setup.py sdist
	python setup.py bdist_wheel
	ls -l dist

install: clean ## install the package to the active Python's site-packages
	python setup.py install
-0.878787982158\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Initial energy: -0.878787982158\n", "140 steps, tau: 0 , sigma: 1.e-04 energy: -0.902049556726 steps, tau: 0.5000, sigma: 0.0312 energy: -0.900622429688\n", "Descent needed 140 step(s).\n", "Created directory output/stage3/12-r/\n", @@ -4917,11 +4899,369 @@ "source": [ "corrected_spectrum.save(\"data/corrected_spectrum.hspy\")" ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[2, 2, 2],\n", + " [3, 3, 4],\n", + " [1, 1, 3]],\n", + "\n", + " [[4, 4, 2],\n", + " [1, 1, 2],\n", + " [2, 3, 2]],\n", + "\n", + " [[4, 2, 4],\n", + " [1, 3, 1],\n", + " [1, 2, 2]]])" + ] + }, + "execution_count": 81, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "np.random.seed(1001)\n", + "np.random.randint(1, 5, size=(3, 3, 3))" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [], + "source": [ + "from hyperspy.signals import Signal2D" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = Signal2D(np.random.randint(1, 5, size=(3, 3, 3)))\n", + "data.axes_manager[0].name" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "imageds = np.random.randint(1, 5, size=(3, 32, 32))" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [], + "source": [ + "mso = MatchSeries(imageds)" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Directory output/ already exists\n", + "reading reference image from: 
input/frame_0.tiff\n", + "reading template image from: input/frame_0.tiff\n", + "Created directory output/stage1/\n", + "Using templates images ---------------------------------------\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "Created directory output/stage1/0/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 3 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: 0.106680475931\n", + " 25 steps, tau: 0 , sigma: 1.e-04 energy: 0.104452782382\n", + "Descent needed 25 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: 0.052256550569\n", + " 24 steps, tau: 0 , sigma: 1.e-04 energy: -0.073858866660\n", + "Descent needed 24 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.020301576081\n", + " 32 steps, tau: 0 , sigma: 1.e-04 energy: -0.049458095702\n", + "Descent needed 32 step(s).\n", + "Created directory output/stage1/1/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 3 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.156992000825\n", + " 18 steps, tau: 0 , sigma: 1.e-04 energy: -0.171316084630\n", + "Descent needed 18 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.072668214978\n", + " 28 steps, tau: 0 , sigma: 1.e-04 energy: -0.074586027421\n", + "Descent needed 28 step(s).\n", + 
"\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.004713604430\n", + " 30 steps, tau: 0 , sigma: 1.e-04 energy: -0.013239817251\n", + "Descent needed 30 step(s).\n", + "Created directory output/stage1/1-r/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.044467289594\n", + " 24 steps, tau: 0 , sigma: 1.e-04 energy: -0.152107954423\n", + "Descent needed 24 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.032114790938\n", + " 18 steps, tau: 0 , sigma: 1.e-04 energy: -0.068217234609\n", + "Descent needed 18 step(s).\n", + "Using templates images ---------------------------------------\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "Initial energy: 0.002786801449\n", + "106 steps, tau: 0 energy: 0.000924978319 DEnorm: 1.54667e-06\n", + "Descent needed 106 step(s).\n", + "Using templates images ---------------------------------------\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "reading reference image from: output/stage1/median.q2bz\n", + "reading template image from: input/frame_0.tiff\n", + "Created directory output/stage2/\n", + "Using templates images ---------------------------------------\n", + "input/frame_0.tiff\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "Created directory output/stage2/0/\n", + "\n", + 
"--------------------------------------------------------\n", + "Registration on level 3 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.029571555907\n", + " 24 steps, tau: 0 , sigma: 1.e-04 energy: -0.053982331555\n", + "Descent needed 24 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.132045194814\n", + " 28 steps, tau: 0 , sigma: 1.e-04 energy: -0.221582629817\n", + "Descent needed 28 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.100072615618\n", + " 29 steps, tau: 0 , sigma: 1.e-04 energy: -0.202034018524\n", + "Descent needed 29 step(s).\n", + "Created directory output/stage2/1-r/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.509109677707\n", + " 31 steps, tau: 0 , sigma: 1.e-04 energy: -0.649321004375\n", + "Descent needed 31 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.582764924738\n", + " 59 steps, tau: 0 , sigma: 1.e-04 energy: -0.674561858825\n", + "Descent needed 59 step(s).\n", + "Created directory output/stage2/2-r/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.461072487825\n", + " 38 steps, tau: 0 , sigma: 1.e-04 energy: -0.660572884655\n", + "Descent needed 38 
step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.649948186128\n", + " 60 steps, tau: 0 , sigma: 1.e-04 energy: -0.740657496617\n", + "Descent needed 60 step(s).\n", + "Using templates images ---------------------------------------\n", + "input/frame_0.tiff\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "Initial energy: 0.002453847589\n", + "333 steps, tau: 0 energy: 0.001866197169 DEnorm: 2.22505e-06\n", + "Descent needed 333 step(s).\n", + "Using templates images ---------------------------------------\n", + "input/frame_0.tiff\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "reading reference image from: output/stage2/median.q2bz\n", + "reading template image from: input/frame_0.tiff\n", + "Created directory output/stage3/\n", + "Using templates images ---------------------------------------\n", + "input/frame_0.tiff\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "Created directory output/stage3/0/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 3 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.481184019033\n", + " 17 steps, tau: 0 , sigma: 1.e-04 energy: -0.507485866870\n", + "Descent needed 17 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.502146653651\n", + " 35 steps, tau: 0 , sigma: 1.e-04 energy: -0.559840057958\n", + "Descent needed 35 step(s).\n", + "\n", + 
"--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.485128244802\n", + " 61 steps, tau: 0 , sigma: 1.e-04 energy: -0.568681305074\n", + "Descent needed 61 step(s).\n", + "Created directory output/stage3/1-r/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.438795789399\n", + " 31 steps, tau: 0 , sigma: 1.e-04 energy: -0.565236490287\n", + "Descent needed 31 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.461153927848\n", + " 44 steps, tau: 0 , sigma: 1.e-04 energy: -0.537615132386\n", + "Descent needed 44 step(s).\n", + "Created directory output/stage3/2-r/\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 4 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.412215077530\n", + " 48 steps, tau: 0 , sigma: 1.e-04 energy: -0.542146703669\n", + "Descent needed 48 step(s).\n", + "\n", + "--------------------------------------------------------\n", + "Registration on level 5 started\n", + "--------------------------------------------------------\n", + "\n", + "Initial energy: -0.401220723950\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 58 steps, tau: 0 , sigma: 1.e-04 energy: -0.608290274199\n", + "Descent needed 58 step(s).\n", + "Using templates images ---------------------------------------\n", + "input/frame_0.tiff\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "Initial energy: 
0.001896369120\n", + "500 steps, tau: 128.0000 energy: 0.001868314424 DEnorm: 1.86872e-06\n", + "Descent needed 500 step(s).\n", + "Using templates images ---------------------------------------\n", + "input/frame_0.tiff\n", + "input/frame_1.tiff\n", + "input/frame_2.tiff\n", + "--------------------------------------------------------------\n", + "CPU time = 0m 4.757s\n", + "Wall clock time = 0m 5.201s\n", + "All output has been written to file output/log.txt\n" + ] + } + ], + "source": [ + "mso.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -4935,7 +5275,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.6" + "version": "3.9.9" } }, "nbformat": 4, diff --git a/pymatchseries/__init__.py b/pymatchseries/__init__.py index 1423068..3b5f8df 100644 --- a/pymatchseries/__init__.py +++ b/pymatchseries/__init__.py @@ -1 +1,3 @@ from pymatchseries.matchseries import MatchSeries + +__all__ = ["MatchSeries"] diff --git a/pymatchseries/config_tools.py b/pymatchseries/config_tools.py index 75a38e4..b8a8725 100644 --- a/pymatchseries/config_tools.py +++ b/pymatchseries/config_tools.py @@ -1,8 +1,7 @@ -import re import os +import re import sys - folder, _ = os.path.split(os.path.abspath(__file__)) DEFAULT_CONFIG_PATH = os.path.join(folder, "default_parameters.param") @@ -14,7 +13,7 @@ class config_dict(dict): - def __init__(self, data): + def __init__(self, data) -> None: super().__init__(data) def __getattr__(self, attr): diff --git a/pymatchseries/implementation/__init__.py b/pymatchseries/implementation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymatchseries/implementation/cuda_kernels.py b/pymatchseries/implementation/cuda_kernels.py new file mode 100644 index 0000000..09ef7cd 
try:
    import cupy as cp
    from cupy import ndarray as carray
    from numba import cuda, float32
    from numba.cuda import jit
except ImportError:
    # The GPU stack is optional: bind placeholders so that this module can
    # still be imported (its functions cannot be executed without the stack).
    cp = None
    carray = None
    cuda = None
    float32 = None

    # see https://stackoverflow.com/questions/57774497
    def jit(*args, **kwargs):
        """No-op replacement for ``numba.cuda.jit`` when the import failed.

        Supports both spellings used in this module:

        * ``@jit`` - the decorated function is the single positional
          argument and is returned unchanged;
        * ``@jit(device=True)`` - the options are ignored and a decorator
          returning the function unchanged is handed back.

        Without the first case a bare ``@jit`` would rebind the kernel name to
        the inner ``decorator`` instead of the kernel function itself.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(f):
            return f

        return decorator
def evaluate_at_quad_points_gpu(
    array: carray,
    node_weights: carray,
) -> carray:
    """Interpolate an array at the quadrature points of every grid cell

    Parameters
    ----------
    array: (N, M) array
        Values on the nodes of the grid
    node_weights: (4, K) of float32
        Weight of each of the 4 nodes surrounding a cell, for each of the K
        quadrature points

    Returns
    -------
    values: ((N-1), (M-1), K) array of float32
        Value at each of the K quadrature points in each of the cells
    """
    n_cell_rows = array.shape[0] - 1
    n_cell_columns = array.shape[1] - 1
    n_quad_points = node_weights.shape[1]

    # uninitialised buffer; the kernel writes the interpolated values into it
    values = cp.empty(
        (n_cell_rows, n_cell_columns, n_quad_points),
        dtype=cp.float32,
    )
    blocks_per_grid, threads_per_block = _get_default_grid_dims_2D(values)
    kernel = _evaluate_at_quad_points_kernel[blocks_per_grid, threads_per_block]
    kernel(array, node_weights, values)
    return values
def _get_default_grid_dims_2D(
    array: carray,
    tpb: Tuple[int, int] = (TPB, TPB),
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Compute the launch configuration covering the first two axes of an array

    Parameters
    ----------
    array: array with at least 2 dimensions
        Only the sizes of axis 0 and axis 1 are used
    tpb: tuple of 2 ints
        Threads per block along axis 0 and axis 1

    Returns
    -------
    bpg: tuple of 2 ints
        Blocks per grid, rounded up so that the blocks span both axes entirely
    tpb: tuple of 2 ints
        The threads per block, returned as passed in
    """

    def blocks_needed(extent: int, threads: int) -> int:
        # integer division rounding up
        return (extent + (threads - 1)) // threads

    blocks_per_grid = (
        blocks_needed(array.shape[0], tpb[0]),
        blocks_needed(array.shape[1], tpb[1]),
    )
    return blocks_per_grid, tpb
@jit(device=True)
def _get_interpolation_parameters(
    coordinate: float,
    axis_size: int,
) -> Tuple[bool, int, float]:
    """Determine if a coordinate is within bounds, and what its weight is

    Parameters
    ----------
    coordinate: float
        Position along one image axis, in units of grid points
    axis_size: int
        Number of grid points along that axis

    Returns
    -------
    is_valid: bool
        True if 0 <= coordinate <= axis_size - 1
    reference_gridpoint: int
        Lower grid index of the interval used for interpolation. The callers
        also read the grid point at reference_gridpoint + 1.
    weight: float
        Weight of the grid point at reference_gridpoint + 1; the reference
        grid point itself gets 1 - weight. Coordinates outside the axis are
        clamped onto the first or last grid point.
    """
    if coordinate >= 0 and coordinate < axis_size - 1:
        is_valid = True
        reference_gridpoint = int(coordinate)
        weight = coordinate - reference_gridpoint
    elif coordinate < 0:
        is_valid = False
        reference_gridpoint = 0
        weight = 0.0
    elif coordinate > axis_size - 1:
        is_valid = False
        reference_gridpoint = axis_size - 2
        weight = 1.0
    elif coordinate == axis_size - 1:
        # exactly on the last grid point: use the last interval with all the
        # weight on its upper end so that reference_gridpoint + 1 stays in range
        is_valid = True
        reference_gridpoint = axis_size - 2
        weight = 1.0
    else:
        # Only reachable when every comparison above is False, i.e. for a NaN
        # coordinate. Return a defined, invalid result instead of leaving the
        # three values unassigned.
        is_valid = False
        reference_gridpoint = 0
        weight = 0.0
    return is_valid, reference_gridpoint, weight
@jit
def _evaluate_pd_on_quad_points_kernel(
    quadrature_values: carray,
    quad_weights_sqrt: carray,
    node_weights: carray,
    data: carray,
    rows: carray,
    cols: carray,
) -> None:
    """Fill the COO-style triplets (data, rows, cols) for one grid cell

    Each thread handles cell (i, j) and writes 4 entries for each of its K
    quadrature points:

    * ``rows`` holds the flat index of the quadrature point, with cells in
      row-major order and the K points of a cell stored consecutively;
    * ``cols`` holds the flat index of a node in the node grid, which has one
      more row and one more column than the cell grid;
    * ``data`` holds quadrature value * sqrt(quadrature weight) * node weight.

    NOTE(review): node 0 is mapped to the node one row and one column past
    node (i, j) and node 3 to node (i, j) itself, whereas
    ``_evaluate_at_quad_points_kernel`` pairs ``node_weights[0]`` with node
    (r, c). Confirm that the reversed ordering is intended.
    """
    i, j = cuda.grid(2)

    n_cell_rows = quadrature_values.shape[0]
    n_cell_columns = quadrature_values.shape[1]
    n_quad_points = quadrature_values.shape[2]

    # threads that fall outside of the cell grid have nothing to write
    if i >= n_cell_rows or j >= n_cell_columns:
        return

    # flat-index offsets, relative to node (i, j), of the 4 nodes of a cell
    n_node_columns = n_cell_columns + 1
    node_offsets = (n_node_columns + 1, n_node_columns, 1, 0)

    cell_index = i * n_cell_columns + j
    # node rows are one element longer than cell rows, hence the extra i
    first_node_index = cell_index + i
    first_quad_index = n_quad_points * cell_index

    for k in range(n_quad_points):
        quad_index = first_quad_index + k
        first_entry = 4 * quad_index
        weighted_value = quadrature_values[i, j, k] * quad_weights_sqrt[k]
        for node in range(4):
            entry = first_entry + node
            data[entry] = weighted_value * node_weights[node, k]
            rows[entry] = quad_index
            cols[entry] = first_node_index + node_offsets[node]
displacement_to_coordinates, + get_dispatcher, + map_coordinates, + mean, + median, + resize_image_stack, + to_device, + to_host, +) + + +class _DataclassConfig: + validate_assignment = True + + +@dataclass(config=_DataclassConfig) +class Regularization: + constant_start: float = 0.1 + factor_level: float = 1.0 + factor_stage: float = 0.1 + + +@dataclass(config=_DataclassConfig) +class ObjectiveConfig: + number_of_quadrature_points: int = 3 + cache_derivative_of_regularizer: bool = True + + +@dataclass(config=_DataclassConfig) +class SolverConfig: + max_iterations: int = 50 + stop_epsilon: float = 0.0 + start_step: float = 1.0 + show_progress: bool = True + + +@dataclass(config=_DataclassConfig) +class IOConfig: + store_each_stage: bool = True + store_each_image_comparison: bool = True + store_each_level: bool = False + + +@dataclass(config=_DataclassConfig) +class JNNRConfig: + device: str = "auto" + n_levels: int = 3 + n_stages: int = 2 + reference_update_function: str = "median" + regularization: Regularization = Regularization() + objective: ObjectiveConfig = ObjectiveConfig() + solver: SolverConfig = SolverConfig() + io: IOConfig = IOConfig() + + +@classic_dataclass +class JNNRState: + images: Signal2D + reference_image: Optional[DenseArrayType] = None + deformations: Optional[ComplexSignal2D] = None + completed_stages: int = 0 + + @classmethod + def load(cls, filepath: str = "saved_jnnr_calculation") -> JNNRState: + raise NotImplementedError() + + def save(self) -> None: + raise NotImplementedError() + + +class JNRR: + def __init__( + self, + images: Signal2D, + ) -> None: + self._validate_images(images) + self.__config = JNNRConfig() + self.__state = JNNRState(images) + + @classmethod + def load( + cls, + filepath: Path, + ) -> JNRR: + raise NotImplementedError() + + def save(self) -> None: + raise NotImplementedError() + + @property + def images(self) -> Signal2D: + return self.state.images + + @property + def state(self) -> JNNRState: + return self.__state 
+ + def run(self) -> None: + L = self.config.regularization.constant_start + n_stages = self.config.n_stages + + for stage in range(n_stages): + displacements = [] + corrected_images = [] + + progress = tqdm(total=self.number_of_images) + + with progress: + progress.set_description( + f"Stage: {stage + 1}, Image: 0/{self.number_of_images}" + ) + # Registration - finding all displacements + for i, image in enumerate( + self._get_image_iterator( + self.images, + device=self.config.device, + ) + ): + dp = get_dispatcher(image) + if self.state.reference_image is None: + self.state.reference_image = image + displacement = dp.zeros((2, *image.shape), dtype=image.dtype) + displacements.append( + self._displacement_to_complex(displacement) + ) + corrected_images.append(image) + progress.update(n=1) + continue + + displacement = self._get_multilevel_displacement_field( + image, + self.state.reference_image, + regularization_constant=L, + configuration=self.config, + ) + displacements.append(self._displacement_to_complex(displacement)) + + corrected_image = self._apply_displacement(image, displacement) + corrected_images.append(corrected_image) + + progress.set_description( + f"Stage: {stage + 1}, Image: {i + 1}/{self.number_of_images}" + ) + progress.update(n=1) + + # Bias correction?? 
+ + self.state.reference_image = self._aggregate_stack( + dp.stack(corrected_images), + ) + self.state.deformations = ComplexSignal2D(dp.stack(corrected_images)) + self.state.completed_stages = stage + 1 + L *= self.config.regularization.factor_stage + + @classmethod + def _displacement_to_complex(cls, displacement: DenseArrayType) -> DenseArrayType: + return displacement[1] + 1j * displacement[0] + + @property + def _aggregate_stack(self) -> Callable: + if self.config.reference_update_function == "mean": + return mean + elif self.config.reference_update_function == "median": + return median + else: + raise NotImplementedError("Unrecognized aggregation method.") + + @property + def number_of_images(self) -> int: + return self.images.axes_manager.navigation_size + + @property + def config(self) -> JNNRConfig: + return self.__config + + def configure(self, options: Mapping[str, Any]) -> None: + """Provide configuration in a dictionary using dot notation""" + for key, value in options.items(): + split_key = key.split(".") + obj = self.config + for key_part in split_key: + sub_obj = getattr(obj, key_part) + if hasattr(sub_obj, "__dataclass_fields__"): + obj = sub_obj + setattr(obj, key_part, value) + + @classmethod + def _validate_images(cls, images: Signal2D) -> None: + if not isinstance(images, Signal2D): + raise ValueError("Images must be a HyperSpy Signal2D object.") + if images.axes_manager.navigation_dimension != 1: + raise ValueError("Navigation dimension must be one dimensional.") + + @classmethod + def _get_image_iterator( + cls, + images: Signal2D, + device: Optional[str] = None, + ) -> Iterator[DenseArrayType]: + is_lazy = isinstance(images.data, da.Array) + + if device == "cpu": + transfer = to_host + elif device == "gpu": + transfer = to_device + else: + + def do_nothing(x): + return x + + transfer = do_nothing + + for image in iter(images): + if is_lazy: + image.compute() + yield transfer(image.data) + + @classmethod + def _apply_displacement( + cls, + 
image: DenseArrayType, + displacement: DenseArrayType, + ) -> DenseArrayType: + return map_coordinates( + image, + displacement_to_coordinates(displacement), + ) + + @classmethod + def _get_multilevel_displacement_field( + cls, + image_deformed: DenseArrayType, + image_reference: DenseArrayType, + regularization_constant: float, + configuration: JNNRConfig, + ) -> DenseArrayType: + """Get the displacement field between two images by progressively scaling""" + n_levels = configuration.n_levels + + im_def_pyramid = create_image_pyramid( + image_deformed, + n_levels, + downscale_factor=2.0, + ) + im_ref_pyramid = create_image_pyramid( + image_reference, + n_levels, + downscale_factor=2.0, + ) + + displacement = None + + progress = tqdm(total=n_levels, leave=False) + + with progress: + progress.set_description(f"Level: 0/{n_levels}") + for i, (im_def, im_ref) in enumerate(zip(im_def_pyramid, im_ref_pyramid)): + if displacement is not None: + displacement = resize_image_stack( + image_stack=displacement, + new_size=im_def.shape, + ) + + displacement = cls._get_displacement_field( + image_deformed=im_def, + image_reference=im_ref, + regularization_constant=regularization_constant, + configuration=configuration, + displacement_start=displacement, + ) + + regularization_constant *= configuration.regularization.factor_level + + progress.set_description( + f"Level: {i + 1}/{n_levels}, Size: {im_def.shape}" + ) + progress.update(n=1) + + return displacement + + @classmethod + def _get_displacement_field( + cls, + image_deformed: DenseArrayType, + image_reference: DenseArrayType, + regularization_constant: float, + configuration: JNNRConfig, + displacement_start: Optional[DenseArrayType] = None, + ) -> DenseArrayType: + """Compare two images and get the optimized displacement field""" + image_shape = image_deformed.shape + objective_configuration = configuration.objective + n_quad_points = objective_configuration.number_of_quadrature_points + cache_d_reg = 
objective_configuration.cache_derivative_of_regularizer + objective = RegistrationObjectiveFunction( + image_deformed, + image_reference, + regularization_constant, + number_of_quadrature_points=n_quad_points, + cache_derivative_of_regularizer=cache_d_reg, + ) + + dp = objective.dispatcher + + if displacement_start is not None: + displacement_vector = displacement_start.ravel() + else: + displacement_vector = dp.zeros(2 * objective.number_of_nodes) + + solver_configuration = configuration.solver + return root_gauss_newton( + F=objective.evaluate_residual, + x0=displacement_vector, + DF=objective.evaluate_residual_gradient, + max_iterations=solver_configuration.max_iterations, + stop_epsilon=solver_configuration.stop_epsilon, + start_step=solver_configuration.start_step, + show_progress=solver_configuration.show_progress, + ).reshape(2, *image_shape) diff --git a/pymatchseries/implementation/interpolation.py b/pymatchseries/implementation/interpolation.py new file mode 100644 index 0000000..6d6e2bf --- /dev/null +++ b/pymatchseries/implementation/interpolation.py @@ -0,0 +1,203 @@ +from typing import Tuple + +import numpy as np +from numba import njit, prange + +from pymatchseries.utils import DenseArrayType, cp, get_dispatcher + + +class BilinearInterpolation2D: + """Class to perform interpolation of an image on an arbitrary non-integer + coordinate grid + + Parameters + ---------- + image: (N, M) array of float32 + The array to use for interpolation + """ + + def __init__( + self, + image: DenseArrayType, + ) -> None: + self.image = image + dispatcher = get_dispatcher(image) + if dispatcher == np: + self.evaluate_function = interpolate_cpu + self.evaluate_gradient_function = interpolate_gradient_cpu + elif dispatcher == cp: + from pymatchseries.implementation.cuda_kernels import ( + interpolate_gpu, + interpolate_gradient_gpu, + ) + + self.evaluate_function = interpolate_gpu + self.evaluate_gradient_function = interpolate_gradient_gpu + else: + raise 
ValueError("Unexpected object type for image") + + def evaluate(self, coordinates: DenseArrayType) -> DenseArrayType: + """Evaluate image at non-integer coordinates with linear interpolation + + Parameters + ---------- + coordinates: (R, C, 2) array of float32 + The coordinates at which to interpolate. (R, C, 0) represent the y + coordinates in the image, (R, C, 1) represent the x coordinate in + the image. + + Returns + ------- + values: (R, C) array of float32 + The interpolated values for all R, C coordinates + + Notes + ----- + R and C are actually arbitrary and serve to define how the problem is + parallelized. In the CPU implementation, a new thread is launched for + each R. In the GPU implementation, the grid is defined based on R and + C. + """ + return self.evaluate_function(self.image, coordinates) + + def evaluate_gradient(self, coordinates: DenseArrayType) -> DenseArrayType: + """Evaluate image gradient at non-integer coordinates with linear interpolation + + Parameters + ---------- + coordinates: (R, C, 2) array of float32 + The coordinates at which to interpolate. (R, C, 0) represent the y + coordinates in the image, (R, C, 1) represent the x coordinate in + the image. R and C are actually arbitrary and serve to define how + the problem is parallelized. + + Returns + ------- + gradient: (R, C, 2) array of float32 + The interpolated gradients at all R, C coordinates. (R, C, 0) is + the y coordinate of the gradient, (R, C, 1) is the x coordinate. + + Notes + ----- + R and C are actually arbitrary and serve to define how the problem is + parallelized. In the CPU implementation, a new thread is launched for + each R. In the GPU implementation, the grid is defined based on R and + C. 
+ """ + return self.evaluate_gradient_function(self.image, coordinates) + + +@njit(parallel=True) +def interpolate_cpu( + image: np.ndarray, + coordinates: np.ndarray, +) -> np.ndarray: + """Evaluate image at non-integer coordinates with linear interpolation + + Parameters + ---------- + image: (N, M) array of float32 + The array to use for interpolation + coordinates: (R, C, 2) array of float32 + The coordinates at which to interpolate. (R, C, 0) represent the y + coordinates in the image, (R, C, 1) represent the x coordinate in the + image + + Returns + ------- + values: (R, C) array of float32 + The interpolated values for all R, C coordinates + """ + result = np.empty(coordinates.shape[:2], dtype=np.float32) + rows = coordinates.shape[0] + columns = coordinates.shape[1] + for row in prange(rows): + temp_weights = np.empty((2, 2), np.float32) + for column in range(columns): + y = coordinates[row, column, 0] + x = coordinates[row, column, 1] + _, y0, wy = _get_interpolation_parameters(y, image.shape[0]) + _, x0, wx = _get_interpolation_parameters(x, image.shape[1]) + sample = image[y0 : y0 + 2, x0 : x0 + 2] + one_minus_wx = 1 - wx + one_minus_wy = 1 - wy + temp_weights[0, 0] = one_minus_wx * one_minus_wy + temp_weights[0, 1] = one_minus_wy * wx + temp_weights[1, 0] = wy * one_minus_wx + temp_weights[1, 1] = wy * wx + result[row, column] = np.sum(sample * temp_weights) + return result + + +@njit(parallel=True) +def interpolate_gradient_cpu( + image: np.ndarray, + coordinates: np.ndarray, +) -> np.ndarray: + """Evaluate image gradient at non-integer coordinates with linear interpolation + + Parameters + ---------- + image: (N, M) array of float32 + The array to use for interpolation + coordinates: (R, C, 2) array of float32 + The coordinates at which to interpolate. (R, C, 0) represent the y + coordinates in the image, (R, C, 1) represent the x coordinate in the + image. 
+ + Returns + ------- + gradient: (R, C, 2) array of float32 + The interpolated gradients at all R, C coordinates. (R, C, 0) is the + y coordinate of the gradient, (R, C, 1) is the x coordinate. + """ + result = np.zeros(coordinates.shape, dtype=np.float32) + rows = coordinates.shape[0] + columns = coordinates.shape[1] + for row in prange(rows): + for column in range(columns): + y = coordinates[row, column, 0] + x = coordinates[row, column, 1] + valid_y, y0, wy = _get_interpolation_parameters(y, image.shape[0]) + valid_x, x0, wx = _get_interpolation_parameters(x, image.shape[1]) + + one_minus_wx = 1.0 - wx + one_minus_wy = 1.0 - wy + y1 = y0 + 1 + x1 = x0 + 1 + + if valid_y: + result[row, column, 0] = ( + image[y1, x0] - image[y0, x0] + ) * one_minus_wx + (image[y1, x1] - image[y0, x1]) * wx + + if valid_x: + result[row, column, 1] = ( + image[y0, x1] - image[y0, x0] + ) * one_minus_wy + (image[y1, x1] - image[y1, x0]) * wy + + return result + + +@njit +def _get_interpolation_parameters( + coordinate: float, + axis_size: int, +) -> Tuple[bool, int, float]: + if coordinate >= 0 and coordinate < axis_size - 1: + is_valid = True + reference_gridpoint = int(coordinate) + weight = coordinate - reference_gridpoint + elif coordinate < 0: + is_valid = False + reference_gridpoint = 0 + weight = 0.0 + elif coordinate > axis_size - 1: + is_valid = False + reference_gridpoint = axis_size - 2 + weight = 1.0 + elif coordinate == axis_size - 1: + is_valid = True + reference_gridpoint = axis_size - 2 + weight = 1.0 + return is_valid, reference_gridpoint, weight diff --git a/pymatchseries/implementation/objective_functions.py b/pymatchseries/implementation/objective_functions.py new file mode 100644 index 0000000..6febb74 --- /dev/null +++ b/pymatchseries/implementation/objective_functions.py @@ -0,0 +1,322 @@ +from functools import cached_property +from math import prod, sqrt +from types import ModuleType +from typing import Dict, Optional, Tuple + +from pymatchseries.utils import 
( + DenseArrayType, + OneValueCache, + SparseMatrixType, + get_dispatcher, + get_sparse_module, +) + +from .interpolation import BilinearInterpolation2D +from .quadrature import Quadrature2D + + +class RegistrationObjectiveFunction: + # class level cache for the regularizer derivative + _DERIVATIVE_OF_REGULARIZER_CACHE: Dict[ + Tuple[int, int, int, float, ModuleType], SparseMatrixType + ] = {} + + def __init__( + self, + image_deformed: DenseArrayType, + image_reference: DenseArrayType, + regularization_constant: float, + number_of_quadrature_points: int = 3, + cache_derivative_of_regularizer: bool = False, + ) -> None: + self.dispatcher = get_dispatcher(image_deformed) + self.grid_shape = image_deformed.shape + self.sparse = get_sparse_module(self.dispatcher) + self.quadrature = Quadrature2D( + grid_shape=(self.grid_shape[0], self.grid_shape[1]), + number_of_points=number_of_quadrature_points, + dispatcher=self.dispatcher, + ) + + self.image_deformed_interpolated = BilinearInterpolation2D(image_deformed) + self.image_reference = image_reference + + self.identity = self.dispatcher.mgrid[ + 0 : self.grid_shape[0], + 0 : self.grid_shape[1], + ].astype(self.dispatcher.float32) + + self.regularization_constant = float(regularization_constant) + self.regularization_constant_sqrt = sqrt(regularization_constant) + + # since this value is used in multiple functions we cache it + self.positions_at_quad_points_cache = OneValueCache() + + self.derivative_of_regularizer = self._get_derivative_of_regularizer() + + if cache_derivative_of_regularizer: + self.cache_derivative_of_regularizer() + + def evaluate_residual( + self, + displacement_vector: DenseArrayType, + ) -> DenseArrayType: + """Evaluate the error on the image corrected with the provided + displacement with respect to the reference image + + Parameters + ---------- + displacement_vector + Array of length (2 * N * M), representing the y and x displacements + in each pixel. 
+ + Returns + ------- + error + Array of length (5 * (N-1) * (M-1) * K) + """ + positions_at_quad_points = self._quantize_displacement_vector( + displacement_vector + ) + + dp = self.dispatcher + R = self.cell_grid_shape[0] + C = self.cell_grid_shape[1] * self.cell_grid_shape[2] + positions_at_quad_points = positions_at_quad_points.reshape(R, C, 2) + + corrected_image = self.image_deformed_interpolated.evaluate( + positions_at_quad_points + ).reshape(-1, self.number_of_quadrature_points) + + ground_truth = self.quadrature.evaluate(self.image_reference).reshape( + -1, self.number_of_quadrature_points + ) + + residual_data = dp.multiply( + self.quadrature.quadrature_point_weights_sqrt, + corrected_image - ground_truth, + ) + residual_regularization = self.derivative_of_regularizer.dot( + displacement_vector + ) + + return dp.concatenate( + ( + residual_data.ravel(), + residual_regularization, + ) + ) + + def evaluate_residual_gradient( + self, + displacement_vector: DenseArrayType, + ) -> SparseMatrixType: + """Evaluate the error on the corrected image with respect to the image_reference + + Parameters + ---------- + displacement_vector + Array of length (2 * N * M), representing the y and x displacements + in each pixel. 
+ + Returns + ------- + error_gradient + Sparse matrix of shape (5 * (N-1) * (M-1) * K, 2 * N * M) + """ + positions_at_quad_points = self._quantize_displacement_vector( + displacement_vector + ) + + dp = self.dispatcher + R = self.cell_grid_shape[0] + C = self.cell_grid_shape[1] * self.cell_grid_shape[2] + df = ( + self.image_deformed_interpolated.evaluate_gradient( + positions_at_quad_points.reshape(R, C, 2) + ) + / self.grid_scaling + ) + dfdy = df[..., 0].reshape(self.cell_grid_shape).astype(dp.float32) + dfdx = df[..., 1].reshape(self.cell_grid_shape).astype(dp.float32) + + data_y, rows_y, cols_y = self.quadrature.evaluate_partial_derivatives( + dfdy, + self.quadrature.basis_f_at_points, + ) + data_x, rows_x, cols_x = self.quadrature.evaluate_partial_derivatives( + dfdx, + self.quadrature.basis_f_at_points, + ) + + gradient_data = self.sparse.csr_matrix( + ( + dp.concatenate((data_y, data_x)), + ( + dp.concatenate((rows_y, rows_x)), + dp.concatenate((cols_y, cols_x + self.number_of_nodes)), + ), + ), + shape=( + self.quadrature.total_number_of_quadrature_points, + 2 * self.number_of_nodes, + ), + ) + + return self.sparse.vstack( + [gradient_data, self.derivative_of_regularizer] + ).tocsr() + + def evaluate_energy( + self, + displacement_vector: DenseArrayType, + ) -> float: + dp = self.dispatcher + return dp.sum(self.evaluate_residual(displacement_vector) ** 2) + + def evaluate_energy_gradient( + self, + displacement_vector: DenseArrayType, + ) -> DenseArrayType: + residual = self.evaluate_residual(displacement_vector) + residual_gradient = self.evaluate_residual_gradient(displacement_vector) + return 2 * residual_gradient.T * residual.ravel() + + def _quantize_displacement_vector( + self, + displacement_vector: DenseArrayType, + ) -> DenseArrayType: + """Convert displacement field vector to quadrature point evaluations + + Parameters + ---------- + displacement_vector + Array of length (2 * N * M), representing the y and x displacements + in each pixel. 
+ + Returns + ------- + positions_at_quad_points + Array of shape ((N-1) * (M-1) * K, 2), representing the position + field evaluated at all quadrature points. [:, 0] is the y component + [:, 1] is the x component. + """ + dp = self.dispatcher + + array_bytes = displacement_vector.tobytes() + if array_bytes in self.positions_at_quad_points_cache: + return self.positions_at_quad_points_cache[array_bytes] + + displacement_y, displacement_x = displacement_vector.reshape( + (2, *self.grid_shape) + ) + pixel_row, pixel_column = self.identity + new_position_x = displacement_x / self.grid_scaling + pixel_column + new_position_y = displacement_y / self.grid_scaling + pixel_row + n_rows = self.quadrature.total_number_of_quadrature_points + positions_at_quad_points = dp.empty((n_rows, 2), dtype=dp.float32) + positions_at_quad_points[:, 0] = self.quadrature.evaluate( + new_position_y + ).ravel() + positions_at_quad_points[:, 1] = self.quadrature.evaluate( + new_position_x + ).ravel() + + self.positions_at_quad_points_cache[array_bytes] = positions_at_quad_points + + return positions_at_quad_points + + @property + def grid_scaling(self) -> float: + return self.quadrature.grid_scaling + + @property + def cell_grid_shape(self) -> Tuple[int, int, int]: + return self.quadrature.cell_grid_shape + + @property + def number_of_quadrature_points(self) -> int: + return self.quadrature.number_of_quadrature_points + + @cached_property + def number_of_nodes(self) -> int: + return prod(self.grid_shape) + + def _get_derivative_of_regularizer(self) -> SparseMatrixType: + """Derivative of regularizer is constant matrix + + Has shape (4 * (N-1) * (M-1) * K, 2 * N * M), with the number of + rows equaling 4x the total number of quadrature points and the + columns 2x the total number of nodes (pixels). 
+ """ + cached = self._get_cached_derivative_of_regularizer() + if cached is not None: + return cached + + # TODO: since the value going in is a constant, the number of unique + # values is limited and there may be shortcuts to calculate this. + dp = self.dispatcher + sparse = self.sparse + # regularization constant for each quadrature point in the grid of cells + quadrature_values = dp.full( + self.quadrature.cell_grid_shape, + fill_value=self.regularization_constant_sqrt, + dtype=dp.float32, + ) + + # reg = regularizer + ( + data_reg_x, + rows_reg_x, + cols_reg_x, + ) = self.quadrature.evaluate_partial_derivatives( + quadrature_values, + node_weights=self.quadrature.basis_dfx_at_points, + ) + ( + data_reg_y, + rows_reg_y, + cols_reg_y, + ) = self.quadrature.evaluate_partial_derivatives( + quadrature_values, + node_weights=self.quadrature.basis_dfy_at_points, + ) + + # combine the data into a single matrix + n_quad_points = self.quadrature.total_number_of_quadrature_points + block_shape = (2 * n_quad_points, self.number_of_nodes) + mat_reg = sparse.csr_matrix( + ( + dp.concatenate((data_reg_x, data_reg_y)), + ( + dp.concatenate((rows_reg_x, rows_reg_y + n_quad_points)), + dp.concatenate((cols_reg_x, cols_reg_y)), + ), + ), + shape=block_shape, + dtype=dp.float32, + ) + + mat_zero = sparse.csr_matrix(block_shape, dtype=dp.float32) + return sparse.vstack( + [ + sparse.hstack([mat_zero, mat_reg]), + sparse.hstack([mat_reg, mat_zero]), + ] + ).tocsr() + + @cached_property + def _cache_key(self) -> Tuple[int, int, int, float, ModuleType]: + return (*self.cell_grid_shape, self.regularization_constant, self.dispatcher) + + def cache_derivative_of_regularizer(self) -> None: + if self._cache_key not in self._DERIVATIVE_OF_REGULARIZER_CACHE: + self._DERIVATIVE_OF_REGULARIZER_CACHE[ + self._cache_key + ] = self.derivative_of_regularizer + + def _get_cached_derivative_of_regularizer(self) -> Optional[SparseMatrixType]: + return 
self._DERIVATIVE_OF_REGULARIZER_CACHE.get(self._cache_key) + + def clear_cache(self) -> None: + self._DERIVATIVE_OF_REGULARIZER_CACHE.clear() diff --git a/pymatchseries/implementation/quadrature.py b/pymatchseries/implementation/quadrature.py new file mode 100644 index 0000000..0391414 --- /dev/null +++ b/pymatchseries/implementation/quadrature.py @@ -0,0 +1,493 @@ +from functools import cached_property +from math import prod, sqrt +from types import ModuleType +from typing import Tuple + +import numpy as np +from numba import njit, prange + +from pymatchseries.utils import DenseArrayType, get_grid_scaling_factor + + +class Quadrature2D: + def __init__( + self, + grid_shape: Tuple[int, int], + number_of_points: int = 3, + dispatcher: ModuleType = np, + ) -> None: + """ + 2D quadrature point representation to approximate the integral or + gradient of a function F(x, y) + """ + self._number_of_points = number_of_points**2 + if number_of_points == 2: + self._points = self._get_gauss_quad_points_2(dispatcher) + self._weight = self._get_gauss_quad_weights_2(dispatcher) + elif number_of_points == 3: + self._points = self._get_gauss_quad_points_3(dispatcher) + self._weight = self._get_gauss_quad_weights_3(dispatcher) + else: + raise NotImplementedError( + f"Quadrature with {number_of_points} points not implemented", + ) + self.grid_shape = grid_shape + self.dispatcher = dispatcher + # Factor to rescale integration domain so that longest dimension is 1 + self.grid_scaling: float = get_grid_scaling_factor(grid_shape) + if self.dispatcher == np: + self.evaluate_function = evaluate_at_quad_points_cpu + self.evaluate_pd_function = evaluate_pd_on_quad_points_cpu + else: + from pymatchseries.implementation.cuda_kernels import ( + evaluate_at_quad_points_gpu, + evaluate_pd_on_quad_points_gpu, + ) + + self.evaluate_function = evaluate_at_quad_points_gpu + self.evaluate_pd_function = evaluate_pd_on_quad_points_gpu + + def evaluate(self, array: DenseArrayType) -> DenseArrayType: + 
"""Get the value of an array interpolated at each quadrature point + + Parameters + ---------- + array: (N, M) array of float32 + The array to evaluate the quad points on + + Returns + ------- + values: ((N-1), (M-1), K) array of float32 + The interpolated value at each K quadrature point in each cell in + the (N-1) x (M-1) grid + """ + return self.evaluate_function(array, self.node_weights) + + def evaluate_partial_derivatives( + self, + quadrature_values: DenseArrayType, + node_weights: DenseArrayType, + ) -> Tuple[DenseArrayType, DenseArrayType, DenseArrayType]: + """Get a sparse representation of the partial derivative at each quadrature points + + This represents a matrix of size (total number of quadrature points, + total number of nodes in the original image). The number of nodes is + equal to the number of basis functions. + + Parameters + ---------- + quadrature_values: ((N-1), (M-1), K) array of float32 + The value of each K quadrature point in each of the (N-1) x (M-1) cells + node_weights: (4, K) array of float32 + The weight each of the 4 surrounding nodes on each of the K quad points + in a cell + + Returns + ------- + data: ((N-1) x (M-1) x K x 4) array of float32 + Values in sparse array + rows: ((N-1) x (M-1) x K x 4) array of int32 + Row indices in sparse matrix + cols: ((N-1) x (M-1) x K x 4) array of int32 + Column indices in sparse matrix + """ + return self.evaluate_pd_function( + quadrature_values, self.quadrature_point_weights_sqrt, node_weights + ) + + @cached_property + def quadrature_points(self) -> DenseArrayType: + """Quadrature point x, y coordinates. Array shape is (K, 2).""" + return self._points + + @cached_property + def quadrature_point_weights(self) -> DenseArrayType: + """Quadrature point weights. Array shape is (K,).""" + return self._weight * (self.grid_scaling**2) + + @cached_property + def quadrature_point_weights_sqrt(self) -> DenseArrayType: + """Root of quadrature point weights. 
Array shape is (K,).""" + return self.dispatcher.sqrt(self.quadrature_point_weights) + + @property + def quadrature_points_x_coordinate(self) -> DenseArrayType: + """x coordinate of quadrature points. Array shape is (K,).""" + return self.quadrature_points[:, 0] + + @property + def quadrature_points_y_coordinate(self) -> DenseArrayType: + """y coordinate of quadrature points. Array shape is (K,).""" + return self.quadrature_points[:, 1] + + @property + def number_of_quadrature_points(self) -> int: + """Number of quadrature points in a cell""" + return self._number_of_points + + @cached_property + def cell_grid_shape(self) -> Tuple[int, int, int]: + """((N-1), (M-1), K)""" + return ( + self.grid_shape[0] - 1, + self.grid_shape[1] - 1, + self.number_of_quadrature_points, + ) + + @cached_property + def total_number_of_quadrature_points(self) -> int: + """Number of quadrature points over all cells""" + return prod(self.cell_grid_shape) + + @cached_property + def node_weights(self) -> DenseArrayType: + """The weights w_i that each surrounding node contributes to evaluating + the function f at the quadrature points at x, y, i.e.: + `f(x, y) = w_0 * f_00 + w_1 * f_01 + w_2 * f_10 + w_4 * f_11.` + + Returns + ------- + weights + Weights of each node for each quadrature point. Array of shape + (4, K) of float32, where K is the number of quadrature points in + a cell. + + Notes + ----- + We assume the following order of nodes: + * 00 = top left + * 01 = top right + * 10 = bottom left + * 11 = bottom right + """ + qx = self.quadrature_points_x_coordinate + qy = self.quadrature_points_y_coordinate + wx1 = 1 - qx + wx2 = qx + wy1 = 1 - qy + wy2 = qy + return self.dispatcher.vstack( + [ + wy1 * wx1, + wy1 * wx2, + wy2 * wx1, + wy2 * wx2, + ] + ) + + @cached_property + def basis_f_at_points(self) -> DenseArrayType: + """The value of the 4 basis functions around a cell at each of the + quadrature points in the cell. + + Returns + ------- + values + Value of 4 basis functions. 
Array of shape (4, K) of float32, + where K is the number of quadrature points in a cell. + + Notes + ----- + We assume the following order of nodes: + * values[0] = top left + * values[1] = top right + * values[2] = bottom left + * values[3] = bottom right + """ + qx = self.quadrature_points_x_coordinate + qy = self.quadrature_points_y_coordinate + one_minus_qx = 1 - qx + one_minus_qy = 1 - qy + return self.dispatcher.vstack( + [ + qx * qy, + one_minus_qx * qy, + qx * one_minus_qy, + one_minus_qx * one_minus_qy, + ] + ) + + @cached_property + def dx_node_weights(self) -> DenseArrayType: + """The weights to evaluate d/dx * f(x, y) at the quadrature points + + Returns + ------- + weights + Weights of each node for each quadrature point. Array of shape + (4, K) of float32, where K is the number of quadrature points in + a cell. + + Notes + ----- + * see also node weights + * since cells are rescaled (see quadrature point weights), gradients + are larger (scaled by same factor) + """ + qy = self.quadrature_points_y_coordinate + one_minus_qy = 1 - qy + return ( + self.dispatcher.vstack([-one_minus_qy, one_minus_qy, -qy, qy]) + / self.grid_scaling + ) + + @cached_property + def dy_node_weights(self) -> DenseArrayType: + """The weights to evaluate d/dy * f(x, y) at the quadrature points + + Returns + ------- + weights + Weights of each node for each quadrature point. Array of shape + (4, K) of float32, where K is the number of quadrature points in + a cell. + + Notes + ----- + * see also node weights + * since cells are rescaled (see quadrature point weights), gradients + are larger (scaled by same factor) + """ + qx = self.quadrature_points_x_coordinate + one_minus_qx = 1 - qx + return ( + self.dispatcher.vstack([-one_minus_qx, -qx, one_minus_qx, qx]) + / self.grid_scaling + ) + + @cached_property + def basis_dfx_at_points(self) -> DenseArrayType: + """d/dx of the 4 basis functions around a cell at each of the + quadrature points in the cell. 
+ + Returns + ------- + values + d/dx of 4 basis functions. Array of shape (4, K) of float32, + where K is the number of quadrature points in a cell. + + Notes + ----- + We assume the following order of nodes: + * values[0] = top left + * values[1] = top right + * values[2] = bottom left + * values[3] = bottom right + """ + qy = self.quadrature_points_y_coordinate + one_minus_qy = 1 - qy + return ( + self.dispatcher.vstack( + [ + qy, + -qy, + one_minus_qy, + -one_minus_qy, + ] + ) + / self.grid_scaling + ) + + @cached_property + def basis_dfy_at_points(self) -> DenseArrayType: + """d/dy of the 4 basis functions around a cell at each of the + quadrature points in the cell. + + Returns + ------- + values + d/dy of 4 basis functions. Array of shape (4, K) of float32, + where K is the number of quadrature points in a cell. + + Notes + ----- + We assume the following order of nodes: + * values[0] = top left + * values[1] = top right + * values[2] = bottom left + * values[3] = bottom right + """ + qx = self.quadrature_points_x_coordinate + one_minus_qx = 1 - qx + return ( + self.dispatcher.vstack( + [ + qx, + one_minus_qx, + -qx, + -one_minus_qx, + ] + ) + / self.grid_scaling + ) + + @classmethod + def _get_gauss_quad_points_2( + cls, + dispatcher: ModuleType = np, + ) -> DenseArrayType: + """ + Get the x, y coordinates of the Gaussian quadrature points with 4 points + """ + p = 1 / sqrt(3) / 2 + quads = dispatcher.array( + [[-p, -p], [p, -p], [-p, p], [p, p]], + dtype=dispatcher.float32, + ) + quads += 0.5 + return quads + + @classmethod + def _get_gauss_quad_weights_2( + cls, + dispatcher: ModuleType = np, + ) -> DenseArrayType: + """ + Get the weights for the Gaussian quadrature points with 4 points + """ + return dispatcher.ones(4, dtype=dispatcher.float32) / 4 + + @classmethod + def _get_gauss_quad_points_3( + cls, + dispatcher: ModuleType = np, + ) -> DenseArrayType: + """ + Get the x, y coordinates of the Gaussian quadrature points with 9 points + """ + p = sqrt(3 
@njit(parallel=True)
def evaluate_pd_on_quad_points_cpu(
    quadrature_values: np.ndarray,
    quad_weights_sqrt: np.ndarray,
    node_weights: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build the (data, rows, cols) triplets of a sparse node-to-quad-point matrix

    The triplets describe a matrix with one row per quadrature point and one
    column per node of the original image. Every quadrature point receives
    exactly four entries, one for each corner node of the cell it lies in.

    Parameters
    ----------
    quadrature_values: ((N-1), (M-1), K) array of float32
        The value of each K quadrature point in each of the (N-1) x (M-1) cells
    quad_weights_sqrt: (K,) array of float32
        Square root of the weight of each quadrature point
    node_weights: (4, K) array of float32
        The weight each of the 4 surrounding nodes on each of the K quad points

    Returns
    -------
    data: ((N-1) x (M-1) x K x 4) array of float32
        Values in sparse array
    rows: ((N-1) x (M-1) x K x 4) array of int32
        Row indices in sparse matrix
    cols: ((N-1) x (M-1) x K x 4) array of int32
        Column indices in sparse matrix
    """
    n_cell_rows = quadrature_values.shape[0]
    n_cell_cols = quadrature_values.shape[1]
    n_quad = quadrature_values.shape[2]
    # the node grid of the original image has one more column than the cell grid
    n_node_cols = n_cell_cols + 1

    n_entries = 4 * quadrature_values.size
    data = np.empty(n_entries, dtype=np.float32)
    rows = np.empty(n_entries, dtype=np.int32)
    cols = np.empty(n_entries, dtype=np.int32)

    # Flat-index offsets from the top-left node of a cell to the node used for
    # node_weights[0..3]: bottom-right, bottom-left, top-right, top-left.
    # NOTE(review): this order is the reverse of the node order used by
    # evaluate_at_quad_points_cpu - confirm the callers pass node_weights in
    # the matching order.
    node_offsets = np.array(
        [n_node_cols + 1, n_node_cols, 1, 0],
        dtype=np.int32,
    )

    for cell_row in prange(n_cell_rows):
        for cell_col in prange(n_cell_cols):
            # position of the cell in the flattened (N-1) x (M-1) cell grid
            cell_index = cell_col + cell_row * n_cell_cols
            # flat index of the cell's top-left node in the N x M node grid
            first_node = cell_index + cell_row
            first_quad = n_quad * cell_index
            for q in range(n_quad):
                # row of the sparse matrix == flat index of the quad point
                quad_index = q + first_quad
                first_entry = 4 * quad_index
                weighted = quadrature_values[cell_row, cell_col, q] * quad_weights_sqrt[q]
                for node in range(4):
                    entry = node + first_entry
                    data[entry] = weighted * node_weights[node, q]
                    rows[entry] = quad_index
                    cols[entry] = first_node + node_offsets[node]

    return data, rows, cols
def root_gauss_newton(
    F: Callable[[DenseArrayType], DenseArrayType],
    x0: DenseArrayType,
    DF: Callable[[DenseArrayType], ArrayType],
    max_iterations: int = 50,
    stop_epsilon: float = 0.0,
    start_step: float = 1.0,
    show_progress: bool = False,
) -> DenseArrayType:
    """
    Implementation of Gauss-Newton iterative solver for sparse systems

    Each iteration solves the linearised least squares problem for the update
    ``dx`` and applies the full step. If the squared error does not decrease,
    the step is undone and a shorter step along the same direction is
    searched with :func:`_get_stepsize`.

    Parameters
    ----------
    F
        Function to find the root of
    x0
        Initial guess, it is copied and not modified
    DF
        Function that returns the Jacobian of F
    max_iterations
        Maximum number of iterations
    stop_epsilon
        Relative error change in a step at which to stop iteration
    start_step
        Initial step size
    show_progress
        Show a tqdm progress bar with step size and error information

    Returns
    -------
    x
        Root of F

    Raises
    ------
    RuntimeError
        If the least squares solution contains non-finite values

    Warns
    -----
    UserWarning
        If max_iterations is exhausted before the stop criterion is met
    """
    x = x0.copy()
    f = F(x)
    total_square_error = f.dot(f)
    logger.info("Initial error {:#.6g}".format(total_square_error))
    step = start_step

    # use tqdm to show a progress bar
    if show_progress:
        iterations = tqdm(range(max_iterations), leave=False)
    else:
        iterations = range(max_iterations)

    dp = get_dispatcher(x0)
    # the Jacobian type is resolved once on the first iteration and then reused
    matrix_type = None

    for _ in iterations:
        matDF = DF(x)
        if matrix_type is None:
            matrix_type = Matrix.get_matrix_type(matDF)
        dx = matrix_type(matDF).solve_lstsq(f)

        if not dp.all(dp.isfinite(dx)):
            raise RuntimeError("Least squares solving failed.")

        x -= dx
        f = F(x)
        updated_total_square_error = f.dot(f)

        if updated_total_square_error >= total_square_error:
            # If the target functional did not decrease with the update, try to
            # find a smaller step so that it does.
            x += dx  # undo the full step
            dx *= -1  # _get_stepsize evaluates x + step * dx
            step = _get_stepsize(F=F, x=x, dx=dx, start_step=min(2 * step, 1))
            x += step * dx
            f = F(x)
            updated_total_square_error = f.dot(f)
        else:
            step = 1

        error_difference = total_square_error - updated_total_square_error

        if show_progress:
            iterations.set_description(
                "step_size={:#.2g}, error={:#.5g}, difference={:.1e}".format(
                    step,
                    updated_total_square_error,
                    error_difference,
                )
            )

        if error_difference <= stop_epsilon * updated_total_square_error or dp.isclose(
            updated_total_square_error, 0
        ):
            # convergence is reached
            if show_progress:
                # close() exists on every tqdm flavour; the `container`
                # attribute is only present on the notebook widget
                iterations.close()
            break

        total_square_error = updated_total_square_error

    else:
        warnings.warn(
            "Reached the maximum number of iterations without reaching stop criterion"
        )

    return x


def _get_stepsize(
    F: Callable[[DenseArrayType], DenseArrayType],
    x: DenseArrayType,
    dx: DenseArrayType,
    start_step: float = 1.0,
    min_step: float = 2**-30,
) -> float:
    """
    Get maximum iteration step width to ensure convergence via geometric search

    The step is halved until the squared error at ``x + step * dx`` is lower
    than at ``x``, or until the step drops below ``min_step``. In the latter
    case the returned step does not guarantee a decrease in energy.

    Parameters
    ----------
    F
        Function to find root of
    x
        Vector of length N that indicates the current best estimate solution
    dx
        Vector of length N that indicates the delta vector to x
    start_step
        Initial guess for the step
    min_step
        Smallest step to accept

    Returns
    -------
    step
        Largest tested step that decreases the energy, or the first step
        smaller than min_step if no such step was found
    """

    def error_function(v):
        evaluated = F(v)
        return evaluated.dot(evaluated)

    step = start_step

    error = error_function(x)
    updated_error = error_function(x + step * dx)

    while (updated_error >= error) and (step >= min_step):
        step *= 0.5
        updated_error = error_function(x + step * dx)

    return step
import concurrent.futures as cf import logging import os -import bz2 +from pathlib import Path + import numpy as np +from PIL import Image def _save_frame_to_file(i, data, folder, prefix, digits, data_format="tiff"): @@ -21,7 +22,7 @@ def _save_frame_to_file(i, data, folder, prefix, digits, data_format="tiff"): pass # in case of a weird datatype if not (frm.dtype == np.uint16 or frm.dtype == np.uint8): - frm = (frm - frm.min()) / (frm.max() - frm.min()) * (2 ** 16 - 1) + frm = (frm - frm.min()) / (frm.max() - frm.min()) * (2**16 - 1) frm = np.uint16(frm) img = Image.fromarray(frm) img.save(fp) diff --git a/pymatchseries/matchseries.py b/pymatchseries/matchseries.py index e5715a2..76212da 100644 --- a/pymatchseries/matchseries.py +++ b/pymatchseries/matchseries.py @@ -2,22 +2,23 @@ Module that includes tools for converting experimental data into the file structure required for match-series """ -from subprocess import Popen, PIPE, STDOUT +import json import logging -from pathlib import Path import os -import numpy as np -import hyperspy.api as hs -from tabulate import tabulate -import warnings -from scipy import ndimage +import shutil import uuid +import warnings +from pathlib import Path +from subprocess import PIPE, STDOUT, Popen + import dask.array as da +import h5py +import hyperspy.api as hs +import numpy as np from dask import delayed from dask.diagnostics import ProgressBar -import shutil -import json -import h5py +from scipy import ndimage +from tabulate import tabulate from pymatchseries import config_tools as ctools from pymatchseries import io_utils as ioutls @@ -285,8 +286,13 @@ def __prepare_calculation(self): self.__update_metadata_file() def __run_match_series(self): - """Run match series using the config file and print all output""" - # from https://github.com/takluyver/rt2-workshop-jupyter/blob/e7fde6565e28adf31a0f9003094db70c3766bd6d/Subprocess%20output.ipynb + """Run match series using the config file and print all output + + See also + -------- + 
* https://github.com/takluyver/rt2-workshop-jupyter/blob/ + e7fde6565e28adf31a0f9003094db70c3766bd6d/Subprocess%20output.ipynb + """ cmd = ["matchSeries", f"{self.config_file_path}"] p = Popen(cmd, stdout=PIPE, stderr=STDOUT, cwd=self.path) while True: diff --git a/pymatchseries/utils.py b/pymatchseries/utils.py new file mode 100644 index 0000000..053210e --- /dev/null +++ b/pymatchseries/utils.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +from types import ModuleType +from typing import TYPE_CHECKING, Callable, Dict, Iterator, Mapping, Tuple, Union + +import dask.array as da +import numpy as np +import scipy +import scipy.ndimage as ndimage +import scipy.sparse as sparse + +try: + import cupy as cp + import cupyx.scipy.ndimage as cndimage + import cupyx.scipy.sparse as csparse + import cupyx.scipy.sparse.linalg as clinalg + + CUPY_IS_INSTALLED = True +except ImportError: + cp = None + cndimage = None + csparse = None + clinalg = None + CUPY_IS_INSTALLED = False + + +if TYPE_CHECKING: + ArrayType = Union[sparse.spmatrix, np.ndarray, csparse.spmatrix, cp.ndarray] + DenseArrayType = Union[np.ndarray, cp.ndarray] + SparseMatrixType = Union[sparse.spmatrix, csparse.spmatrix] +else: + ArrayType = None + DenseArrayType = None + SparseMatrixType = None + + +def mean(images: DenseArrayType) -> DenseArrayType: + """Calculate the mean of an image stack, stack dimension is axis 0""" + dp = get_dispatcher(images) + return dp.mean(images, axis=0) + + +def median(images: DenseArrayType) -> DenseArrayType: + """Calculate the median of an image stack, stack dimension is axis 0""" + dp = get_dispatcher(images) + return dp.median(images, axis=0) + + +def to_host(array: DenseArrayType) -> np.ndarray: + if CUPY_IS_INSTALLED and isinstance(array, cp.ndarray): + return array.get() + elif isinstance(array, np.ndarray): + return array + else: + raise ValueError(f"Array type is {type(array)}, must be array.") + + +def to_device(array: DenseArrayType) -> cp.ndarray: + if 
def get_array_type(
    array: Union[np.ndarray, cp.ndarray, da.Array],
) -> Tuple[ModuleType, bool]:
    """Returns the underlying dispatcher and whether an array is lazy

    For a dask array only the first chunk is computed; the dispatcher is
    derived from the type of that computed chunk.
    """
    is_lazy = False
    if isinstance(array, da.Array):
        first_chunk_slice = tuple(slice(chunk_dim[0]) for chunk_dim in array.chunks)
        # da.compute() returns a tuple of results, which get_dispatcher
        # rejects; compute the sliced array itself to get the chunk back
        array = array[first_chunk_slice].compute()
        is_lazy = True
    return get_dispatcher(array), is_lazy


def displacement_to_coordinates(
    displacement: DenseArrayType,
) -> DenseArrayType:
    """Convert a displacement field into absolute grid coordinates

    Axes 1 and 2 of ``displacement`` are taken as the grid shape. The
    displacement is divided by the grid scaling factor and added to the index
    grid built with ``mgrid``.
    """
    dp = get_dispatcher(displacement)
    grid_shape = (displacement.shape[1], displacement.shape[2])
    scaling_factor = get_grid_scaling_factor(grid_shape)
    identity = dp.mgrid[0 : grid_shape[0], 0 : grid_shape[1]].astype(displacement.dtype)
    return displacement / scaling_factor + identity


def get_grid_scaling_factor(grid_shape: Tuple[int, int]) -> float:
    """Return ``1 / (max(grid_shape) - 1)``

    Raises ZeroDivisionError when the largest grid dimension is 1.
    """
    return 1 / (max(grid_shape) - 1)


def map_coordinates(
    image: DenseArrayType,
    displacement: DenseArrayType,
    **kwargs,
) -> DenseArrayType:
    """Deform an image by sampling it with the matching ndimage.map_coordinates

    ``displacement`` is forwarded as the coordinates argument. The
    interpolation ``order`` defaults to 1; all other keyword arguments are
    passed through unchanged.
    """
    dp = get_dispatcher(image)
    ndi = get_ndimage_module(dp)
    return ndi.map_coordinates(
        image,
        displacement,
        order=kwargs.pop("order", 1),
        **kwargs,
    )


def create_image_pyramid(
    image: DenseArrayType,
    n_levels: int,
    downscale_factor: float = 2.0,
    **kwargs,
) -> Iterator[DenseArrayType]:
    """Create an iterator of an image resized by a constant factor

    Images are yielded from the smallest level up to the original size
    (zoom factor 1). Keyword arguments are forwarded to ``ndimage.zoom``;
    the interpolation ``order`` defaults to 1.

    Raises
    ------
    ValueError
        On first iteration, if the smallest image dimension at the coarsest
        level would be below 2
    """
    smallest_dimension = min(image.shape)
    dp = get_dispatcher(image)
    ndi = get_ndimage_module(dp)
    sf = 1 / downscale_factor
    if smallest_dimension * sf ** (n_levels - 1) < 2:
        raise ValueError("The image size is reduced too much.")
    # Extract the order once: popping inside the loop would apply a
    # caller-supplied order to the first level only.
    order = kwargs.pop("order", 1)
    for level in reversed(range(n_levels)):
        yield ndi.zoom(
            image,
            sf**level,
            order=order,
            **kwargs,
        )
class OneValueCache(dict):
    """Dictionary that keeps a single entry when filled via item assignment

    Assigning to a key that is not yet present clears the dictionary first;
    assigning to the present key overwrites its value.
    """

    # adapted from https://stackoverflow.com/questions/2437617
    def __init__(self):
        dict.__init__(self)

    def __setitem__(self, key, value):
        if key not in self:
            self.clear()
        dict.__setitem__(self, key, value)


def get_dispatcher(array: DenseArrayType) -> ModuleType:
    """Returns the correct dispatcher to work with an array

    Raises
    ------
    ValueError
        If the array is neither a numpy nor a cupy array
    """
    if CUPY_IS_INSTALLED and isinstance(array, cp.ndarray):
        return cp
    elif isinstance(array, np.ndarray):
        return np
    else:
        raise ValueError(f"Array type is {type(array)}, must be array.")


def get_sparse_module(dispatcher: ModuleType) -> ModuleType:
    """Returns the sparse matrix module that matches a dispatcher

    Raises
    ------
    ValueError
        If the dispatcher is neither numpy nor an installed cupy
    """
    # Without cupy, `cp` is None; guard so that a None dispatcher is rejected
    # instead of being matched to the (None) cupy sparse module.
    if CUPY_IS_INSTALLED and dispatcher is cp:
        return csparse
    elif dispatcher is np:
        return sparse
    else:
        raise ValueError("Array must be numpy or cupy array")


def get_ndimage_module(dispatcher: ModuleType) -> ModuleType:
    """Returns the ndimage module that matches a dispatcher

    Raises
    ------
    ValueError
        If the dispatcher is neither numpy nor an installed cupy
    """
    # Without cupy, `cp` is None; guard so that a None dispatcher is rejected
    # instead of being matched to the (None) cupy ndimage module.
    if CUPY_IS_INSTALLED and dispatcher is cp:
        return cndimage
    elif dispatcher is np:
        return ndimage
    else:
        raise ValueError("Array must be numpy or cupy array")
__init__(self, matrix: ArrayType) -> None: + raise NotImplementedError("The array module could not be determined") + + @property + def module(self) -> ModuleType: + raise NotImplementedError("The array module could not be determined") + + @classmethod + def new(cls, matrix: ArrayType) -> Matrix: + matrix_type = cls.get_matrix_type(matrix) + return matrix_type(matrix) + + @classmethod + def get_matrix_type(cls, matrix: ArrayType) -> type[Matrix]: + if CUPY_IS_INSTALLED and isinstance(matrix, cp.ndarray): + return CupyMatrix + elif isinstance(matrix, np.ndarray): + return NumpyMatrix + elif CUPY_IS_INSTALLED and isinstance(matrix, csparse.spmatrix): + return SparseCupyMatrix + elif isinstance(matrix, sparse.spmatrix): + return SparseNumpyMatrix + else: + raise ValueError(f"Array type is {type(matrix)}, must be array.") + + def to_host(self) -> Matrix: + return self + + def to_device(self) -> Matrix: + return self + + def to_sparse(self, sparse_type: str) -> Matrix: + return self + + def to_dense(self) -> Matrix: + return self + + def solve(self, b: ArrayType) -> ArrayType: + raise NotImplementedError("No solving method is implemented") + + def solve_lstsq(self, b: ArrayType) -> ArrayType: + raise NotImplementedError("No least squares method is implemented") + + +class NumpyMatrix(Matrix): + def __init__(self, matrix: np.ndarray) -> None: + self.data = matrix + + @property + def module(self) -> ModuleType: + return np + + def to_device(self) -> CupyMatrix: + if not CUPY_IS_INSTALLED: + raise RuntimeError("Cupy must be installed.") + return CupyMatrix(cp.array(self.data)) + + def to_sparse(self, sparse_type: str = "coo") -> SparseNumpyMatrix: + sparse_array = self._to_sparse_methods_cpu[sparse_type](self.data) + return SparseNumpyMatrix(sparse_array) + + def solve(self, b: np.ndarray) -> np.ndarray: + return scipy.linalg.solve(self.data, b) + + def solve_lstsq(self, b: np.ndarray) -> np.ndarray: + """Solve linear least squares directly for predictable performance""" + 
class CupyMatrix(Matrix):
    """Dense matrix backed by a cupy array"""

    def __init__(self, matrix: cp.ndarray) -> None:
        self.data = matrix

    @property
    def module(self) -> ModuleType:
        # return (not raise) the array module, mirroring NumpyMatrix.module
        return cp

    def to_host(self) -> NumpyMatrix:
        """Copy the data to a numpy array wrapped in a NumpyMatrix"""
        return NumpyMatrix(cp.asnumpy(self.data))

    def to_sparse(self, sparse_type: str = "coo") -> SparseCupyMatrix:
        """Convert to a cupy sparse matrix of format "coo", "csc" or "csr" """
        sparse_array = self._to_sparse_methods_gpu[sparse_type](self.data)
        return SparseCupyMatrix(sparse_array)

    def solve(self, b: cp.ndarray) -> cp.ndarray:
        """Solve the linear system ``self.data @ x = b``"""
        return cp.linalg.solve(self.data, b)

    def solve_lstsq(self, b: cp.ndarray) -> cp.ndarray:
        """Solve linear least squares directly for predictable performance"""
        M = self.data
        M_t = M.T
        # normal equations: (M^T M) x = M^T b
        MtM = M_t.dot(M)
        Mtb = M_t.dot(b)
        return cp.linalg.solve(MtM, Mtb)


class SparseNumpyMatrix(Matrix):
    """Sparse matrix backed by a scipy sparse matrix"""

    def __init__(self, matrix: sparse.spmatrix) -> None:
        self.data = matrix

    def to_dense(self) -> NumpyMatrix:
        """Convert to a dense NumpyMatrix"""
        return NumpyMatrix(self.data.toarray())

    def convert_to_csc(self) -> None:
        """Convert the stored matrix to CSC format in place"""
        self.data = sparse.csc_matrix(self.data)

    def convert_to_csr(self) -> None:
        """Convert the stored matrix to CSR format in place"""
        self.data = sparse.csr_matrix(self.data)

    def convert_to_coo(self) -> None:
        """Convert the stored matrix to COO format in place"""
        self.data = sparse.coo_matrix(self.data)

    def convert_to(self, sparse_type: str) -> None:
        """Convert the stored matrix in place to "coo", "csc" or "csr" format"""
        self.data = self._to_sparse_methods_cpu[sparse_type](self.data)

    def to_device(self) -> SparseCupyMatrix:
        """Copy the matrix to the GPU, keeping its sparse format and shape

        Raises
        ------
        RuntimeError
            If cupy is not installed or the sparse format is not COO, CSR or CSC
        """
        if not CUPY_IS_INSTALLED:
            raise RuntimeError("Cupy must be installed.")
        matrix = self.data
        data = cp.array(matrix.data)
        # Pass the shape explicitly: when it is inferred from the indices,
        # trailing all-zero rows or columns are dropped.
        shape = matrix.shape

        if sparse.isspmatrix_coo(matrix):
            row = cp.array(matrix.row)
            col = cp.array(matrix.col)
            return SparseCupyMatrix(
                csparse.coo_matrix((data, (row, col)), shape=shape)
            )

        elif sparse.isspmatrix_csr(matrix):
            indices = cp.array(matrix.indices)
            indptr = cp.array(matrix.indptr)
            return SparseCupyMatrix(
                csparse.csr_matrix((data, indices, indptr), shape=shape)
            )

        elif sparse.isspmatrix_csc(matrix):
            indices = cp.array(matrix.indices)
            indptr = cp.array(matrix.indptr)
            return SparseCupyMatrix(
                csparse.csc_matrix((data, indices, indptr), shape=shape)
            )

        else:
            raise RuntimeError("Unrecognized sparse matrix format.")

    def solve(self, b: np.ndarray) -> np.ndarray:
        """Solve the sparse linear system ``self.data @ x = b``"""
        return sparse.linalg.spsolve(self.data, b)

    def solve_lstsq(self, b: np.ndarray) -> np.ndarray:
        """Solve linear least squares directly for predictable performance"""
        M = self.data
        M_t = M.T
        # normal equations: (M^T M) x = M^T b
        MtM = M_t.dot(M)
        Mtb = M_t.dot(b)
        return sparse.linalg.spsolve(MtM, Mtb)


class SparseCupyMatrix(Matrix):
    """Sparse matrix backed by a cupy sparse matrix"""

    def __init__(self, matrix: csparse.spmatrix) -> None:
        self.data = matrix

    def to_dense(self) -> CupyMatrix:
        """Convert to a dense CupyMatrix"""
        return CupyMatrix(self.data.toarray())

    def convert_to_csc(self) -> None:
        """Convert the stored matrix to CSC format in place"""
        self.data = csparse.csc_matrix(self.data)

    def convert_to_csr(self) -> None:
        """Convert the stored matrix to CSR format in place"""
        self.data = csparse.csr_matrix(self.data)

    def convert_to_coo(self) -> None:
        """Convert the stored matrix to COO format in place"""
        self.data = csparse.coo_matrix(self.data)

    def convert_to(self, sparse_type: str) -> None:
        """Convert the stored matrix in place to "coo", "csc" or "csr" format"""
        self.data = self._to_sparse_methods_gpu[sparse_type](self.data)

    def to_host(self) -> SparseNumpyMatrix:
        """Copy the matrix to the host, keeping its sparse format and shape

        Raises
        ------
        RuntimeError
            If the sparse format is not COO, CSR or CSC
        """
        matrix = self.data
        data = cp.asnumpy(matrix.data)
        # Pass the shape explicitly: when it is inferred from the indices,
        # trailing all-zero rows or columns are dropped.
        shape = matrix.shape

        if csparse.isspmatrix_coo(matrix):
            row = cp.asnumpy(matrix.row)
            col = cp.asnumpy(matrix.col)
            return SparseNumpyMatrix(
                sparse.coo_matrix((data, (row, col)), shape=shape)
            )

        elif csparse.isspmatrix_csr(matrix):
            indices = cp.asnumpy(matrix.indices)
            indptr = cp.asnumpy(matrix.indptr)
            return SparseNumpyMatrix(
                sparse.csr_matrix((data, indices, indptr), shape=shape)
            )

        elif csparse.isspmatrix_csc(matrix):
            indices = cp.asnumpy(matrix.indices)
            indptr = cp.asnumpy(matrix.indptr)
            return SparseNumpyMatrix(
                sparse.csc_matrix((data, indices, indptr), shape=shape)
            )

        else:
            raise RuntimeError("Unrecognized sparse matrix format.")

    def solve(self, b: cp.ndarray) -> cp.ndarray:
        """Solve the sparse linear system ``self.data @ x = b``"""
        return clinalg.spsolve(self.data, b)

    def solve_lstsq(self, b: cp.ndarray) -> cp.ndarray:
        """Solve linear least squares directly for predictable performance"""
        M = self.data
        M_t = M.T
        # normal equations: (M^T M) x = M^T b
        MtM = M_t.dot(M)
        Mtb = M_t.dot(b)
        return clinalg.spsolve(MtM, Mtb)
open("requirements.in") as f: + requirements = f.read().splitlines() + +with open("dev-requirements.in") as f: + dev_requirements = f.read().splitlines() # Projects with optional features for building the documentation and running # tests. From setuptools: # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras-optional-features-with-their-own-dependencies extra_feature_requirements = { - "doc": ["sphinx >= 3.0.2", "sphinx-rtd-theme >= 0.4.3"], - "tests": ["pytest >= 5.4", "pytest-cov >= 2.8.1", "coverage >= 5.0"], + "dev": dev_requirements, + "gpu": ["cupy"], } -extra_feature_requirements["dev"] = ["black >= 19.3b0", "pre-commit >= 1.16"] + list( - chain(*list(extra_feature_requirements.values())) -) setup( name="pyMatchSeries", - version="0.1.0", - description=("A python wrapper for the non-rigid-registration " - "code match-series"), - url='https://github.com/din14970/pyMatchSeries', - author='Niels Cautaerts', - author_email='nielscautaerts@hotmail.com', - license='GPL-3.0', + version="0.3.0", + description=( + "A python implementation of joint-non-rigid-registration " + "and wrapper of match-series." 
+ ), + url="https://github.com/din14970/pyMatchSeries", + author="Niels Cautaerts", + author_email="nielscautaerts@hotmail.com", + license="GPL-3.0", long_description=readme, long_description_content_type="text/markdown", - classifiers=['Topic :: Scientific/Engineering :: Physics', - 'Intended Audience :: Science/Research', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8'], - keywords='TEM', + classifiers=[ + "Topic :: Scientific/Engineering :: Physics", + "Intended Audience :: Science/Research", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + ], + keywords="TEM", extras_require=extra_feature_requirements, packages=find_packages(exclude=["*tests*", "*examples*"]), - package_data={'': ['pymatchseries/default_parameters.param']}, + package_data={"": ["pymatchseries/default_parameters.param"]}, + entry_points={"hyperspy.extensions": "pyMatchSeries = pyMatchSeries"}, include_package_data=True, - install_requires=[ - 'hyperspy>=1.6.1', - 'Pillow', - 'tabulate', - ], + install_requires=requirements, ) diff --git a/tests/data/expected_derivative_regularizer.npz b/tests/data/expected_derivative_regularizer.npz new file mode 100644 index 0000000..8b20cee Binary files /dev/null and b/tests/data/expected_derivative_regularizer.npz differ diff --git a/tests/data/expected_energy_gradient.npy b/tests/data/expected_energy_gradient.npy new file mode 100644 index 0000000..06a9fc9 Binary files /dev/null and b/tests/data/expected_energy_gradient.npy differ diff --git a/tests/data/expected_eval_residual.npy b/tests/data/expected_eval_residual.npy new file mode 100644 index 0000000..98ca992 Binary files /dev/null and b/tests/data/expected_eval_residual.npy differ diff --git a/tests/data/expected_eval_residual_grad.npz b/tests/data/expected_eval_residual_grad.npz new file mode 100644 index 0000000..17878aa Binary files /dev/null and 
b/tests/data/expected_eval_residual_grad.npz differ diff --git a/tests/implementation/test_cuda_kernels.py b/tests/implementation/test_cuda_kernels.py new file mode 100644 index 0000000..04144c2 --- /dev/null +++ b/tests/implementation/test_cuda_kernels.py @@ -0,0 +1,121 @@ +import numpy as np +import pytest + +from pymatchseries.implementation.cuda_kernels import ( + evaluate_at_quad_points_gpu, + evaluate_pd_on_quad_points_gpu, + interpolate_gpu, + interpolate_gradient_gpu, +) +from pymatchseries.implementation.interpolation import ( + interpolate_cpu, + interpolate_gradient_cpu, +) +from pymatchseries.implementation.quadrature import ( + evaluate_at_quad_points_cpu, + evaluate_pd_on_quad_points_cpu, +) +from pymatchseries.utils import CUPY_IS_INSTALLED, cp + +RTOL = 1e-6 + + +@pytest.mark.skipif( + not CUPY_IS_INSTALLED, + reason="cupy not installed, gpu probably not installed", +) +def test_interpolate_gpu() -> None: + image = np.random.rand(400, 500).astype(np.float32) + coordinates = np.mgrid[0 : image.shape[0], 0 : image.shape[1]] + coordinates = np.moveaxis(coordinates, 0, -1) + coordinates = coordinates.astype(np.float32) + jitter = np.random.rand(*coordinates.shape) - 0.5 + coordinates_2 = coordinates + jitter + + image_gpu = cp.asarray(image) + coordinates_gpu = cp.asarray(coordinates) + coordinates_2_gpu = cp.asarray(coordinates_2) + + result_cpu_1 = interpolate_cpu(image, coordinates) + result_gpu_1 = interpolate_gpu(image_gpu, coordinates_gpu) + + result_cpu_2 = interpolate_cpu(image, coordinates_2) + result_gpu_2 = interpolate_gpu(image_gpu, coordinates_2_gpu) + + np.testing.assert_allclose(cp.asnumpy(result_gpu_1), result_cpu_1, rtol=RTOL) + np.testing.assert_allclose(cp.asnumpy(result_gpu_2), result_cpu_2, rtol=RTOL) + + +@pytest.mark.skipif( + not CUPY_IS_INSTALLED, + reason="cupy not installed, gpu probably not installed", +) +def test_interpolate_gradient_gpu() -> None: + image = np.random.rand(400, 500).astype(np.float32) + coordinates = 
np.mgrid[0 : image.shape[0], 0 : image.shape[1]] + coordinates = np.moveaxis(coordinates, 0, -1) + coordinates = coordinates.astype(np.float32) + jitter = np.random.rand(*coordinates.shape) - 0.5 + coordinates_2 = coordinates + jitter + + image_gpu = cp.asarray(image) + coordinates_gpu = cp.asarray(coordinates) + coordinates_2_gpu = cp.asarray(coordinates_2) + + result_cpu_1 = interpolate_gradient_cpu(image, coordinates) + result_gpu_1 = interpolate_gradient_gpu(image_gpu, coordinates_gpu) + + result_cpu_2 = interpolate_gradient_cpu(image, coordinates_2) + result_gpu_2 = interpolate_gradient_gpu(image_gpu, coordinates_2_gpu) + + np.testing.assert_allclose(cp.asnumpy(result_gpu_1), result_cpu_1, rtol=RTOL) + np.testing.assert_allclose(cp.asnumpy(result_gpu_2), result_cpu_2, rtol=RTOL) + + +@pytest.mark.skipif( + not CUPY_IS_INSTALLED, + reason="cupy not installed, gpu probably not installed", +) +def test_evaluate_at_quad_points_gpu() -> None: + image = np.random.rand(400, 500).astype(np.float32) + node_weights = np.random.rand(4, 11).astype(np.float32) + + image_gpu = cp.asarray(image) + node_weights_gpu = cp.asarray(node_weights) + + result_cpu = evaluate_at_quad_points_cpu(image, node_weights) + result_gpu = evaluate_at_quad_points_gpu(image_gpu, node_weights_gpu) + + np.testing.assert_allclose(cp.asnumpy(result_gpu), result_cpu, rtol=RTOL) + + +@pytest.mark.skipif( + not CUPY_IS_INSTALLED, + reason="cupy not installed, gpu probably not installed", +) +def test_evaluate_pd_on_quad_points_gpu() -> None: + K = 11 + N = 400 + M = 500 + quadrature_values = np.random.rand(N - 1, M - 1, K).astype(np.float32) + quad_weights_sqrt = np.random.rand(K).astype(np.float32) + node_weights = np.random.rand(4, K).astype(np.float32) + + quadrature_values_gpu = cp.asarray(quadrature_values) + quad_weights_sqrt_gpu = cp.asarray(quad_weights_sqrt) + node_weights_gpu = cp.asarray(node_weights) + + result_cpu = evaluate_pd_on_quad_points_cpu( + quadrature_values, + quad_weights_sqrt, 
def test_multilevel() -> None:
    """The multilevel displacement field has one 2D component per image axis"""
    image_shape = (128, 132)
    deformed = np.zeros(image_shape, dtype=np.float32)
    reference = np.zeros(image_shape, dtype=np.float32)

    # the same sized rectangle at two different positions
    deformed[2 * 10 : 2 * 30, 2 * 15 : 2 * 45] = 1
    reference[2 * 25 : 2 * 45, 2 * 25 : 2 * 55] = 1

    config = JNNRConfig()
    config.solver.show_progress = False
    config.n_levels = 5
    regularization = 0.1

    field = JNRR._get_multilevel_displacement_field(
        image_deformed=deformed,
        image_reference=reference,
        regularization_constant=regularization,
        configuration=config,
    )

    expected_shape = (2, *image_shape)
    assert field.shape == expected_shape
def test_interpolate_cpu():
    """Interpolation reproduces grid values and matches reference off-grid values"""
    image, coordinates, coordinates_2 = _get_image_and_coords()
    # sampling exactly on the grid nodes must give back the image
    result = interpolate_cpu(image, coordinates)

    # sampling at the jittered coordinates is compared to stored values
    result_2 = interpolate_cpu(image, coordinates_2)
    expected = np.array(
        [
            [3.3521428, 4.2464533, 6.3119893, 5.436767, 0.898885],
            [2.707175, 1.3840603, 3.173547, 1.5321381, 0.16751026],
            [1.0, 1.0555547, 1.1533972, 2.818614, 4.8747425],
            [1.0, 1.116166, 0.69160265, 0.79082614, 4.7007623],
        ]
    )

    # assert_allclose takes (actual, desired): the tolerance is relative to
    # the desired values and failure messages label the arrays accordingly
    np.testing.assert_allclose(result, image, rtol=RTOL)
    np.testing.assert_allclose(result_2, expected, rtol=RTOL)
[-1.0, 0.0], + ], + [ + [-1.3495493, 0.39707753], + [-0.5753218, -1.3351147], + [-4.6834044, 1.3181751], + [-7.826705, -1.3703043], + [-2.461396, -0.47638488], + ], + [ + [0.0, 0.0], + [-0.26727632, -0.41571072], + [-0.28894424, 3.73642], + [4.028469, 1.3993475], + [-2.3606489, 2.2772436], + ], + [ + [0.0, 0.0], + [1.0, -1.0], + [0.0, -1.0], + [-1.9883605, -0.02306885], + [0.0, 5.0], + ], + ], + dtype=np.float32, + ) + + np.testing.assert_allclose(verify, result, rtol=RTOL) + np.testing.assert_allclose(verify_2, result_2, rtol=RTOL) diff --git a/tests/implementation/test_objective_functions.py b/tests/implementation/test_objective_functions.py new file mode 100644 index 0000000..5ecfd5f --- /dev/null +++ b/tests/implementation/test_objective_functions.py @@ -0,0 +1,251 @@ +from pathlib import Path + +import numpy as np +import pytest +from scipy import sparse + +from pymatchseries.implementation.objective_functions import ( + RegistrationObjectiveFunction, +) +from pymatchseries.utils import CUPY_IS_INSTALLED, cp + +this_file = Path(__file__) +data_folder = this_file.parent.parent / "data" + + +image_deformed = np.array( + [ + [9, 8, 0, 5, 5, 5], + [4, 7, 9, 5, 2, 3], + [9, 4, 5, 7, 1, 8], + [8, 1, 2, 7, 2, 6], + [8, 2, 9, 7, 0, 3], + ] +).astype(np.float32) + +image_reference = np.array( + [ + [3, 4, 2, 6, 9, 0], + [5, 6, 2, 4, 7, 1], + [0, 5, 1, 2, 7, 2], + [0, 7, 2, 0, 0, 8], + [6, 4, 0, 9, 7, 8], + ] +).astype(np.float32) +regularization_constant = 10 +number_of_quadrature_points = 3 + +regobj_cpu = RegistrationObjectiveFunction( + image_deformed=image_deformed, + image_reference=image_reference, + regularization_constant=regularization_constant, + number_of_quadrature_points=number_of_quadrature_points, +) + +params = [regobj_cpu] + +if CUPY_IS_INSTALLED: + image_deformed_gpu = cp.asarray(image_deformed) + image_reference_gpu = cp.asarray(image_reference) + regobj_gpu = RegistrationObjectiveFunction( + image_deformed=image_deformed_gpu, + 
image_reference=image_reference_gpu, + regularization_constant=regularization_constant, + number_of_quadrature_points=number_of_quadrature_points, + ) + params.append(regobj_gpu) + + +DATA = { + "expected_derivative": sparse.load_npz( + data_folder / "expected_derivative_regularizer.npz" + ), + "expected_eval_residual": np.load(data_folder / "expected_eval_residual.npy"), + "expected_eval_residual_grad": sparse.load_npz( + data_folder / "expected_eval_residual_grad.npz" + ), + "expected_energy_gradient": np.load(data_folder / "expected_energy_gradient.npy"), +} + +RTOL = 1e-6 +ATOL = 1e-6 + + +@pytest.mark.parametrize("of", params) +class TestRegistrationObjectiveFunction: + def test_derivative_of_regularizer( + self, + of: RegistrationObjectiveFunction, + ) -> None: + dreg = of.derivative_of_regularizer + expected_shape = ( + 4 * of.quadrature.total_number_of_quadrature_points, + 2 * of.number_of_nodes, + ) + assert dreg.shape == expected_shape + + if of.dispatcher != np: + dreg = dreg.get() + + np.testing.assert_allclose( + DATA["expected_derivative"].data, dreg.data, rtol=RTOL + ) + np.testing.assert_array_equal(DATA["expected_derivative"].indices, dreg.indices) + np.testing.assert_array_equal(DATA["expected_derivative"].indptr, dreg.indptr) + + def test_evaluate_residual( + self, + of: RegistrationObjectiveFunction, + ) -> None: + dp = of.dispatcher + v = self.displacement_vector + if dp != np: + v = cp.asarray(v) + error = of.evaluate_residual(v) + expected_shape = (5 * of.quadrature.total_number_of_quadrature_points,) + assert error.shape == expected_shape + + if dp != np: + error = error.get() + + np.testing.assert_allclose( + DATA["expected_eval_residual"], error, rtol=RTOL, atol=ATOL + ) + + def test_evaluate_residual_gradient( + self, + of: RegistrationObjectiveFunction, + ) -> None: + dp = of.dispatcher + v = self.displacement_vector + if dp != np: + v = cp.asarray(v) + error_grad = of.evaluate_residual_gradient(v) + expected_shape = ( + 5 * 
of.quadrature.total_number_of_quadrature_points, + 2 * of.number_of_nodes, + ) + assert error_grad.shape == expected_shape + + if dp != np: + error_grad = error_grad.get() + + np.testing.assert_allclose( + DATA["expected_eval_residual_grad"].data, + error_grad.data, + rtol=RTOL, + atol=ATOL, + ) + np.testing.assert_array_equal( + DATA["expected_eval_residual_grad"].indices, error_grad.indices + ) + np.testing.assert_array_equal( + DATA["expected_eval_residual_grad"].indptr, error_grad.indptr + ) + + def test_evaluate_energy( + self, + of: RegistrationObjectiveFunction, + ) -> None: + dp = of.dispatcher + v = self.displacement_vector + if dp != np: + v = cp.asarray(v) + error = of.evaluate_energy(v) + + if dp != np: + error = error.get() + + expected = 97.65085868419155 + + assert abs(error - expected) / expected < RTOL + + def test_evaluate_energy_gradient( + self, + of: RegistrationObjectiveFunction, + ) -> None: + dp = of.dispatcher + v = self.displacement_vector + if dp != np: + v = cp.asarray(v) + error_grad = of.evaluate_energy_gradient(v) + expected_shape = (2 * of.number_of_nodes,) + assert error_grad.shape == expected_shape + + if dp != np: + error_grad = error_grad.get() + + np.testing.assert_allclose( + DATA["expected_energy_gradient"], + error_grad, + rtol=RTOL, + atol=ATOL, + ) + + # Some data + + # random displacement vector + displacement_vector = np.array( + [ + 0.2, + 0.4, + 0.7, + 0.5, + 0.7, + 0.7, + 0.3, + 0.1, + 0.1, + 0.9, + 0.7, + 0.8, + 0.7, + 0.3, + 0.7, + 0.1, + 0.5, + 0.8, + 0.2, + 0.9, + 0.3, + 0.9, + 0.5, + 0.3, + 0.2, + 0.6, + 0.6, + 0.1, + 0.6, + 0.5, + 0.2, + 0.1, + 0.7, + 0.0, + 0.6, + 0.4, + 0.4, + 0.3, + 0.8, + 0.0, + 0.2, + 0.7, + 0.6, + 0.6, + 0.4, + 0.9, + 0.8, + 0.8, + 0.5, + 0.1, + 0.1, + 0.4, + 0.0, + 0.6, + 0.2, + 0.4, + 0.3, + 0.6, + 0.9, + 0.1, + ] + ).astype(np.float32) diff --git a/tests/implementation/test_quadrature.py b/tests/implementation/test_quadrature.py new file mode 100644 index 0000000..c5212fc --- /dev/null 
+++ b/tests/implementation/test_quadrature.py @@ -0,0 +1,1868 @@ +from math import prod + +import numpy as np +import pytest + +from pymatchseries.implementation.quadrature import ( + Quadrature2D, + evaluate_at_quad_points_cpu, + evaluate_pd_on_quad_points_cpu, +) +from pymatchseries.utils import CUPY_IS_INSTALLED, cp + +RTOL = 1e-6 + + +if CUPY_IS_INSTALLED: + params = [ + Quadrature2D((7, 8), 2, np), + Quadrature2D((4, 5), 3, np), + Quadrature2D((7, 8), 2, cp), + Quadrature2D((4, 5), 3, cp), + ] +else: + params = [ + Quadrature2D((7, 8), 2, np), + Quadrature2D((4, 5), 3, np), + ] + + +@pytest.mark.parametrize("quad", params) +class TestQuadrature: + """Simply tests whether we get the right array shapes from all functions""" + + def test_evaluate(self, quad: Quadrature2D) -> None: + dp = quad.dispatcher + image = dp.random.rand(*quad.grid_shape).astype(dp.float32) + result = quad.evaluate(image) + expected_shape = ( + image.shape[0] - 1, + image.shape[1] - 1, + quad.number_of_quadrature_points, + ) + assert result.shape == expected_shape + + def test_evaluate_partial_derivatives(self, quad: Quadrature2D) -> None: + dp = quad.dispatcher + shape = quad.cell_grid_shape + quad_values = dp.random.rand(*shape).astype(dp.float32) + node_weights = dp.random.rand(4, quad.number_of_quadrature_points).astype( + dp.float32 + ) + result = quad.evaluate_partial_derivatives(quad_values, node_weights) + expected_shape = (quad.total_number_of_quadrature_points * 4,) + assert len(result) == 3 + assert result[0].shape == expected_shape + assert result[1].shape == expected_shape + assert result[2].shape == expected_shape + + def test_quadrature_points(self, quad: Quadrature2D) -> None: + points = quad.quadrature_points + assert points.shape == (quad.number_of_quadrature_points, 2) + + def test_quadrature_points_weights(self, quad: Quadrature2D) -> None: + weights = quad.quadrature_point_weights + assert weights.shape == (quad.number_of_quadrature_points,) + + def 
test_quadrature_points_weights_sqrt(self, quad: Quadrature2D) -> None: + weights = quad.quadrature_point_weights_sqrt + assert weights.shape == (quad.number_of_quadrature_points,) + + def test_quadrature_points_x_coordinate(self, quad: Quadrature2D) -> None: + points = quad.quadrature_points_x_coordinate + assert points.shape == (quad.number_of_quadrature_points,) + + def test_quadrature_points_y_coordinate(self, quad: Quadrature2D) -> None: + points = quad.quadrature_points_y_coordinate + assert points.shape == (quad.number_of_quadrature_points,) + + def test_node_weights(self, quad: Quadrature2D) -> None: + node_weights = quad.node_weights + assert node_weights.shape == (4, quad.number_of_quadrature_points) + + def test_basis_f_at_points(self, quad: Quadrature2D) -> None: + basis_f = quad.basis_f_at_points + assert basis_f.shape == (4, quad.number_of_quadrature_points) + + def test_dx_node_weights(self, quad: Quadrature2D) -> None: + values = quad.dx_node_weights + assert values.shape == (4, quad.number_of_quadrature_points) + + def test_dy_node_weights(self, quad: Quadrature2D) -> None: + values = quad.dy_node_weights + assert values.shape == (4, quad.number_of_quadrature_points) + + def test_basis_dfx_at_points(self, quad: Quadrature2D) -> None: + values = quad.basis_dfx_at_points + assert values.shape == (4, quad.number_of_quadrature_points) + + def test_basis_dfy_at_points(self, quad: Quadrature2D) -> None: + values = quad.basis_dfy_at_points + assert values.shape == (4, quad.number_of_quadrature_points) + + +def test_evaluate_at_quad_points_cpu() -> None: + image = np.array( + [ + [5, 9, 5, 3, 3], + [2, 1, 8, 0, 0], + [5, 7, 7, 3, 0], + [5, 7, 7, 3, 0], + [2, 1, 8, 0, 0], + [5, 2, 9, 2, 8], + ] + ).astype(np.float32) + node_weights = ( + np.array( + [ + [4, 8, 4, 3, 7, 6, 7, 9, 9, 6, 3], + [7, 3, 8, 9, 9, 3, 9, 3, 1, 5, 7], + [7, 2, 6, 6, 1, 4, 2, 8, 7, 5, 8], + [4, 4, 6, 9, 4, 9, 0, 3, 0, 6, 4], + ] + ).astype(np.float32) + / 20 + ) + + result = 
evaluate_at_quad_points_cpu(image, node_weights) + expected = np.array( + [ + 5.049999, + 3.75, + 5.5000005, + 5.8499994, + 6.0999994, + 3.7, + 5.9999995, + 4.55, + 3.4, + 4.55, + 4.8999996, + 5.5, + 6.05, + 6.5000005, + 7.5, + 7.0499997, + 7.25, + 5.4999995, + 6.3999996, + 4.6499996, + 6.6, + 5.1, + 4.85, + 3.25, + 4.6000004, + 4.5, + 3.5, + 3.5500002, + 3.8999999, + 5.9, + 5.2, + 4.25, + 5.0, + 1.65, + 1.6500001, + 1.8000001, + 1.8, + 2.3999999, + 1.35, + 2.3999999, + 1.8, + 1.4999999, + 1.6500001, + 1.5, + 3.9, + 2.85, + 4.4, + 5.3999996, + 2.8, + 4.8999996, + 1.65, + 4.1, + 2.7, + 4.2, + 4.05, + 6.85, + 3.6999998, + 7.6000004, + 9.0, + 5.7, + 6.05, + 4.6499996, + 5.5, + 3.3000002, + 6.1500006, + 7.15, + 4.65, + 4.5, + 4.6000004, + 4.65, + 3.75, + 5.15, + 3.5, + 6.8499994, + 6.05, + 5.05, + 4.6, + 1.05, + 0.3, + 0.90000004, + 0.90000004, + 0.15, + 0.6, + 0.3, + 1.2, + 1.05, + 0.75, + 1.2, + 6.6, + 4.9500003, + 7.4000006, + 8.549999, + 6.5499997, + 6.7, + 5.3999996, + 6.3500004, + 4.35, + 6.6000004, + 6.6, + 7.7000003, + 5.95, + 8.400001, + 9.45, + 7.35, + 7.7, + 6.2999997, + 8.05, + 5.95, + 7.700001, + 7.7000003, + 5.4999995, + 4.55, + 5.6, + 5.85, + 4.75, + 5.3, + 4.5, + 6.8499994, + 5.75, + 5.5000005, + 5.4999995, + 1.65, + 1.5, + 1.5, + 1.35, + 1.1999999, + 1.5, + 1.3499999, + 2.55, + 2.3999999, + 1.6500001, + 1.6500001, + 4.35, + 3.4500003, + 4.7000003, + 4.95, + 5.1999993, + 3.4000003, + 5.0999994, + 4.2500005, + 3.3, + 4.05, + 4.2, + 5.7999997, + 5.5499997, + 6.9, + 8.1, + 7.25, + 6.95, + 5.7, + 5.8, + 3.8499997, + 6.5000005, + 5.5, + 5.25, + 4.05, + 5.0, + 4.8, + 4.2, + 4.15, + 4.6, + 6.8, + 6.1, + 4.8500004, + 5.3, + 0.6, + 1.2, + 0.6, + 0.45000002, + 1.05, + 0.90000004, + 1.05, + 1.3499999, + 1.3499999, + 0.90000004, + 0.45000002, + 2.9, + 1.85, + 2.9, + 3.15, + 1.8, + 2.65, + 1.65, + 3.35, + 2.7, + 2.6999998, + 3.0500002, + 5.5, + 3.6000001, + 6.7, + 8.4, + 5.85, + 5.95, + 4.1499996, + 3.8000002, + 1.55, + 5.5, + 5.55, + 5.15, + 4.5, + 4.9, + 4.8, + 
3.65, + 5.1000004, + 3.7, + 7.5, + 6.75, + 5.25, + 5.2000003, + 2.3, + 1.8000001, + 3.0, + 4.2, + 1.7, + 4.0, + 0.2, + 2.0, + 0.7, + 2.9, + 2.4, + ] + ).astype(np.float32) + + expected_shape = (image.shape[0] - 1, image.shape[1] - 1, node_weights.shape[1]) + assert result.shape == expected_shape + np.testing.assert_allclose(result.ravel(), expected, rtol=RTOL) + + +def test_evaluate_pd_at_quad_points_cpu() -> None: + # (N - 1) = 4 + # (M - 1) = 5 + # K = 6 + quadrature_values = np.array( + [ + [ + [9, 5, 1, 8, 6, 1], + [5, 1, 7, 2, 1, 1], + [2, 9, 1, 2, 6, 0], + [5, 1, 0, 4, 6, 8], + [5, 4, 1, 7, 4, 9], + ], + [ + [6, 2, 1, 1, 8, 4], + [1, 3, 0, 4, 1, 2], + [5, 1, 0, 0, 7, 1], + [6, 5, 1, 2, 8, 0], + [5, 4, 5, 3, 9, 2], + ], + [ + [0, 9, 3, 2, 6, 4], + [2, 2, 6, 8, 8, 5], + [6, 7, 1, 2, 2, 6], + [4, 3, 3, 5, 9, 7], + [3, 7, 0, 1, 8, 5], + ], + [ + [2, 5, 2, 2, 9, 8], + [7, 0, 2, 8, 4, 6], + [2, 6, 9, 3, 8, 8], + [1, 8, 8, 5, 5, 9], + [5, 9, 8, 7, 2, 6], + ], + ] + ).astype(np.float32) + quad_weights_sqrt = np.array([0, 4, 2, 1, 3, 8]).astype(np.float32) + node_weights = np.array( + [[5, 1, 3, 4, 7, 4], [2, 7, 4, 0, 9, 5], [3, 1, 3, 8, 2, 3], [8, 0, 0, 8, 6, 5]] + ).astype(np.float32) + + data, rows, cols = evaluate_pd_on_quad_points_cpu( + quadrature_values, quad_weights_sqrt, node_weights + ) + + expected_shape = (4 * prod(quadrature_values.shape),) + assert data.shape == expected_shape + assert rows.shape == expected_shape + assert cols.shape == expected_shape + + data_expected = np.array( + [ + 0.0, + 0.0, + 0.0, + 0.0, + 20.0, + 140.0, + 20.0, + 0.0, + 6.0, + 8.0, + 6.0, + 0.0, + 32.0, + 0.0, + 64.0, + 64.0, + 126.0, + 162.0, + 36.0, + 108.0, + 32.0, + 40.0, + 24.0, + 40.0, + 0.0, + 0.0, + 0.0, + 0.0, + 4.0, + 28.0, + 4.0, + 0.0, + 42.0, + 56.0, + 42.0, + 0.0, + 8.0, + 0.0, + 16.0, + 16.0, + 21.0, + 27.0, + 6.0, + 18.0, + 32.0, + 40.0, + 24.0, + 40.0, + 0.0, + 0.0, + 0.0, + 0.0, + 36.0, + 252.0, + 36.0, + 0.0, + 6.0, + 8.0, + 6.0, + 0.0, + 8.0, + 0.0, + 16.0, + 
16.0, + 126.0, + 162.0, + 36.0, + 108.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 4.0, + 28.0, + 4.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 16.0, + 0.0, + 32.0, + 32.0, + 126.0, + 162.0, + 36.0, + 108.0, + 256.0, + 320.0, + 192.0, + 320.0, + 0.0, + 0.0, + 0.0, + 0.0, + 16.0, + 112.0, + 16.0, + 0.0, + 6.0, + 8.0, + 6.0, + 0.0, + 28.0, + 0.0, + 56.0, + 56.0, + 84.0, + 108.0, + 24.0, + 72.0, + 288.0, + 360.0, + 216.0, + 360.0, + 0.0, + 0.0, + 0.0, + 0.0, + 8.0, + 56.0, + 8.0, + 0.0, + 6.0, + 8.0, + 6.0, + 0.0, + 4.0, + 0.0, + 8.0, + 8.0, + 168.0, + 216.0, + 48.0, + 144.0, + 128.0, + 160.0, + 96.0, + 160.0, + 0.0, + 0.0, + 0.0, + 0.0, + 12.0, + 84.0, + 12.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 16.0, + 0.0, + 32.0, + 32.0, + 21.0, + 27.0, + 6.0, + 18.0, + 64.0, + 80.0, + 48.0, + 80.0, + 0.0, + 0.0, + 0.0, + 0.0, + 4.0, + 28.0, + 4.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 147.0, + 189.0, + 42.0, + 126.0, + 32.0, + 40.0, + 24.0, + 40.0, + 0.0, + 0.0, + 0.0, + 0.0, + 20.0, + 140.0, + 20.0, + 0.0, + 6.0, + 8.0, + 6.0, + 0.0, + 8.0, + 0.0, + 16.0, + 16.0, + 168.0, + 216.0, + 48.0, + 144.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 16.0, + 112.0, + 16.0, + 0.0, + 30.0, + 40.0, + 30.0, + 0.0, + 12.0, + 0.0, + 24.0, + 24.0, + 189.0, + 243.0, + 54.0, + 162.0, + 64.0, + 80.0, + 48.0, + 80.0, + 0.0, + 0.0, + 0.0, + 0.0, + 36.0, + 252.0, + 36.0, + 0.0, + 18.0, + 24.0, + 18.0, + 0.0, + 8.0, + 0.0, + 16.0, + 16.0, + 126.0, + 162.0, + 36.0, + 108.0, + 128.0, + 160.0, + 96.0, + 160.0, + 0.0, + 0.0, + 0.0, + 0.0, + 8.0, + 56.0, + 8.0, + 0.0, + 36.0, + 48.0, + 36.0, + 0.0, + 32.0, + 0.0, + 64.0, + 64.0, + 168.0, + 216.0, + 48.0, + 144.0, + 160.0, + 200.0, + 120.0, + 200.0, + 0.0, + 0.0, + 0.0, + 0.0, + 28.0, + 196.0, + 28.0, + 0.0, + 6.0, + 8.0, + 6.0, + 0.0, + 8.0, + 0.0, + 16.0, + 16.0, + 42.0, + 54.0, + 12.0, + 36.0, + 192.0, + 240.0, + 144.0, + 240.0, + 0.0, + 0.0, + 0.0, + 0.0, + 12.0, + 84.0, + 12.0, + 0.0, + 18.0, + 24.0, 
+ 18.0, + 0.0, + 20.0, + 0.0, + 40.0, + 40.0, + 189.0, + 243.0, + 54.0, + 162.0, + 224.0, + 280.0, + 168.0, + 280.0, + 0.0, + 0.0, + 0.0, + 0.0, + 28.0, + 196.0, + 28.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 4.0, + 0.0, + 8.0, + 8.0, + 168.0, + 216.0, + 48.0, + 144.0, + 160.0, + 200.0, + 120.0, + 200.0, + 0.0, + 0.0, + 0.0, + 0.0, + 20.0, + 140.0, + 20.0, + 0.0, + 12.0, + 16.0, + 12.0, + 0.0, + 8.0, + 0.0, + 16.0, + 16.0, + 189.0, + 243.0, + 54.0, + 162.0, + 256.0, + 320.0, + 192.0, + 320.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 12.0, + 16.0, + 12.0, + 0.0, + 32.0, + 0.0, + 64.0, + 64.0, + 84.0, + 108.0, + 24.0, + 72.0, + 192.0, + 240.0, + 144.0, + 240.0, + 0.0, + 0.0, + 0.0, + 0.0, + 24.0, + 168.0, + 24.0, + 0.0, + 54.0, + 72.0, + 54.0, + 0.0, + 12.0, + 0.0, + 24.0, + 24.0, + 168.0, + 216.0, + 48.0, + 144.0, + 256.0, + 320.0, + 192.0, + 320.0, + 0.0, + 0.0, + 0.0, + 0.0, + 32.0, + 224.0, + 32.0, + 0.0, + 48.0, + 64.0, + 48.0, + 0.0, + 20.0, + 0.0, + 40.0, + 40.0, + 105.0, + 135.0, + 30.0, + 90.0, + 288.0, + 360.0, + 216.0, + 360.0, + 0.0, + 0.0, + 0.0, + 0.0, + 36.0, + 252.0, + 36.0, + 0.0, + 48.0, + 64.0, + 48.0, + 0.0, + 28.0, + 0.0, + 56.0, + 56.0, + 42.0, + 54.0, + 12.0, + 36.0, + 192.0, + 240.0, + 144.0, + 240.0, + ], + dtype=np.float32, + ) + + rows_expected = np.array( + [ + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 3.0, + 3.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + 5.0, + 5.0, + 6.0, + 6.0, + 6.0, + 6.0, + 7.0, + 7.0, + 7.0, + 7.0, + 8.0, + 8.0, + 8.0, + 8.0, + 9.0, + 9.0, + 9.0, + 9.0, + 10.0, + 10.0, + 10.0, + 10.0, + 11.0, + 11.0, + 11.0, + 11.0, + 12.0, + 12.0, + 12.0, + 12.0, + 13.0, + 13.0, + 13.0, + 13.0, + 14.0, + 14.0, + 14.0, + 14.0, + 15.0, + 15.0, + 15.0, + 15.0, + 16.0, + 16.0, + 16.0, + 16.0, + 17.0, + 17.0, + 17.0, + 17.0, + 18.0, + 18.0, + 18.0, + 18.0, + 19.0, + 19.0, + 19.0, + 19.0, + 20.0, + 20.0, + 20.0, + 20.0, + 21.0, + 21.0, + 21.0, + 21.0, + 22.0, + 
22.0, + 22.0, + 22.0, + 23.0, + 23.0, + 23.0, + 23.0, + 24.0, + 24.0, + 24.0, + 24.0, + 25.0, + 25.0, + 25.0, + 25.0, + 26.0, + 26.0, + 26.0, + 26.0, + 27.0, + 27.0, + 27.0, + 27.0, + 28.0, + 28.0, + 28.0, + 28.0, + 29.0, + 29.0, + 29.0, + 29.0, + 30.0, + 30.0, + 30.0, + 30.0, + 31.0, + 31.0, + 31.0, + 31.0, + 32.0, + 32.0, + 32.0, + 32.0, + 33.0, + 33.0, + 33.0, + 33.0, + 34.0, + 34.0, + 34.0, + 34.0, + 35.0, + 35.0, + 35.0, + 35.0, + 36.0, + 36.0, + 36.0, + 36.0, + 37.0, + 37.0, + 37.0, + 37.0, + 38.0, + 38.0, + 38.0, + 38.0, + 39.0, + 39.0, + 39.0, + 39.0, + 40.0, + 40.0, + 40.0, + 40.0, + 41.0, + 41.0, + 41.0, + 41.0, + 42.0, + 42.0, + 42.0, + 42.0, + 43.0, + 43.0, + 43.0, + 43.0, + 44.0, + 44.0, + 44.0, + 44.0, + 45.0, + 45.0, + 45.0, + 45.0, + 46.0, + 46.0, + 46.0, + 46.0, + 47.0, + 47.0, + 47.0, + 47.0, + 48.0, + 48.0, + 48.0, + 48.0, + 49.0, + 49.0, + 49.0, + 49.0, + 50.0, + 50.0, + 50.0, + 50.0, + 51.0, + 51.0, + 51.0, + 51.0, + 52.0, + 52.0, + 52.0, + 52.0, + 53.0, + 53.0, + 53.0, + 53.0, + 54.0, + 54.0, + 54.0, + 54.0, + 55.0, + 55.0, + 55.0, + 55.0, + 56.0, + 56.0, + 56.0, + 56.0, + 57.0, + 57.0, + 57.0, + 57.0, + 58.0, + 58.0, + 58.0, + 58.0, + 59.0, + 59.0, + 59.0, + 59.0, + 60.0, + 60.0, + 60.0, + 60.0, + 61.0, + 61.0, + 61.0, + 61.0, + 62.0, + 62.0, + 62.0, + 62.0, + 63.0, + 63.0, + 63.0, + 63.0, + 64.0, + 64.0, + 64.0, + 64.0, + 65.0, + 65.0, + 65.0, + 65.0, + 66.0, + 66.0, + 66.0, + 66.0, + 67.0, + 67.0, + 67.0, + 67.0, + 68.0, + 68.0, + 68.0, + 68.0, + 69.0, + 69.0, + 69.0, + 69.0, + 70.0, + 70.0, + 70.0, + 70.0, + 71.0, + 71.0, + 71.0, + 71.0, + 72.0, + 72.0, + 72.0, + 72.0, + 73.0, + 73.0, + 73.0, + 73.0, + 74.0, + 74.0, + 74.0, + 74.0, + 75.0, + 75.0, + 75.0, + 75.0, + 76.0, + 76.0, + 76.0, + 76.0, + 77.0, + 77.0, + 77.0, + 77.0, + 78.0, + 78.0, + 78.0, + 78.0, + 79.0, + 79.0, + 79.0, + 79.0, + 80.0, + 80.0, + 80.0, + 80.0, + 81.0, + 81.0, + 81.0, + 81.0, + 82.0, + 82.0, + 82.0, + 82.0, + 83.0, + 83.0, + 83.0, + 83.0, + 84.0, + 84.0, + 84.0, + 
84.0, + 85.0, + 85.0, + 85.0, + 85.0, + 86.0, + 86.0, + 86.0, + 86.0, + 87.0, + 87.0, + 87.0, + 87.0, + 88.0, + 88.0, + 88.0, + 88.0, + 89.0, + 89.0, + 89.0, + 89.0, + 90.0, + 90.0, + 90.0, + 90.0, + 91.0, + 91.0, + 91.0, + 91.0, + 92.0, + 92.0, + 92.0, + 92.0, + 93.0, + 93.0, + 93.0, + 93.0, + 94.0, + 94.0, + 94.0, + 94.0, + 95.0, + 95.0, + 95.0, + 95.0, + 96.0, + 96.0, + 96.0, + 96.0, + 97.0, + 97.0, + 97.0, + 97.0, + 98.0, + 98.0, + 98.0, + 98.0, + 99.0, + 99.0, + 99.0, + 99.0, + 100.0, + 100.0, + 100.0, + 100.0, + 101.0, + 101.0, + 101.0, + 101.0, + 102.0, + 102.0, + 102.0, + 102.0, + 103.0, + 103.0, + 103.0, + 103.0, + 104.0, + 104.0, + 104.0, + 104.0, + 105.0, + 105.0, + 105.0, + 105.0, + 106.0, + 106.0, + 106.0, + 106.0, + 107.0, + 107.0, + 107.0, + 107.0, + 108.0, + 108.0, + 108.0, + 108.0, + 109.0, + 109.0, + 109.0, + 109.0, + 110.0, + 110.0, + 110.0, + 110.0, + 111.0, + 111.0, + 111.0, + 111.0, + 112.0, + 112.0, + 112.0, + 112.0, + 113.0, + 113.0, + 113.0, + 113.0, + 114.0, + 114.0, + 114.0, + 114.0, + 115.0, + 115.0, + 115.0, + 115.0, + 116.0, + 116.0, + 116.0, + 116.0, + 117.0, + 117.0, + 117.0, + 117.0, + 118.0, + 118.0, + 118.0, + 118.0, + 119.0, + 119.0, + 119.0, + 119.0, + ], + dtype=np.float32, + ) + + cols_expected = np.array( + [ + 7.0, + 6.0, + 1.0, + 0.0, + 7.0, + 6.0, + 1.0, + 0.0, + 7.0, + 6.0, + 1.0, + 0.0, + 7.0, + 6.0, + 1.0, + 0.0, + 7.0, + 6.0, + 1.0, + 0.0, + 7.0, + 6.0, + 1.0, + 0.0, + 8.0, + 7.0, + 2.0, + 1.0, + 8.0, + 7.0, + 2.0, + 1.0, + 8.0, + 7.0, + 2.0, + 1.0, + 8.0, + 7.0, + 2.0, + 1.0, + 8.0, + 7.0, + 2.0, + 1.0, + 8.0, + 7.0, + 2.0, + 1.0, + 9.0, + 8.0, + 3.0, + 2.0, + 9.0, + 8.0, + 3.0, + 2.0, + 9.0, + 8.0, + 3.0, + 2.0, + 9.0, + 8.0, + 3.0, + 2.0, + 9.0, + 8.0, + 3.0, + 2.0, + 9.0, + 8.0, + 3.0, + 2.0, + 10.0, + 9.0, + 4.0, + 3.0, + 10.0, + 9.0, + 4.0, + 3.0, + 10.0, + 9.0, + 4.0, + 3.0, + 10.0, + 9.0, + 4.0, + 3.0, + 10.0, + 9.0, + 4.0, + 3.0, + 10.0, + 9.0, + 4.0, + 3.0, + 11.0, + 10.0, + 5.0, + 4.0, + 11.0, + 10.0, + 5.0, 
+ 4.0, + 11.0, + 10.0, + 5.0, + 4.0, + 11.0, + 10.0, + 5.0, + 4.0, + 11.0, + 10.0, + 5.0, + 4.0, + 11.0, + 10.0, + 5.0, + 4.0, + 13.0, + 12.0, + 7.0, + 6.0, + 13.0, + 12.0, + 7.0, + 6.0, + 13.0, + 12.0, + 7.0, + 6.0, + 13.0, + 12.0, + 7.0, + 6.0, + 13.0, + 12.0, + 7.0, + 6.0, + 13.0, + 12.0, + 7.0, + 6.0, + 14.0, + 13.0, + 8.0, + 7.0, + 14.0, + 13.0, + 8.0, + 7.0, + 14.0, + 13.0, + 8.0, + 7.0, + 14.0, + 13.0, + 8.0, + 7.0, + 14.0, + 13.0, + 8.0, + 7.0, + 14.0, + 13.0, + 8.0, + 7.0, + 15.0, + 14.0, + 9.0, + 8.0, + 15.0, + 14.0, + 9.0, + 8.0, + 15.0, + 14.0, + 9.0, + 8.0, + 15.0, + 14.0, + 9.0, + 8.0, + 15.0, + 14.0, + 9.0, + 8.0, + 15.0, + 14.0, + 9.0, + 8.0, + 16.0, + 15.0, + 10.0, + 9.0, + 16.0, + 15.0, + 10.0, + 9.0, + 16.0, + 15.0, + 10.0, + 9.0, + 16.0, + 15.0, + 10.0, + 9.0, + 16.0, + 15.0, + 10.0, + 9.0, + 16.0, + 15.0, + 10.0, + 9.0, + 17.0, + 16.0, + 11.0, + 10.0, + 17.0, + 16.0, + 11.0, + 10.0, + 17.0, + 16.0, + 11.0, + 10.0, + 17.0, + 16.0, + 11.0, + 10.0, + 17.0, + 16.0, + 11.0, + 10.0, + 17.0, + 16.0, + 11.0, + 10.0, + 19.0, + 18.0, + 13.0, + 12.0, + 19.0, + 18.0, + 13.0, + 12.0, + 19.0, + 18.0, + 13.0, + 12.0, + 19.0, + 18.0, + 13.0, + 12.0, + 19.0, + 18.0, + 13.0, + 12.0, + 19.0, + 18.0, + 13.0, + 12.0, + 20.0, + 19.0, + 14.0, + 13.0, + 20.0, + 19.0, + 14.0, + 13.0, + 20.0, + 19.0, + 14.0, + 13.0, + 20.0, + 19.0, + 14.0, + 13.0, + 20.0, + 19.0, + 14.0, + 13.0, + 20.0, + 19.0, + 14.0, + 13.0, + 21.0, + 20.0, + 15.0, + 14.0, + 21.0, + 20.0, + 15.0, + 14.0, + 21.0, + 20.0, + 15.0, + 14.0, + 21.0, + 20.0, + 15.0, + 14.0, + 21.0, + 20.0, + 15.0, + 14.0, + 21.0, + 20.0, + 15.0, + 14.0, + 22.0, + 21.0, + 16.0, + 15.0, + 22.0, + 21.0, + 16.0, + 15.0, + 22.0, + 21.0, + 16.0, + 15.0, + 22.0, + 21.0, + 16.0, + 15.0, + 22.0, + 21.0, + 16.0, + 15.0, + 22.0, + 21.0, + 16.0, + 15.0, + 23.0, + 22.0, + 17.0, + 16.0, + 23.0, + 22.0, + 17.0, + 16.0, + 23.0, + 22.0, + 17.0, + 16.0, + 23.0, + 22.0, + 17.0, + 16.0, + 23.0, + 22.0, + 17.0, + 16.0, + 23.0, + 22.0, + 17.0, + 
16.0, + 25.0, + 24.0, + 19.0, + 18.0, + 25.0, + 24.0, + 19.0, + 18.0, + 25.0, + 24.0, + 19.0, + 18.0, + 25.0, + 24.0, + 19.0, + 18.0, + 25.0, + 24.0, + 19.0, + 18.0, + 25.0, + 24.0, + 19.0, + 18.0, + 26.0, + 25.0, + 20.0, + 19.0, + 26.0, + 25.0, + 20.0, + 19.0, + 26.0, + 25.0, + 20.0, + 19.0, + 26.0, + 25.0, + 20.0, + 19.0, + 26.0, + 25.0, + 20.0, + 19.0, + 26.0, + 25.0, + 20.0, + 19.0, + 27.0, + 26.0, + 21.0, + 20.0, + 27.0, + 26.0, + 21.0, + 20.0, + 27.0, + 26.0, + 21.0, + 20.0, + 27.0, + 26.0, + 21.0, + 20.0, + 27.0, + 26.0, + 21.0, + 20.0, + 27.0, + 26.0, + 21.0, + 20.0, + 28.0, + 27.0, + 22.0, + 21.0, + 28.0, + 27.0, + 22.0, + 21.0, + 28.0, + 27.0, + 22.0, + 21.0, + 28.0, + 27.0, + 22.0, + 21.0, + 28.0, + 27.0, + 22.0, + 21.0, + 28.0, + 27.0, + 22.0, + 21.0, + 29.0, + 28.0, + 23.0, + 22.0, + 29.0, + 28.0, + 23.0, + 22.0, + 29.0, + 28.0, + 23.0, + 22.0, + 29.0, + 28.0, + 23.0, + 22.0, + 29.0, + 28.0, + 23.0, + 22.0, + 29.0, + 28.0, + 23.0, + 22.0, + ], + dtype=np.float32, + ) + + np.testing.assert_allclose(data, data_expected, rtol=RTOL) + np.testing.assert_allclose(cols, cols_expected, rtol=RTOL) + np.testing.assert_allclose(rows, rows_expected, rtol=RTOL) diff --git a/tests/implementation/test_solvers.py b/tests/implementation/test_solvers.py new file mode 100644 index 0000000..916ed9e --- /dev/null +++ b/tests/implementation/test_solvers.py @@ -0,0 +1,68 @@ +from types import ModuleType +from typing import Tuple + +import numpy as np +import pytest +from skimage.transform import pyramid_gaussian + +from pymatchseries.implementation.objective_functions import ( + RegistrationObjectiveFunction, +) +from pymatchseries.implementation.solvers import root_gauss_newton +from pymatchseries.utils import CUPY_IS_INSTALLED, DenseArrayType, cp + +if CUPY_IS_INSTALLED: + params = [np, cp] +else: + params = [np] + + +def _setup(dp: ModuleType = np) -> Tuple[RegistrationObjectiveFunction, DenseArrayType]: + im1 = np.zeros((128, 128), dtype=np.float32) + im2 = 
np.zeros((128, 128), dtype=np.float32) + + im1[2 * 10 : 2 * 30, 2 * 15 : 2 * 45] = 1 + im2[2 * 25 : 2 * 45, 2 * 25 : 2 * 55] = 1 + + num_levels = 5 + + # Create an image hierarchy for both of our images + pyramid_tem = tuple( + pyramid_gaussian(im1, max_layer=num_levels - 1, downscale=2, channel_axis=None) + ) + pyramid_ref = tuple( + pyramid_gaussian(im2, max_layer=num_levels - 1, downscale=2, channel_axis=None) + ) + + # Regularization parameter + L = 0.1 + + image_def = pyramid_tem[-1] + image_ref = pyramid_ref[-1] + + if dp != np: + image_def = cp.asarray(image_def) + image_ref = cp.asarray(image_ref) + + objective = RegistrationObjectiveFunction( + image_ref, + image_def, + L, + ) + + disp = dp.zeros_like(objective.identity) + return objective, disp + + +@pytest.mark.parametrize("dp", params) +def test_root_gauss_newton(dp: ModuleType) -> None: + objective, disp = _setup(dp) + disp_new = root_gauss_newton( + objective.evaluate_residual, + disp.ravel(), + objective.evaluate_residual_gradient, + ).reshape(disp.shape) + + assert objective.evaluate_energy(disp.ravel()) > objective.evaluate_energy( + disp_new.ravel() + ) diff --git a/pymatchseries/tests/test_config_tools.py b/tests/test_config_tools.py similarity index 99% rename from pymatchseries/tests/test_config_tools.py rename to tests/test_config_tools.py index 486ce11..de0e23d 100644 --- a/pymatchseries/tests/test_config_tools.py +++ b/tests/test_config_tools.py @@ -1,7 +1,9 @@ -from pymatchseries import config_tools as ctools import os + import pytest +from pymatchseries import config_tools as ctools + def test_load_config(): ctools.load_config() diff --git a/pymatchseries/tests/test_matchseries.py b/tests/test_matchseries.py similarity index 98% rename from pymatchseries/tests/test_matchseries.py rename to tests/test_matchseries.py index 3031a3d..8001b05 100644 --- a/pymatchseries/tests/test_matchseries.py +++ b/tests/test_matchseries.py @@ -1,12 +1,13 @@ -from pymatchseries import matchseries as ms 
-import os import gc +import os import shutil -import pytest -from hyperspy.signals import Signal2D, EDSTEMSpectrum + import dask.array as da import numpy as np +import pytest +from hyperspy.signals import EDSTEMSpectrum, Signal2D +from pymatchseries import matchseries as ms # numpy dataset np.random.seed(1001) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..e69de29