Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
euxhenh committed Jul 29, 2023
1 parent fa8377b commit 7f4b55e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/grinch/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(self, cfg: Config, /):
self._reporter = reporter

@property
def logs_path(self) -> Path:
def logs_path(self) -> Path | None:
return self.cfg.logs_path

@contextmanager
Expand All @@ -131,12 +131,12 @@ def interactive(self, save_path: str | Path | None = None, **kwargs):
plt.ioff()

if all_not_None(self.logs_path, save_path):
self.logs_path.mkdir(parents=True, exist_ok=True)
self.logs_path.mkdir(parents=True, exist_ok=True) # type: ignore
# Set good defaults
kwargs.setdefault('dpi', 300)
kwargs.setdefault('bbox_inches', 'tight')
kwargs.setdefault('transparent', True)
plt.savefig(self.logs_path / save_path, **kwargs)
plt.savefig(self.logs_path / save_path, **kwargs) # type: ignore

plt.clf()
plt.show()
Expand Down
4 changes: 2 additions & 2 deletions src/grinch/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

import matplotlib.pyplot as plt
import numpy.typing as npt
import numpy as np
import seaborn as sns
from scipy.stats import norm

Expand All @@ -11,7 +11,7 @@


def plot1d(
rvs: npt.ArrayLike,
rvs: np.ndarray,
dist: str,
*,
title: str | None = None,
Expand Down
24 changes: 19 additions & 5 deletions src/grinch/utils/stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass
from functools import wraps
from typing import Any, Dict, Hashable, List, Optional, Tuple, overload
from typing import Any, Dict, Hashable, List, Tuple, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -101,11 +101,25 @@ def _var(x, axis=None, ddof=0, mean=None):
return var


@overload
def mean_var(
x: npt.ArrayLike,
axis: Optional[int] = None,
ddof: int = 0
) -> Tuple[int | np.ndarray, int | np.ndarray]:
axis: None = None,
ddof: int = 0,
) -> Tuple[float, float]:
...


@overload
def mean_var(
x: npt.ArrayLike,
axis: int,
ddof: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
...


def mean_var(x, axis=None, ddof=0):
"""Returns both mean and variance.
Parameters
Expand Down Expand Up @@ -135,7 +149,7 @@ def mean_var(
def ttest(
a: npt.ArrayLike,
b: npt.ArrayLike,
axis: Optional[int] = 0
axis: int | None = 0
) -> Tuple[np.ndarray, np.ndarray]:
"""Performs a Welch's t-test (unequal sample sizes, unequal vars).
Extends scipy's ttest_ind to support sparse matrices.
Expand Down

0 comments on commit 7f4b55e

Please sign in to comment.