From 6b4ff10b3b6976cab15ba7df29e23a7339fc9179 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Tue, 29 Oct 2024 15:33:52 -0700 Subject: [PATCH] `DLCPosVideo`: Use provided epoch, multithread matplotlib (#1168) * Use provided epoch * Save video to temp dir * Remove open-cv support * WIP: Multithread, RAM hungry * Limit number of workers * Save file images in batches * Reduce RAM cost, remove cv2 dep * Update changelog * Get debug arg from params * Revert merge error #870, #975 * Adjust for final frame. Resume from existing * Resume from fail * except IndexError for final frame * Delay delete temp files until complete * Explicit error messages * Return video object for debugging --- .gitignore | 1 + CHANGELOG.md | 6 +- src/spyglass/common/common_behav.py | 2 +- src/spyglass/position/v1/dlc_utils_makevid.py | 690 +++++++++--------- .../position/v1/position_dlc_selection.py | 30 +- 5 files changed, 351 insertions(+), 378 deletions(-) diff --git a/.gitignore b/.gitignore index 0cbd43c74..032ec4f7c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ spyglass.code-workspace mountainsort4_output/ .idea/ mysql_config +memray* # Notebooks *.ipynb diff --git a/CHANGELOG.md b/CHANGELOG.md index 9530fa9de..3abcc5861 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,7 +59,11 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Fix video directory bug in `DLCPoseEstimationSelection` #1103 - Restore #973, allow DLC without position tracking #1100 - Minor fix to `DLCCentroid` make function order #1112, #1148 - - Pass output path as string to `cv2.VideoWriter` #1150 + - Video creator tools: + - Pass output path as string to `cv2.VideoWriter` #1150 + - Set `DLCPosVideo` default processor to `matplotlib`, remove support + for `open-cv` #1168 + - `VideoMaker` class to process frames in multithreaded batches #1168 - Spike Sorting diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 22afaf8c7..bab4ba075 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -541,7 +541,7 @@ def _no_transaction_make(self, key): # Skip populating if no pos interval list names if len(pos_intervals) == 0: - logger.error(f"NO POS INTERVALS FOR {key}; {no_pop_msg}") + logger.error(f"NO POS INTERVALS FOR {key};\n{no_pop_msg}") self.insert1(null_key, **insert_opts) return diff --git a/src/spyglass/position/v1/dlc_utils_makevid.py b/src/spyglass/position/v1/dlc_utils_makevid.py index 994c027f0..51c209134 100644 --- a/src/spyglass/position/v1/dlc_utils_makevid.py +++ b/src/spyglass/position/v1/dlc_utils_makevid.py @@ -1,16 +1,21 @@ # Convenience functions + # some DLC-utils copied from datajoint element-interface utils.py +import shutil +import subprocess +from concurrent.futures import ProcessPoolExecutor, as_completed +from os import system as os_system from pathlib import Path +from typing import Tuple -import cv2 import matplotlib.pyplot as plt import numpy as np import pandas as pd -from tqdm import tqdm as tqdm +from tqdm import tqdm +from spyglass.settings import temp_dir from spyglass.utils import logger from spyglass.utils.position import convert_to_pixels as _to_px -from spyglass.utils.position import fill_nan RGB_PINK = (234, 82, 111) RGB_YELLOW = (253, 231, 76) @@ -37,328 +42,145 @@ def __init__( position_time, video_frame_inds=None, likelihoods=None, - processor="opencv", # opencv, opencv-trodes, matplotlib - video_time=None, + processor="matplotlib", frames=None, percent_frames=1, output_video_filename="output.mp4", 
cm_to_pixels=1.0,
         disable_progressbar=False,
         crop=None,
-        arrow_radius=15,
-        circle_radius=8,
+        batch_size=500,
+        max_workers=25,
+        max_jobs_in_queue=250,
+        debug=False,
+        key_hash=None,
+        *args,
+        **kwargs,
     ):
+        """Create a video from a set of position data.
+
+        Frames are processed in batches of batch_size, with all
+        intermediate files written to temp_dir.
+        1. Extract frames from the original video to 'orig_XXXX.png'
+        2. Plot frames with matplotlib in a worker pool to 'plot_XXXX.png'
+        3. Stitch plotted frames into a partial video 'partial_XXXX.mp4'
+        4. Concatenate the partial videos into the final video output
+
+        """
+        if processor != "matplotlib":
+            raise ValueError(
+                "open-cv processors are no longer supported. \n"
+                + "Use matplotlib or submit a feature request via GitHub."
+            )
+
+        # key_hash supports resume from previous run
+        self.temp_dir = Path(temp_dir) / f"dlc_vid_{key_hash}"
+        self.temp_dir.mkdir(parents=True, exist_ok=True)
+        logger.debug(f"Temporary directory: {self.temp_dir}")
+
+        if not Path(video_filename).exists():
+            raise FileNotFoundError(f"Video not found: {video_filename}")
+
         self.video_filename = video_filename
         self.video_frame_inds = video_frame_inds
-        self.position_mean = position_mean
-        self.orientation_mean = orientation_mean
+        self.position_mean = position_mean["DLC"]
+        self.orientation_mean = orientation_mean["DLC"]
         self.centroids = centroids
         self.likelihoods = likelihoods
         self.position_time = position_time
-        self.processor = processor
-        self.video_time = video_time
-        self.frames = frames
         self.percent_frames = percent_frames
+        self.frames = frames
         self.output_video_filename = output_video_filename
         self.cm_to_pixels = cm_to_pixels
-        self.disable_progressbar = disable_progressbar
         self.crop = crop
-        self.arrow_radius = arrow_radius
-        self.circle_radius = circle_radius
+        self.window_ind = np.arange(501) - 501 // 2
+        self.debug = debug

-        if not Path(self.video_filename).exists():
-            raise FileNotFoundError(f"Video not found: {self.video_filename}")
+        self.dropped_frames = set()

-        if frames is None:
-            self.n_frames = (
-                int(self.orientation_mean.shape[0])
-                if processor == "opencv-trodes"
-                else int(len(video_frame_inds) * percent_frames)
-            )
-            self.frames = np.arange(0, self.n_frames)
-        else:
-            self.n_frames = len(frames)
-
-        self.tqdm_kwargs = {
-            "iterable": (
-                range(self.n_frames - 1)
-                if self.processor == "opencv-trodes"
-                else self.frames
-            ),
-            "desc": "frames",
-            "disable": self.disable_progressbar,
-        }
+        self.batch_size = batch_size
+        self.max_workers = max_workers
+        self.max_jobs_in_queue = max_jobs_in_queue

-        # init for cv
-        self.video, self.frame_size = None, None
-        self.frame_rate, self.out = None, None
-        self.source_map = {
-            "DLC": RGB_BLUE,
-            "Trodes": RGB_ORANGE,
-            "Common": RGB_PINK,
-        }
+        self.ffmpeg_log_args = ["-hide_banner", "-loglevel", "error"]
+        self.ffmpeg_fmt_args = ["-c:v", "libx264", "-pix_fmt", "yuv420p"]

-        # intit for matplotlib
-        self.image, self.title, self.progress_bar = None, None, None
-        self.crop_offset_x = crop[0] if crop else 0
-        self.crop_offset_y = crop[2] if crop else 0
-        self.centroid_plot_objs, self.centroid_position_dot = None, None
-        self.orientation_line = None
-        self.likelihood_objs = None
-        self.window_ind = np.arange(501) - 501 // 2
+        _ = self._set_frame_info()
+        _ = self._set_plot_bases()

-        self.make_video()
+        logger.info(
+            f"Making video: {self.output_video_filename} "
+            + f"in batches of {self.batch_size}"
+        )
+        self.process_frames()
+        plt.close(self.fig)
+        logger.info(f"Finished video: {self.output_video_filename}")
+        logger.debug(f"Dropped frames: 
{self.dropped_frames}") + + shutil.rmtree(self.temp_dir) # Clean up temp directory - def make_video(self): - """Make video based on processor chosen at init.""" - if self.processor == "opencv": - self.make_video_opencv() - elif self.processor == "opencv-trodes": - self.make_trodes_video() - elif self.processor == "matplotlib": - self.make_video_matplotlib() + def _set_frame_info(self): + """Set the frame information for the video.""" + logger.debug("Setting frame information") - def _init_video(self): - logger.info(f"Making video: {self.output_video_filename}") - self.video = cv2.VideoCapture(str(self.video_filename)) + width, height, self.frame_rate = self._get_input_stats() self.frame_size = ( - (int(self.video.get(3)), int(self.video.get(4))) + (width, height) if not self.crop else ( self.crop[1] - self.crop[0], self.crop[3] - self.crop[2], ) ) - self.frame_rate = self.video.get(5) - - def _init_cv_video(self): - _ = self._init_video() - self.out = cv2.VideoWriter( - filename=str(self.output_video_filename), - fourcc=cv2.VideoWriter_fourcc(*"mp4v"), - fps=self.frame_rate, - frameSize=self.frame_size, - isColor=True, - ) - frames_log = ( - f"\tFrames start: {self.frames[0]}\n" if np.any(self.frames) else "" - ) - inds_log = ( - f"\tVideo frame inds: {self.video_frame_inds[0]}\n" - if np.any(self.video_frame_inds) - else "" - ) - logger.info( - f"\n{frames_log}{inds_log}\tcv2 ind start: {int(self.video.get(1))}" - ) - - def _close_cv_video(self): - self.video.release() - self.out.release() - try: - cv2.destroyAllWindows() - except cv2.error: # if cv is already closed or does not have func - pass - logger.info(f"Finished video: {self.output_video_filename}") - - def _get_frame(self, frame, init_only=False, crop_order=(0, 1, 2, 3)): - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - if init_only or not self.crop: - return frame - x1, x2, y1, y2 = self.crop_order - return frame[ - self.crop[x1] : self.crop[x2], self.crop[y1] : self.crop[y2] - ].copy() - - def _video_set_by_ind(self, time_ind): - if time_ind == 0: - self.video.set(1, time_ind + 1) - elif int(self.video.get(1)) != time_ind - 1: - self.video.set(1, time_ind - 1) - - def _all_num(self, *args): - return all(np.all(~np.isnan(data)) for data in args) - - def _make_arrow( - self, - position, - orientation, - color, - img, - thickness=4, - line_type=8, - tipLength=0.25, - shift=cv2.CV_8U, - ): - if not self._all_num(position, orientation): - return - arrow_tip = ( - int(position[0] + self.arrow_radius * np.cos(orientation)), - int(position[1] + self.arrow_radius * np.sin(orientation)), - ) - cv2.arrowedLine( - img=img, - pt1=tuple(position.astype(int)), - pt2=arrow_tip, - color=color, - thickness=thickness, - line_type=line_type, - tipLength=tipLength, - shift=shift, - ) - - def _make_circle( - self, - data, - color, - img, - radius=None, - thickness=-1, - shift=cv2.CV_8U, - **kwargs, - ): - if not self._all_num(data): - return - cv2.circle( - img=img, - center=tuple(data.astype(int)), - radius=radius or self.circle_radius, - color=color, - thickness=thickness, - shift=shift, + self.ratio = ( + (self.crop[3] - self.crop[2]) / (self.crop[1] - self.crop[0]) + if self.crop + else self.frame_size[1] / self.frame_size[0] ) + self.fps = int(np.round(self.frame_rate)) - def make_video_opencv(self): - """Make video using opencv.""" - _ = self._init_cv_video() - - if self.video_time: - self.position_mean = { - key: fill_nan( - self.position_mean[key]["position"], - self.video_time, - self.position_time, - ) - for key in 
self.position_mean.keys()
-            }
-            self.orientation_mean = {
-                key: fill_nan(
-                    self.position_mean[key]["orientation"],
-                    self.video_time,
-                    self.position_time,
-                )
-                for key in self.position_mean.keys()
-            }
-
-        for time_ind in tqdm(**self.tqdm_kwargs):
-            _ = self._video_set_by_ind(time_ind)
-
-            is_grabbed, frame = self.video.read()
-
-            if not is_grabbed:
-                break
-
-            frame = self._get_frame(frame)
-
-            cv2.putText(
-                img=frame,
-                text=f"time_ind: {int(time_ind)} video frame: {int(self.video.get(1))}",
-                org=(10, 10),
-                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
-                fontScale=0.5,
-                color=RGB_YELLOW,
-                thickness=1,
-            )
-
-            if time_ind < self.video_frame_inds[0] - 1:
-                self.out.write(self._get_frame(frame, init_only=True))
-                continue
-
-            pos_ind = time_ind - self.video_frame_inds[0]
-
-            for key in self.position_mean:
-                position = _to_px(
-                    data=self.position_mean[key][pos_ind],
-                    cm_to_pixels=self.cm_to_pixels,
-                )
-                orientation = self.orientation_mean[key][pos_ind]
-                cv_kwargs = {
-                    "img": frame,
-                    "color": self.source_map[key],
-                }
-                self._make_arrow(position, orientation, **cv_kwargs)
-                self._make_circle(data=position, **cv_kwargs)
-
-            self._get_frame(frame, init_only=True)
-            self.out.write(frame)
-        self._close_cv_video()
-        return
-
-    def make_trodes_video(self):
-        """Make video using opencv with trodes data."""
-        _ = self._init_cv_video()
-
-        if np.any(self.video_time):
-            centroids = {
-                color: fill_nan(
-                    variable=data,
-                    video_time=self.video_time,
-                    variable_time=self.position_time,
-                )
-                for color, data in self.centroids.items()
-            }
-            position_mean = fill_nan(
-                self.position_mean, self.video_time, self.position_time
-            )
-            orientation_mean = fill_nan(
-                self.orientation_mean, self.video_time, self.position_time
-            )
-
-        for time_ind in tqdm(**self.tqdm_kwargs):
-            is_grabbed, frame = self.video.read()
-            if not is_grabbed:
-                break
-
-            frame = self._get_frame(frame)
-
-            red_centroid = centroids["red"][time_ind]
-            green_centroid = centroids["green"][time_ind]
-            position = position_mean[time_ind]
-            position = _to_px(data=position, cm_to_pixels=self.cm_to_pixels)
-            orientation = orientation_mean[time_ind]
-
-            self._make_circle(data=red_centroid, img=frame, color=RGB_YELLOW)
-            self._make_circle(data=green_centroid, img=frame, color=RGB_PINK)
-            self._make_arrow(
-                position=position,
-                orientation=orientation,
-                color=RGB_WHITE,
-                img=frame,
+        if self.frames is None:
+            self.n_frames = int(
+                len(self.video_frame_inds) * self.percent_frames
             )
-            self._make_circle(data=position, img=frame, color=RGB_WHITE)
-            self._get_frame(frame, init_only=True)
-            self.out.write(frame)
-
-        self._close_cv_video()
-
-    def make_video_matplotlib(self):
-        """Make video using matplotlib."""
-        import matplotlib.animation as animation
-
-        self.position_mean = self.position_mean["DLC"]
-        self.orientation_mean = self.orientation_mean["DLC"]
-
-        _ = self._init_video()
+            self.frames = np.arange(0, self.n_frames)
+        else:
+            self.n_frames = len(self.frames)
+        self.pad_len = len(str(self.n_frames))
+
+    def _get_input_stats(self, video_filename=None) -> Tuple[int, int, float]:
+        """Get the width, height, and frame rate of the video."""
+        logger.debug("Getting video dimensions")
+
+        video_filename = video_filename or self.video_filename
+        ret = subprocess.run(
+            [
+                "ffprobe",
+                "-v",
+                "error",
+                "-select_streams",
+                "v",
+                "-show_entries",
+                "stream=width,height,r_frame_rate",
+                "-of",
+                "csv=p=0:s=x",
+                video_filename,
+            ],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+        if ret.returncode != 0:
+            raise ValueError(f"Error getting video 
dimensions: {ret.stderr}") - video_slowdown = 1 - fps = int(np.round(self.frame_rate / video_slowdown)) - Writer = animation.writers["ffmpeg"] - writer = Writer(fps=fps, bitrate=-1) + stats = ret.stdout.strip().split("x") + width, height = tuple(map(int, stats[:-1])) + frame_rate = eval(stats[-1]) - ret, frame = self.video.read() - frame = self._get_frame(frame, crop_order=(2, 3, 0, 1)) + return width, height, frame_rate - frame_ind = 0 + def _set_plot_bases(self): + """Create the figure and axes for the video.""" + logger.debug("Setting plot bases") plt.style.use("dark_background") fig, axes = plt.subplots( 2, @@ -371,9 +193,6 @@ def make_video_matplotlib(self): axes[0].tick_params(colors="white", which="both") axes[0].spines["bottom"].set_color("white") axes[0].spines["left"].set_color("white") - self.image = axes[0].imshow(frame, animated=True) - - logger.info(f"frame after init plot: {self.video.get(1)}") self.centroid_plot_objs = { bodypart: axes[0].scatter( @@ -383,7 +202,7 @@ def make_video_matplotlib(self): zorder=102, color=color, label=f"{bodypart} position", - animated=True, + # animated=True, alpha=0.6, ) for color, bodypart in zip(COLOR_SWATCH, self.centroids.keys()) @@ -395,7 +214,6 @@ def make_video_matplotlib(self): zorder=102, color="#b045f3", label="centroid position", - animated=True, alpha=0.6, ) (self.orientation_line,) = axes[0].plot( @@ -403,23 +221,18 @@ def make_video_matplotlib(self): [], color="cyan", linewidth=1, - animated=True, label="Orientation", ) axes[0].set_xlabel("") axes[0].set_ylabel("") - ratio = ( - (self.crop[3] - self.crop[2]) / (self.crop[1] - self.crop[0]) - if self.crop - else self.frame_size[1] / self.frame_size[0] - ) - x_left, x_right = axes[0].get_xlim() y_low, y_high = axes[0].get_ylim() - axes[0].set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio) + axes[0].set_aspect( + abs((x_right - x_left) / (y_low - y_high)) * self.ratio + ) axes[0].spines["top"].set_color("black") axes[0].spines["right"].set_color("black") @@ -429,7 +242,7 @@ def make_video_matplotlib(self): axes[0].legend(loc="lower right", fontsize=4) self.title = axes[0].set_title( - f"time = {time_delta:3.4f}s\n frame = {frame_ind}", + f"time = {time_delta:3.4f}s\n frame = {0}", fontsize=8, ) axes[0].axis("off") @@ -441,7 +254,6 @@ def make_video_matplotlib(self): [], color=color, linewidth=1, - animated=True, clip_on=False, label=bodypart, )[0] @@ -463,21 +275,8 @@ def make_video_matplotlib(self): axes[1].spines["right"].set_color("black") axes[1].legend(loc="upper right", fontsize=4) - self.progress_bar = tqdm(leave=True, position=0) - self.progress_bar.reset(total=self.n_frames) - - movie = animation.FuncAnimation( - fig, - self._update_plot, - frames=self.frames, - interval=1000 / fps, - blit=True, - ) - movie.save(self.output_video_filename, writer=writer, dpi=400) - self.video.release() - plt.style.use("default") - logger.info("finished making video with matplotlib") - return + self.fig = fig + self.axes = axes def _get_centroid_data(self, pos_ind): def centroid_to_px(*idx): @@ -504,64 +303,231 @@ def orient_list(c): c0, c1 = self._get_centroid_data(pos_ind) self.orientation_line.set_data(orient_list(c0), orient_list(c1)) - def _update_plot(self, time_ind, *args): - _ = self._video_set_by_ind(time_ind) - - ret, frame = self.video.read() - if ret: - frame = self._get_frame(frame, crop_order=(2, 3, 0, 1)) - self.image.set_array(frame) + def _generate_single_frame(self, frame_ind): + """Generate a single frame and save it as an image.""" + padded = 
self._pad(frame_ind) + frame_file = self.temp_dir / f"orig_{padded}.png" + if not frame_file.exists(): + self.dropped_frames.add(frame_ind) + print(f"\rFrame not found: {frame_file}", end="") + return + frame = plt.imread(frame_file) + _ = self.axes[0].imshow(frame) - pos_ind = np.where(self.video_frame_inds == time_ind)[0] + pos_ind = np.where(self.video_frame_inds == frame_ind)[0] if len(pos_ind) == 0: self.centroid_position_dot.set_offsets((np.NaN, np.NaN)) for bodypart in self.centroid_plot_objs.keys(): self.centroid_plot_objs[bodypart].set_offsets((np.NaN, np.NaN)) self.orientation_line.set_data((np.NaN, np.NaN)) - self.title.set_text(f"time = {0:3.4f}s\n frame = {time_ind}") - self.progress_bar.update() - return - - pos_ind = pos_ind[0] - likelihood_inds = pos_ind + self.window_ind - # initial implementation did not cover case of both neg and over < 0 - neg_inds = np.where(likelihood_inds < 0)[0] - likelihood_inds[neg_inds] = 0 if len(neg_inds) > 0 else -1 + self.title.set_text(f"time = {0:3.4f}s\n frame = {frame_ind}") + else: + pos_ind = pos_ind[0] + likelihood_inds = pos_ind + self.window_ind + neg_inds = np.where(likelihood_inds < 0)[0] + likelihood_inds[neg_inds] = 0 if len(neg_inds) > 0 else -1 + + dlc_centroid_data = self._get_centroid_data(pos_ind) + + for bodypart in self.centroid_plot_objs: + self.centroid_plot_objs[bodypart].set_offsets( + _to_px( + data=self.centroids[bodypart][pos_ind], + cm_to_pixels=self.cm_to_pixels, + ) + ) + self.centroid_position_dot.set_offsets(dlc_centroid_data) + _ = self._set_orient_line(frame, pos_ind) - dlc_centroid_data = self._get_centroid_data(pos_ind) + time_delta = pd.Timedelta( + pd.to_datetime(self.position_time[pos_ind] * 1e9, unit="ns") + - pd.to_datetime(self.position_time[0] * 1e9, unit="ns") + ).total_seconds() - for bodypart in self.centroid_plot_objs: - self.centroid_plot_objs[bodypart].set_offsets( - _to_px( - data=self.centroids[bodypart][pos_ind], - cm_to_pixels=self.cm_to_pixels, - ) + self.title.set_text( + f"time = {time_delta:3.4f}s\n frame = {frame_ind}" ) - self.centroid_position_dot.set_offsets(dlc_centroid_data) - _ = self._set_orient_line(frame, pos_ind) + if self.likelihoods: + for bodypart in self.likelihood_objs.keys(): + self.likelihood_objs[bodypart].set_data( + self.window_ind / self.frame_rate, + np.asarray(self.likelihoods[bodypart][likelihood_inds]), + ) - time_delta = pd.Timedelta( - pd.to_datetime(self.position_time[pos_ind] * 1e9, unit="ns") - - pd.to_datetime(self.position_time[0] * 1e9, unit="ns") - ).total_seconds() + # Zero-padded filename based on the dynamic padding length + frame_path = self.temp_dir / f"plot_{padded}.png" + self.fig.savefig(frame_path, dpi=400) + plt.cla() # clear the current axes - self.title.set_text(f"time = {time_delta:3.4f}s\n frame = {time_ind}") - for bodypart in self.likelihood_objs.keys(): - self.likelihood_objs[bodypart].set_data( - self.window_ind / self.frame_rate, - np.asarray(self.likelihoods[bodypart][likelihood_inds]), + return frame_ind + + def process_frames(self): + """Process video frames in batches and generate matplotlib frames.""" + + progress_bar = tqdm(leave=True, position=0, disable=self.debug) + progress_bar.reset(total=self.n_frames) + + for start_frame in range(0, self.n_frames, self.batch_size): + if start_frame >= self.n_frames: # Skip if no frames left + break + end_frame = min(start_frame + self.batch_size, self.n_frames) - 1 + logger.debug(f"Processing frames: {start_frame} - {end_frame}") + + output_partial_video = ( + self.temp_dir / 
f"partial_{self._pad(start_frame)}.mp4" ) - self.progress_bar.update() + if output_partial_video.exists(): + logger.debug(f"Skipping existing video: {output_partial_video}") + progress_bar.update(end_frame - start_frame) + continue - return ( - self.image, - self.centroid_position_dot, - self.orientation_line, - self.title, - ) + self.ffmpeg_extract(start_frame, end_frame) + self.plot_frames(start_frame, end_frame, progress_bar) + self.ffmpeg_stitch_partial(start_frame, str(output_partial_video)) + + for frame_file in self.temp_dir.glob("*.png"): + frame_file.unlink() # Delete orig and plot frames + + progress_bar.close() + + logger.info("Concatenating partial videos") + self.concat_partial_videos() + + def _debug_print(self, msg=None, end=""): + """Print a self-overwiting message if debug is enabled.""" + if self.debug: + print(f"\r{msg}", end=end) + + def plot_frames(self, start_frame, end_frame, progress_bar=None): + logger.debug(f"Plotting frames: {start_frame} - {end_frame}") + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + jobs = {} # dict of jobs + + frames_left = end_frame - start_frame + frames_iter = iter(range(start_frame, end_frame)) + + while frames_left: + while len(jobs) < self.max_jobs_in_queue: + try: + this_frame = next(frames_iter) + self._debug_print(f"Submit: {this_frame}") + job = executor.submit( + self._generate_single_frame, this_frame + ) + jobs[job] = this_frame + except StopIteration: + break # No more frames to submit + + for job in as_completed(jobs): + frames_left -= 1 + try: + ret = job.result() + except IndexError: + ret = "IndexError" + self._debug_print(f"Finish: {ret}") + progress_bar.update() + del jobs[job] + self._debug_print(end="\n") + + def ffmpeg_extract(self, start_frame, end_frame): + """Use ffmpeg to extract a batch of frames.""" + logger.debug(f"Extracting frames: {start_frame} - {end_frame}") + output_pattern = str(self.temp_dir / f"orig_%0{self.pad_len}d.png") + + # Use ffmpeg to extract frames + ffmpeg_cmd = [ + "ffmpeg", + "-y", # overwrite + "-i", + self.video_filename, + "-vf", + f"select=between(n\\,{start_frame}\\,{end_frame})", + "-vsync", + "vfr", + "-start_number", + str(start_frame), + "-n", # no overwrite + output_pattern, + *self.ffmpeg_log_args, + ] + ret = subprocess.run(ffmpeg_cmd, stderr=subprocess.PIPE) + + extracted = len(list(self.temp_dir.glob("orig_*.png"))) + logger.debug(f"Extracted frames: {start_frame}, len: {extracted}") + if extracted < self.batch_size - 1: + logger.warning( + f"Could not extract frames: {extracted} / {self.batch_size-1}" + ) + one_err = "\n".join(str(ret.stderr).split("\\")[-3:-1]) + logger.debug(f"\nExtract Error: {one_err}") + + def _pad(self, frame_ind): + return f"{frame_ind:0{self.pad_len}d}" + + def ffmpeg_stitch_partial(self, start_frame, output_partial_video): + """Stitch a partial movie from processed frames.""" + logger.debug(f"Stitch part vid : {start_frame}") + frame_pattern = str(self.temp_dir / f"plot_%0{self.pad_len}d.png") + + ffmpeg_cmd = [ + "ffmpeg", + "-y", # overwrite + "-r", + str(self.fps), + "-start_number", + str(start_frame), + "-i", + frame_pattern, + *self.ffmpeg_fmt_args, + output_partial_video, + *self.ffmpeg_log_args, + ] + try: + ret = subprocess.run( + ffmpeg_cmd, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=True, + text=True, + ) + except subprocess.CalledProcessError as e: + logger.error(f"Error stitching partial video: {e.stderr}") + + def concat_partial_videos(self): + """Concatenate all the partial videos into one final 
video.""" + partial_vids = sorted(self.temp_dir.glob("partial_*.mp4")) + logger.debug(f"Concat part vids: {len(partial_vids)}") + concat_list_path = self.temp_dir / "concat_list.txt" + with open(concat_list_path, "w") as f: + for partial_video in partial_vids: + f.write(f"file '{partial_video}'\n") + + ffmpeg_cmd = [ + "ffmpeg", + "-y", # overwrite + "-f", + "concat", + "-safe", + "0", + "-i", + str(concat_list_path), + *self.ffmpeg_fmt_args, + str(self.output_video_filename), + *self.ffmpeg_log_args, + ] + try: + ret = subprocess.run( + ffmpeg_cmd, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + logger.error(f"Error stitching partial video: {e.stderr}") def make_video(**kwargs): """Passthrough for VideoMaker class for backwards compatibility.""" - VideoMaker(**kwargs) + return VideoMaker(**kwargs) diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index e0bd0359e..581140797 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -342,18 +342,8 @@ def make(self, key): M_TO_CM = 100 params = (DLCPosVideoParams & key).fetch1("params") + epoch = key["epoch"] - interval_name = convert_epoch_interval_name_to_position_interval_name( - { - "nwb_file_name": key["nwb_file_name"], - "epoch": key["epoch"], - }, - populate_missing=False, - ) - epoch = ( - int(interval_name.replace("pos ", "").replace(" valid times", "")) - + 1 - ) pose_est_key = { "nwb_file_name": key["nwb_file_name"], "epoch": epoch, @@ -424,7 +414,12 @@ def make(self, key): ) frames = params.get("frames", None) - make_video( + if limit := params.get("limit", None): # new int param for debugging + output_video_filename = Path(".") / f"TEST_VID_{limit}.mp4" + video_frame_inds = video_frame_inds[:limit] + pos_info_df = pos_info_df.head(limit) + + video_maker = make_video( video_filename=video_filename, video_frame_inds=video_frame_inds, position_mean={ @@ -434,12 +429,19 @@ def make(self, key): centroids=centroids, likelihoods=likelihoods, position_time=np.asarray(pos_info_df.index), - processor=params.get("processor", "opencv"), + processor=params.get("processor", "matplotlib"), frames=np.arange(frames[0], frames[1]) if frames else None, percent_frames=params.get("percent_frames", None), output_video_filename=output_video_filename, cm_to_pixels=meters_per_pixel * M_TO_CM, crop=pose_estimation_params.get("cropping"), + key_hash=dj.hash.key_hash(key), + debug=params.get("debug", False), **params.get("video_params", {}), ) - self.insert1(key) + + if limit: # don't insert if we're just debugging + return video_maker + + if output_video_filename.exists(): + self.insert1(key)