diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bfe55b55..bbcaac88d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,9 +65,10 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Minor fix to `DLCCentroid` make function order #1112, #1148 - Video creator tools: - Pass output path as string to `cv2.VideoWriter` #1150 - - Set `DLCPosVideo` default processor to `matplotlib`, remove support - for `open-cv` #1168 - - `VideoMaker` class to process frames in multithreaded batches #1168 + - Set `DLCPosVideo` default processor to `matplotlib`, remove support for + `open-cv` #1168 + - `VideoMaker` class to process frames in multithreaded batches #1168, #1174 + - `TrodesPosVideo` updates for `matplotlib` processor #1174 - Spike Sorting diff --git a/dj_local_conf_example.json b/dj_local_conf_example.json index b9b5e725e..bde731751 100644 --- a/dj_local_conf_example.json +++ b/dj_local_conf_example.json @@ -53,4 +53,4 @@ }, "kachery_zone": "franklab.default" } -} \ No newline at end of file +} diff --git a/notebooks/40_Extracting_Clusterless_Waveform_Features.ipynb b/notebooks/40_Extracting_Clusterless_Waveform_Features.ipynb index 07b3130a5..51ebbf36d 100644 --- a/notebooks/40_Extracting_Clusterless_Waveform_Features.ipynb +++ b/notebooks/40_Extracting_Clusterless_Waveform_Features.ipynb @@ -404,7 +404,7 @@ "for sorting_id in sorting_ids:\n", " try:\n", " sgs.CurationV1.insert_curation(sorting_id=sorting_id)\n", - " except KeyError as e:\n", + " except KeyError:\n", " pass\n", "\n", "SpikeSortingOutput.insert(\n", diff --git a/notebooks/py_scripts/40_Extracting_Clusterless_Waveform_Features.py b/notebooks/py_scripts/40_Extracting_Clusterless_Waveform_Features.py index ad17a7c6f..5449183a8 100644 --- a/notebooks/py_scripts/40_Extracting_Clusterless_Waveform_Features.py +++ b/notebooks/py_scripts/40_Extracting_Clusterless_Waveform_Features.py @@ -176,7 +176,7 @@ for sorting_id in sorting_ids: try: sgs.CurationV1.insert_curation(sorting_id=sorting_id) - except KeyError as e: + except KeyError: pass SpikeSortingOutput.insert( diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 7ea82fa70..592e02964 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -434,7 +434,9 @@ def find_mp4( .rsplit(video_filepath.parent.as_posix(), maxsplit=1)[-1] .split("/")[-1] ) - return _convert_mp4(video_file, video_path, output_path, videotype="mp4") + return _convert_mp4( + video_file, video_path, output_path, videotype="mp4", count_frames=True + ) def _convert_mp4( diff --git a/src/spyglass/position/v1/dlc_utils_makevid.py b/src/spyglass/position/v1/dlc_utils_makevid.py index 51c209134..2763a7898 100644 --- a/src/spyglass/position/v1/dlc_utils_makevid.py +++ b/src/spyglass/position/v1/dlc_utils_makevid.py @@ -3,17 +3,17 @@ # some DLC-utils copied from datajoint element-interface utils.py import shutil import subprocess -from concurrent.futures import ProcessPoolExecutor, as_completed -from os import system as os_system +from concurrent.futures import ProcessPoolExecutor, TimeoutError, as_completed from pathlib import Path from typing import Tuple +import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd from tqdm import tqdm -from spyglass.settings import temp_dir +from spyglass.settings import temp_dir, test_mode from spyglass.utils import logger from spyglass.utils.position import convert_to_pixels as _to_px @@ -49,9 +49,9 @@ def __init__( cm_to_pixels=1.0, disable_progressbar=False, crop=None, - batch_size=500, - max_workers=25, - max_jobs_in_queue=250, + batch_size=512, + max_workers=256, + max_jobs_in_queue=128, debug=False, key_hash=None, *args, @@ -80,10 +80,16 @@ def __init__( if not Path(video_filename).exists(): raise FileNotFoundError(f"Video not found: {video_filename}") + try: + position_mean = position_mean["DLC"] + orientation_mean = orientation_mean["DLC"] + except IndexError: + pass # trodes data provides bare arrays + self.video_filename = video_filename self.video_frame_inds = video_frame_inds - self.position_mean = position_mean["DLC"] - self.orientation_mean = orientation_mean["DLC"] + self.position_mean = position_mean + self.orientation_mean = orientation_mean self.centroids = centroids self.likelihoods = likelihoods self.position_time = position_time @@ -94,16 +100,21 @@ def __init__( self.crop = crop self.window_ind = np.arange(501) - 501 // 2 self.debug = debug + self.start_time = pd.to_datetime(position_time[0] * 1e9, unit="ns") self.dropped_frames = set() self.batch_size = batch_size self.max_workers = max_workers self.max_jobs_in_queue = max_jobs_in_queue + self.timeout = 30 if test_mode else 300 self.ffmpeg_log_args = ["-hide_banner", "-loglevel", "error"] self.ffmpeg_fmt_args = ["-c:v", "libx264", "-pix_fmt", "yuv420p"] + prev_backend = matplotlib.get_backend() + matplotlib.use("Agg") # Use non-interactive backend + _ = self._set_frame_info() _ = self._set_plot_bases() @@ -116,15 +127,41 @@ def __init__( logger.info(f"Finished video: {self.output_video_filename}") logger.debug(f"Dropped frames: {self.dropped_frames}") - shutil.rmtree(self.temp_dir) # Clean up temp directory + if not debug: + shutil.rmtree(self.temp_dir) # Clean up temp directory + + matplotlib.use(prev_backend) # Reset to previous backend def _set_frame_info(self): """Set the frame information for the video.""" logger.debug("Setting frame information") - width, height, self.frame_rate = self._get_input_stats() + ret = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v", + "-show_entries", + "stream=width,height,r_frame_rate,nb_frames", + "-of", + "csv=p=0:s=x", + str(self.video_filename), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if ret.returncode != 0: + raise ValueError(f"Error getting video dimensions: {ret.stderr}") + + stats = ret.stdout.strip().split("x") + self.width, self.height = tuple(map(int, stats[:2])) + self.frame_rate = eval(stats[2]) + self.frame_size = ( - (width, height) + (self.width, self.height) if not self.crop else ( self.crop[1] - self.crop[0], @@ -138,45 +175,24 @@ def _set_frame_info(self): ) self.fps = int(np.round(self.frame_rate)) - if self.frames is None: + if self.frames is None and self.video_frame_inds is not None: self.n_frames = int( len(self.video_frame_inds) * self.percent_frames ) self.frames = np.arange(0, self.n_frames) - else: + elif self.frames is not None: self.n_frames = len(self.frames) - self.pad_len = len(str(self.n_frames)) + else: + self.n_frames = int(stats[3]) - def _get_input_stats(self, video_filename=None) -> Tuple[int, int]: - """Get the width and height of the video.""" - logger.debug("Getting video dimensions") + if self.debug: # If debugging, limit frames to available data + self.n_frames = min(len(self.position_mean), self.n_frames) - video_filename = video_filename or self.video_filename - ret = subprocess.run( - [ - "ffprobe", - "-v", - "error", - "-select_streams", - "v", - "-show_entries", - "stream=width,height,r_frame_rate", - "-of", - "csv=p=0:s=x", - video_filename, - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - if ret.returncode != 0: - raise ValueError(f"Error getting video dimensions: {ret.stderr}") - - stats = ret.stdout.strip().split("x") - width, height = tuple(map(int, stats[:-1])) - frame_rate = eval(stats[-1]) + self.pad_len = len(str(self.n_frames)) - return width, height, frame_rate + def _set_input_stats(self, video_filename=None) -> Tuple[int, int]: + """Get the width and height of the video.""" + logger.debug("Getting video stats with ffprobe") def _set_plot_bases(self): """Create the figure and axes for the video.""" @@ -202,7 +218,6 @@ def _set_plot_bases(self): zorder=102, color=color, label=f"{bodypart} position", - # animated=True, alpha=0.6, ) for color, bodypart in zip(COLOR_SWATCH, self.centroids.keys()) @@ -240,6 +255,7 @@ def _set_plot_bases(self): self.position_time[0] - self.position_time[-1] ).total_seconds() + # TODO: Update legend location based on centroid position axes[0].legend(loc="lower right", fontsize=4) self.title = axes[0].set_title( f"time = {time_delta:3.4f}s\n frame = {0}", @@ -305,12 +321,18 @@ def orient_list(c): def _generate_single_frame(self, frame_ind): """Generate a single frame and save it as an image.""" + # Zero-padded filename based on the dynamic padding length padded = self._pad(frame_ind) + frame_out_path = self.temp_dir / f"plot_{padded}.png" + if frame_out_path.exists() and not self.debug: + return frame_ind # Skip if frame already exists + frame_file = self.temp_dir / f"orig_{padded}.png" - if not frame_file.exists(): + if not frame_file.exists(): # Skip if input frame not found self.dropped_frames.add(frame_ind) - print(f"\rFrame not found: {frame_file}", end="") + self._debug_print(f"Frame not found: {frame_file}", end="") return + frame = plt.imread(frame_file) _ = self.axes[0].imshow(frame) @@ -322,42 +344,42 @@ def _generate_single_frame(self, frame_ind): self.centroid_plot_objs[bodypart].set_offsets((np.NaN, np.NaN)) self.orientation_line.set_data((np.NaN, np.NaN)) self.title.set_text(f"time = {0:3.4f}s\n frame = {frame_ind}") - else: - pos_ind = pos_ind[0] - likelihood_inds = pos_ind + self.window_ind - neg_inds = np.where(likelihood_inds < 0)[0] - likelihood_inds[neg_inds] = 0 if len(neg_inds) > 0 else -1 - - dlc_centroid_data = self._get_centroid_data(pos_ind) - - for bodypart in self.centroid_plot_objs: - self.centroid_plot_objs[bodypart].set_offsets( - _to_px( - data=self.centroids[bodypart][pos_ind], - cm_to_pixels=self.cm_to_pixels, - ) - ) - self.centroid_position_dot.set_offsets(dlc_centroid_data) - _ = self._set_orient_line(frame, pos_ind) - time_delta = pd.Timedelta( - pd.to_datetime(self.position_time[pos_ind] * 1e9, unit="ns") - - pd.to_datetime(self.position_time[0] * 1e9, unit="ns") - ).total_seconds() + self.fig.savefig(frame_out_path, dpi=400) + plt.cla() # clear the current axes + return frame_ind + + pos_ind = pos_ind[0] + likelihood_inds = pos_ind + self.window_ind + neg_inds = np.where(likelihood_inds < 0)[0] + likelihood_inds[neg_inds] = 0 if len(neg_inds) > 0 else -1 - self.title.set_text( - f"time = {time_delta:3.4f}s\n frame = {frame_ind}" + dlc_centroid_data = self._get_centroid_data(pos_ind) + + for bodypart in self.centroid_plot_objs: + self.centroid_plot_objs[bodypart].set_offsets( + _to_px( + data=self.centroids[bodypart][pos_ind], + cm_to_pixels=self.cm_to_pixels, + ) ) - if self.likelihoods: - for bodypart in self.likelihood_objs.keys(): - self.likelihood_objs[bodypart].set_data( - self.window_ind / self.frame_rate, - np.asarray(self.likelihoods[bodypart][likelihood_inds]), - ) + self.centroid_position_dot.set_offsets(dlc_centroid_data) + _ = self._set_orient_line(frame, pos_ind) - # Zero-padded filename based on the dynamic padding length - frame_path = self.temp_dir / f"plot_{padded}.png" - self.fig.savefig(frame_path, dpi=400) + time_delta = pd.Timedelta( + pd.to_datetime(self.position_time[pos_ind] * 1e9, unit="ns") + - self.start_time + ).total_seconds() + + self.title.set_text(f"time = {time_delta:3.4f}s\n frame = {frame_ind}") + if self.likelihoods: + for bodypart in self.likelihood_objs.keys(): + self.likelihood_objs[bodypart].set_data( + self.window_ind / self.frame_rate, + np.asarray(self.likelihoods[bodypart][likelihood_inds]), + ) + + self.fig.savefig(frame_out_path, dpi=400) plt.cla() # clear the current axes return frame_ind @@ -394,7 +416,7 @@ def process_frames(self): logger.info("Concatenating partial videos") self.concat_partial_videos() - def _debug_print(self, msg=None, end=""): + def _debug_print(self, msg=" ", end=""): """Print a self-overwiting message if debug is enabled.""" if self.debug: print(f"\r{msg}", end=end) @@ -411,7 +433,7 @@ def plot_frames(self, start_frame, end_frame, progress_bar=None): while len(jobs) < self.max_jobs_in_queue: try: this_frame = next(frames_iter) - self._debug_print(f"Submit: {this_frame}") + self._debug_print(f"Submit: {self._pad(this_frame)}") job = executor.submit( self._generate_single_frame, this_frame ) @@ -422,23 +444,28 @@ def plot_frames(self, start_frame, end_frame, progress_bar=None): for job in as_completed(jobs): frames_left -= 1 try: - ret = job.result() - except IndexError: - ret = "IndexError" - self._debug_print(f"Finish: {ret}") + ret = job.result(timeout=self.timeout) + except (IndexError, TimeoutError) as e: + ret = type(e).__name__ + self._debug_print(f"Finish: {self._pad(ret)}") progress_bar.update() del jobs[job] - self._debug_print(end="\n") + self._debug_print(msg="", end="\n") def ffmpeg_extract(self, start_frame, end_frame): """Use ffmpeg to extract a batch of frames.""" logger.debug(f"Extracting frames: {start_frame} - {end_frame}") + last_frame = self.temp_dir / f"orig_{self._pad(end_frame)}.png" + if last_frame.exists(): # assumes all frames previously extracted + logger.debug(f"Skipping existing frames: {last_frame}") + return + output_pattern = str(self.temp_dir / f"orig_%0{self.pad_len}d.png") # Use ffmpeg to extract frames ffmpeg_cmd = [ "ffmpeg", - "-y", # overwrite + "-n", # no overwrite "-i", self.video_filename, "-vf", @@ -447,7 +474,6 @@ def ffmpeg_extract(self, start_frame, end_frame): "vfr", "-start_number", str(start_frame), - "-n", # no overwrite output_pattern, *self.ffmpeg_log_args, ] @@ -462,7 +488,12 @@ def ffmpeg_extract(self, start_frame, end_frame): one_err = "\n".join(str(ret.stderr).split("\\")[-3:-1]) logger.debug(f"\nExtract Error: {one_err}") - def _pad(self, frame_ind): + def _pad(self, frame_ind=None): + """Pad a frame index with leading zeros.""" + if frame_ind is None: + return "?" * self.pad_len + elif not isinstance(frame_ind, int): + return frame_ind return f"{frame_ind:0{self.pad_len}d}" def ffmpeg_stitch_partial(self, start_frame, output_partial_video): @@ -493,6 +524,7 @@ def ffmpeg_stitch_partial(self, start_frame, output_partial_video): ) except subprocess.CalledProcessError as e: logger.error(f"Error stitching partial video: {e.stderr}") + logger.debug(f"stderr: {ret.stderr}") def concat_partial_videos(self): """Concatenate all the partial videos into one final video.""" @@ -526,6 +558,7 @@ def concat_partial_videos(self): ) except subprocess.CalledProcessError as e: logger.error(f"Error stitching partial video: {e.stderr}") + logger.debug(f"stderr: {ret.stderr}") def make_video(**kwargs): diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index 581140797..627a55cf7 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -8,9 +8,6 @@ import pynwb from datajoint.utils import to_camel_case -from spyglass.common.common_behav import ( - convert_epoch_interval_name_to_position_interval_name, -) from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.position.v1.dlc_utils_makevid import make_video from spyglass.position.v1.position_dlc_centroid import DLCCentroid @@ -436,7 +433,7 @@ def make(self, key): cm_to_pixels=meters_per_pixel * M_TO_CM, crop=pose_estimation_params.get("cropping"), key_hash=dj.hash.key_hash(key), - debug=params.get("debug", False), + debug=params.get("debug", True), # REVERT TO FALSE **params.get("video_params", {}), ) diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 72adada46..031c0e6bb 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -1,5 +1,6 @@ import copy import os +from pathlib import Path import datajoint as dj import numpy as np @@ -13,6 +14,7 @@ from spyglass.position.v1.dlc_utils_makevid import make_video from spyglass.settings import test_mode from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.position import fill_nan schema = dj.schema("position_v1_trodes_position") @@ -270,7 +272,7 @@ def make(self, key): - Raw position data from the RawPosition table - Position data from the TrodesPosV1 table - Video data from the VideoFile table - Generates a video using opencv and the VideoMaker class. + Generates a video using VideoMaker class. """ M_TO_CM = 100 @@ -303,10 +305,31 @@ def make(self, key): {"nwb_file_name": key["nwb_file_name"], "epoch": epoch} ) + # Check if video exists if not video_path: self.insert1(dict(**key, has_video=False)) return + # Check timepoints overlap + if not set(video_time).intersection(set(pos_df.index)): + raise ValueError( + "No overlapping time points between video and position data" + ) + + params_pk = "trodes_pos_params_name" + params = (TrodesPosParams() & {params_pk: key[params_pk]}).fetch1( + "params" + ) + + # Check if upsampled + if params["is_upsampled"]: + logger.error( + "Upsampled position data not supported for video creation\n" + + "Please submit a feature request via GitHub if needed." + ) + self.insert1(dict(**key, has_video=False)) # Null insert + return + video_path = find_mp4( video_path=os.path.dirname(video_path) + "/", video_filename=video_filename, @@ -315,31 +338,70 @@ def make(self, key): output_video_filename = ( key["nwb_file_name"].replace(".nwb", "") + f"_{epoch:02d}_" - + f'{key["trodes_pos_params_name"]}.mp4' + + f"{key[params_pk]}.mp4" ) adj_df = _fix_col_names(raw_df) # adjust 'xloc1' to 'xloc' - if test_mode: + limit = params.get("limit", None) + if limit or test_mode: + params["debug"] = True + output_video_filename = Path(".") / f"TEST_VID_{limit}.mp4" # pytest video data has mismatched shapes in some cases - min_len = min(len(adj_df), len(pos_df), len(video_time)) - adj_df = adj_df[:min_len] - pos_df = pos_df[:min_len] + min_len = limit or min(len(adj_df), len(pos_df), len(video_time)) + adj_df = adj_df.head(min_len) + pos_df = pos_df.head(min_len) video_time = video_time[:min_len] - make_video( - processor="opencv-trodes", + centroids = { + "red": np.asarray(adj_df[["xloc", "yloc"]]), + "green": np.asarray(adj_df[["xloc2", "yloc2"]]), + } + position_mean = np.asarray(pos_df[["position_x", "position_y"]]) + orientation_mean = np.asarray(pos_df[["orientation"]]) + position_time = np.asarray(pos_df.index) + + ind_col = ( + pos_df["video_frame_ind"] + if "video_frame_ind" in pos_df.columns + else pos_df.index + ) + video_frame_inds = ind_col.astype(int).to_numpy() + + centroids = { + color: fill_nan( + variable=data, + video_time=video_time, + variable_time=position_time, + ) + for color, data in centroids.items() + } + position_mean = fill_nan( + variable=position_mean, + video_time=video_time, + variable_time=position_time, + ) + orientation_mean = fill_nan( + variable=orientation_mean, + video_time=video_time, + variable_time=position_time, + ) + + vid_maker = make_video( video_filename=video_path, - centroids={ - "red": np.asarray(adj_df[["xloc", "yloc"]]), - "green": np.asarray(adj_df[["xloc2", "yloc2"]]), - }, - position_mean=np.asarray(pos_df[["position_x", "position_y"]]), - orientation_mean=np.asarray(pos_df[["orientation"]]), + video_frame_inds=video_frame_inds, + centroids=centroids, video_time=video_time, - position_time=np.asarray(pos_df.index), + position_mean=position_mean, + orientation_mean=orientation_mean, + position_time=position_time, output_video_filename=output_video_filename, cm_to_pixels=meters_per_pixel * M_TO_CM, - disable_progressbar=False, + key_hash=dj.hash.key_hash(key), + **params, ) + + if limit: + return vid_maker + self.insert1(dict(**key, has_video=True)) diff --git a/tests/position/conftest.py b/tests/position/conftest.py index 8f9e90795..c6c58d199 100644 --- a/tests/position/conftest.py +++ b/tests/position/conftest.py @@ -30,7 +30,6 @@ def dlc_video_params(sgp): "params": { "percent_frames": 0.05, "incl_likelihood": True, - "processor": "opencv", }, }, skip_duplicates=True,