diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..4d018ca --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +## Change the following files in GitHub language stats +### '*.filesuffix linguist-vendored': ignored +### '*.filesuffix linguist-detectable': detected as the language +### '*.filesuffix linguist-language=Python': forced to be Python +*.ipynb linguist-vendored \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..f869776 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,149 @@ +name: build + +on: + push: + branches: [ "main" ] + # branches: [ "dev" ] + # branches: [ "main", "dev" ] + # pull_request: + # branches: [ "main" ] + # branches: [ "dev" ] + # branches: [ "main", "dev" ] + workflow_dispatch: + inputs: + name: + description: 'description' + required: false + default: '' + +permissions: + contents: read + +jobs: + + build: + + name: ${{ matrix.platform }}, py${{ matrix.python-version }}, ${{ matrix.install-level }} + runs-on: ${{ matrix.platform }} + strategy: + fail-fast: false + matrix: + platform: [ + # ubuntu-latest, + ubuntu-22.04, + ubuntu-20.04, + # # windows-latest, + windows-2022, + windows-2019, + # # macos-latest, + macos-12, + macos-11.0, + # macos-10.15, + ] + python-version: [ + # "3.9", + "3.10", + "3.11", + "3.12", + ] + install-level: [ + system, + user, + ] + + steps: + + - name: Set up conda + uses: conda-incubator/setup-miniconda@v2 + with: + miniconda-version: latest + activate-environment: sparse_convolution + auto-activate-base: true + auto-update-conda: false + remove-profiles: true + architecture: x64 + clean-patched-environment-file: true + run-post: true + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Check out repository code + uses: actions/checkout@v3 + + - name: Prepare PowerShell + shell: pwsh + 
run: | + conda info + conda list + + - name: Check specs of the machine -- Linux + if: startsWith(matrix.platform, 'ubuntu') + run: | + ## check cpu, memory, disk, etc. + ## print the command inputs to the workflow + echo "CPU info (lscpu):" + lscpu + echo "Memory info (free -h):" + free -h + echo "Disk info (df -h):" + df -h + echo "Network info (ip a):" + ip a + echo "OS info (uname -a):" + uname -a + - name: Check specs of the machine -- Windows + if: startsWith(matrix.platform, 'windows') + run: | + ## check cpu, memory, disk, etc. + ## just do a generic check on system info + ## print the command inputs to the workflow + echo "System info (systeminfo):" + systeminfo + - name: Check specs of the machine -- MacOS + if: startsWith(matrix.platform, 'macos') + run: | + ## check cpu, memory, disk, etc. + ## print the command inputs to the workflow + echo "CPU info (sysctl -n machdep.cpu.brand_string):" + sysctl -n machdep.cpu.brand_string + echo "Memory info (sysctl -n hw.memsize):" + sysctl -n hw.memsize + echo "Disk info (df -h):" + df -h + echo "OS info (uname -a):" + uname -a + + + - name: Install package with pip dependencies -- system-level + if: matrix.install-level == 'system' + run: | + ## install dependencies with optional extras + pip install -v -e . + - name: Install package with pip dependencies -- user-level + if: matrix.install-level == 'user' + run: | + pip install -v -e . 
--user + + + - name: Check installed packages + run: | + pip list + ## Below, check which versions of torch and torchvision are installed; and whether CUDA is available + python -c "import torch, torchvision; print(f'Using versions: torch=={torch.__version__}, torchvision=={torchvision.__version__}'); print('torch.cuda.is_available() = ', torch.cuda.is_available())" + + - name: Run pytest and generate coverage report + run: | + # pip install tox tox-gh-actions + pip install pytest pytest-cov + python -m pytest --capture=tee-sys --cov=sparse_convolution --cov-report=xml:coverage.xml --color=yes + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 ## this is a public action recognized by GitHub Actions + with: + token: ${{ secrets.CODECOV_TOKEN }} ## this is a secret variable + file: coverage.xml ## this is the default name of the coverage report file + fail_ci_if_error: false + verbose: true diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml new file mode 100644 index 0000000..fcdd4f3 --- /dev/null +++ b/.github/workflows/pypi_release.yml @@ -0,0 +1,101 @@ +name: Publish Python 🐍 distribution 📦 to PyPI + +on: + workflow_dispatch: + inputs: + name: + description: 'description' + required: false + default: '' +jobs: + build: + name: Build distribution 📦 + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install pypa/build + run: >- + python3 -m pip install build --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v3 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python 🐍 distribution 📦 to PyPI + # if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: 
pypi + url: https://pypi.org/project/sparse_convolution # Replace with your PyPI project name + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + + steps: + - name: Download all the dists + uses: actions/download-artifact@v3 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: >- + Sign the Python 🐍 distribution 📦 with Sigstore + and upload them to GitHub Release + needs: + - publish-to-pypi + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + + steps: + - uses: actions/checkout@v4 + - name: Download all the dists + uses: actions/download-artifact@v3 + with: + name: python-package-distributions + path: dist/ + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v1.2.3 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: List files in workspace + run: >- + ls ${{ github.workspace }} + - name: List files in upper directory + run: >- + ls ${{ github.workspace }}/.. + - name: List files in dist + run: >- + ls ${{ github.workspace }}/dist + - name: Store the version number + run: | + version=$(grep '__version__' ${{ github.workspace }}/sparse_convolution/__init__.py | awk -F = '{print $2}' | tr -d ' ' | tr -d \"\') + echo "VERSION=$version" >> $GITHUB_ENV + - name: Check version number + run: echo "Version is $VERSION" + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create "v$VERSION" + --title "v$VERSION" + --notes "Release for version $VERSION. 
Also available on PyPI: https://pypi.org/project/bnpm/$VERSION/" + --repo ${{ github.repository }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6a4507f..6769e21 100644 --- a/.gitignore +++ b/.gitignore @@ -1,54 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class -# Windows image file caches -Thumbs.db -ehthumbs.db +# C extensions +*.so -# Folder config file -Desktop.ini +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST -# Recycle Bin used on file shares -$RECYCLE.BIN/ +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec -# Compiled source -*.mexa64 -*.mexw64 -*.asv +# Installer logs +pip-log.txt +pip-delete-this-directory.txt -# Windows shortcuts -*.lnk +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ -# python and jupyter -*.npy -jupyter/.ipynb_checkpoints/ -jupyter/__pycache__/ -.ipynb_checkpoints/ -__pycache__/ -dist/ -sparse_convolution.egg-info/ -build/ -*.ipynb +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject -# ========================= -# Operating System Files -# ========================= +# mkdocs documentation +/site -# OSX -# ========================= +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json -.DS_Store -.AppleDouble -.LSOverride +# Pyre type checker +.pyre/ -# Thumbnails -._* +# pytype static type analyzer +.pytype/ -# Files that might appear on external disk -.Spotlight-V100 -.Trashes +# Cython debug symbols +cython_debug/ -# Directories potentially created on remote AFP share -.AppleDB -.AppleDesktop -Network Trash Folder -Temporary Items -.apdisk \ No newline at end of file +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at 
https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..636dfbf --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 RichieHakim + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..540b720 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt \ No newline at end of file diff --git a/README.md b/README.md index 132d388..0c90074 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # sparse_convolution -Sparse convolution in python. +Sparse convolution in python. \ Uses Toeplitz convolutional matrix multiplication to perform sparse convolution. 
\ This allows for extremely fast convolution when: - The kernel is small (<= 30x30) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..0f94f37 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +description_file = README.md \ No newline at end of file diff --git a/setup.py b/setup.py index 2fefaf5..0c7f80b 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,54 @@ -"""A setuptools based setup module. -See: -https://packaging.python.org/en/latest/distributing.html -https://github.com/pypa/sampleproject -""" +## setup.py file for sparse_convolution +from pathlib import Path -import setuptools -from codecs import open -from os import path +from distutils.core import setup -here = path.abspath(path.dirname(__file__)) +## Get the parent directory of this file +dir_parent = Path(__file__).parent -# Get the long description from the README file -with open(path.join(here, 'README.md'), encoding='utf-8') as f: - long_description = f.read() +## Get requirements from requirements.txt +def read_requirements(): + with open(str(dir_parent / "requirements.txt"), "r") as req: + content = req.read() ## read the file + requirements = content.split("\n") ## make a list of requirements split by (\n) which is the new line character -requirements_f = open('requirements.txt', 'r') -dependencies = list(requirements_f.readlines()) + ## Filter out any empty strings from the list + requirements = [req for req in requirements if req] + ## Filter out any lines starting with # + requirements = [req for req in requirements if not req.startswith("#")] + ## Remove any commas, quotation marks, and spaces from each requirement + requirements = [req.replace(",", "").replace("\"", "").replace("\'", "").strip() for req in requirements] -setuptools.setup( + return requirements +deps_all = read_requirements() + + +## Dependencies: latest versions of requirements +### remove everything starting and after the first =,>,<,! 
sign +deps_names = [req.split('=')[0].split('>')[0].split('<')[0].split('!')[0] for req in deps_all] +deps_all_dict = dict(zip(deps_names, deps_all)) +deps_all_latest = dict(zip(deps_names, deps_names)) + + +## Get README.md +with open(str(dir_parent / "README.md"), "r") as f: + readme = f.read() + +## Get version number +with open(str(dir_parent / "sparse_convolution" / "__init__.py"), "r") as f: + for line in f: + if line.startswith("__version__"): + version = line.split("=")[1].strip().replace("\"", "").replace("\'", "") + break + + + +setup( name='sparse_convolution', - version='0.1.0', + version=version, description='Sparse convolution in python using Toeplitz convolution matrix multiplication.', - long_description=long_description, + long_description=open('README.md').read(), long_description_content_type='text/markdown', # The project's main homepage. @@ -33,7 +59,7 @@ author_email='richhakim@gmail.com', # Choose your license - # license='MIT', + license='MIT', # Supported platforms platforms=['Any'], @@ -67,7 +93,7 @@ # You can just specify the packages manually here if your project is # simple. Or you can use find_packages(). - packages=setuptools.find_packages(), + packages=['sparse_convolution'], # Alternatively, if you want to distribute just a my_module.py, uncomment # this: @@ -77,7 +103,16 @@ # your project is installed. 
For an analysis of "install_requires" vs pip's # requirements files see: # https://packaging.python.org/en/latest/requirements.html - install_requires=dependencies, + install_requires=list(deps_all_dict.values()), include_package_data=True, -) \ No newline at end of file +) + + + + + + + + + diff --git a/sparse_convolution/__init__.py b/sparse_convolution/__init__.py index ac3027c..36e7d02 100644 --- a/sparse_convolution/__init__.py +++ b/sparse_convolution/__init__.py @@ -1,2 +1,3 @@ from sparse_convolution.sparse_convolution import Toeplitz_convolution2d -from sparse_convolution.tests import test_toeplitz_convolution2d, benchmark_toeplitz_convolution2d \ No newline at end of file + +__version__ = '0.1.1' \ No newline at end of file diff --git a/sparse_convolution/sparse_convolution.py b/sparse_convolution/sparse_convolution.py index 8463028..e2b385b 100644 --- a/sparse_convolution/sparse_convolution.py +++ b/sparse_convolution/sparse_convolution.py @@ -148,7 +148,7 @@ def __call__( p_r = self.x_shape[1]+1 if p_r==0 else p_r if batching: - idx_crop = np.zeros((self.so), dtype=np.bool8) + idx_crop = np.zeros((self.so), dtype=np.bool_) idx_crop[p_t:p_b, p_l:p_r] = True idx_crop = idx_crop.reshape(-1) out = out_v[idx_crop,:].T diff --git a/sparse_convolution/tests.py b/tests/benchmarks.py similarity index 65% rename from sparse_convolution/tests.py rename to tests/benchmarks.py index c221192..cd8663b 100644 --- a/sparse_convolution/tests.py +++ b/tests/benchmarks.py @@ -4,91 +4,7 @@ import numpy as np import scipy.signal -from .sparse_convolution import Toeplitz_convolution2d - -def test_toeplitz_convolution2d(): - """ - Test toeplitz_convolution2d - Tests for modes, shapes, values, and for sparse matrices against - scipy.signal.convolve2d. 
- - RH 2022 - """ - ## test toepltiz convolution - - print(f'testing with batching=False') - - stt = shapes_to_try = np.meshgrid(np.arange(1, 7), np.arange(1, 7), np.arange(1, 7), np.arange(1, 7)) - stt = [s.reshape(-1) for s in stt] - - for mode in ['full', 'same', 'valid']: - for ii in range(len(stt[0])): - x = np.random.rand(stt[0][ii], stt[1][ii]) - k = np.random.rand(stt[2][ii], stt[3][ii]) - # print(stt[0][ii], stt[1][ii], stt[2][ii], stt[3][ii]) - - try: - t = Toeplitz_convolution2d(x_shape=x.shape, k=k, mode=mode, dtype=None) - out_t2d = t(x, batching=False, mode=mode) - out_t2d_s = t(scipy.sparse.csr_matrix(x), batching=False, mode=mode) - out_sp = scipy.signal.convolve2d(x, k, mode=mode) - except Exception as e: - if mode == 'valid' and (stt[0][ii] < stt[2][ii] or stt[1][ii] < stt[3][ii]): - if 'x must be larger than k' in str(e): - continue - print(f'A) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') - success = False - break - try: - if np.allclose(out_t2d, out_t2d_s.A) and np.allclose(out_t2d, out_sp) and np.allclose(out_sp, out_t2d_s.A): - success = True - continue - except Exception as e: - print(f'B) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') - success = False - break - - else: - print(f'C) test failed with batching==False, shapes: x: {x.shape}, k: {k.shape} and mode: {mode}') - success = False - break - - print(f'testing with batching=True') - - for mode in ['full', 'same', 'valid']: - for ii in range(len(stt[0])): - x = np.stack([np.random.rand(stt[0][ii], stt[1][ii]).reshape(-1) for jj in range(3)], axis=0) - k = np.random.rand(stt[2][ii], stt[3][ii]) - # print(stt[0][ii], stt[1][ii], stt[2][ii], stt[3][ii]) - - try: - t = Toeplitz_convolution2d(x_shape=(stt[0][ii], stt[1][ii]), k=k, mode=mode, dtype=None) - out_t2d = t(x, batching=True, mode=mode).todense() - out_t2d_s = t(scipy.sparse.csr_matrix(x), 
batching=True, mode=mode).todense() - out_sp = np.stack([scipy.signal.convolve2d(x_i.reshape(stt[0][ii], stt[1][ii]), k, mode=mode) for x_i in x], axis=0) - except Exception as e: - if mode == 'valid' and (stt[0][ii] < stt[2][ii] or stt[1][ii] < stt[3][ii]): - if 'x must be larger than k' in str(e): - continue - else: - print(f'A) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') - success = False - break - try: - if np.allclose(out_t2d, out_t2d_s) and np.allclose(out_t2d, out_sp) and np.allclose(out_sp, out_t2d_s): - success = True - continue - except Exception as e: - print(f'B) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') - success = False - break - - else: - print(f'C) test failed with batching==False, shapes: x: {x.shape}, k: {k.shape} and mode: {mode}') - success = False - break - print(f'success with all shapes and modes') if success else None - return success +from sparse_convolution import Toeplitz_convolution2d def benchmark_toeplitz_convolution2d(): """ diff --git a/tests/test_unit.py b/tests/test_unit.py new file mode 100644 index 0000000..015e17d --- /dev/null +++ b/tests/test_unit.py @@ -0,0 +1,98 @@ +import traceback +import time + +import numpy as np +import scipy.signal + +from sparse_convolution import Toeplitz_convolution2d + +def test_toeplitz_convolution2d(): + """ + Test toeplitz_convolution2d + Tests for modes, shapes, values, and for sparse matrices against + scipy.signal.convolve2d. 
+ + RH 2022 + """ + ## test toepltiz convolution + + print(f'testing with batching=False') + + stt = shapes_to_try = np.meshgrid(np.arange(1, 7), np.arange(1, 7), np.arange(1, 7), np.arange(1, 7)) + stt = [s.reshape(-1) for s in stt] + + for mode in ['full', 'same', 'valid']: + for ii in range(len(stt[0])): + x = np.random.rand(stt[0][ii], stt[1][ii]) + k = np.random.rand(stt[2][ii], stt[3][ii]) + # print(stt[0][ii], stt[1][ii], stt[2][ii], stt[3][ii]) + + try: + t = Toeplitz_convolution2d(x_shape=x.shape, k=k, mode=mode, dtype=None) + out_t2d = t(x, batching=False, mode=mode) + out_t2d_s = t(scipy.sparse.csr_matrix(x), batching=False, mode=mode) + out_sp = scipy.signal.convolve2d(x, k, mode=mode) + except Exception as e: + if mode == 'valid' and (stt[0][ii] < stt[2][ii] or stt[1][ii] < stt[3][ii]): + if 'x must be larger than k' in str(e): + continue + print(f'A) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') + success = False + break + try: + if np.allclose(out_t2d, out_t2d_s.A) and np.allclose(out_t2d, out_sp) and np.allclose(out_sp, out_t2d_s.A): + success = True + continue + except Exception as e: + print(f'B) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') + success = False + break + + else: + print(f'C) test failed with batching==False, shapes: x: {x.shape}, k: {k.shape} and mode: {mode}') + success = False + break + + print(f'testing with batching=True') + + for mode in ['full', 'same', 'valid']: + for ii in range(len(stt[0])): + x = np.stack([np.random.rand(stt[0][ii], stt[1][ii]).reshape(-1) for jj in range(3)], axis=0) + k = np.random.rand(stt[2][ii], stt[3][ii]) + # print(stt[0][ii], stt[1][ii], stt[2][ii], stt[3][ii]) + + try: + t = Toeplitz_convolution2d(x_shape=(stt[0][ii], stt[1][ii]), k=k, mode=mode, dtype=None) + out_sp = np.stack([scipy.signal.convolve2d(x_i.reshape(stt[0][ii], stt[1][ii]), k, mode=mode) for x_i 
in x], axis=0) + out_t2d = t(x, batching=True, mode=mode).reshape(3, out_sp.shape[1], out_sp.shape[2]) + out_t2d_s = t(scipy.sparse.csr_matrix(x), batching=True, mode=mode).toarray().reshape(3, out_sp.shape[1], out_sp.shape[2]) + except Exception as e: + if mode == 'valid' and (stt[0][ii] < stt[2][ii] or stt[1][ii] < stt[3][ii]): + if 'x must be larger than k' in str(e): + continue + else: + print(f'A) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') + success = False + break + try: + if np.allclose(out_t2d, out_t2d_s) and np.allclose(out_t2d, out_sp) and np.allclose(out_sp, out_t2d_s): + success = True + continue + except Exception as e: + print(f'B) test failed with shapes: x: {x.shape}, k: {k.shape} and mode: {mode} and Exception: {e} {traceback.format_exc()}') + success = False + break + + else: + print(f'C) test failed with batching==False, shapes: x: {x.shape}, k: {k.shape} and mode: {mode}') + print(f"Failure analysis: \n") + print(f"Shapes: x: {x.shape}, k: {k.shape}, out_t2d: {out_t2d.shape}, out_t2d_s: {out_t2d_s.shape}, out_sp: {out_sp.shape}") + print(f"out_t2d: {out_t2d}") + print(f"out_t2d_s: {out_t2d_s}") + print(f"out_sp: {out_sp}") + + success = False + break + print(f'success with all shapes and modes') if success else None + assert success, 'test failed' + # return success \ No newline at end of file