Skip to content

Commit

Permalink
Multicam DLC project support (#834)
Browse files Browse the repository at this point in the history
* allow for  multicamera epochs

* change VideoFile definition

* update video_file_num iteration

* add update methods

* change VideoFile restriction in get_video_path

* modify formatting in position_dlc_project

* allow selective video selection in pose estimation

* add NotImplementedError to DLCProject insert

* allow for video file addition after proj creation

* fix add_video_files

* fix video_file_num determination in VideoFile

* modify add_video_files method

* change call to add_video_files

* modify interval_list_name call

* Fix linting errors

* Tested DLC multicam pipeline  (#841)

* Fault-permit insert and remove mutual exclusivity protections on Merge (#824)

* #719, #804, #212

* #768

* Add merge delete and populate

* Changes following PR review @edeno

* Replace delayed import of ImportedSpikeSorting

* Update CITATION.cff (#826)

* Update CITATION.cff

* Update change log

* Update ref

* Add MUA notebook and fix numbering.

* Only apply include labels filter if include_labels not empty (#827)

* dont skip unit if include_labels list is empty

* update check for np array size

* gh-actions docs fixes (#828)

* Update 'latest' in docs deploy

* Docs bugfix. Rename Link action, incorporate cspell.

* MUA as own heading

* Update README.md

* Fix citation

* Fix citation

* include all relevant restrictions on video file

* Proposed structure for user roles. (#832)

* Add roles

* Remove use of unix user group. Add note for retroactive role assign

* Add docs on roles and external tables. Reduce key length

* Fix test for update of position tools (#835)

Related to single LED halving the data bug

* fix build error in mamba and restriction for dlc

* flush stdout before converting mp4

* Fix notebook name (#840)

* remove deprecated yaml.safe_load function

* Fix test

* replace deprecated yaml.safe_load function

* only call no_transaction_make if video key not present
---------

Co-authored-by: dpeg22 <[email protected]>
Co-authored-by: CBroz1 <[email protected]>
Co-authored-by: Samuel Bray <[email protected]>
Co-authored-by: Chris Brozdowski <[email protected]>
  • Loading branch information
5 people authored Feb 21, 2024
1 parent e450392 commit eb25feb
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 62 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@

- Add user roles to `database_settings.py`. #832

### Pipelines

- Position:
- Fixes to `environment_dlc.yml` restricting tensorflow #834
- Video restriction for multicamera epochs #834
- Fixes to `_convert_mp4` #834
- Replace deprecated calls to `yaml.safe_load()` #834

## [0.5.0] (February 9, 2024)

### Infrastructure
Expand Down
2 changes: 1 addition & 1 deletion environment_dlc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ dependencies:
- libgcc # dlc-only
- matplotlib
- non_local_detector
- numpy<1.24
- pip>=20.2.*
- position_tools
- pybind11 # req by mountainsort4 -> isosplit5
Expand All @@ -47,4 +46,5 @@ dependencies:
- pynwb>=2.2.0,<3
- sortingview>=0.11
- spikeinterface>=0.98.2,<0.99
- tensorflow<=2.12 # dlc-only
- .[dlc]
9 changes: 7 additions & 2 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib
import re
from functools import reduce
from typing import Dict

Expand Down Expand Up @@ -391,7 +392,7 @@ def _no_transaction_make(self, key, verbose=True):
"interval_list_name": interval_list_name,
}
).fetch1("valid_times")

cam_device_str = r"camera_device (\d+)"
is_found = False
for ind, video in enumerate(videos.values()):
if isinstance(video, pynwb.image.ImageSeries):
Expand All @@ -404,7 +405,11 @@ def _no_transaction_make(self, key, verbose=True):
interval_list_contains(valid_times, video_obj.timestamps)
> 0.9 * len(video_obj.timestamps)
):
key["video_file_num"] = ind
nwb_cam_device = video_obj.device.name
# returns whatever was captured in the first group (within the parentheses) of the regular expression -- in this case, 0
key["video_file_num"] = int(
re.match(cam_device_str, nwb_cam_device)[1]
)
camera_name = video_obj.device.camera_name
if CameraDevice & {"camera_name": camera_name}:
key["camera_name"] = video_obj.device.camera_name
Expand Down
3 changes: 2 additions & 1 deletion src/spyglass/position/v1/dlc_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def pkl(self):
def yml(self):
if self._yml is None:
with open(self.yml_path, "rb") as f:
self._yml = yaml.safe_load(f)
safe_yaml = yaml.YAML(typ="safe", pure=True)
self._yml = safe_yaml.load(f)
return self._yml

@property
Expand Down
44 changes: 25 additions & 19 deletions src/spyglass/position/v1/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pathlib
import pwd
import subprocess
import sys
from collections import abc
from contextlib import redirect_stdout
from itertools import groupby
Expand All @@ -18,6 +19,8 @@
import pandas as pd
from tqdm import tqdm as tqdm

from spyglass.common.common_behav import VideoFile
from spyglass.utils import logger
from spyglass.settings import dlc_output_dir, dlc_video_dir, raw_dir


Expand Down Expand Up @@ -418,10 +421,9 @@ def get_video_path(key):
"""
import pynwb

from ...common.common_behav import VideoFile

vf_key = {"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]}
VideoFile()._no_transaction_make(vf_key, verbose=False)
vf_key = {k: val for k, val in key.items() if k in VideoFile.heading.names}
if not VideoFile & vf_key:
VideoFile()._no_transaction_make(vf_key, verbose=False)
video_query = VideoFile & vf_key

if len(video_query) != 1:
Expand All @@ -434,9 +436,7 @@ def get_video_path(key):
with pynwb.NWBHDF5IO(path=nwb_path, mode="r") as in_out:
nwb_file = in_out.read()
nwb_video = nwb_file.objects[video_info["video_file_object_id"]]
video_filepath = VideoFile.get_abs_path(
{"nwb_file_name": key["nwb_file_name"], "epoch": key["epoch"]}
)
video_filepath = VideoFile.get_abs_path(vf_key)
video_dir = os.path.dirname(video_filepath) + "/"
video_filename = video_filepath.split(video_dir)[-1]
meters_per_pixel = nwb_video.device.meters_per_pixel
Expand Down Expand Up @@ -540,18 +540,24 @@ def _convert_mp4(
"copy",
f"{dest_path.as_posix()}",
]
try:
convert_process = subprocess.Popen(
convert_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
except subprocess.CalledProcessError as err:
raise RuntimeError(
f"command {err.cmd} return with error (code {err.returncode}): {err.output}"
) from err
out, _ = convert_process.communicate()
print(out.decode("utf-8"))
print(f"finished converting {filename}")
print(
if dest_path.exists():
logger.info(f"{dest_path} already exists, skipping conversion")
else:
try:
sys.stdout.flush()
convert_process = subprocess.Popen(
convert_command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as err:
raise RuntimeError(
f"command {err.cmd} return with error (code {err.returncode}): {err.output}"
) from err
out, _ = convert_process.communicate()
logger.info(out.decode("utf-8"))
logger.info(f"finished converting {filename}")
logger.info(
f"Checking that number of packets match between {orig_filename} and {dest_filename}"
)
num_packets = []
Expand Down
3 changes: 2 additions & 1 deletion src/spyglass/position/v1/position_dlc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def make(self, key):
raise OSError(f"config_path {config_path} does not exist.")
if config_path.suffix in (".yml", ".yaml"):
with open(config_path, "rb") as f:
dlc_config = yaml.safe_load(f)
safe_yaml = yaml.YAML(typ="safe", pure=True)
dlc_config = safe_yaml.load(f)
if isinstance(params["params"], dict):
dlc_config.update(params["params"])
del params["params"]
Expand Down
15 changes: 8 additions & 7 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

from spyglass.common.common_behav import ( # noqa: F401
RawPosition,
VideoFile,
convert_epoch_interval_name_to_position_interval_name,
)
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.utils.dj_mixin import SpyglassMixin

from ...common.common_nwbfile import AnalysisNwbfile
from ...utils.dj_mixin import SpyglassMixin
from .dlc_utils import OutputLogger, infer_output_dir
from .position_dlc_model import DLCModel

Expand Down Expand Up @@ -87,10 +86,11 @@ def insert_estimation_task(
Parameters
----------
key: DataJoint key specifying a pairing of VideoRecording and Model.
task_mode (bool): Default 'trigger' computation. Or 'load' existing results.
task_mode (bool): Default 'trigger' computation.
Or 'load' existing results.
params (dict): Optional. Parameters passed to DLC's analyze_videos:
videotype, gputouse, save_as_csv, batchsize, cropping, TFGPUinference,
dynamic, robust_nframes, allow_growth, use_shelve
videotype, gputouse, save_as_csv, batchsize, cropping,
TFGPUinference, dynamic, robust_nframes, allow_growth, use_shelve
"""
from .dlc_utils import check_videofile, get_video_path

Expand Down Expand Up @@ -261,7 +261,8 @@ def make(self, key):
del key["meters_per_pixel"]
body_parts = dlc_result.df.columns.levels[0]
body_parts_df = {}
# Insert dlc pose estimation into analysis NWB file for each body part.
# Insert dlc pose estimation into analysis NWB file for
# each body part.
for body_part in bodyparts:
if body_part in body_parts:
body_parts_df[body_part] = pd.DataFrame.from_dict(
Expand Down
112 changes: 81 additions & 31 deletions src/spyglass/position/v1/position_dlc_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def insert_new_project(
groupname: str = None,
project_directory: str = dlc_project_dir,
output_path: str = dlc_video_dir,
set_permissions=False,
**kwargs,
):
"""Insert a new project into DLCProject table.
Expand All @@ -251,9 +250,6 @@ def insert_new_project(
output_path : str
target path to output converted videos
(Default is '/nimbus/deeplabcut/videos/')
set_permissions : bool
if True, will set permissions for user and group to be read+write
(Default is False)
"""
project_names_in_use = np.unique(cls.fetch("project_name"))
if project_name in project_names_in_use:
Expand Down Expand Up @@ -338,24 +334,6 @@ def insert_new_project(
"config_path": config_path,
"frames_per_video": frames_per_video,
}
# TODO: make permissions setting more flexible.
if set_permissions:
permissions = (
stat.S_IRUSR
| stat.S_IWUSR
| stat.S_IRGRP
| stat.S_IWGRP
| stat.S_IROTH
)
username = getpass.getuser()
if not groupname:
groupname = username
_set_permissions(
directory=project_directory,
mode=permissions,
username=username,
groupname=groupname,
)
cls.insert1(key, **kwargs)
cls.BodyPart.insert(
[
Expand All @@ -375,9 +353,75 @@ def insert_new_project(
config_path = config_path.as_posix()
return {"project_name": project_name, "config_path": config_path}

@classmethod
def add_video_files(
    cls,
    video_list,
    config_path=None,
    key=None,
    output_path: str = os.getenv("DLC_VIDEO_PATH"),
    add_new=False,
    add_to_files=True,
    **kwargs,
):
    """Check/convert videos and optionally attach them to a DLC project.

    Parameters
    ----------
    video_list : list
        Either a list of dicts, each passed to ``get_video_path``, or a
        list of paths to existing video files.
    config_path : str, optional
        Path to the project's config file. When omitted it is fetched
        from this table using ``key``.
    key : dict, optional
        Restriction identifying the project entry. When omitted and
        ``add_to_files`` is True, it is resolved from ``config_path``,
        which must then match exactly one entry.
    output_path : str
        Forwarded to ``check_videofile`` as ``output_path``. The default
        is read from the ``DLC_VIDEO_PATH`` environment variable when the
        function is defined.
    add_new : bool
        If True, pass the resulting videos to deeplabcut's
        ``add_new_videos`` for the project at ``config_path``.
    add_to_files : bool
        If True, call ``add_training_files`` for the project afterwards.
    **kwargs
        Forwarded to ``add_training_files``.

    Returns
    -------
    list of str
        Posix path of the first item returned by ``check_videofile`` for
        each entry of ``video_list``.

    Raises
    ------
    ValueError
        If ``add_new`` is set without ``key`` or ``config_path``, if
        ``add_to_files`` is set and the project cannot be identified
        unambiguously, or if ``video_list`` is empty.
    OSError
        If ``video_list`` holds paths and at least one does not exist.
    """
    has_config_or_key = bool(config_path) or bool(key)

    if add_new and not has_config_or_key:
        raise ValueError("If add_new, must provide key or config_path")
    config_path = config_path or (cls & key).fetch1("config_path")

    if add_to_files and not key:
        project_query = cls & {"config_path": config_path}
        if len(project_query) != 1:
            raise ValueError("Cannot set add_to_files=True without passing key")
        # add_training_files indexes key["project_name"], so build the key
        # from the unique entry instead of passing None through.
        key = {"project_name": project_query.fetch1("project_name")}

    # Fail early with a ValueError; the error message must not reference
    # names that only exist inside one of the branches below.
    if not video_list:
        raise ValueError("no .mp4 videos found: video_list is empty")

    if all(isinstance(n, dict) for n in video_list):
        videos_to_convert = [
            get_video_path(video_key) for video_key in video_list
        ]
        videos = [
            check_videofile(
                video_path=video[0],
                output_path=output_path,
                video_filename=video[1],
            )[0].as_posix()
            for video in videos_to_convert
        ]
    # If not dict, assume list of video file paths
    # that may or may not need to be converted
    else:
        if not all(Path(video).exists() for video in video_list):
            raise OSError("at least one file in video_list does not exist")
        videos = []
        for video in video_list:
            video = Path(video)
            # Path.name is the final path component; splitting the string
            # on the parent breaks for bare filenames (parent is ".").
            videos.append(
                check_videofile(
                    video_path=video.parent,
                    output_path=output_path,
                    video_filename=video.name,
                )[0].as_posix()
            )

    if add_new:
        from deeplabcut import add_new_videos

        add_new_videos(config=config_path, videos=videos, copy_videos=True)
    if add_to_files:
        # Add videos to training files
        cls.add_training_files(key, **kwargs)
    return videos

@classmethod
def add_training_files(cls, key, **kwargs):
"""Add training videos and labeled frames .h5 and .csv to DLCProject.File"""
"""Add training videos and labeled frames .h5
and .csv to DLCProject.File"""
config_path = (cls & {"project_name": key["project_name"]}).fetch1(
"config_path"
)
Expand All @@ -394,7 +438,8 @@ def add_training_files(cls, key, **kwargs):
)[0]
training_files.extend(
glob.glob(
f"{cfg['project_path']}/labeled-data/{video_name}/*Collected*"
f"{cfg['project_path']}/"
f"labeled-data/{video_name}/*Collected*"
)
)
for video in video_names:
Expand Down Expand Up @@ -457,16 +502,19 @@ def import_labeled_frames(
video_filenames: Union[str, List],
**kwargs,
):
"""Function to import pre-labeled frames from an existing project into a new project
"""Function to import pre-labeled frames from an existing project
into a new project
Parameters
----------
key : Dict
key to specify entry in DLCProject table to add labeled frames to
import_project_path : str
absolute path to project directory containing labeled frames to import
absolute path to project directory containing
labeled frames to import
video_filenames : str or List
filename or list of filenames of video(s) from which to import frames.
filename or list of filenames of video(s)
from which to import frames.
without file extension
"""
project_entry = (cls & key).fetch1()
Expand All @@ -476,9 +524,10 @@ def import_labeled_frames(
f"{current_project_path.as_posix()}/labeled-data"
)
if isinstance(import_project_path, PosixPath):
assert (
import_project_path.exists()
), f"import_project_path: {import_project_path.as_posix()} does not exist"
assert import_project_path.exists(), (
"import_project_path: "
f"{import_project_path.as_posix()} does not exist"
)
import_labeled_data_path = Path(
f"{import_project_path.as_posix()}/labeled-data"
)
Expand All @@ -504,7 +553,8 @@ def import_labeled_frames(
dlc_df.columns = dlc_df.columns.set_levels([team_name], level=0)
dlc_df.to_hdf(
Path(
f"{current_labeled_data_path.as_posix()}/{video_file}/CollectedData_{team_name}.h5"
f"{current_labeled_data_path.as_posix()}/"
f"{video_file}/CollectedData_{team_name}.h5"
).as_posix(),
"df_with_missing",
)
Expand Down

0 comments on commit eb25feb

Please sign in to comment.