forked from Netflix/vmaf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_vmaf_cross_validation.py
46 lines (35 loc) · 1.54 KB
/
run_vmaf_cross_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
__copyright__ = "Copyright 2016-2018, Netflix, Inc."
__license__ = "Apache, Version 2.0"
import matplotlib.pyplot as plt
import numpy as np
from vmaf.config import VmafConfig, DisplayConfig
from vmaf.routine import run_vmaf_cv, run_vmaf_kfold_cv
if __name__ == '__main__':
    # Script entry point: runs two VMAF cross-validation experiments and then
    # shows whatever DisplayConfig has queued for display.

    # ==== Run simple cross validation: one training and one testing dataset ====
    # Train on the NFLX public dataset, test on VQEGHD3, using the vmaf_v3
    # parameter file; the model path is placed under the workspace directory.
    run_vmaf_cv(
        train_dataset_filepath=VmafConfig.resource_path('dataset', 'NFLX_dataset_public.py'),
        test_dataset_filepath=VmafConfig.resource_path('dataset', 'VQEGHD3_dataset.py'),
        param_filepath=VmafConfig.resource_path('param', 'vmaf_v3.py'),
        output_model_filepath=VmafConfig.workspace_path('model', 'test_model1.pkl'),
    )

    # ==== Run cross validation across genres (tough test) ====
    nflx_dataset_path = VmafConfig.resource_path('dataset', 'NFLX_dataset_public.py')

    # Content ids of the NFLX dataset grouped by genre; each inner list is one
    # group handed to run_vmaf_kfold_cv.
    # NOTE(review): presumably each group is held out as the test fold in
    # turn -- confirm against run_vmaf_kfold_cv.
    contentid_groups = [
        [0, 5],     # cartoon: BigBuckBunny, FoxBird
        [1],        # CG: BirdsInCage
        [2, 6, 7],  # complex: CrowdRun, OldTownCross, Seeking
        [3, 4],     # ElFuente: ElFuente1, ElFuente2
        [8],        # sports: Tennis
    ]
    param_filepath = VmafConfig.resource_path('param', 'vmaf_v3.py')

    # Aggregation function passed through to run_vmaf_kfold_cv. The
    # alternatives below are kept for reference; enabling them requires
    # importing ListStats / functools.partial, which the imports visible in
    # this file do not provide.
    aggregate_method = np.mean
    # aggregate_method = ListStats.harmonic_mean
    # aggregate_method = partial(ListStats.lp_norm, p=2.0)

    run_vmaf_kfold_cv(
        dataset_filepath=nflx_dataset_path,
        contentid_groups=contentid_groups,
        param_filepath=param_filepath,
        aggregate_method=aggregate_method,
    )

    DisplayConfig.show()

    # Function-call form prints the same text on Python 2 and Python 3; the
    # bare `print 'Done.'` statement is a SyntaxError on Python 3.
    print('Done.')