Skip to content

Commit

Permalink
replace multiprocessing with pyav
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr committed Feb 20, 2024
1 parent d26f369 commit 14ea5cb
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 116 deletions.
1 change: 1 addition & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
"SPACY_MODEL_NAME": "en_core_web_trf",
"PRIVATE_AI_API_KEY": "<set your api key in .env>",
"RECORD_VIDEO": True,
"VIDEO_PIXEL_FORMAT": "rgb24",
}

# each string in STOP_STRS should only contain strings
Expand Down
228 changes: 141 additions & 87 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from collections import namedtuple
from functools import partial, wraps
from typing import Any, Callable, Union
import av
import multiprocessing
import numpy as np
import os
import queue
import signal
Expand All @@ -33,6 +35,7 @@
from openadapt.extensions import synchronized_queue as sq
from openadapt.models import ActionEvent


Event = namedtuple("Event", ("timestamp", "type", "data"))

EVENT_TYPES = ("screen", "action", "window")
Expand Down Expand Up @@ -168,6 +171,9 @@ def process_events(
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
video_container: av.container.OutputContainer,
video_stream: av.stream.Stream,
video_start_time: float,
) -> None:
"""Process events from the event queue and write them to write queues.
Expand All @@ -179,6 +185,9 @@ def process_events(
perf_q: A queue for collecting performance data.
recording_timestamp: The timestamp of the recording.
terminate_event: An event to signal the termination of the process.
video_container: The video container.
video_stream: The stream to which to write video frames.
video_start_time: The timestamp at which the video strema was started.
"""
logger.info("Starting")
Notify("Status", "Starting recording...", "OpenAdapt").send()
Expand All @@ -188,6 +197,7 @@ def process_events(
prev_window_event = None
prev_saved_screen_timestamp = 0
prev_saved_window_timestamp = 0
last_pts = 0
while not terminate_event.is_set() or not event_q.empty():
event = event_q.get()
logger.trace(f"{event=}")
Expand All @@ -199,6 +209,16 @@ def process_events(
)
if event.type == "screen":
prev_screen_event = event

last_pts = write_frame(
video_container,
video_stream,
event.data,
event.timestamp,
video_start_time,
last_pts,
)

elif event.type == "window":
prev_window_event = event
elif event.type == "action":
Expand Down Expand Up @@ -512,13 +532,12 @@ def read_screen_events(
recording_timestamp: The timestamp of the recording.
"""
logger.info("Starting")
with mss.mss() as sct:
while not terminate_event.is_set():
screenshot = utils.take_screenshot(sct)
if screenshot is None:
logger.warning("Screenshot was None")
continue
event_q.put(Event(utils.get_timestamp(), "screen", screenshot))
while not terminate_event.is_set():
screenshot = utils.take_screenshot()
if screenshot is None:
logger.warning("Screenshot was None")
continue
event_q.put(Event(utils.get_timestamp(), "screen", screenshot))
logger.info("Done")


Expand Down Expand Up @@ -818,13 +837,11 @@ def record(
recording_timestamp = recording.timestamp

video_file_name = get_video_file_name(recording_timestamp)
ffmpeg_process = start_ffmpeg_recording(video_file_name)
video_start_time = wait_for_ffmpeg_to_start(ffmpeg_process)
logger.info(f"{video_start_time=}")
if video_start_time is None:
logger.error("Failed to detect the start of the ffmpeg recording process.")
ffmpeg_process.terminate()
return
# TODO XXX replace with utils.get_monitor_dims() once fixed
width, height = utils.take_screenshot().size
video_container, video_stream, video_start_time = initialize_video_writer(
video_file_name, width, height,
)

crud.update_video_start_time(recording_timestamp, video_start_time)

Expand Down Expand Up @@ -883,6 +900,9 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
video_container,
video_stream,
video_start_time,
),
)
event_processor.start()
Expand Down Expand Up @@ -972,8 +992,6 @@ def record(
term_pipe_parent_action.send(action_write_q.qsize())
term_pipe_parent_screen.send(screen_write_q.qsize())

ffmpeg_process.terminate()

logger.info("joining...")
keyboard_event_reader.join()
mouse_event_reader.join()
Expand All @@ -985,6 +1003,8 @@ def record(
window_event_writer.join()
terminate_perf_event.set()

finalize_video_writer(video_container, video_stream)

if PLOT_PERFORMANCE:
mem_plotter.join()
utils.plot_performance(recording_timestamp)
Expand All @@ -997,92 +1017,126 @@ def start() -> None:
"""Starts the recording process."""
fire.Fire(record)

import subprocess
import re
import time
###

def get_video_file_name(recording_timestamp: float):
return f"oa_recording-{recording_timestamp}.mp4"

def find_desktop_capture_index():
def initialize_video_writer(
output_path: str,
width: int,
height: int,
fps: int = 24,
codec: str = 'libx264rgb',
pix_fmt: str = config.VIDEO_PIXEL_FORMAT,
crf: int = 0,
preset: str = 'veryslow',
) -> tuple[av.container.OutputContainer, av.stream.Stream, float]:
"""
Runs ffmpeg to list avfoundation devices and parses the output to find the desktop capture index.
Initializes the video writer and returns the container, stream, and base timestamp.
Args:
output_path (str): Path to the output video file.
width (int): Width of the video.
height (int): Height of the video.
fps (int, optional): Frames per second of the video. Defaults to 24.
codec (str, optional): Codec used for encoding the video.
Defaults to 'libx264rgb'.
pix_fmt (str, optional): Pixel format of the video. Defaults to 'rgb24'.
crf (int, optional): Constant Rate Factor for encoding quality.
Defaults to 0 for lossless.
preset (str, optional): Encoding speed/quality trade-off.
Defaults to 'veryslow' for maximum compression.
Returns:
The index of the desktop capture device as a string, or None if not found.
tuple[av.container.OutputContainer, av.stream.Stream, float]: The initialized
container, stream, and base timestamp.
"""
command = ['ffmpeg', '-f', 'avfoundation', '-list_devices', 'true', '-i', '""']
process = subprocess.Popen(command, stderr=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True)
stdout, stderr = process.communicate()
logger.info("initializing video stream...")
container = av.open(output_path, mode='w')
stream = container.add_stream(codec, rate=fps)
print(f"{width=} {height=}")
stream.width = width
stream.height = height
stream.pix_fmt = pix_fmt
stream.options = {'crf': str(crf), 'preset': preset}

base_timestamp = utils.get_timestamp()

return container, stream, base_timestamp

from fractions import Fraction

def write_frame(
container: av.container.OutputContainer,
stream: av.stream.Stream,
screenshot: mss.base.ScreenShot,
timestamp: float,
base_timestamp: float,
last_pts: int, # Add this parameter to track the last PTS
pix_fmt: str = 'rgb24', # Assuming 'rgb24' is your default pixel format
) -> int: # Returns the updated last_pts
# Convert MSS ScreenShot to np.ndarray
frame = screenshot_to_np(screenshot)

# Convert the numpy array to an AVFrame
av_frame = av.VideoFrame.from_ndarray(frame, format=pix_fmt)

# Calculate the time difference in seconds
time_diff = timestamp - base_timestamp

# Calculate PTS, taking into account the fractional average rate
pts = int(time_diff * float(Fraction(stream.average_rate)))

# Ensure monotonically increasing PTS
if pts <= last_pts:
pts = last_pts + 1
av_frame.pts = pts
last_pts = pts # Update the last_pts

# Encode and write the frame
for packet in stream.encode(av_frame):
container.mux(packet)

return last_pts # Return the updated last_pts for the next call

def finalize_video_writer(
container: av.container.OutputContainer,
stream: av.stream.Stream,
) -> None:
"""
Finalizes the video writer, ensuring all buffered frames are encoded and written.
# Pattern to match the desktop capture device
pattern = re.compile(r"\[AVFoundation indev @ .*\] \[(\d+)\] Capture screen 0")
Args:
container (av.container.OutputContainer): The AV container to finalize.
stream (av.stream.Stream): The AV stream to finalize.
"""
# Flush stream
logger.info("flushing...")
for packet in stream.encode():
container.mux(packet)

match = pattern.search(stderr) # stderr is used because ffmpeg outputs device info there
if match:
return match.group(1) # Returns the index of the desktop capture device
else:
return None

def start_ffmpeg_recording(output_file):
desktop_index = find_desktop_capture_index()
assert desktop_index is not None, "Desktop capture device not found."

if sys.platform == "darwin":
capture_device = "avfoundation"
input_source = f"{desktop_index}:none" # Assuming no audio capture
frame_rate = "30" # Set to a supported frame rate
elif sys.platform == "win32":
capture_device = "gdigrab"
input_source = "desktop"
frame_rate = "" # TODO

command = [
'ffmpeg',
'-f', capture_device,
'-r', frame_rate, # Frame rate option included here directly
'-i', input_source,
'-loglevel', 'verbose',
'-copyts',
output_file
]
# Ensure all parts of the command are valid before executing
command = [arg for arg in command if arg]
logger.info(f"{command=}")
# Close the container
logger.info("closing...")
container.close()

return subprocess.Popen(command, stderr=subprocess.PIPE, universal_newlines=True)
logger.info("done")

def wait_for_ffmpeg_to_start(process, timeout=30, print_output=True):
def screenshot_to_np(screenshot: mss.base.ScreenShot) -> np.ndarray:
"""
Waits for ffmpeg to start and returns the approximate start time by monitoring its stderr output.
Converts an MSS screenshot to a NumPy array.
Args:
- process: The Popen object of the running ffmpeg process.
- timeout: Maximum time to wait for the ffmpeg start signal in seconds.
- print_output: Whether to print ffmpeg logs.
screenshot (mss.base.ScreenShot): The screenshot object from MSS.
Returns:
- The Unix timestamp of when ffmpeg likely started recording, or None if not detected.
np.ndarray: The screenshot as a NumPy array in RGB format.
"""
start_time_pattern = re.compile(r'Stream mapping:')
start_time = None
end_time = time.time() + timeout

while time.time() < end_time and not start_time:
line = process.stderr.readline()
if print_output:
print(line)
if not line:
break # End of output or timeout
# Check for the log line indicating that ffmpeg has started processing
if start_time_pattern.search(line):
start_time = time.time() # Use current time as an approximation of the start time
print(f"{line=}")
print(f"{start_time=}")
logger.info("ffmpeg has started processing.")

return start_time

def get_video_file_name(recording_timestamp: float):
return f"oa_recording-{recording_timestamp}.mp4"

# Convert the screenshot to a PIL Image first (mss provides a method for this)
img = screenshot.rgb # Get the RGB data from the screenshot
# Convert the RGB data to a NumPy array and reshape it to the correct dimensions
frame = np.frombuffer(img, dtype=np.uint8).reshape(screenshot.height, screenshot.width, 3)
return frame

if __name__ == "__main__":
fire.Fire(record)
37 changes: 9 additions & 28 deletions openadapt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,17 @@ def get_double_click_distance_pixels() -> int:
raise Exception(f"Unsupported {sys.platform=}")


def get_monitor_dims() -> tuple:
def get_monitor_dims() -> tuple[int, int]:
"""Get the dimensions of the monitor.
Returns:
tuple: The width and height of the monitor.
tuple[int, int]: The width and height of the monitor.
"""


# TODO XXX: replace with get_screenshot().size and remove get_scale_ratios?


SCT = mss.mss()
monitor = SCT.monitors[0]
monitor_width = monitor["width"]
Expand Down Expand Up @@ -578,37 +583,13 @@ def image2utf8(image: Image.Image) -> str:
return image_utf8


_start_time = None
_start_perf_counter = None


def set_start_time(value: float = None) -> float:
"""Set the start time for performance measurements.
Args:
value (float): The start time value. Defaults to the current time.
Returns:
float: The start time.
"""
global _start_time
_start_time = value or time.time()
logger.debug(f"{_start_time=}")
return _start_time


def get_timestamp(is_global: bool = False) -> float:
def get_timestamp() -> float:
"""Get the current timestamp.
Args:
is_global (bool): Flag indicating whether to use the global
start time. Defaults to False.
Returns:
float: The current timestamp.
"""
global _start_time
return _start_time + time.perf_counter()
return time.time()


# https://stackoverflow.com/a/50685454
Expand Down
Loading

0 comments on commit 14ea5cb

Please sign in to comment.