Skip to content

Commit

Permalink
WIP: Benchmark key vs. extracted frames
Browse files Browse the repository at this point in the history
  • Loading branch information
medengineer committed Jan 5, 2025
1 parent ab0eee7 commit 92dac52
Showing 1 changed file with 145 additions and 61 deletions.
206 changes: 145 additions & 61 deletions tests/frame-grabber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,37 @@
import numpy as np
import cv2

import pandas as pd
pd.set_option('display.max_rows', None)

from open_ephys.control import OpenEphysHTTPServer
from open_ephys.analysis import Session

#disable data frame preview limit
import pandas as pd
pd.set_option('display.max_rows', None)
import matplotlib.pyplot as plt

import pytesseract
from PIL import Image

def convert_time_to_ms(time_str):
parts = time_str.split(':')
minutes = int(parts[0])
seconds, hundredths = map(int, parts[1].split('.'))
return (minutes * 60000) + (seconds * 1000) + (hundredths * 10)
from scipy.interpolate import interp1d

class KeyFrame():
class VideoFrame():

def __init__(self, image_path, sample_number, software_time):
def __init__(self, image_path, sample_number, software_time, clock_time=None):
self.image_path = image_path
self.sample_number = sample_number
self.software_time = software_time
self.time = self.extract_time()
if clock_time is not None:
self.clock_time = clock_time
else:
self.clock_time = self.extract_time()

@staticmethod
def convert_time_to_ms(time_str):
parts = time_str.split(':')
minutes = int(parts[0])
seconds, hundredths = map(int, parts[1].split('.'))
ms_time = (minutes * 60000) + (seconds * 1000) + (hundredths * 10)
return ms_time

def extract_time(self):
image = cv2.imread(self.image_path)
Expand All @@ -47,26 +55,113 @@ def extract_time(self):
pil_image = Image.fromarray(clean)
text = pytesseract.image_to_string(pil_image)

return convert_time_to_ms(text.replace(" ", ""))
print(f'Detected text: {text}', end=' -> ')

#fix OCR errors
text = text.replace('00:24 6 re}','00:24.63')
text = text.replace('00:27 faye}','00:27.63')
text = text.replace('00:27. Qa','00:27.94')
text = text.replace('00:34.9 fe','00:34.93')
text = text.replace('00:38. nye)','00:38.16')
text = text.replace('OO: 53.5. 4', '00:53.54')

text = text.replace(" ", "")
text = text.replace('OO','00')
text = text.replace('0G','10')
text = text.replace('g','8')
text = text.replace('e','8')
text = text.replace('B','8')
text = text.replace('@','4')
text = text.replace(',','')
text = text.replace('a','8')
text = text.replace('%','5')
text = text.replace('Z','7')
text = text.replace('z','7')
text = text.replace('8y8','00')
text = text.replace('on','07')

if text == '': text = '00:07.52' #special case?

if text.count(":") > 1: text = text[::-1].replace(":", ".", 1)[::-1]
if text.count('.') > 1: text = text.replace(".", ":", 1)
if '\u00A2' in text: text = text.replace('\u00A2', '7')
if '.' in text and ':' in text:
print(f'{text}')
return self.convert_time_to_ms(text)
else:
if '.' not in text:
#add it to position -2
text = text[:-3] + '.' + text[-3:]
print(f'{text}')
return self.convert_time_to_ms(text)
else:
print(f'\t\t {text} is not a valid time')
return "00:00.00"

class VideoRecording():

def __init__(self, directory, experiment_index, recording_index):
def __init__(self, directory):
self.directory = directory
self.experiment_index = experiment_index
self.recording_index = recording_index

self._load_timestamps()
EXTRACT_FRAMES = False
if EXTRACT_FRAMES:
self._load_images()
self._process_images()
#self._show()
if os.path.exists(os.path.join(self.directory, 'extracted_key_times.txt')):
with open(os.path.join(self.directory, 'extracted_key_times.txt'), 'r') as f:
self.key_frame_times = list(map(int, f.readlines()))

if os.path.exists(os.path.join(self.directory, 'extracted_frame_times.txt')):
with open(os.path.join(self.directory, 'extracted_frame_times.txt'), 'r') as f:
self.extracted_frame_times = list(map(int, f.readlines()))

# find any duplicates in extracted_frame_times
seen = set()
duplicates = set()
for time in self.extracted_frame_times:
if time in seen:
duplicates.add(time)
else:
seen.add(time)
print(f"Found {len(duplicates)} duplicates in extracted_frame_times")

self.key_frames = []
self.extracted_frames = []
self._load_key_frames()
self._load_images()
self._process_images()

#load line from sync_messages.txt
sync_messages = []
if os.path.exists(os.path.join(self.directory, 'sync_messages.txt')):
with open(os.path.join(self.directory, 'sync_messages.txt'), 'r') as f:
sync_messages = f.readlines()
first_recorded_software_time = sync_messages[0].split()[-1]

for frame in self.extracted_frames:
#find exact match in key_frames
for key_frame in self.key_frames:
if key_frame.clock_time == frame.clock_time:
print(f"Matched {key_frame.clock_time} with {frame.clock_time}")




def _load_timestamps(self):
def _load_key_frames(self):
timestamps_file = os.path.join(self.directory, 'frame_timestamps.csv')
self.timestamps = np.loadtxt(timestamps_file, delimiter=',')
self.sample_numbers = self.timestamps[:, 3]

for i, row in enumerate(self.timestamps):
img_path = os.path.join(self.directory, 'frames', f"frame at {int(row[0]):010d}.jpg")
sample_number = int(row[3])
software_time = row[4]
if self.key_frame_times is not None:
self.key_frames.append(VideoFrame(img_path, int(sample_number), software_time, self.key_frame_times[i]))
else:
self.key_frames.append(VideoFrame(img_path, int(sample_number), software_time))

if not self.key_frame_times:
frame_times = [key_frame.time for key_frame in self.key_frames]
with open(os.path.join(self.directory, 'extracted_key_times.txt'), 'w') as f:
f.write('\n'.join(map(str, frame_times)))

self.key_frame_times = frame_times

def _load_images(self):
images = []
Expand All @@ -91,35 +186,29 @@ def _load_images(self):
print(f"Directory {frames_dir} already exists...loading frames from disk...")
for frame_file in glob.glob(os.path.join(frames_dir, '*.jpg')):
images.append(frame_file)
print(f"Found {len(images)} frames")

self.images = images

def _process_images(self):

intensities = []

for image in self.images:
img = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
intensities.append(np.sum(img))

values = intensities
peaks = []
troughs = []
for i in range(1, len(values) - 1):
# Detect peaks
if values[i-1] < values[i] > values[i+1]:
# Ensure the last detected point was not a peak to avoid consecutive peaks
if not peaks or (peaks and peaks[-1][0] != i-1):
peaks.append((i, values[i]))
# Detect troughs
elif values[i-1] > values[i] < values[i+1]:
# Ensure the last detected point was not a trough to avoid consecutive troughs
if not troughs or (troughs and troughs[-1][0] != i-1):
troughs.append((i, values[i]))

self.intensities = intensities
self.peaks = peaks
self.troughs = troughs

if self.extracted_frame_times is None:
for image in self.images:
img = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
#intensities.append(np.sum(img))
self.extracted_frames.append(VideoFrame(image, -1, -1))
frame_times = [key_frame.time for key_frame in self.key_frames]
with open(os.path.join(self.directory, 'extracted_frame_times.txt'), 'w') as f:
f.write('\n'.join(map(str, frame_times)))

self.extracted_frame_times = frame_times

else:

for i, image in enumerate(self.images):
#img = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
#intensities.append(np.sum(img))
self.extracted_frames.append(VideoFrame(image, -1, -1, self.extracted_frame_times[i]))

def _show(self):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -158,11 +247,9 @@ def _detect_recordings(self):
recording_directories = glob.glob(os.path.join(experiment_directory, 'recording*'))
recording_directories.sort()

for recording_index, recording_directory in enumerate(recording_directories):
for directory in recording_directories:

recordings.append(VideoRecording(recording_directory,
experiment_index,
recording_index))
recordings.append(VideoRecording(directory))

self.recordings = recordings

Expand Down Expand Up @@ -214,8 +301,6 @@ def test(gui, params):

session = Session(path)

print(path)

node = session.recordnodes[0]

recording = node.recordings[0]
Expand All @@ -226,19 +311,18 @@ def test(gui, params):

events_by_sample_number = recording.events.sort_values(by='sample_number')

#sent_events = events_by_sample_number[events_by_sample_number["stream_name"] == "Source_Sim-1100"]

received_events = events_by_sample_number[events_by_sample_number["stream_name"] == "PXIe-6341"]
received_events = events_by_sample_number[events_by_sample_number["stream_name"] == "Probe-A-AP"]

results["Received events"] = len(received_events["sample_number"])
results["Key frame events"] = len(received_events["sample_number"])

frame_grabber_nodes = detect_frame_grabbers(path)

video = frame_grabber_nodes[0].recordings[0]

results["Extracted frames"] = len(video.images)
results["fps"] = int(len(video.images) / duration_sec)

SHOW = True
SHOW = False
if SHOW:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, 1, figsize=(30, 12), sharex=True)
Expand Down Expand Up @@ -299,8 +383,8 @@ def test(gui, params):
if __name__ == '__main__':

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--mode', required=True, choices={'local', 'githubactions'})
parser.add_argument('--fetch', required=True, type=int, default=1)
parser.add_argument('--mode', required=False, choices={'local', 'githubactions'})
parser.add_argument('--fetch', required=False, type=int, default=1)
parser.add_argument('--address', required=False, type=str, default='http://127.0.0.1')
parser.add_argument('--cfg_path', required=False, type=str, default=CONFIG_FILE)
parser.add_argument('--acq_time', required=False, type=int, default=2)
Expand Down

0 comments on commit 92dac52

Please sign in to comment.