Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
Browse files Browse the repository at this point in the history
…poch_reject
  • Loading branch information
withmywoessner committed Jan 10, 2024
2 parents 3ece314 + 16c17b4 commit db0cf11
Show file tree
Hide file tree
Showing 16 changed files with 351 additions and 247 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12336.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :meth:`mne.io.Raw.interpolate_bads` and :meth:`mne.Epochs.interpolate_bads` to work on ``ecog`` and ``seeg`` data; for ``seeg`` data a spline is fit to neighboring electrode contacts on the same shaft, by `Alex Rockhill`_
1 change: 1 addition & 0 deletions doc/changes/devel/12343.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up raw FIF reading when using small buffer sizes by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12345.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix clicking on an axis of :func:`mne.viz.plot_evoked_topo` when multiple vertical lines ``vlines`` are used, by `Mathieu Scheltienne`_.
29 changes: 17 additions & 12 deletions mne/_fiff/open.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from ..utils import _file_like, logger, verbose, warn
from .constants import FIFF
from .tag import Tag, _call_dict_names, _matrix_info, read_tag, read_tag_info
from .tag import (
Tag,
_call_dict_names,
_matrix_info,
_read_tag_header,
read_tag,
)
from .tree import dir_tree_find, make_dir_tree


Expand Down Expand Up @@ -139,7 +145,7 @@ def _fiff_open(fname, fid, preload):
with fid as fid_old:
fid = BytesIO(fid_old.read())

tag = read_tag_info(fid)
tag = _read_tag_header(fid, 0)

# Check that this looks like a fif file
prefix = f"file {repr(fname)} does not"
Expand All @@ -152,7 +158,7 @@ def _fiff_open(fname, fid, preload):
if tag.size != 20:
raise ValueError(f"{prefix} start with a file id tag")

tag = read_tag(fid)
tag = read_tag(fid, tag.next_pos)

if tag.kind != FIFF.FIFF_DIR_POINTER:
raise ValueError(f"{prefix} have a directory pointer")
Expand All @@ -176,16 +182,15 @@ def _fiff_open(fname, fid, preload):
directory = dir_tag.data
read_slow = False
if read_slow:
fid.seek(0, 0)
pos = 0
fid.seek(pos, 0)
directory = list()
while tag.next >= 0:
pos = fid.tell()
tag = read_tag_info(fid)
while pos is not None:
tag = _read_tag_header(fid, pos)
if tag is None:
break # HACK : to fix file ending with empty tag...
else:
tag.pos = pos
directory.append(tag)
pos = tag.next_pos
directory.append(tag)

tree, _ = make_dir_tree(fid, directory)

Expand Down Expand Up @@ -309,7 +314,7 @@ def _show_tree(
for k, kn, size, pos, type_ in zip(kinds[:-1], kinds[1:], sizes, poss, types):
if not tag_found and k != tag_id:
continue
tag = Tag(k, size, 0, pos)
tag = Tag(kind=k, type=type_, size=size, next=FIFF.FIFFV_NEXT_NONE, pos=pos)
if read_limit is None or size <= read_limit:
try:
tag = read_tag(fid, pos)
Expand Down Expand Up @@ -348,7 +353,7 @@ def _show_tree(
)
else:
postpend += " ... type=" + str(type(tag.data))
postpend = ">" * 20 + "BAD" if not good else postpend
postpend = ">" * 20 + f"BAD @{pos}" if not good else postpend
matrix_info = _matrix_info(tag)
if matrix_info is not None:
_, type_, _, _ = matrix_info
Expand Down
97 changes: 38 additions & 59 deletions mne/_fiff/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import html
import re
import struct
from dataclasses import dataclass
from functools import partial
from typing import Any

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix
Expand All @@ -28,40 +30,16 @@
# HELPERS


@dataclass
class Tag:
"""Tag in FIF tree structure.
"""Tag in FIF tree structure."""

Parameters
----------
kind : int
Kind of Tag.
type_ : int
Type of Tag.
size : int
Size in bytes.
int : next
Position of next Tag.
pos : int
Position of Tag is the original file.
"""

def __init__(self, kind, type_, size, next, pos=None):
self.kind = int(kind)
self.type = int(type_)
self.size = int(size)
self.next = int(next)
self.pos = pos if pos is not None else next
self.pos = int(self.pos)
self.data = None

def __repr__(self): # noqa: D105
attrs = list()
for attr in ("kind", "type", "size", "next", "pos", "data"):
try:
attrs.append(f"{attr} {getattr(self, attr)}")
except AttributeError:
pass
return "<Tag | " + " - ".join(attrs) + ">"
kind: int
type: int
size: int
next: int
pos: int
data: Any = None

def __eq__(self, tag): # noqa: D105
return int(
Expand All @@ -73,17 +51,15 @@ def __eq__(self, tag): # noqa: D105
and self.data == tag.data
)


def read_tag_info(fid):
"""Read Tag info (or header)."""
tag = _read_tag_header(fid)
if tag is None:
return None
if tag.next == 0:
fid.seek(tag.size, 1)
elif tag.next > 0:
fid.seek(tag.next, 0)
return tag
@property
def next_pos(self):
"""The next tag position."""
if self.next == FIFF.FIFFV_NEXT_SEQ: # 0
return self.pos + 16 + self.size
elif self.next > 0:
return self.next
else: # self.next should be -1 if we get here
return None # safest to return None so that things like fid.seek die


def _frombuffer_rows(fid, tag_size, dtype=None, shape=None, rlims=None):
Expand Down Expand Up @@ -157,16 +133,18 @@ def _loc_to_eeg_loc(loc):
# by the function names.


def _read_tag_header(fid):
def _read_tag_header(fid, pos):
"""Read only the header of a Tag."""
s = fid.read(4 * 4)
fid.seek(pos, 0)
s = fid.read(16)
if len(s) != 16:
where = fid.tell() - len(s)
extra = f" in file {fid.name}" if hasattr(fid, "name") else ""
warn(f"Invalid tag with only {len(s)}/16 bytes at position {where}{extra}")
return None
# struct.unpack faster than np.frombuffer, saves ~10% of time some places
return Tag(*struct.unpack(">iIii", s))
kind, type_, size, next_ = struct.unpack(">iIii", s)
return Tag(kind, type_, size, next_, pos)


def _read_matrix(fid, tag, shape, rlims):
Expand All @@ -178,10 +156,10 @@ def _read_matrix(fid, tag, shape, rlims):

matrix_coding, matrix_type, bit, dtype = _matrix_info(tag)

pos = tag.pos + 16
fid.seek(pos + tag.size - 4, 0)
if matrix_coding == "dense":
# Find dimensions and return to the beginning of tag data
pos = fid.tell()
fid.seek(tag.size - 4, 1)
ndim = int(np.frombuffer(fid.read(4), dtype=">i4").item())
fid.seek(-(ndim + 1) * 4, 1)
dims = np.frombuffer(fid.read(4 * ndim), dtype=">i4")[::-1]
Expand All @@ -205,8 +183,6 @@ def _read_matrix(fid, tag, shape, rlims):
data.shape = dims
else:
# Find dimensions and return to the beginning of tag data
pos = fid.tell()
fid.seek(tag.size - 4, 1)
ndim = int(np.frombuffer(fid.read(4), dtype=">i4").item())
fid.seek(-(ndim + 2) * 4, 1)
dims = np.frombuffer(fid.read(4 * (ndim + 1)), dtype=">i4")
Expand Down Expand Up @@ -388,7 +364,16 @@ def _read_old_pack(fid, tag, shape, rlims):

def _read_dir_entry_struct(fid, tag, shape, rlims):
"""Read dir entry struct tag."""
return [_read_tag_header(fid) for _ in range(tag.size // 16 - 1)]
pos = tag.pos + 16
entries = list()
for offset in range(1, tag.size // 16):
ent = _read_tag_header(fid, pos + offset * 16)
# The position of the real tag on disk is stored in the "next" entry within the
# directory, so we need to overwrite ent.pos. For safety let's also overwrite
# ent.next to point nowhere
ent.pos, ent.next = ent.next, FIFF.FIFFV_NEXT_NONE
entries.append(ent)
return entries


def _read_julian(fid, tag, shape, rlims):
Expand Down Expand Up @@ -439,7 +424,7 @@ def _read_julian(fid, tag, shape, rlims):
_call_dict_names[key] = dtype


def read_tag(fid, pos=None, shape=None, rlims=None):
def read_tag(fid, pos, shape=None, rlims=None):
"""Read a Tag from a file at a given position.
Parameters
Expand All @@ -462,9 +447,7 @@ def read_tag(fid, pos=None, shape=None, rlims=None):
tag : Tag
The Tag read.
"""
if pos is not None:
fid.seek(pos, 0)
tag = _read_tag_header(fid)
tag = _read_tag_header(fid, pos)
if tag is None:
return tag
if tag.size > 0:
Expand All @@ -477,10 +460,6 @@ def read_tag(fid, pos=None, shape=None, rlims=None):
except KeyError:
raise Exception(f"Unimplemented tag data type {tag.type}") from None
tag.data = fun(fid, tag, shape, rlims)
if tag.next != FIFF.FIFFV_NEXT_SEQ:
# f.seek(tag.next,0)
fid.seek(tag.next, 1) # XXX : fix? pb when tag.next < 0

return tag


Expand Down
47 changes: 1 addition & 46 deletions mne/_fiff/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from ..utils import logger, verbose
from .constants import FIFF
from .tag import Tag, read_tag
from .write import _write, end_block, start_block, write_id
from .tag import read_tag


def dir_tree_find(tree, kind):
Expand Down Expand Up @@ -108,46 +106,3 @@ def make_dir_tree(fid, directory, start=0, indent=0, verbose=None):
logger.debug(" " * indent + "end } %d" % block)
last = this
return tree, last


###############################################################################
# Writing


def copy_tree(fidin, in_id, nodes, fidout):
"""Copy directory subtrees from fidin to fidout."""
if len(nodes) <= 0:
return

if not isinstance(nodes, list):
nodes = [nodes]

for node in nodes:
start_block(fidout, node["block"])
if node["id"] is not None:
if in_id is not None:
write_id(fidout, FIFF.FIFF_PARENT_FILE_ID, in_id)

write_id(fidout, FIFF.FIFF_BLOCK_ID, in_id)
write_id(fidout, FIFF.FIFF_PARENT_BLOCK_ID, node["id"])

if node["directory"] is not None:
for d in node["directory"]:
# Do not copy these tags
if (
d.kind == FIFF.FIFF_BLOCK_ID
or d.kind == FIFF.FIFF_PARENT_BLOCK_ID
or d.kind == FIFF.FIFF_PARENT_FILE_ID
):
continue

# Read and write tags, pass data through transparently
fidin.seek(d.pos, 0)
tag = Tag(*np.fromfile(fidin, (">i4,>I4,>i4,>i4"), 1)[0])
tag.data = np.fromfile(fidin, ">B", tag.size)
_write(fidout, tag.data, tag.kind, 1, tag.type, ">B")

for child in node["children"]:
copy_tree(fidin, in_id, child, fidout)

end_block(fidout, node["block"])
38 changes: 27 additions & 11 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,9 +870,12 @@ def interpolate_bads(
.. versionadded:: 0.9.0
"""
from .interpolation import (
_interpolate_bads_ecog,
_interpolate_bads_eeg,
_interpolate_bads_meeg,
_interpolate_bads_nan,
_interpolate_bads_nirs,
_interpolate_bads_seeg,
)

_check_preload(self, "interpolation")
Expand All @@ -894,35 +897,48 @@ def interpolate_bads(
"eeg": ("spline", "MNE", "nan"),
"meg": ("MNE", "nan"),
"fnirs": ("nearest", "nan"),
"ecog": ("spline", "nan"),
"seeg": ("spline", "nan"),
}
for key in method:
_check_option("method[key]", key, ("meg", "eeg", "fnirs"))
_check_option("method[key]", key, tuple(valids))
_check_option(f"method['{key}']", method[key], valids[key])
logger.info("Setting channel interpolation method to %s.", method)
idx = _picks_to_idx(self.info, list(method), exclude=(), allow_empty=True)
if idx.size == 0 or len(pick_info(self.info, idx)["bads"]) == 0:
warn("No bad channels to interpolate. Doing nothing...")
return self
for ch_type in method.copy():
idx = _picks_to_idx(self.info, ch_type, exclude=(), allow_empty=True)
if len(pick_info(self.info, idx)["bads"]) == 0:
method.pop(ch_type)
logger.info("Interpolating bad channels.")
origin = _check_origin(origin, self.info)
needs_origin = [key != "seeg" and val != "nan" for key, val in method.items()]
if any(needs_origin):
origin = _check_origin(origin, self.info)
for ch_type, interp in method.items():
if interp == "nan":
_interpolate_bads_nan(self, ch_type, exclude=exclude)
if method.get("eeg", "") == "spline":
_interpolate_bads_eeg(self, origin=origin, exclude=exclude)
eeg_mne = False
elif "eeg" not in method:
eeg_mne = False
else:
eeg_mne = True
if "meg" in method or eeg_mne:
meg_mne = method.get("meg", "") == "MNE"
eeg_mne = method.get("eeg", "") == "MNE"
if meg_mne or eeg_mne:
_interpolate_bads_meeg(
self,
mode=mode,
origin=origin,
meg=meg_mne,
eeg=eeg_mne,
origin=origin,
exclude=exclude,
method=method,
)
if "fnirs" in method:
_interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"])
if method.get("fnirs", "") == "nearest":
_interpolate_bads_nirs(self, exclude=exclude)
if method.get("ecog", "") == "spline":
_interpolate_bads_ecog(self, origin=origin, exclude=exclude)
if method.get("seeg", "") == "spline":
_interpolate_bads_seeg(self, exclude=exclude)

if reset_bads is True:
if "nan" in method.values():
Expand Down
Loading

0 comments on commit db0cf11

Please sign in to comment.