Skip to content

Commit

Permalink
[FIX] each audio file have all segments in parquet file
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Jan 7, 2025
1 parent 298312e commit 1f8db62
Showing 1 changed file with 20 additions and 55 deletions.
75 changes: 20 additions & 55 deletions src/parse_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from utils import remove_extension


def setup_logging():
logging.basicConfig(
filename="audio_processing.log",
Expand All @@ -21,7 +20,6 @@ def setup_logging():
datefmt="%Y-%m-%d %H:%M:%S",
)


@retry(wait=wait_exponential(multiplier=1, min=4, max=120))
def do_connection(connection_string):
"""Establish a connection to the filesystem with retries."""
Expand All @@ -34,7 +32,6 @@ def do_connection(connection_string):
logging.info("Retrying connection...")
raise


def walk_audio(filesystem, input_path):
"""Walk through the filesystem and yield audio files."""
walker = filesystem.walk(
Expand All @@ -45,18 +42,17 @@ def walk_audio(filesystem, input_path):
for f in flist:
yield fs.path.combine(path, f.name)


def parse_folders(filesystem, apath, rpath):
"""
Parse audio and result folders, matching audio files with their corresponding result files.
""" # noqa: E501
"""
audio_files = get_audio_files(filesystem, apath)
audio_no_extension = [remove_extension(audio_file) for audio_file in audio_files]

result_files = [
f
for f in glob.glob(rpath + "/**/*", recursive=True)
if os.path.isfile(f) # noqa: PTH113, PTH207
if os.path.isfile(f)
]
matched_files = match_audio_and_results(
audio_files, audio_no_extension, result_files
Expand All @@ -65,20 +61,18 @@ def parse_folders(filesystem, apath, rpath):
logging.info(f"Found {len(matched_files)} audio files with valid result file.")
return matched_files


def get_audio_files(filesystem, apath):
    """Return the audio files found under ``apath``.

    Args:
        filesystem: Filesystem object handed to ``walk_audio``. When it is
            falsy, the local disk is searched with ``glob`` instead.
        apath: Root directory to search.

    Returns:
        A list of file paths. In the local-disk case the search is
        recursive and keeps only regular files whose name ends in
        ``.WAV``, ``.wav`` or ``.mp3``; otherwise the list holds whatever
        ``walk_audio`` yields.
    """
    if filesystem:
        return list(walk_audio(filesystem, apath))

    # The suffix match is case-sensitive: only these exact spellings count.
    audio_suffixes = (".WAV", ".wav", ".mp3")
    return [
        f
        for f in glob.glob(apath + "/**/*", recursive=True)
        # Cheap string test first so os.path.isfile only stats candidates.
        if f.endswith(audio_suffixes) and os.path.isfile(f)
    ]


def match_audio_and_results(audio_files, audio_no_extension, result_files):
"""Match audio files with their corresponding result files."""
matched_files = []
Expand All @@ -89,51 +83,23 @@ def match_audio_and_results(audio_files, audio_no_extension, result_files):
matched_files.append({"audio": audio_files[audio_idx], "result": result})
return matched_files


def parse_files(file_list, max_segments=10, threshold=0.6):
    """Collect every segment found across the given audio/result pairs.

    Args:
        file_list: Iterable of dicts, each with an ``"audio"`` and a
            ``"result"`` key.
        max_segments: Not used by this function; kept in the signature so
            existing callers that pass it keep working.
        threshold: Forwarded to ``find_segments`` as its confidence
            threshold.

    Returns:
        A flat list with the segments ``find_segments`` returned for each
        pair, in ``file_list`` order.
    """
    segments = []
    for files in file_list:
        segments.extend(find_segments(files["audio"], files["result"], threshold))
    logging.info(f"Found {len(segments)} segments in total.")
    return segments


def find_segments(audio_file, result_file, confidence_threshold):
"""Find segments in the result file that meet the confidence threshold."""
segments = []
try:
with open(result_file) as rf: # noqa: PTH123
with open(result_file) as rf:
lines = [line.strip() for line in rf.readlines()]

for i, line in enumerate(lines):
if i > 0:
if i > 0: # Skip header
data = line.split("\t")
start, end, species, confidence = (
float(data[3]),
Expand All @@ -158,7 +124,6 @@ def find_segments(audio_file, result_file, confidence_threshold):

return segments


if __name__ == "__main__":
setup_logging()

Expand All @@ -170,26 +135,26 @@ def find_segments(audio_file, result_file, confidence_threshold):
)
args = parser.parse_args()

with open(args.config) as config_file: # noqa: PTH123
config = yaml.load(config_file, Loader=yaml.FullLoader) # noqa: S506
with open(args.config) as config_file:
config = yaml.load(config_file, Loader=yaml.FullLoader)

myfs = do_connection(config["CONNECTION_STRING"])
parsed_folders = parse_folders(myfs, config["INPUT_PATH"], config["OUTPUT_PATH_BIRDNET"])
parsed_files = parse_files(
parsed_folders, config["NUM_SEGMENTS"], config["THRESHOLD"]
parsed_segments = parse_files(
parsed_folders, max_segments=config["NUM_SEGMENTS"], threshold=config["THRESHOLD"]
)

entries = [item[1][0] for item in parsed_files]
entries.sort(key=lambda e: e["audio"])

# Create a Parquet table
table = pa.table(
{
"audio": (entry["audio"] for entry in entries),
"start": (entry["start"] for entry in entries),
"end": (entry["end"] for entry in entries),
"species": (entry["species"] for entry in entries),
"confidence": (entry["confidence"] for entry in entries),
"audio": [segment["audio"] for segment in parsed_segments],
"start": [segment["start"] for segment in parsed_segments],
"end": [segment["end"] for segment in parsed_segments],
"species": [segment["species"] for segment in parsed_segments],
"confidence": [segment["confidence"] for segment in parsed_segments],
}
)

# Write the table to a Parquet file
pq.write_table(table, "sample.parquet")
logging.info("Parquet file written successfully with all segments!")

0 comments on commit 1f8db62

Please sign in to comment.