Skip to content
This repository has been archived by the owner on Feb 9, 2023. It is now read-only.

[WIP] Refactor decoding #6

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ dependencies:
- attrs==19.1.0
- cycler==0.10.0
- joblib==0.13.2
- https://github.com/kpu/kenlm/archive/master.zip
- kiwisolver==1.0.1
- matplotlib==3.0.3
- mock==2.0.0
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies:
- gast==0.2.2
- grpcio==1.19.0
- joblib==0.13.2
- https://github.com/kpu/kenlm/archive/master.zip
- keras==2.2.4
- keras-applications==1.0.7
- keras-preprocessing==1.0.9
Expand Down
50 changes: 50 additions & 0 deletions scripts/evaluate_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List

import numpy as np
import pandas as pd
import h5py
from tqdm import tqdm
from tabulate import tabulate

from decoding import BestPathDecoder, CTCDecoder
from text import Alphabet
from metric import get_metrics


# HDF5 file that holds both the reference transcripts (pandas HDFStore key
# 'references') and the per-sample network outputs (h5py groups 'outputs/1/<id>').
ACTIVATION_PATH = '../scripts/evaluation-clarin-without-activations.hdf5'
# Character alphabet of the Polish model, shared by all decoders below.
ALPHABET = Alphabet(file_path='../models/pl/alphabet.txt')
# NOTE(review): -1 is later used directly as a slice bound ([:LIMIT]), which
# actually drops the LAST sample instead of keeping all — confirm intent.
LIMIT = -1  # Number of samples to use for evaluation (-1 for all)


def get_references(fname: str) -> pd.DataFrame:
    """Load the reference-transcripts table stored under the 'references' key.

    Opens the HDF5 store read-only and guarantees it is closed again even if
    the key lookup raises.
    """
    store = pd.HDFStore(fname, mode='r')
    try:
        return store['references']
    finally:
        store.close()


def read_probabilities(fname: str, references: pd.DataFrame) -> List[np.ndarray]:
    """Read the stored network outputs for every sample id in `references`.

    Each dataset lives under 'outputs/<output_index>/<sample_id>' in the HDF5
    file; the `[:]` copy materialises it as an in-memory numpy array.
    """
    output_index = 1  # which model output head to read
    probabilities = []
    with h5py.File(fname, mode='r') as store:
        for sample_id in tqdm(references.index):
            probabilities.append(store[f'outputs/{output_index}/{sample_id}'][:])
    return probabilities


# ---------------------------------------------------------------------------
# Evaluation driver: decode the stored activations with each decoder and
# report the average CER / WER against the reference transcripts.
# ---------------------------------------------------------------------------
references = get_references(ACTIVATION_PATH)
probs = read_probabilities(ACTIVATION_PATH, references)
transcripts = references['transcript']

# Translate the LIMIT sentinel into a slice bound.  BUG FIX: slicing directly
# with LIMIT == -1 evaluated to [:-1], silently dropping the last sample
# instead of using all of them.
limit = None if LIMIT == -1 else LIMIT

decoders = [
    ('Best-path decoding', BestPathDecoder(config={'alphabet': ALPHABET})),
    ('CTC decoding w/o lm', CTCDecoder(language_model=None, config={'alphabet': ALPHABET})),
]
results = []
for name, decoder in decoders:
    # batch_decode yields (transcript, score) pairs; only the text is needed.
    decoded_transcripts = [transcript for transcript, _ in decoder.batch_decode(probs[:limit])]
    metrics = list(get_metrics(decoded_transcripts, transcripts[:limit]))
    if not metrics:
        # No samples evaluated — avoid a ZeroDivisionError and make it visible.
        results.append((name, float('nan'), float('nan')))
        continue
    # BUG FIX: the original averaged `metric.wer` into avg_cer and
    # `metric.cer` into avg_wer — the two metrics were swapped.
    avg_cer = sum(metric.cer for metric in metrics) / len(metrics)
    avg_wer = sum(metric.wer for metric in metrics) / len(metrics)
    results.append((name, avg_cer, avg_wer))
print(tabulate(results, headers=['name', 'cer', 'wer'], tablefmt="grid"))



4 changes: 4 additions & 0 deletions scripts/prepare_lm_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Download and preprocess the PolEval 2018 Task 3 corpus for KenLM training.
# BUG FIX: wget's lowercase -o writes the *log* to the given file and leaves
# the payload in task3_train.txt.gz; uppercase -O saves the download itself.
wget http://2018.poleval.pl/task3/task3_train.txt.gz -O language_model_training_data.txt.gz
# gzip -d replaces the .gz archive with the decompressed file, so the original
# follow-up `rm` of the archive always failed (file already gone) — dropped.
gzip -d language_model_training_data.txt.gz
# Keep only Polish letters and spaces, mark word boundaries with '@', then
# space-separate every character to train a character-level language model.
sed -E 's/[^a-zA-Ząćęłóśźż ]//g' language_model_training_data.txt | sed 's/\ /@/g' | sed -e 's/\(.\)/\1 /g' > language_model_training_data_preprocessed.txt
140 changes: 0 additions & 140 deletions source/ctc_decoder.py

This file was deleted.

Loading