dct.py
#!/usr/bin/env python3
# Copyright (2021-) Shahruk Hossain <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Union, Iterable, Tuple

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer


class DCT(Layer):
    """
    This layer implements the Discrete Cosine Transform (DCT). The transform
    is Kaldi-compliant: for the same input, it should produce the same output
    as Kaldi.

    Only DCT-II is supported at the moment, since that is what is needed to
    compute MFCCs.

    This is adapted from PyTorch's Kaldi compliance module:

    https://pytorch.org/audio/stable/_modules/torchaudio/compliance/kaldi.html#mfcc
    """

    def __init__(self,
                 length: int,
                 dct_type: int = 2,
                 norm: str = "ortho",
                 name: str = None,
                 **kwargs):
        """
        Instantiates a DCT layer with the given configuration.
        """
        super(DCT, self).__init__(trainable=False, name=name, **kwargs)

        self.length = length
        if self.length <= 0:
            raise ValueError(f"DCT length must be > 0, got {length}")

        self.dctType = dct_type
        if self.dctType not in [2]:
            raise NotImplementedError(f"DCT-{dct_type} is not supported yet")

        self.norm = norm.lower()
        if self.norm not in ["ortho"]:
            raise NotImplementedError(f"{norm} normalization is not supported yet")

        # The DCT is achieved by multiplying the input with the DCT matrix and
        # then carrying out any normalization. The matrix is computed after
        # receiving the shape of the input to this layer in build().
        self.dct = None

        # Inputs to this layer are expected to be in the shape
        # (batch, timesteps, featdim).
        self.featAxis = -1

    def build(self, input_shape: Iterable[Union[int, None]]):
        """
        Precomputes the transform matrix that needs to be applied to the input
        to achieve the desired DCT transform.

        Parameters
        ----------
        input_shape : Iterable[Union[int, None]]
            Shape of the input to this layer. Expected to have three axes,
            (batch, time, feats).

        Raises
        ------
        ValueError
            If the input feature length (featDim) < DCT length.
        """
        super(DCT, self).build(input_shape)

        featDim = input_shape[self.featAxis]
        if featDim < self.length:
            raise ValueError("input feature length must be >= DCT length")

        if self.dctType == 2:
            self.computeType2Matrix(featDim)

    def computeType2Matrix(self, inputLength: int):
        """
        Creates a DCT-II transformation matrix with shape (inputLength, self.length),
        normalized depending on norm. The matrix, when applied to the input tensor,
        achieves the DCT-II transform. See the following link for the equation and
        the meaning of the terms:

        https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II

        This has been adapted from the `create_dct()` method defined in
        `torchaudio.functional.functional`:

        https://pytorch.org/audio/stable/_modules/torchaudio/functional/functional.html

        Parameters
        ----------
        inputLength : int
            Number of samples in the input.
        """
        # 0 to N.
        N = inputLength
        n = np.arange(N)
        N = float(N)

        # 0 to K; adding a dimension for matrix multiplication.
        K = self.length
        k = np.expand_dims(np.arange(K, dtype=np.float64), 1)

        # dct.shape = (K, N) = (num_mfcc, num_mels)
        dct = np.cos((np.pi / N) * (n + 0.5) * k)
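
        # Each entry above is the unnormalized DCT-II basis,
        #   dct[k, n] = cos((pi / N) * (n + 0.5) * k),
        # and the "ortho" branch below scales row 0 by 1/sqrt(2) and the whole
        # matrix by sqrt(2/N), making the transform orthonormal when K == N.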
        if self.norm is None:
            dct *= 2.0
        elif self.norm == "ortho":
            dct[0] *= (1.0 / np.sqrt(2.0))
            dct *= np.sqrt(2.0 / N)

        dct = dct.T  # shape = (N, K) = (num_mels, num_mfcc)

        # Kaldi expects the first cepstral coefficient to be a weighted sum with
        # factor sqrt(1/num_mels). Since this layer multiplies the input on the
        # right by the DCT matrix (i.e. input @ dct), that corresponds to the
        # first column of dct. Note that Kaldi uses a left multiply, so there it
        # is the first row of Kaldi's DCT matrix.
        dct[:, 0] = np.sqrt(1.0 / N)

        self.dct = tf.constant(dct, dtype=self.dtype)

    def compute_output_shape(self, input_shape: Iterable[Union[int, None]]) -> Tuple[Union[int, None], ...]:
        """
        Returns the shape of the DCT output, given the shape of the input.

        Parameters
        ----------
        input_shape : Iterable[Union[int, None]]
            Shape of the input to this layer. Expected to have three axes,
            (batch, time, feats).

        Returns
        -------
        Tuple[Union[int, None], ...]
            Shape of the output of this layer.
        """
        batch, time, feat = input_shape
        outputShape = (batch, time, self.length)

        return outputShape

    def get_config(self) -> dict:
        """
        Returns the configuration of this layer as a dict so that the layer
        can be re-instantiated later.
        """
        config = super(DCT, self).get_config()
        config.update({
            "length": self.length,
            "dct_type": self.dctType,
            "norm": self.norm,
        })

        return config

    def call(self, inputs):
        # Apply the DCT by right-multiplying with the precomputed DCT matrix.
        return tf.matmul(inputs, self.dct)
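

# Minimal usage sketch: runs random, log-mel-like features of shape
# (batch, time, num_mels) through the layer to get the first 13 cepstral
# coefficients. The shapes and values here are arbitrary and for illustration
# only; they are not tied to any particular feature pipeline.
if __name__ == "__main__":
    feats = tf.random.normal((2, 100, 40))  # (batch=2, time=100, num_mels=40)
    dctLayer = DCT(length=13)

    # The DCT matrix is built lazily on the first call, once the feature
    # dimension (40) is known.
    mfcc = dctLayer(feats)
    print(mfcc.shape)  # (2, 100, 13)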