-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
47 lines (36 loc) · 1.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#! /home/alergn/virtualenvs/torchgpu/bin/python3.9
import importlib
import sys

import click
import yaml


@click.command()
@click.option(
    '--config_file',
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help='YAML config naming the training script and the experiment to run.',
)
def training(config_file):
    """Load a YAML config, build its ``Training`` object and run one experiment.

    The config must provide:

    * ``training_script`` -- an importable module name whose module defines a
      ``Training`` class constructed with the config file path;
    * ``experiment_name`` -- one of the keys of ``experiments`` below;
    * ``nomad_model_path`` -- required by every experiment except ``Training``.

    Raises:
        click.UsageError: if the config is not a mapping, a required key is
            missing, or ``experiment_name`` is not recognised. The checks run
            before the training script is imported or instantiated.
    """
    # *** EXPERIMENT CHOICE ***
    # experiment_name -> (Training method to call, whether it takes nomad_model_path)
    experiments = {
        # TRAINING NOMAD
        'Training': ('training_loop', False),
        # PERFORMANCE EVALUATION
        # NMR AUDIO QUALITY
        'quality_nmr': ('eval_audio_quality', True),
        # NMR RANKING VALIDATION SET CONDITIONS
        'valid_rank': ('eval_degr_level', True),
        # NMR RANKING DEGRADATION INTENSITY
        'intensity': ('eval_degradation_intensity', True),
        # FULL REFERENCE AUDIO QUALITY
        'quality_fr': ('eval_full_reference', True),
    }

    # Open config file.
    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # prefer yaml.safe_load if config files can come from untrusted sources.
    with open(config_file) as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    # An empty YAML file loads as None; anything but a mapping is unusable.
    if not isinstance(config, dict):
        raise click.UsageError(
            f"config file {config_file!r} must contain a YAML mapping"
        )

    def require(key):
        """Return ``config[key]``, turning a missing key into a CLI error."""
        try:
            return config[key]
        except KeyError:
            raise click.UsageError(
                f"config file {config_file!r} is missing key '{key}'"
            ) from None

    # Validate everything before importing/constructing the (possibly
    # expensive) training object, so a bad config fails fast.
    experiment_name = require('experiment_name')
    if experiment_name not in experiments:
        raise click.UsageError(
            f"unknown experiment_name {experiment_name!r}; expected one of: "
            + ', '.join(experiments)
        )
    method_name, needs_model_path = experiments[experiment_name]
    method_args = (require('nomad_model_path'),) if needs_model_path else ()

    # Import the script named in the config (it varies per config file) and
    # create the training object it defines.
    imported_module = importlib.import_module(require('training_script'))
    train_obj = imported_module.Training(config_file)

    getattr(train_obj, method_name)(*method_args)
if __name__ == '__main__':
    # Script entry point: let click parse the command line and run the command.
    training()