From ba100eaef49010706a28835e51a905e4254bb6f0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 21 Jun 2024 04:57:52 +0000 Subject: [PATCH 1/9] build(deps): bump ruff from 0.1.14 to 0.4.10 Bumps [ruff](https://github.com/astral-sh/ruff) from 0.1.14 to 0.4.10. - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/v0.1.14...v0.4.10) --- updated-dependencies: - dependency-name: ruff dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db3705201..eb8eca6b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ training = [ "codecarbon>=2.0.0,<3.0.0", ] quality = [ - "ruff==0.1.14", + "ruff==0.4.10", "mypy==1.9.0", "types-tqdm", "pre-commit>=3.0.0,<4.0.0", @@ -83,7 +83,7 @@ dev = [ "pytest-pretty>=1.0.0,<2.0.0", "onnx>=1.13.0,<2.0.0", # style - "ruff==0.1.14", + "ruff==0.4.10", "mypy==1.9.0", "types-tqdm", "pre-commit>=3.0.0,<4.0.0", From 314395a2665e213b786e4a0635b3633acf2cce13 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:06:38 +0200 Subject: [PATCH 2/9] style(pre-commit): bump ruff to 0.4.10 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16626bfb4..26001cd0f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: args: ['--branch', 'main'] - id: trailing-whitespace - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.1.14' + rev: 'v0.4.10' hooks: - id: ruff args: From ad436825ae2a0255e582e1471b32b67814238af2 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:11:57 +0200 Subject: [PATCH 3/9] style(pyproject): update ruff config --- pyproject.toml | 67 +++++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eb8eca6b0..36b6a383e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,38 +113,45 @@ exclude = ["demo*", "docs*", "notebooks*", "scripts*", "tests*"] source = ["holocron"] [tool.ruff] +line-length = 120 +target-version = "py39" +preview = true + +[tool.ruff.lint] select = [ + "F", # pyflakes "E", # pycodestyle errors "W", # pycodestyle warnings - "D101", "D103", # pydocstyle missing docstring in public function/class - "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", # pydocstyle - "F", # pyflakes "I", # isort - "C4", # flake8-comprehensions - "B", # flake8-bugbear - "CPY001", # flake8-copyright - "ISC", # flake8-implicit-str-concat - "PYI", # flake8-pyi - "NPY", # numpy - "PERF", # perflint - "RUF", # ruff specific - "PTH", # flake8-use-pathlib - "S", # flake8-bandit "N", # pep8-naming - "T10", # flake8-debugger - "T20", # flake8-print - "PT", # flake8-pytest-style - "LOG", # flake8-logging - "SIM", # flake8-simplify + "D101", "D103", # pydocstyle missing docstring in public function/class + "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", # pydocstyle "YTT", # flake8-2020 "ANN", # flake8-annotations "ASYNC", # flake8-async + "S", # flake8-bandit "BLE", # flake8-blind-except + "B", # flake8-bugbear "A", # flake8-builtins + "COM", # flake8-commas + "CPY", # flake8-copyright + "C4", # flake8-comprehensions + "T10", # flake8-debugger + "ISC", # flake8-implicit-str-concat "ICN", # flake8-import-conventions + "LOG", # flake8-logging "PIE", # flake8-pie + "T20", # flake8-print + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "RET", # flake8-return + "SIM", # flake8-simplify "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "PERF", # perflint + "NPY", # numpy "FURB", # refurb + "RUF", # ruff specific ] ignore = [ "E501", # line too long, handled by black @@ -161,19 +168,20 @@ ignore = [ "ANN003", # missing type annotations on **kwargs "PT011", # pytest.raises must have a match pattern "N812", # Lowercase imported as non-lowercase + "COM812", # trailing comma missing (handled by format) "ISC001", # implicit string concatenation (handled by format) "ANN401", # Dynamically typed expressions (typing.Any) are disallowed ] exclude = [".git"] -line-length = 120 -target-version = "py39" -preview = true -[tool.ruff.format] -quote-style = "double" -indent-style = "space" +[tool.ruff.lint.flake8-quotes] +docstring-quotes = "double" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.isort] +known-first-party = ["holocron", "app"] +known-third-party = ["wandb"] + +[tool.ruff.lint.per-file-ignores] "**/__init__.py" = ["I001", "F401", "CPY001"] "scripts/**.py" = ["D", "T201", "ANN"] ".github/**.py" = ["D", "T201", "S602", "ANN"] @@ -191,12 +199,9 @@ indent-style = "space" "holocron/nn/functional.py" = ["N802", "ARG"] "holocron/trainer/**.py" = ["ARG"] -[tool.ruff.flake8-quotes] -docstring-quotes = "double" - -[tool.ruff.isort] -known-first-party = ["holocron", "app"] -known-third-party = ["wandb"] +[tool.ruff.format] +quote-style = "double" +indent-style = "space" [tool.mypy] python_version = "3.9" From 74fead6616b2bc39d8ec9626e9d74f11b6cca8c4 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:12:34 +0200 Subject: [PATCH 4/9] style(ruff): fix lint --- .github/collect_env.py | 9 ++++----- docs/source/conf.py | 2 +- holocron/trainer/core.py | 3 +-- holocron/transforms/interpolation.py | 15 +++++++-------- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/.github/collect_env.py b/.github/collect_env.py index 93cb40b22..d55d84f3d 100644 --- a/.github/collect_env.py +++ b/.github/collect_env.py @@ -168,14 +168,13 @@ def get_nvidia_smi(): def get_platform(): if sys.platform.startswith("linux"): return "linux" - elif sys.platform.startswith("win32"): + if sys.platform.startswith("win32"): return "win32" - elif sys.platform.startswith("cygwin"): + if sys.platform.startswith("cygwin"): return "cygwin" - elif sys.platform.startswith("darwin"): + if sys.platform.startswith("darwin"): return "darwin" - else: - return sys.platform + return sys.platform def get_mac_version(run_lambda): diff --git a/docs/source/conf.py b/docs/source/conf.py index ee6c1927b..80f7dbd2b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -220,7 +220,7 @@ def inject_checkpoint_metadata(app, what, name, obj, options, lines): table.extend(( ("url", f"`link <{meta.url}>`__"), ("sha256", meta.sha256[:16]), - ("size", f"{meta.size / 1024 ** 2:.1f}MB"), + ("size", f"{meta.size / 1024**2:.1f}MB"), ("num_params", f"{meta.num_params / 1000000.0:.1f}M"), )) # Wrap the text diff --git a/holocron/trainer/core.py b/holocron/trainer/core.py index 0a4f7c1ae..5f300608a 100644 --- a/holocron/trainer/core.py +++ b/holocron/trainer/core.py @@ -362,8 +362,7 @@ def find_lr( if torch.isnan(batch_loss) or torch.isinf(batch_loss): if batch_idx == 0: raise ValueError("loss value is NaN or inf.") - else: - break + break self.loss_recorder.append(batch_loss.item()) # Stop after the number of iterations if batch_idx + 1 == num_it: diff --git a/holocron/transforms/interpolation.py b/holocron/transforms/interpolation.py index 9d6c740ea..3c5c39a78 100644 --- a/holocron/transforms/interpolation.py +++ b/holocron/transforms/interpolation.py @@ -87,14 +87,13 @@ def get_params(self, image: Union[Image.Image, torch.Tensor]) -> Tuple[int, int] def forward(self, image: Union[Image.Image, torch.Tensor]) -> Union[Image.Image, torch.Tensor]: if self.mode == ResizeMethod.SQUISH: return super().forward(image) - else: - h, w = self.get_params(image) - img = resize(image, (h, w), self.interpolation) - # get the padding - h_pad, w_pad = self.size[0] - h, self.size[1] - w - _padding = w_pad // 2, h_pad // 2, w_pad - w_pad // 2, h_pad - h_pad // 2 - # Fill the rest up to target_size - return pad(img, _padding, padding_mode=self.pad_mode) + h, w = self.get_params(image) + img = resize(image, (h, w), self.interpolation) + # get the padding + h_pad, w_pad = self.size[0] - h, self.size[1] - w + _padding = w_pad // 2, h_pad // 2, w_pad - w_pad // 2, h_pad - h_pad // 2 + # Fill the rest up to target_size + return pad(img, _padding, padding_mode=self.pad_mode) class RandomZoomOut(nn.Module): From d5764b1edfc497ac61ab17ef7757b69067578854 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:31:26 +0200 Subject: [PATCH 5/9] docs(makefile): update ruff command --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 81bffa1de..70f3cca53 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ quality: # this target runs checks on all files and potentially modifies some of them style: ruff format . - ruff --fix . + ruff check --fix . # Run tests for the library test: From a8c5cc5abaf1b955d8560431b9c9441e2bfb2301 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:31:45 +0200 Subject: [PATCH 6/9] ci(style): switched to uv installer --- .github/workflows/style.yml | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 65e09bc90..ec980050b 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -21,7 +21,8 @@ jobs: architecture: x64 - name: Run ruff run: | - pip install ruff==0.1.14 + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' ruff --version ruff check --diff . @@ -44,10 +45,8 @@ jobs: key: ${{ runner.os }}-python-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-quality - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e '.[quality]' - - name: Run mypy - run: | + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' mypy --version mypy @@ -65,7 +64,8 @@ jobs: architecture: x64 - name: Run ruff run: | - pip install ruff==0.1.14 + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' ruff --version ruff format --check --diff . @@ -83,7 +83,8 @@ jobs: architecture: x64 - name: Run ruff run: | - pip install pre-commit + python -m pip install --upgrade uv + uv pip install --system -e '.[quality]' git checkout -b temp pre-commit install pre-commit --version From 1769d39dbc651eb946c5a6a4fdb08dfd1ad5f24f Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:33:51 +0200 Subject: [PATCH 7/9] ci(labeler): update label rules --- .github/labeler.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index ae95e19b1..6d69bf672 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -60,4 +60,10 @@ 'topic: build': - changed-files: - - any-glob-to-any-file: setup.py + - any-glob-to-any-file: + - setup.py + - pyproject.toml + +'topic: style': +- changed-files: + - any-glob-to-any-file: .pre-commit-config.yaml From edb1963308e3d39e8c2c9850bc3af41dd6950785 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:36:08 +0200 Subject: [PATCH 8/9] style(ruff): fix lint --- .github/verify_labels.py | 4 +--- holocron/models/classification/sknet.py | 4 +--- holocron/models/classification/tridentnet.py | 4 +--- holocron/models/detection/yolo.py | 4 +--- holocron/models/detection/yolov2.py | 4 +--- holocron/models/detection/yolov4.py | 3 +-- holocron/models/segmentation/unet.py | 6 ++---- holocron/models/segmentation/unet3p.py | 3 +-- holocron/nn/functional.py | 8 ++------ holocron/nn/modules/activation.py | 6 ++---- holocron/nn/modules/attention.py | 3 +-- holocron/nn/modules/conv.py | 4 +--- holocron/nn/modules/downsample.py | 9 +++------ holocron/nn/modules/lambda_layer.py | 3 +-- references/clean_checkpoint.py | 4 +--- 15 files changed, 20 insertions(+), 49 deletions(-) diff --git a/.github/verify_labels.py b/.github/verify_labels.py index 08ba8aa4a..34ef6e630 100644 --- a/.github/verify_labels.py +++ b/.github/verify_labels.py @@ -82,9 +82,7 @@ def parse_args(): ) parser.add_argument("pr", type=int, help="PR number") - args = parser.parse_args() - - return args + return parser.parse_args() if __name__ == "__main__": diff --git a/holocron/models/classification/sknet.py b/holocron/models/classification/sknet.py index 183d954de..2eaf4137c 100644 --- a/holocron/models/classification/sknet.py +++ b/holocron/models/classification/sknet.py @@ -110,9 +110,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: b, m, c = paths.shape[:3] z = self.sa(paths.sum(dim=1)).view(b, m, c, 1, 1) attention_factors = torch.softmax(z, dim=1) - out = (attention_factors * paths).sum(dim=1) - - return out + return (attention_factors * paths).sum(dim=1) class SKBottleneck(_ResBlock): diff --git a/holocron/models/classification/tridentnet.py b/holocron/models/classification/tridentnet.py index c10df0db2..25f354a14 100644 --- a/holocron/models/classification/tridentnet.py +++ b/holocron/models/classification/tridentnet.py @@ -41,7 +41,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dilations = [1] * self.num_branches if self.dilation[0] == 1 else [1 + idx for idx in range(self.num_branches)] # Use shared weight to apply the convolution - out = torch.cat( + return torch.cat( [ F.conv2d( _x, @@ -57,8 +57,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: 1, ) - return out - class Tridentneck(_ResBlock): expansion: int = 4 diff --git a/holocron/models/detection/yolo.py b/holocron/models/detection/yolo.py index d1c82f6b3..e8a9e936c 100644 --- a/holocron/models/detection/yolo.py +++ b/holocron/models/detection/yolo.py @@ -339,9 +339,7 @@ def _format_outputs(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: def _forward(self, x: Tensor) -> Tensor: out = self.backbone(x) out = self.block4(out) - out = self.classifier(out) - - return out + return self.classifier(out) def forward( self, x: Tensor, target: Optional[List[Dict[str, Tensor]]] = None diff --git a/holocron/models/detection/yolov2.py b/holocron/models/detection/yolov2.py index 594033d00..6ee2f6600 100644 --- a/holocron/models/detection/yolov2.py +++ b/holocron/models/detection/yolov2.py @@ -208,9 +208,7 @@ def _forward(self, x: Tensor) -> Tensor: out = torch.cat((passthrough, out), 1) out = self.block6(out) - out = self.head(out) - - return out + return self.head(out) def forward( self, x: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], target: Optional[List[Dict[str, Tensor]]] = None diff --git a/holocron/models/detection/yolov4.py b/holocron/models/detection/yolov4.py index 955cae3f5..11a5838f6 100644 --- a/holocron/models/detection/yolov4.py +++ b/holocron/models/detection/yolov4.py @@ -628,7 +628,7 @@ def forward( y3 = self.yolo3(o3, target) if not self.training: - detections = [ + return [ { "boxes": torch.cat((det1["boxes"], det2["boxes"], det3["boxes"]), dim=0), "scores": torch.cat((det1["scores"], det2["scores"], det3["scores"]), dim=0), @@ -636,7 +636,6 @@ def forward( } for det1, det2, det3 in zip(y1, y2, y3) ] - return detections return {k: y1[k] + y2[k] + y3[k] for k in y1} diff --git a/holocron/models/segmentation/unet.py b/holocron/models/segmentation/unet.py index 89c36790f..2df1cf9d5 100644 --- a/holocron/models/segmentation/unet.py +++ b/holocron/models/segmentation/unet.py @@ -223,8 +223,7 @@ def forward(self, x: Tensor) -> Tensor: x = decoder(xs.pop(), x) # Classifier - x = self.classifier(x) - return x + return self.classifier(x) class UBlock(nn.Module): @@ -368,8 +367,7 @@ def forward(self, x: Tensor) -> Tensor: x = self.upsample(x) # Classifier - x = self.classifier(x) - return x + return self.classifier(x) def _unet(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> UNet: diff --git a/holocron/models/segmentation/unet3p.py b/holocron/models/segmentation/unet3p.py index dc1d90d3a..f616f19cd 100644 --- a/holocron/models/segmentation/unet3p.py +++ b/holocron/models/segmentation/unet3p.py @@ -155,8 +155,7 @@ def forward(self, x: Tensor) -> Tensor: xs[idx] = self.decoder[idx](xs[:idx], xs[idx], xs[idx + 1 :]) # Classifier - x = self.classifier(xs[0]) - return x + return self.classifier(xs[0]) def _unet(arch: str, pretrained: bool, progress: bool, **kwargs: Any) -> nn.Module: diff --git a/holocron/nn/functional.py b/holocron/nn/functional.py index 3f0245ab0..41cfb0ee9 100644 --- a/holocron/nn/functional.py +++ b/holocron/nn/functional.py @@ -133,9 +133,7 @@ def concat_downsample2d(x: Tensor, scale_factor: int) -> Tensor: # N * C * H * W --> N * C * (H/scale_factor) * scale_factor * (W/scale_factor) * scale_factor x = x.view(b, c, h // scale_factor, scale_factor, w // scale_factor, scale_factor) x = x.permute(0, 3, 5, 1, 2, 4).contiguous() - x = x.view(b, int(c * scale_factor**2), h // scale_factor, w // scale_factor) - - return x + return x.view(b, int(c * scale_factor**2), h // scale_factor, w // scale_factor) def z_pool(x: Tensor, dim: int) -> Tensor: @@ -364,9 +362,7 @@ def _xcorr2d( h = floor((h + (2 * padding[0]) - (dilation[0] * (weight.shape[-2] - 1)) - 1) / stride[0] + 1) w = floor((w + (2 * padding[1]) - (dilation[1] * (weight.shape[-1] - 1)) - 1) / stride[1] + 1) - x = x.view(-1, weight.shape[0], h, w) - - return x + return x.view(-1, weight.shape[0], h, w) def _convNd(x: Tensor, weight: Tensor) -> Tensor: diff --git a/holocron/nn/modules/activation.py b/holocron/nn/modules/activation.py index 539f1683f..5bbd6ba0c 100644 --- a/holocron/nn/modules/activation.py +++ b/holocron/nn/modules/activation.py @@ -22,8 +22,7 @@ def __init__(self, inplace: bool = False) -> None: self.inplace = inplace def extra_repr(self) -> str: - inplace_str = "inplace=True" if self.inplace else "" - return inplace_str + return "inplace=True" if self.inplace else "" class HardMish(_Activation): @@ -80,5 +79,4 @@ def __init__(self, in_channels: int, kernel_size: int = 3) -> None: def forward(self, x: Tensor) -> Tensor: out = self.conv(x) out = self.bn(out) - x = torch.max(x, out) - return x + return torch.max(x, out) diff --git a/holocron/nn/modules/attention.py b/holocron/nn/modules/attention.py index e01fb6637..1475c34c7 100644 --- a/holocron/nn/modules/attention.py +++ b/holocron/nn/modules/attention.py @@ -74,5 +74,4 @@ def forward(self, x: Tensor) -> Tensor: x_h = cast(Tensor, self.h_branch(x)) x_w = cast(Tensor, self.w_branch(x)) - out = (x_c + x_h + x_w) / 3 - return out + return (x_c + x_h + x_w) / 3 diff --git a/holocron/nn/modules/conv.py b/holocron/nn/modules/conv.py index fb588a2b1..192b80188 100644 --- a/holocron/nn/modules/conv.py +++ b/holocron/nn/modules/conv.py @@ -496,6 +496,4 @@ def forward(self, x: Tensor) -> Tensor: # Multiply-Add operation # --> (N, C, H // s, W // s) - out = (kernel * x_unfolded).sum(dim=3).view(*x.shape[:2], *kernel.shape[-2:]) - - return out + return (kernel * x_unfolded).sum(dim=3).view(*x.shape[:2], *kernel.shape[-2:]) diff --git a/holocron/nn/modules/downsample.py b/holocron/nn/modules/downsample.py index a7f766a5e..db1e2eedf 100644 --- a/holocron/nn/modules/downsample.py +++ b/holocron/nn/modules/downsample.py @@ -74,8 +74,7 @@ def forward(self, x: Tensor) -> Tensor: return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) def extra_repr(self) -> str: - inplace_str = "flatten=True" if self.flatten else "" - return inplace_str + return "flatten=True" if self.flatten else "" class GlobalMaxPool2d(nn.Module): @@ -97,13 +96,11 @@ def forward(self, x: Tensor) -> Tensor: return x.view(x.size(0), x.size(1), -1).max(-1).values.view(x.size(0), x.size(1), 1, 1) def extra_repr(self) -> str: - inplace_str = "flatten=True" if self.flatten else "" - return inplace_str + return "flatten=True" if self.flatten else "" def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int: - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding + return ((stride - 1) + dilation * (kernel_size - 1)) // 2 class BlurPool2d(nn.Module): diff --git a/holocron/nn/modules/lambda_layer.py b/holocron/nn/modules/lambda_layer.py index a0dd4e02d..edc4717b7 100644 --- a/holocron/nn/modules/lambda_layer.py +++ b/holocron/nn/modules/lambda_layer.py @@ -105,5 +105,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Y = Yc + Yp # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W - out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w) - return out + return Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w) diff --git a/references/clean_checkpoint.py b/references/clean_checkpoint.py index 6400172c6..9d3e30120 100644 --- a/references/clean_checkpoint.py +++ b/references/clean_checkpoint.py @@ -27,9 +27,7 @@ def parse_args(): parser.add_argument("checkpoint", type=str, help="path to the training checkpoint") parser.add_argument("outfile", type=str, help="model") - args = parser.parse_args() - - return args + return parser.parse_args() if __name__ == "__main__": From 81928708ebaaed9f5d58d35575cee9e09917ac83 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sun, 23 Jun 2024 19:38:06 +0200 Subject: [PATCH 9/9] style(api): fix async --- api/app/routes/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/routes/classification.py b/api/app/routes/classification.py index a484e9af6..7b8a0c3fb 100644 --- a/api/app/routes/classification.py +++ b/api/app/routes/classification.py @@ -13,7 +13,7 @@ @router.post("/", status_code=status.HTTP_200_OK, summary="Perform image classification") -async def classify(file: UploadFile = File(...)) -> ClsCandidate: +def classify(file: UploadFile = File(...)) -> ClsCandidate: """Runs holocron vision model to analyze the input image""" probs = classify_image(decode_image(file.file.read()))