diff --git a/lingvo/core/attention.py b/lingvo/core/attention.py index b51eb7e85..74575fe9c 100644 --- a/lingvo/core/attention.py +++ b/lingvo/core/attention.py @@ -93,7 +93,6 @@ def SafeCumprod(x, *args, **kwargs): ) -# pyformat: disable def MonotonicAttentionProb(p_choose_i, previous_attention, mode): """Compute monotonic attention distribution from choosing probabilities. @@ -115,7 +114,6 @@ def MonotonicAttentionProb(p_choose_i, previous_attention, mode): 0] for all n in [0, ... batch_size - 1]. mode: How to compute the attention distribution. Must be one of `recursive`, `parallel`, or `hard`. - * recursive: uses tf.scan to recursively compute the distribution. This is slowest but is exact, general, and does not suffer from numerical instabilities. @@ -136,17 +134,18 @@ def MonotonicAttentionProb(p_choose_i, previous_attention, mode): Raises: ValueError: mode is not one of 'recursive', 'parallel', 'hard'. """ - # pyformat: enable # Force things to be tensors p_choose_i = tf.convert_to_tensor(p_choose_i, name='p_choose_i') previous_attention = tf.convert_to_tensor( - previous_attention, name='previous_attention') + previous_attention, name='previous_attention' + ) if mode == 'recursive': batch_size = py_utils.GetShape(p_choose_i)[0] tf.logging.info(batch_size) # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]] shifted_1mp_choose_i = tf.concat( - [tf.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1) + [tf.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1 + ) # Compute attention distribution recursively as # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i] # attention[i] = p_choose_i[i]*q[i] @@ -158,19 +157,26 @@ def MonotonicAttentionProb(p_choose_i, previous_attention, mode): # Loop variables yz[0] and yz[1] [ tf.transpose(shifted_1mp_choose_i), - tf.transpose(previous_attention) + tf.transpose(previous_attention), ], # Initial value of x is just zeros - tf.zeros((batch_size,)))) + tf.zeros((batch_size,)), + ) + ) elif mode == 'parallel': # SafeCumprod computes cumprod in logspace with numeric checks cumprod_1mp_choose_i = SafeCumprod(1 - p_choose_i, axis=1, exclusive=True) # Compute recurrence relation solution - attention = p_choose_i * cumprod_1mp_choose_i * py_utils.CumSum( - previous_attention / - # Clip cumprod_1mp to avoid divide-by-zero - tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), - axis=1) + attention = ( + p_choose_i + * cumprod_1mp_choose_i + * py_utils.CumSum( + previous_attention / + # Clip cumprod_1mp to avoid divide-by-zero + tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.0), + axis=1, + ) + ) elif mode == 'hard': # Remove any probabilities before the index chosen last time step p_choose_i *= tf.cumsum(previous_attention, axis=1) @@ -180,7 +186,8 @@ def MonotonicAttentionProb(p_choose_i, previous_attention, mode): # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0] # Product of above: [0, 0, 0, 1, 0, 0, 0, 0] attention = p_choose_i * tf.math.cumprod( - 1 - p_choose_i, axis=1, exclusive=True) + 1 - p_choose_i, axis=1, exclusive=True + ) else: raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.") return attention @@ -192,19 +199,30 @@ class BaseAttentionLayer(quant_utils.QuantizableLayer): @classmethod def Params(cls): p = super().Params() - p.Define('atten_dropout_prob', 0.0, - 'Probability at which we apply dropout to the attention weights.') p.Define( - 'atten_dropout_deterministic', False, + 'atten_dropout_prob', + 0.0, + 'Probability at which we apply dropout to the attention weights.', 
+ ) + p.Define( + 'atten_dropout_deterministic', + False, 'Whether to dropout in a fully deterministic way, which is more ' - 'suitable for TPU.') - p.Define('packed_input', False, - 'If True, each training example may pack multiple sequences.') + 'suitable for TPU.', + ) + p.Define( + 'packed_input', + False, + 'If True, each training example may pack multiple sequences.', + ) p.qdomain.Define('softmax', None, 'QDomain for the internal softmax.') p.qdomain.Define( - 'fullyconnected', None, 'Fully connected layers are fed ' - 'into activation functions which have known input ranges') + 'fullyconnected', + None, + 'Fully connected layers are fed ' + 'into activation functions which have known input ranges', + ) return p @@ -220,12 +238,14 @@ def _CreateLayerVariables(self): super()._CreateLayerVariables() self.TrackQActs('logits', domain='fullyconnected') - def InitForSourcePacked(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def InitForSourcePacked( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): """Initialize attention for the given source vectors. Must set `_source_init_done` to True in the function. @@ -250,16 +270,19 @@ def InitForSourcePacked(self, inspected or modified by its callers. """ self._source_init_done = True - self._packed_src = self.PackSource(theta, source_vecs, source_contexts, - source_padding, source_segment_id) + self._packed_src = self.PackSource( + theta, source_vecs, source_contexts, source_padding, source_segment_id + ) return self._packed_src - def PackSource(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): """Packs source vectors. Does not change attention state. @@ -285,13 +308,15 @@ def PackSource(self, """ raise NotImplementedError('Abstract method.') - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Args: @@ -315,12 +340,14 @@ def ComputeContextVectorWithSource(self, """ raise NotImplementedError('Abstract method.') - def ComputeContextVector(self, - theta, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVector( + self, + theta, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Unlike `ComputeContextVectorWithSource` which explicitly asks for the packed @@ -344,10 +371,14 @@ def ComputeContextVector(self, dimensions [target_batch, ...] """ assert self._source_init_done - return self.ComputeContextVectorWithSource(theta, self._packed_src, - query_vec, attention_state, - per_step_source_padding, - query_segment_id) + return self.ComputeContextVectorWithSource( + theta, + self._packed_src, + query_vec, + attention_state, + per_step_source_padding, + query_segment_id, + ) def GetInitializationSourceState(self): """Gets the attention initialization state. 
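# A minimal NumPy sketch of the recurrence documented in MonotonicAttentionProb
# above; the helper names below are illustrative only and not part of lingvo.
# Both the 'recursive' and 'parallel' modes evaluate
#   q[i] = (1 - p[i - 1]) * q[i - 1] + previous_attention[i]
#   attention[i] = p[i] * q[i]
# the parallel form via an exclusive cumulative product and a cumulative sum,
# mirroring the SafeCumprod / py_utils.CumSum path in the TF code.
import numpy as np

def monotonic_attention_recursive(p, prev_attn):
  attention = np.zeros_like(p)
  q = 0.0
  for i in range(len(p)):
    q = (1.0 - (p[i - 1] if i > 0 else 0.0)) * q + prev_attn[i]
    attention[i] = p[i] * q
  return attention

def monotonic_attention_parallel(p, prev_attn, eps=1e-10):
  # Exclusive cumprod of (1 - p): [1, 1-p[0], (1-p[0])(1-p[1]), ...].
  cumprod_1mp = np.cumprod(np.concatenate([[1.0], 1.0 - p[:-1]]))
  # Clip avoids divide-by-zero, as in the 'parallel' branch above.
  return p * cumprod_1mp * np.cumsum(prev_attn / np.clip(cumprod_1mp, eps, 1.0))

p = np.array([0.1, 0.7, 0.6, 0.2])
prev_attn = np.array([0.2, 0.5, 0.3, 0.0])
np.testing.assert_allclose(
    monotonic_attention_parallel(p, prev_attn),
    monotonic_attention_recursive(p, prev_attn))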
@@ -398,8 +429,9 @@ def _PaddedSoftmax(self, logits, padding): logits = tf.abs(logits) assert logits.dtype.is_floating assert hasattr(logits.dtype, 'max') - very_negative_logits = ( - tf.constant(-0.7 * logits.dtype.max, dtype=logits.dtype)) + very_negative_logits = tf.constant( + -0.7 * logits.dtype.max, dtype=logits.dtype + ) if self.do_eval: very_negative_logits = self.QAct('logits', very_negative_logits) padded_logits = py_utils.ApplyPadding(padding, logits, very_negative_logits) @@ -407,8 +439,9 @@ def _PaddedSoftmax(self, logits, padding): # incompatible concats. return fns.qsoftmax(padded_logits, qdomain='softmax') - def _UpdatePaddingWithPackedInputMask(self, padding, source_segment_ids, - query_segment_ids): + def _UpdatePaddingWithPackedInputMask( + self, padding, source_segment_ids, query_segment_ids + ): """Creates an attention mask based on source and query segment ids. This creates a mask that removes invalid attention, where the query vector @@ -426,11 +459,13 @@ def _UpdatePaddingWithPackedInputMask(self, padding, source_segment_ids, # Generating packed input mask for attention padding. source_segment_ids = tf.expand_dims(source_segment_ids, 1) query_segment_ids = tf.reshape( - query_segment_ids, - [1, -1, py_utils.GetShape(source_segment_ids)[2]]) + query_segment_ids, [1, -1, py_utils.GetShape(source_segment_ids)[2]] + ) padding = tf.where_v2( - tf.equal(source_segment_ids, query_segment_ids), padding, - tf.ones([], padding.dtype)) + tf.equal(source_segment_ids, query_segment_ids), + padding, + tf.ones([], padding.dtype), + ) return padding @@ -455,8 +490,10 @@ def Params(cls): # Fill in reasonable default for params init p.params_init = py_utils.WeightInit.GaussianSqrtDim() p.Define( - 'same_batch_size', False, - 'True iff the source and target sequence has the same batch size.') + 'same_batch_size', + False, + 'True iff the source and target sequence has the same batch size.', + ) return p def __init__(self, params): @@ -481,7 +518,8 @@ def AttenProbs(inputs): # tf.reshape(v, [1, 1, 1, hidden_dim]), 3) logits = py_utils.Matmul( tf.reshape(summed, [-1, p.hidden_dim]), - tf.reshape(inputs.v, [p.hidden_dim, 1])) + tf.reshape(inputs.v, [p.hidden_dim, 1]), + ) logits = tf.reshape(logits, tf.shape(summed)[:3]) # Take out the padding states. # _source_padding is of shape [source_length, source_batch]. @@ -491,7 +529,8 @@ def AttenProbs(inputs): source_padding = tf.expand_dims(inputs.source_padding, 1) per_step_source_padding = tf.reshape( tf.transpose(inputs.per_step_source_padding), - [-1, multiplier, source_batch]) + [-1, multiplier, source_batch], + ) if source_padding.dtype != tf.bool: source_padding = source_padding > 0 if per_step_source_padding.dtype != tf.bool: @@ -500,19 +539,29 @@ def AttenProbs(inputs): if p.packed_input: source_padding = self._UpdatePaddingWithPackedInputMask( - source_padding, inputs.source_segment_id, inputs.query_segment_id) + source_padding, inputs.source_segment_id, inputs.query_segment_id + ) # Reshape logits to a matrix of shape [target_batch, source_length] and # takes the softmax to compute the probabilities. logits = tf.transpose(tf.reshape(logits, [-1, target_batch])) source_padding = tf.transpose( - tf.reshape(source_padding, [-1, target_batch])) + tf.reshape(source_padding, [-1, target_batch]) + ) probs = self._PaddedSoftmax(logits, source_padding) return probs # Adds the atten function into the graph's library. 
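# A minimal NumPy sketch of the masking idea used by _PaddedSoftmax and
# _UpdatePaddingWithPackedInputMask above; the helpers below are illustrative
# only and not lingvo APIs, and -1e30 stands in for the
# -0.7 * logits.dtype.max constant used by the real code. Positions whose
# source segment id differs from the query's segment id are forced to padding,
# and padded logits are then pushed toward a very negative constant so the
# softmax assigns them ~0 probability.
import numpy as np

def padded_softmax(logits, padding, very_negative=-1e30):
  # Replace padded positions with a very negative logit, then normalize.
  masked = np.where(padding > 0, very_negative, logits)
  exp = np.exp(masked - masked.max(axis=-1, keepdims=True))
  return exp / exp.sum(axis=-1, keepdims=True)

def packed_input_padding(padding, source_segment_id, query_segment_id):
  # Keep the original padding where segments match; pad out everything else.
  return np.where(source_segment_id == query_segment_id, padding, 1.0)

padding = packed_input_padding(
    padding=np.array([0., 0., 0., 1.]),        # last position already padded
    source_segment_id=np.array([1, 1, 2, 2]),  # two sequences packed together
    query_segment_id=1)                        # query belongs to segment 1
print(padded_softmax(np.array([2.0, 1.0, 3.0, 0.5]), padding))
# Only the first two (same-segment, unpadded) positions get non-zero probability.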
- def Atten(v, w, source_padding, source_segment_id, concated_source_vecs, - concated_source_contexts, query_vec, query_segment_id, - per_step_source_padding): + def Atten( + v, + w, + source_padding, + source_segment_id, + concated_source_vecs, + concated_source_contexts, + query_vec, + query_segment_id, + per_step_source_padding, + ): """Computes the attention context vector. Args: @@ -525,6 +574,7 @@ def Atten(v, w, source_padding, source_segment_id, concated_source_vecs, query_vec: [target_batch, query_dim] query_segment_id: [target_batch] per_step_source_padding: [target_batch, source_length] + Note: concated_source_vecs are the vectors that are used to compute the attention score between the query_vec and each concated_source_vec. The concated_source_contexts are the vectors that compose the result. The @@ -546,7 +596,8 @@ def Atten(v, w, source_padding, source_segment_id, concated_source_vecs, # query_vec is reshaped to # [1, target_batch/source_batch, source_batch, hidden_dims]. query_vec_reshaped = tf.reshape( - query_vec_transformed, [1, multiplier, source_batch, p.hidden_dim]) + query_vec_transformed, [1, multiplier, source_batch, p.hidden_dim] + ) # probs is of shape [target_batch, source_length] probs = py_utils.CallDefun( AttenProbs, @@ -557,7 +608,9 @@ def Atten(v, w, source_padding, source_segment_id, concated_source_vecs, v=v, per_step_source_padding=per_step_source_padding, source_segment_id=source_segment_id, - query_segment_id=query_segment_id)) + query_segment_id=query_segment_id, + ), + ) probs.set_shape(per_step_source_padding.shape) # Apply dropout to weights if applicable. @@ -583,10 +636,17 @@ def Atten(v, w, source_padding, source_segment_id, concated_source_vecs, return tf.reshape(summed, [target_batch, -1]), probs # The source batch size equals to the target batch size. - def AttenSameBatchSize(v, w, source_padding, source_segment_id, - concated_source_vecs, concated_source_contexts, - query_vec, query_segment_id, - per_step_source_padding): + def AttenSameBatchSize( + v, + w, + source_padding, + source_segment_id, + concated_source_vecs, + concated_source_contexts, + query_vec, + query_segment_id, + per_step_source_padding, + ): """Computes the attention context vector. Args: @@ -609,6 +669,7 @@ def AttenSameBatchSize(v, w, source_padding, source_segment_id, # [b, hidden_dim] query_vec = py_utils.Matmul(query_vec, w) + # [sl, b] def AttenProbs(inputs): """Calculates atten probs with padding.""" @@ -616,13 +677,13 @@ def AttenProbs(inputs): summed = tf.tanh(inputs.x + inputs.y) # [-1, hidden_dim] * [hidden_dim, 1] = [-1, 1] res = py_utils.Matmul( - tf.reshape(summed, [-1, p.hidden_dim]), tf.expand_dims(inputs.v, 1)) + tf.reshape(summed, [-1, p.hidden_dim]), tf.expand_dims(inputs.v, 1) + ) # Reshape res to [sl, b] logits = tf.reshape(res, tf.shape(summed)[:2]) # Take out the padding states. _source_padding is of shape [sl, b]. 
source_padding = inputs.source_padding - per_step_source_padding = tf.transpose( - inputs.per_step_source_padding) + per_step_source_padding = tf.transpose(inputs.per_step_source_padding) if source_padding.dtype != tf.bool: source_padding = source_padding > 0 if per_step_source_padding.dtype != tf.bool: @@ -631,8 +692,10 @@ def AttenProbs(inputs): if p.packed_input: source_padding = self._UpdatePaddingWithPackedInputMask( - tf.expand_dims(source_padding, 1), inputs.source_segment_id, - inputs.query_segment_id) + tf.expand_dims(source_padding, 1), + inputs.source_segment_id, + inputs.query_segment_id, + ) source_padding = tf.squeeze(source_padding, 1) # [b, sl] source_padding = tf.transpose(source_padding) @@ -650,7 +713,9 @@ def AttenProbs(inputs): v=v, per_step_source_padding=per_step_source_padding, source_segment_id=source_segment_id, - query_segment_id=query_segment_id)) + query_segment_id=query_segment_id, + ), + ) probs.set_shape(per_step_source_padding.shape) # contexts[i, :] is a weighted (probs[i, :]) average of @@ -674,7 +739,8 @@ def EncodeSource(theta, vecs, ctxs): ctxs = py_utils.HasShape(ctxs, [time, batch, -1]) transformed_vecs = tf.matmul(vecs, theta.source_var) transformed_vecs = tf.identity( - transformed_vecs, name='source_vecs_projected') + transformed_vecs, name='source_vecs_projected' + ) transposed_ctxs = tf.transpose(ctxs, [1, 0, 2]) transposed_ctxs = tf.identity(transposed_ctxs, name='source_ctx') return transformed_vecs, transposed_ctxs @@ -688,21 +754,24 @@ def _CreateLayerVariables(self): shape=[p.source_dim, p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['AdditiveAttention_vars']) + collections=['AdditiveAttention_vars'], + ) self.CreateVariable('source_var', pc) pc = py_utils.WeightParams( shape=[p.query_dim, p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['AdditiveAttention_vars']) + collections=['AdditiveAttention_vars'], + ) self.CreateVariable('query_var', pc) pc = py_utils.WeightParams( shape=[p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['AdditiveAttention_vars']) + collections=['AdditiveAttention_vars'], + ) self.CreateVariable('hidden_var', pc) def AddGlobalVN(self, theta): @@ -712,12 +781,14 @@ def AddGlobalVN(self, theta): theta.query_var = self.AddVN(theta.query_var) return theta - def PackSource(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): """Packs source vectors. Does not change attention state. @@ -737,7 +808,8 @@ def PackSource(self, if source_segment_id is None: source_segment_id = tf.zeros_like(source_padding) concated_source_vecs, concated_source_contexts = self._encode_source( - theta, source_vecs, source_contexts) + theta, source_vecs, source_contexts + ) return py_utils.NestedMap( # [time, batch_size, hidden_dim]. source_vecs=concated_source_vecs, @@ -749,7 +821,8 @@ def PackSource(self, # [time, batch_size]. source_padding=source_padding, # [time, batch_size]. 
- source_segment_id=source_segment_id) + source_segment_id=source_segment_id, + ) def ZeroAttentionState(self, source_length, decoder_batch_size): p = self.params @@ -758,13 +831,15 @@ def ZeroAttentionState(self, source_length, decoder_batch_size): zs = tf.zeros([decoder_batch_size, 1], dtype=py_utils.FPropDtype(p)) return zs - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Note: `packed_src.source_vecs` are the vectors that are used to compute the @@ -801,10 +876,12 @@ def ComputeContextVectorWithSource(self, query_batch_size = py_utils.GetShape(query_vec)[0] source_length = py_utils.GetShape(source_padding)[0] if per_step_source_padding is None: - per_step_source_padding = tf.zeros([query_batch_size, source_length], - source_padding.dtype) + per_step_source_padding = tf.zeros( + [query_batch_size, source_length], source_padding.dtype + ) per_step_source_padding = py_utils.HasShape( - per_step_source_padding, [query_batch_size, source_length]) + per_step_source_padding, [query_batch_size, source_length] + ) hidden = self.AddVN(theta.hidden_var, per_step=True) query = self.AddVN(theta.query_var, per_step=True) @@ -812,12 +889,20 @@ def ComputeContextVectorWithSource(self, source_segment_id = tf.zeros_like(source_padding) if query_segment_id is None: query_segment_id = tf.zeros( - tf.shape(query_vec)[0], dtype=source_padding.dtype) + tf.shape(query_vec)[0], dtype=source_padding.dtype + ) - ctx_vec, prob = self._ctx_vec(hidden, query, source_padding, - source_segment_id, concated_source_vecs, - concated_source_contexts, query_vec, - query_segment_id, per_step_source_padding) + ctx_vec, prob = self._ctx_vec( + hidden, + query, + source_padding, + source_segment_id, + concated_source_vecs, + concated_source_contexts, + query_vec, + query_segment_id, + per_step_source_padding, + ) return ctx_vec, prob, attention_state @@ -841,9 +926,12 @@ def Params(cls): p.Define('query_dim', 0, 'Number of query nodes.') p.Define('hidden_dim', 0, 'Number of hidden nodes.') p.Define( - 'use_dim_scale', True, 'Whether or not to use per_dim_scale to scale ' + 'use_dim_scale', + True, + 'Whether or not to use per_dim_scale to scale ' 'the individual dims when calculating attention probabilities. It can ' - 'increase training stability when set to False.') + 'increase training stability when set to False.', + ) p.Define('atten_logit_cap', None, 'Clip softmax logits.') return p @@ -882,13 +970,13 @@ def AttenProbs(inputs): Args: inputs: a NestedMap containing: - - per_dim_scale: [source_dim], a vec to scale individual dims. - - source_padding: [time, source_batch]. - - concated_source_vecs: [source_batch, time, source_dim]. - - query_vec: [target_batch, source_dim]. - - per_step_source_padding: [target_batch, source_length] - - source_segment_id: [time, source_batch]. - - query_segment_id: [target_batch]. + * per_dim_scale: [source_dim], a vec to scale individual dims. + * source_padding: [time, source_batch]. + * concated_source_vecs: [source_batch, time, source_dim]. + * query_vec: [target_batch, source_dim]. + * per_step_source_padding: [target_batch, source_length] + * source_segment_id: [time, source_batch]. + * query_segment_id: [target_batch]. 
Returns: logits [target_batch, source_time]. @@ -900,7 +988,10 @@ def AttenProbs(inputs): tf.math.rsqrt( tf.cast( py_utils.GetShape(inputs.query_vec)[1], - dtype=py_utils.FPropDtype(p)))) + dtype=py_utils.FPropDtype(p), + ) + ) + ) source_batch, _, source_dim = py_utils.GetShape(concated_source_vecs) target_batch = py_utils.GetShape(inputs.query_vec)[0] query_vec = inputs.query_vec * inputs.per_dim_scale @@ -911,8 +1002,9 @@ def AttenProbs(inputs): query_vec = tf.transpose(query_vec, [1, 2, 0]) source_length = py_utils.GetShape(inputs.per_step_source_padding)[1] # => [n, source_batch, source_length] - per_step_source_padding = tf.reshape(inputs.per_step_source_padding, - [-1, source_batch, source_length]) + per_step_source_padding = tf.reshape( + inputs.per_step_source_padding, [-1, source_batch, source_length] + ) # => [source_batch, source_length, n] per_step_source_padding = tf.transpose(per_step_source_padding, [1, 2, 0]) # Dot-product part. @@ -923,7 +1015,8 @@ def AttenProbs(inputs): act_lhs=concated_source_vecs, act_rhs=query_vec, act_lhs_distribution=quant_utils.QDistribution.SYMMETRIC, - act_rhs_distribution=quant_utils.QDistribution.SYMMETRIC) + act_rhs_distribution=quant_utils.QDistribution.SYMMETRIC, + ) logits = self.QMatmul(concated_source_vecs, query_vec) logits = self.FromAqtActActMatmul(logits) @@ -939,7 +1032,8 @@ def AttenProbs(inputs): if p.packed_input: source_padding = tf.transpose(source_padding, [1, 2, 0]) source_padding = self._UpdatePaddingWithPackedInputMask( - source_padding, inputs.source_segment_id, inputs.query_segment_id) + source_padding, inputs.source_segment_id, inputs.query_segment_id + ) source_padding = tf.transpose(source_padding, [1, 2, 0]) else: source_padding = tf.transpose(source_padding, [2, 0, 1]) @@ -956,9 +1050,16 @@ def AttenProbs(inputs): probs = self._PaddedSoftmax(logits, source_padding) return probs - def Atten(per_dim_scale, source_padding, source_segment_id, - concated_source_vecs, concated_source_contexts, query_vec, - query_segment_id, per_step_source_padding): + def Atten( + per_dim_scale, + source_padding, + source_segment_id, + concated_source_vecs, + concated_source_contexts, + query_vec, + query_segment_id, + per_step_source_padding, + ): """Main attention function. Args: @@ -970,6 +1071,7 @@ def Atten(per_dim_scale, source_padding, source_segment_id, query_vec: [target_batch, source_dim]. query_segment_id: [target_batch]. per_step_source_padding: [target_batch, source_length] + Note: concated_source_vecs are the vectors that are used to compute the attention score between the query_vec and each concated_source_vec. The concated_source_contexts are the vectors that compose the result. The @@ -983,15 +1085,20 @@ def Atten(per_dim_scale, source_padding, source_segment_id, - context_vector: [target_batch, context_dim]. - probs: [target_batch, time]. 
""" - py_utils.assert_shape_match([py_utils.GetShape(concated_source_vecs)[2]], - [py_utils.GetShape(query_vec)[1]]) - py_utils.assert_shape_match([py_utils.GetShape(concated_source_vecs)[2]], - [symbolic.ToStatic(p.source_dim)]) + py_utils.assert_shape_match( + [py_utils.GetShape(concated_source_vecs)[2]], + [py_utils.GetShape(query_vec)[1]], + ) + py_utils.assert_shape_match( + [py_utils.GetShape(concated_source_vecs)[2]], + [symbolic.ToStatic(p.source_dim)], + ) time, source_batch = py_utils.GetShape(concated_source_vecs, 2) target_batch = py_utils.GetShape(query_vec)[0] concated_source_vecs = tf.transpose(concated_source_vecs, [1, 0, 2]) concated_source_vecs = tf.identity( - concated_source_vecs, name='concated_source_vecs') + concated_source_vecs, name='concated_source_vecs' + ) returned_probs = py_utils.CallDefun( AttenProbs, py_utils.NestedMap( @@ -1001,7 +1108,9 @@ def Atten(per_dim_scale, source_padding, source_segment_id, query_vec=query_vec, per_step_source_padding=per_step_source_padding, source_segment_id=source_segment_id, - query_segment_id=query_segment_id)) + query_segment_id=query_segment_id, + ), + ) returned_probs.set_shape(per_step_source_padding.shape) # => [n, source_batch, time] where n = target_batch // source_batch @@ -1017,19 +1126,22 @@ def Atten(per_dim_scale, source_padding, source_segment_id, # [source_batch, n, time] * [source_batch, time, context_dim] # => [source_batch, n, context_dim]. concated_source_contexts = tf.identity( - concated_source_contexts, name='concated_source_contexts') + concated_source_contexts, name='concated_source_contexts' + ) probs, concated_source_contexts = self.ToAqtActActInputs( act_lhs=probs, act_rhs=concated_source_contexts, act_lhs_distribution=quant_utils.QDistribution.POSITIVE, - act_rhs_distribution=quant_utils.QDistribution.SYMMETRIC) + act_rhs_distribution=quant_utils.QDistribution.SYMMETRIC, + ) context_vector = self.QMatmul(probs, concated_source_contexts) context_vector = self.FromAqtActActMatmul(context_vector) # => [n, source_batch, context_dim]. context_vector = tf.transpose(context_vector, [1, 0, 2]) - context_vector = gshard_utils.MeshSplit(context_vector, p.device_mesh, - p.activation_split_dims_mapping) + context_vector = gshard_utils.MeshSplit( + context_vector, p.device_mesh, p.activation_split_dims_mapping + ) # => [n * source_batch, context_dim]. context_vector = tf.reshape(context_vector, [target_batch, -1]) @@ -1046,16 +1158,19 @@ def _CreateLayerVariables(self): shape=[p.hidden_dim], init=py_utils.WeightInit.Constant(0.0), dtype=p.dtype, - collections=['DotProductAttention_vars']) + collections=['DotProductAttention_vars'], + ) self.CreateVariable('per_dim_scale', pc) - def PackSource(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): """Packs source vectors. Does not change attention state. @@ -1090,20 +1205,23 @@ def PackSource(self, # [time, batch_size]. source_padding=source_padding, # [time, batch_size]. - source_segment_id=source_segment_id) + source_segment_id=source_segment_id, + ) def ZeroAttentionState(self, source_length, decoder_batch_size): p = self.params # No states to keep track of currently. 
return tf.zeros([decoder_batch_size, 1], dtype=py_utils.FPropDtype(p)) - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Args: @@ -1138,15 +1256,17 @@ def ComputeContextVectorWithSource(self, source_sequence_length = py_utils.GetShape(source_padding)[0] if per_step_source_padding is None: per_step_source_padding = tf.zeros( - [query_batch_size, source_sequence_length], - dtype=source_padding.dtype) + [query_batch_size, source_sequence_length], dtype=source_padding.dtype + ) per_step_source_padding = py_utils.HasShape( - per_step_source_padding, [query_batch_size, source_sequence_length]) + per_step_source_padding, [query_batch_size, source_sequence_length] + ) if source_segment_id is None: source_segment_id = tf.zeros_like(source_padding) if query_segment_id is None: query_segment_id = tf.zeros( - py_utils.GetShape(query_vec)[0], dtype=source_padding.dtype) + py_utils.GetShape(query_vec)[0], dtype=source_padding.dtype + ) def ScaleFn(x): return tf.nn.softplus(x) / tf.nn.softplus(tf.constant(0.0, dtype=x.dtype)) @@ -1157,9 +1277,15 @@ def ScaleFn(x): per_dim_scale_var = tf.constant(0.0, dtype=query_vec.dtype) ctx_vec, prob = self._ctx_vec( - ScaleFn(per_dim_scale_var), source_padding, source_segment_id, - concated_source_vecs, concated_source_contexts, query_vec, - query_segment_id, per_step_source_padding) + ScaleFn(per_dim_scale_var), + source_padding, + source_segment_id, + concated_source_vecs, + concated_source_contexts, + query_vec, + query_segment_id, + per_step_source_padding, + ) return ctx_vec, prob, attention_state @@ -1201,48 +1327,74 @@ def Params(cls): p.Define('hidden_dim', 0, 'Number of hidden nodes.') p.Define('num_attention_heads', 2, 'Num of attention heads.') p.Define( - 'use_source_vec_as_attention_value', True, + 'use_source_vec_as_attention_value', + True, 'Whether or not to use source_vec as the attention value as well.' 
- ' If True, we expect source_vec and source_contexts are the same.') - p.Define('enable_source_proj', True, - 'If False, source side linear projection is disabled.') - p.Define('enable_query_proj', True, - 'If False, query side linear projection is disabled.') - p.Define('inner_atten_params', DotProductAttention.Params(), - 'Params for underlying attention mechanism.') + ' If True, we expect source_vec and source_contexts are the same.', + ) + p.Define( + 'enable_source_proj', + True, + 'If False, source side linear projection is disabled.', + ) + p.Define( + 'enable_query_proj', + True, + 'If False, query side linear projection is disabled.', + ) p.Define( - 'enable_ctx_pre_proj', False, - 'If True, context is pre-projected before processing into' - ' hidden_dim.') + 'inner_atten_params', + DotProductAttention.Params(), + 'Params for underlying attention mechanism.', + ) + p.Define( + 'enable_ctx_pre_proj', + False, + 'If True, context is pre-projected before processing into hidden_dim.', + ) p.Define( - 'enable_ctx_post_proj', False, - 'If True, computed context is post projected into' - ' ctx_post_proj_dim.') + 'enable_ctx_post_proj', + False, + 'If True, computed context is post projected into ctx_post_proj_dim.', + ) p.Define('ctx_post_proj_dim', 0, 'Number of post projection nodes.') p.Define( - 'num_post_proj', 1, 'Number of post projections, usually the same as ' + 'num_post_proj', + 1, + 'Number of post projections, usually the same as ' 'number of tasks. Each task may choose to use one of the post ' - 'projection layers.') + 'projection layers.', + ) p.Define( - 'proj_init', 'default', 'Initialization approach for projection ' + 'proj_init', + 'default', + 'Initialization approach for projection ' 'layers:' 'uniform: Use uniform initialization. ' - 'default: Use the default Xavier initialization.') + 'default: Use the default Xavier initialization.', + ) p.Define( - 'attention_head_prob_index', -1, 'If > 0, instead of averaging ' + 'attention_head_prob_index', + -1, + 'If > 0, instead of averaging ' 'the probabilities of all attention heads when returning the ' - 'attention probability, instead return the selected index prob.') + 'attention probability, instead return the selected index prob.', + ) p.Define('use_bias', True, 'Whether to use bias for projection layer.') - p.Define('enable_per_dim_scale', True, - 'Whether to use per_dim_scale in inner_atten.') + p.Define( + 'enable_per_dim_scale', + True, + 'Whether to use per_dim_scale in inner_atten.', + ) # Often the attention context output needs to be concated # with tensors from another layer. This allows them to share # quantization parameters. By convention, all attention layers # need to include their context output vectors in this domain. - p.qdomain.Define('atten_context', None, - 'Quantization domain for attention context.') + p.qdomain.Define( + 'atten_context', None, 'Quantization domain for attention context.' + ) p.params_init = py_utils.WeightInit.Xavier(scale=1.0) @@ -1252,9 +1404,12 @@ def __init__(self, params): """Constructs a MultiHeadedAttention object.""" super().__init__(params) p = self.params - assert symbolic.ToStatic( - p.hidden_dim) % p.num_attention_heads == 0, '%s mod %s != 0' % ( - symbolic.ToStatic(p.hidden_dim), p.num_attention_heads) + assert ( + symbolic.ToStatic(p.hidden_dim) % p.num_attention_heads == 0 + ), '%s mod %s != 0' % ( + symbolic.ToStatic(p.hidden_dim), + p.num_attention_heads, + ) if p.proj_init not in ('uniform', 'default'): raise ValueError('Unknown proj_init: %s!' 
% p.proj_init) @@ -1268,7 +1423,8 @@ def __init__(self, params): dtype=p.dtype, atten_dropout_prob=p.atten_dropout_prob, atten_dropout_deterministic=p.atten_dropout_deterministic, - packed_input=p.packed_input) + packed_input=p.packed_input, + ) if att_p.cls == DotProductAttention: att_p.use_dim_scale = p.enable_per_dim_scale @@ -1283,22 +1439,26 @@ def __init__(self, params): 'query_proj', shape=[p.query_dim, p.hidden_dim], feature_axis=-1, - legacy_aqt_weight_name='query_proj_aqt') + legacy_aqt_weight_name='query_proj_aqt', + ) self.TrackQWeight( 'source_proj', shape=[p.source_dim, p.hidden_dim], feature_axis=-1, - legacy_aqt_weight_name='source_proj_aqt') + legacy_aqt_weight_name='source_proj_aqt', + ) self.TrackQWeight( 'ctx_proj', shape=[p.context_dim, p.hidden_dim], feature_axis=-1, - legacy_aqt_weight_name='ctx_pre_proj_aqt') + legacy_aqt_weight_name='ctx_pre_proj_aqt', + ) self.TrackQWeight( 'ctx_post_proj', shape=[p.hidden_dim, p.ctx_post_proj_dim], feature_axis=-1, - legacy_aqt_weight_name='ctx_post_proj_aqt') + legacy_aqt_weight_name='ctx_post_proj_aqt', + ) def _CreateLayerVariables(self): super()._CreateLayerVariables() @@ -1321,7 +1481,8 @@ def InitProj(layer_dim, bias=False): shape=[p.hidden_dim], init=InitProj(p.hidden_dim, bias=True), dtype=p.dtype, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) else: pc_bias = None @@ -1332,7 +1493,8 @@ def InitProj(layer_dim, bias=False): dtype=p.dtype, device_mesh=p.device_mesh, tensor_split_dims_mapping=p.weight_split_dims_mapping, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) self.CreateVariable('source_proj', pc) if p.use_bias: self.CreateVariable('source_proj_b', pc_bias) @@ -1346,7 +1508,8 @@ def InitProj(layer_dim, bias=False): dtype=p.dtype, device_mesh=p.device_mesh, tensor_split_dims_mapping=p.weight_split_dims_mapping, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) self.CreateVariable('query_proj', pc) if p.use_bias: self.CreateVariable('query_proj_b', pc_bias) @@ -1361,7 +1524,8 @@ def InitProj(layer_dim, bias=False): dtype=p.dtype, device_mesh=p.device_mesh, tensor_split_dims_mapping=p.weight_split_dims_mapping, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) self.CreateVariable('ctx_proj', pc) if p.use_bias: self.CreateVariable('ctx_proj_b', pc_bias) @@ -1385,38 +1549,47 @@ def InitProj(layer_dim, bias=False): dtype=p.dtype, device_mesh=p.device_mesh, tensor_split_dims_mapping=weight_split_dims_mapping, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) self.CreateVariable('ctx_post_proj', pc) if p.use_bias: pc_bias_post_proj = py_utils.WeightParams( shape=pc_b_shape, init=InitProj(p.ctx_post_proj_dim, bias=True), dtype=p.dtype, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) self.CreateVariable('ctx_post_proj_b', pc_bias_post_proj) - self.TrackQActs('source_proj_matmul', 'source_proj_add', - 'query_proj_matmul', 'query_proj_add', - 'ctx_pre_proj_matmul', 'ctx_pre_proj_add') + self.TrackQActs( + 'source_proj_matmul', + 'source_proj_add', + 'query_proj_matmul', + 'query_proj_add', + 'ctx_pre_proj_matmul', + 'ctx_pre_proj_add', + ) # TODO(suderman): Remove the self.do_eval check below once brop quant within # defun is fixed on the training side. 
This is less than ideal as-is because # training will just trend to match downstream quant constraints vs force # alignment. self.TrackQActs( - 'ctx_post_proj_matmul', 'ctx_post_proj_add', domain='atten_context') + 'ctx_post_proj_matmul', 'ctx_post_proj_add', domain='atten_context' + ) @classmethod def SetOutputContextDim(cls, p, out_dim): p.ctx_post_proj_dim = out_dim @py_utils.NameScopeDecorator('MultiHeadedAttention/PackSource') - def PackSource(self, - theta: py_utils.NestedMap, - source_vecs: tf.Tensor, - source_contexts: Optional[tf.Tensor], - source_padding: tf.Tensor, - source_segment_id: Optional[tf.Tensor] = None - ) -> py_utils.NestedMap: + def PackSource( + self, + theta: py_utils.NestedMap, + source_vecs: tf.Tensor, + source_contexts: Optional[tf.Tensor], + source_padding: tf.Tensor, + source_segment_id: Optional[tf.Tensor] = None, + ) -> py_utils.NestedMap: """Packs source vectors. Does not change attention state. @@ -1449,13 +1622,14 @@ def PackSource(self, if p.use_source_vec_as_attention_value: assert source_contexts is not None # [time_steps, batch_size, context_dim] - source_contexts = py_utils.HasShape(source_contexts, - [time_steps, batch_size, -1]) - source_padding = py_utils.HasShape(source_padding, - [time_steps, batch_size]) + source_contexts = py_utils.HasShape( + source_contexts, [time_steps, batch_size, -1] + ) + source_padding = py_utils.HasShape(source_padding, [time_steps, batch_size]) if source_segment_id is not None: - source_segment_id = py_utils.HasShape(source_segment_id, - [time_steps, batch_size]) + source_segment_id = py_utils.HasShape( + source_segment_id, [time_steps, batch_size] + ) with tf.name_scope('vecs'): if p.enable_source_proj: @@ -1463,7 +1637,8 @@ def PackSource(self, 'source_proj', act=source_vecs, weight=theta.source_proj, - w_feature_axis=-1) + w_feature_axis=-1, + ) w_source_proj = self.QWeight(w_source_proj) source_vecs = self.QMatmul(source_vecs, w_source_proj) source_vecs = self.QAct('source_proj_matmul', source_vecs) @@ -1472,24 +1647,27 @@ def PackSource(self, source_vecs = fns.qadd( source_vecs, self.QWeight(theta.source_proj_b), - qout_name='source_proj_add') + qout_name='source_proj_add', + ) source_vecs = gshard_utils.MeshSplit( - source_vecs, - p.device_mesh, - p.activation_split_dims_mapping) + source_vecs, p.device_mesh, p.activation_split_dims_mapping + ) # => [time_steps, batch_size, hidden_dim] - source_vecs = py_utils.HasShape(source_vecs, - [time_steps, - batch_size, - symbolic.ToStatic(p.hidden_dim)]) + source_vecs = py_utils.HasShape( + source_vecs, [time_steps, batch_size, symbolic.ToStatic(p.hidden_dim)] + ) # => [time_steps, batch_size * num_heads, hidden_dim / num_heads] - source_vecs = tf.reshape(source_vecs, - [-1, - batch_size * num_heads, - symbolic.ToStatic(p.hidden_dim // num_heads)]) - source_vecs = gshard_utils.MeshSplit(source_vecs, - p.device_mesh, - p.activation_split_dims_mapping) + source_vecs = tf.reshape( + source_vecs, + [ + -1, + batch_size * num_heads, + symbolic.ToStatic(p.hidden_dim // num_heads), + ], + ) + source_vecs = gshard_utils.MeshSplit( + source_vecs, p.device_mesh, p.activation_split_dims_mapping + ) source_vecs = self.ProcessProjectionVec(theta, source_vecs, 'source') with tf.name_scope('contexts'): @@ -1501,7 +1679,8 @@ def PackSource(self, 'ctx_proj', act=source_contexts, weight=theta.ctx_proj, - w_feature_axis=-1) + w_feature_axis=-1, + ) w_ctx_proj = self.QWeight(w_ctx_proj) source_contexts = self.QMatmul(source_contexts, w_ctx_proj) source_contexts = 
self.QAct('ctx_pre_proj_matmul', source_contexts) @@ -1510,26 +1689,27 @@ def PackSource(self, source_contexts = fns.qadd( source_contexts, self.QWeight(theta.ctx_proj_b), - qout_name='ctx_pre_proj_add') + qout_name='ctx_pre_proj_add', + ) source_contexts = gshard_utils.MeshSplit( - source_contexts, - p.device_mesh, - p.activation_split_dims_mapping) + source_contexts, p.device_mesh, p.activation_split_dims_mapping + ) # => [time_steps, batch_size, context_dim] - source_contexts = py_utils.HasShape(source_contexts, - [time_steps, batch_size, -1]) + source_contexts = py_utils.HasShape( + source_contexts, [time_steps, batch_size, -1] + ) context_dim = py_utils.GetShape(source_contexts)[-1] # => [time_steps, batch_size * num_heads, context_dim / num_heads] source_contexts = tf.reshape( source_contexts, - [-1, batch_size * num_heads, context_dim // num_heads]) + [-1, batch_size * num_heads, context_dim // num_heads], + ) source_contexts = gshard_utils.MeshSplit( - source_contexts, - p.device_mesh, - p.activation_split_dims_mapping) - source_contexts = self.ProcessProjectionVec(theta, - source_contexts, - 'ctx') + source_contexts, p.device_mesh, p.activation_split_dims_mapping + ) + source_contexts = self.ProcessProjectionVec( + theta, source_contexts, 'ctx' + ) with tf.name_scope('padding'): # => [time_steps, batch_size, 1] @@ -1548,24 +1728,29 @@ def PackSource(self, # => [time_steps, batch_size, num_heads] source_segment_id = tf.tile(source_segment_id, [1, 1, num_heads]) # => [time_steps, batch_size * num_heads] - source_segment_id = tf.reshape(source_segment_id, - [-1, batch_size * num_heads]) + source_segment_id = tf.reshape( + source_segment_id, [-1, batch_size * num_heads] + ) - return self.atten.PackSource(theta.atten, - source_vecs, - source_contexts, - source_padding, - source_segment_id) + return self.atten.PackSource( + theta.atten, + source_vecs, + source_contexts, + source_padding, + source_segment_id, + ) @py_utils.NameScopeDecorator('MultiHeadedAttention/ExtendSourcePacked') - def ExtendSourcePacked(self, - theta, - new_source_vecs, - new_source_contexts, - new_source_paddings, - new_source_segment_ids, - cached_packed_src, - t=None): + def ExtendSourcePacked( + self, + theta, + new_source_vecs, + new_source_contexts, + new_source_paddings, + new_source_segment_ids, + cached_packed_src, + t=None, + ): """Extend cached source_vecs and source_contexts by one more timestep. Args: @@ -1583,21 +1768,21 @@ def ExtendSourcePacked(self, source_vecs and source_contexts for the previous t-1 steps. To support tf.while_loop on TPU (satisfying static shape requirement), instead of using tf.concat to update the cached vectors, the time dimension of each - cached vector is fixed as the max_sequence_length and inplace - update op is used to update the information for each time step: + cached vector is fixed as the max_sequence_length and inplace update op + is used to update the information for each time step: * source_vecs: A tensor of shape [max_sequence_length, source_batch, hidden_dim]. [:t, :, :] contains valid preprocessed source_vecs in the - previous t - 1 timesteps, the rests are invalid data. + previous t - 1 timesteps, the rests are invalid data. * source_contexts: A tensor of shape [max_sequence_length, source_batch, hidden_dim]. [:t, :, :] contains valid preprocessed source_contexts in - the previous t - 1 timesteps, the rests are invalid data. + the previous t - 1 timesteps, the rests are invalid data. 
* source_padding: If not None, a tensor of shape [max_sequence_length, source_batch, num_heads]. [:t, :, :] contains cached source padding - for the previous t - 1 timesteps, the rests are invalid data. + for the previous t - 1 timesteps, the rests are invalid data. * source_segment_id: If not None, a tensor of shape [max_sequence_length, source_batch, num_heads]. [:t, :, :] contains - cached source segment id for the previous t - 1 timesteps, the rests - are invalid data. + cached source segment id for the previous t - 1 timesteps, the rests + are invalid data. When t is None (not running on TPU or the while loop is unrolled): * source_vecs: A tensor of shape [t - 1, source_batch, hidden_dim]. * source_contexts: A tensor of shape [t - 1, source_batch, hidden_dim]. @@ -1624,21 +1809,28 @@ def ExtendSourcePacked(self, 'extended_source_context' is of shape [t, batch_size, num_heads * dim]; 'source_padding' is of shape [t, batch_size, num_heads]; 'source_segment_id' is of shape [t, batch_size, num_heads]. - """ + """ # pyformat: disable batch_size = py_utils.GetShape(new_source_vecs)[0] if new_source_paddings is None: new_source_paddings = tf.zeros([batch_size], dtype=new_source_vecs.dtype) if new_source_segment_ids is None: - new_source_segment_ids = tf.zeros([batch_size], - dtype=new_source_vecs.dtype) + new_source_segment_ids = tf.zeros( + [batch_size], dtype=new_source_vecs.dtype + ) processed_packed_src = self.InitForSourcePacked( - theta, tf.expand_dims(new_source_vecs, 0), + theta, + tf.expand_dims(new_source_vecs, 0), tf.expand_dims(new_source_contexts, 0), tf.expand_dims(new_source_paddings, 0), - tf.expand_dims(new_source_segment_ids, 0)) + tf.expand_dims(new_source_segment_ids, 0), + ) extended_packed_src = py_utils.NestedMap() - for key in ('source_vecs', 'source_contexts', 'source_padding', - 'source_segment_id'): + for key in ( + 'source_vecs', + 'source_contexts', + 'source_padding', + 'source_segment_id', + ): if cached_packed_src.get(key, None) is None: extended_packed_src[key] = None else: @@ -1653,28 +1845,35 @@ def ExtendSourcePacked(self, # replacing it with e.g. `tf.tensor_scatter_nd_update` if py_utils.ReplaceAliasInplaceUpdateInAttention(): extended_packed_src[key] = tf.tensor_scatter_nd_update( - cached_packed_src[key], [[t]], [processed]) + cached_packed_src[key], [[t]], [processed] + ) else: extended_packed_src[key] = inplace_ops.alias_inplace_update( - cached_packed_src[key], t, processed) + cached_packed_src[key], t, processed + ) else: processed = tf.reshape(processed_packed_src[key], [1, batch_size, -1]) extended_packed_src[key] = tf.concat( - [cached_packed_src[key], processed], axis=0) + [cached_packed_src[key], processed], axis=0 + ) return extended_packed_src @py_utils.NameScopeDecorator('MultiHeadedAttention/ZeroAttentionState') def ZeroAttentionState(self, source_length, decoder_batch_size): zero_att_state = self.atten.ZeroAttentionState( - source_length, decoder_batch_size * self.params.num_attention_heads) + source_length, decoder_batch_size * self.params.num_attention_heads + ) # [batch * num_heads, length] => [batch, num_heads * length]. 
zero_att_state = _RecursiveReshape(zero_att_state, [decoder_batch_size, -1]) nested_map_zero_att_state = py_utils.NestedMap(inner=zero_att_state) if self.params.attention_head_prob_index >= 0: - selected_prob_head = tf.zeros([decoder_batch_size, source_length], - dtype=py_utils.FPropDtype(self.params)) - nested_map_zero_att_state[ - 'selected_attention_head_probs'] = selected_prob_head + selected_prob_head = tf.zeros( + [decoder_batch_size, source_length], + dtype=py_utils.FPropDtype(self.params), + ) + nested_map_zero_att_state['selected_attention_head_probs'] = ( + selected_prob_head + ) return nested_map_zero_att_state def ProcessProjectionVec(self, theta, projection_vec, projection_type): @@ -1683,15 +1882,18 @@ def ProcessProjectionVec(self, theta, projection_vec, projection_type): return projection_vec @py_utils.NameScopeDecorator( - 'MultiHeadedAttention/ComputeContextVectorWithSource') - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None, - atten_idx=None): + 'MultiHeadedAttention/ComputeContextVectorWithSource' + ) + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + atten_idx=None, + ): """Computes the context vector given the current query output. Args: @@ -1700,9 +1902,9 @@ def ComputeContextVectorWithSource(self, packed_src: A `.NestedMap` object returned by PackSource or InitForSourcePacked. query_vec: a tensor of shape [target_batch, query_dim]. - attention_state: A NestedMap. 'inner' contains the inner attention - state. It is not used in AdditiveAttention, and is simply passed - through. Optionally, if attention_head_prob_index >= 0, then + attention_state: A NestedMap. 'inner' contains the inner attention state. + It is not used in AdditiveAttention, and is simply passed through. + Optionally, if attention_head_prob_index >= 0, then 'selected_attention_head_probs' contains the selected attention probability head. per_step_source_padding: Source sequence padding to apply at this step. If @@ -1712,6 +1914,7 @@ def ComputeContextVectorWithSource(self, different samples in a batch, each of which may come from different tasks. This is usually used in multi-task setting. A tensor of shape [target_batch]. + Note: concated_source_vecs are the vectors that are used to compute the attention score between the query_vec and each concated_source_vec. The concated_source_contexts are the vectors that compose the result. 
The @@ -1742,47 +1945,59 @@ def ComputeContextVectorWithSource(self, 'query_proj', act=query_vec, weight=theta.query_proj, - w_feature_axis=-1) + w_feature_axis=-1, + ) w_query_proj = self.QWeight(w_query_proj) query_vec_projected = self.QMatmul(query_vec, w_query_proj) query_vec_projected = self.QAct('query_proj_matmul', query_vec_projected) - query_vec_projected = self.FromAqtMatmul('query_proj', - query_vec_projected) + query_vec_projected = self.FromAqtMatmul( + 'query_proj', query_vec_projected + ) if p.use_bias: query_vec_projected = fns.qadd( query_vec_projected, self.QWeight(theta.query_proj_b), - qout_name='query_proj_add') - query_vec_projected = tf.reshape(query_vec_projected, - query_vec_projected_shape) - query_vec_projected = self.ProcessProjectionVec(theta, - query_vec_projected, - 'query') + qout_name='query_proj_add', + ) + query_vec_projected = tf.reshape( + query_vec_projected, query_vec_projected_shape + ) + query_vec_projected = self.ProcessProjectionVec( + theta, query_vec_projected, 'query' + ) else: query_vec_projected = tf.reshape(query_vec, query_vec_projected_shape) if p.activation_split_dims_mapping: query_vec_projected = gshard_utils.MeshSplit( - query_vec_projected, p.device_mesh, - p.activation_split_dims_mapping[1:]) + query_vec_projected, + p.device_mesh, + p.activation_split_dims_mapping[1:], + ) query_batch_size = py_utils.GetShape(query_vec)[0] if query_segment_id is None: query_segment_id = tf.zeros( - query_batch_size * num_heads, dtype=source_padding.dtype) + query_batch_size * num_heads, dtype=source_padding.dtype + ) else: query_segment_id_repl = tf.tile( - tf.expand_dims(query_segment_id, 1), [1, num_heads]) + tf.expand_dims(query_segment_id, 1), [1, num_heads] + ) query_segment_id = tf.reshape(query_segment_id_repl, [-1]) if per_step_source_padding is None: - per_step_source_padding = tf.zeros([query_batch_size, source_seq_len], - dtype=source_padding.dtype) + per_step_source_padding = tf.zeros( + [query_batch_size, source_seq_len], dtype=source_padding.dtype + ) per_step_source_padding = py_utils.HasShape( - per_step_source_padding, [query_batch_size, source_seq_len]) + per_step_source_padding, [query_batch_size, source_seq_len] + ) per_step_source_padding = tf.reshape( - tf.tile(per_step_source_padding, [1, num_heads]), [-1, source_seq_len]) - attention_state = _RecursiveReshape(attention_state, - [batch_size * num_heads, -1]) + tf.tile(per_step_source_padding, [1, num_heads]), [-1, source_seq_len] + ) + attention_state = _RecursiveReshape( + attention_state, [batch_size * num_heads, -1] + ) if isinstance(attention_state, py_utils.NestedMap): if 'emit_probs' in attention_state: inner_state = attention_state @@ -1793,23 +2008,32 @@ def ComputeContextVectorWithSource(self, else: inner_state = attention_state ctx_vec, prob, new_inner_state = self.atten.ComputeContextVectorWithSource( - theta.atten, packed_src, query_vec_projected, inner_state, - per_step_source_padding, query_segment_id) + theta.atten, + packed_src, + query_vec_projected, + inner_state, + per_step_source_padding, + query_segment_id, + ) ctx_vec = tf.reshape(ctx_vec, [batch_size, -1]) if p.activation_split_dims_mapping: - ctx_vec = gshard_utils.MeshSplit(ctx_vec, p.device_mesh, - p.activation_split_dims_mapping[1:]) + ctx_vec = gshard_utils.MeshSplit( + ctx_vec, p.device_mesh, p.activation_split_dims_mapping[1:] + ) if p.enable_ctx_post_proj: if atten_idx is None: assert p.num_post_proj == 1, ( 'atten_idx is None, this means there is no need to select ' 'different post projections, 
and p.num_post_proj is supposed to be ' - '1. However you set p.num_post_proj=%s .' % p.num_post_proj) + '1. However you set p.num_post_proj=%s .' + % p.num_post_proj + ) ctx_vec, w_ctx_post_proj = self.ToAqtInputs( 'ctx_post_proj', act=ctx_vec, weight=theta.ctx_post_proj, - w_feature_axis=-1) + w_feature_axis=-1, + ) w_ctx_post_proj = self.QWeight(w_ctx_post_proj) ctx_vec = self.QMatmul(ctx_vec, w_ctx_post_proj) ctx_vec = self.QAct('ctx_post_proj_matmul', ctx_vec) @@ -1818,12 +2042,15 @@ def ComputeContextVectorWithSource(self, ctx_vec = fns.qadd( ctx_vec, self.QWeight(theta.ctx_post_proj_b), - qout_name='ctx_post_proj_add') + qout_name='ctx_post_proj_add', + ) else: assert p.num_post_proj > 1, ( 'atten_idx is not None, this means there are multiple post ' 'projections, and p.num_post_proj is supposed to be > 1. However ' - 'you set p.num_post_proj=%s .' % p.num_post_proj) + 'you set p.num_post_proj=%s .' + % p.num_post_proj + ) bs_range = [tf.range(batch_size)] select = tf.transpose(tf.concat([bs_range, [atten_idx]], axis=0)) # => [batch, dim, num_langs] @@ -1838,11 +2065,13 @@ def ComputeContextVectorWithSource(self, # explicitly name this tensor for potential future reference multi_headed_atten_prob = tf.reshape( - prob, [batch_size, num_heads, -1], name='multi_headed_atten_prob') + prob, [batch_size, num_heads, -1], name='multi_headed_atten_prob' + ) prob = self.QRAct( tf.reduce_mean(multi_headed_atten_prob, 1), quant_utils.QDistribution.SOFTMAX, - domain='softmax') + domain='softmax', + ) if isinstance(attention_state, py_utils.NestedMap): att_state = attention_state if 'emit_probs' in attention_state: @@ -1852,16 +2081,19 @@ def ComputeContextVectorWithSource(self, else: att_state = new_inner_state if p.attention_head_prob_index >= 0: - selected_prob_head = multi_headed_atten_prob[:, p. - attention_head_prob_index, :] + selected_prob_head = multi_headed_atten_prob[ + :, p.attention_head_prob_index, : + ] att_state.selected_attention_head_probs = selected_prob_head att_state = _RecursiveReshape(att_state, [batch_size, -1]) return ctx_vec, prob, att_state @py_utils.NameScopeDecorator( - 'MultiHeadedAttention/ComputeContextVectorWithAttenProbs') - def ComputeContextVectorWithAttenProbs(self, theta, packed_context, - atten_probs): + 'MultiHeadedAttention/ComputeContextVectorWithAttenProbs' + ) + def ComputeContextVectorWithAttenProbs( + self, theta, packed_context, atten_probs + ): """Computes the context vector given the attention probailities. 
Args: @@ -1882,13 +2114,19 @@ def ComputeContextVectorWithAttenProbs(self, theta, packed_context, # packed_context: [batch_size * num_head, num_style, # hidden_dim / num_head] # inp: [batch_size * num_head, num_style] - packed_context = py_utils.with_dependencies([ - py_utils.assert_shape_match([py_utils.GetShape(packed_context)[0]], - [py_utils.GetShape(atten_probs)[0]]) - ], packed_context) + packed_context = py_utils.with_dependencies( + [ + py_utils.assert_shape_match( + [py_utils.GetShape(packed_context)[0]], + [py_utils.GetShape(atten_probs)[0]], + ) + ], + packed_context, + ) b_size = py_utils.GetShape(packed_context)[0] // num_heads ctx_vec = tf.reshape( - tf.matmul(tf.expand_dims(atten_probs, 1), packed_context), [b_size, -1]) + tf.matmul(tf.expand_dims(atten_probs, 1), packed_context), [b_size, -1] + ) if p.enable_ctx_post_proj: ctx_vec_proj = tf.matmul(ctx_vec, theta.ctx_post_proj) ctx_vec_proj += theta.ctx_post_proj_b @@ -1908,35 +2146,46 @@ def PackCachedSource(self, cached_src): num_heads = p.num_attention_heads packed_src = py_utils.NestedMap() packed_src.source_vecs = tf.reshape( - concated_source_vecs, [src_seq_len, batch_size * num_heads, -1]) + concated_source_vecs, [src_seq_len, batch_size * num_heads, -1] + ) # TODO(yonghui): Rewrite the following with just one transpose. packed_src.source_contexts = tf.transpose( - tf.reshape(concated_source_contexts, - [src_seq_len, batch_size * num_heads, -1]), [1, 0, 2]) + tf.reshape( + concated_source_contexts, [src_seq_len, batch_size * num_heads, -1] + ), + [1, 0, 2], + ) if source_padding is not None: packed_src.source_padding = tf.reshape( - source_padding, [src_seq_len, batch_size * num_heads]) + source_padding, [src_seq_len, batch_size * num_heads] + ) else: packed_src.source_padding = tf.zeros( - [src_seq_len, batch_size * num_heads], dtype=py_utils.FPropDtype(p)) + [src_seq_len, batch_size * num_heads], dtype=py_utils.FPropDtype(p) + ) if source_segment_id is None: packed_src.source_segment_id = tf.zeros( [src_seq_len, batch_size * num_heads], - dtype=packed_src.source_padding.dtype) + dtype=packed_src.source_padding.dtype, + ) else: packed_src.source_segment_id = tf.reshape( - source_segment_id, [src_seq_len, batch_size * num_heads]) + source_segment_id, [src_seq_len, batch_size * num_heads] + ) return packed_src @py_utils.NameScopeDecorator( - 'MultiHeadedAttention/ComputeContextVectorWithCachedSource') - def ComputeContextVectorWithCachedSource(self, - theta, - cached_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + 'MultiHeadedAttention/ComputeContextVectorWithCachedSource' + ) + def ComputeContextVectorWithCachedSource( + self, + theta, + cached_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Same as the ComputeContextVectorWithSource api above, except values ... in source_vecs, source_contexts and source_padding are ordered differently. @@ -1961,8 +2210,13 @@ def ComputeContextVectorWithCachedSource(self, dimensions [target_batch....] 
""" return self.ComputeContextVectorWithSource( - theta, self.PackCachedSource(cached_src), query_vec, attention_state, - per_step_source_padding, query_segment_id) + theta, + self.PackCachedSource(cached_src), + query_vec, + attention_state, + per_step_source_padding, + query_segment_id, + ) class LocationSensitiveAttention(BaseAttentionLayer): @@ -1977,25 +2231,33 @@ def Params(cls): """Params for this LocationSensitiveAttention class.""" p = super().Params() p.Define('source_dim', 0, 'Number of source nodes.') - p.Define('location_filter_size', 0, - 'Location filter size, should be an odd number e.g. 31.') + p.Define( + 'location_filter_size', + 0, + 'Location filter size, should be an odd number e.g. 31.', + ) p.Define('location_num_filters', 0, 'Number of location filters, e.g. 32.') p.Define('query_dim', 0, 'Number of query nodes.') p.Define('hidden_dim', 0, 'Number of hidden nodes.') p.Define( - 'same_batch_size', False, - 'True iff the source and target sequence has the same batch size.') + 'same_batch_size', + False, + 'True iff the source and target sequence has the same batch size.', + ) p.Define( - 'location_features', ['PREV_PROBS'], + 'location_features', + ['PREV_PROBS'], 'List signals to run the convolutions on. Possible options are: ' - 'PREV_PROBS, CUMULATIVE_PROBS.') + 'PREV_PROBS, CUMULATIVE_PROBS.', + ) # Often the attention context output needs to be concated # with tensors from another layer. This allows them to share # quantization parameters. By convention, all attention layers # need to include their context output vectors in this domain. - p.qdomain.Define('atten_context', None, - 'Quantization domain for attention context.') + p.qdomain.Define( + 'atten_context', None, 'Quantization domain for attention context.' + ) # Fill in reasonable default for params init p.params_init = py_utils.WeightInit.GaussianSqrtDim() @@ -2006,8 +2268,9 @@ def __init__(self, params): super().__init__(params) p = self.params self._is_quantized = p.qdomain.default is not None - assert not p.packed_input, ('Packed input is not supported yet for ' - 'LocationSensitiveAttention.') + assert ( + not p.packed_input + ), 'Packed input is not supported yet for LocationSensitiveAttention.' if p.atten_dropout_prob != 0: raise NotImplementedError('dropout is not supported') @@ -2022,7 +2285,8 @@ def CollapseOutDim(x): # => [sl, sb, hd] location_feats = tf.transpose(inputs.location_feats, [2, 0, 1]) location_hidden = py_utils.Matmul( - CollapseOutDim(location_feats), inputs.location_var) + CollapseOutDim(location_feats), inputs.location_var + ) location_hidden = self.QAct('logits_mul', location_hidden) sl = py_utils.GetShape(location_feats)[0] @@ -2037,7 +2301,8 @@ def CollapseOutDim(x): summed = fns.qadd( inputs.concated_source_vecs, inputs.query_vec_reshaped, - qout_name='logits_add') + qout_name='logits_add', + ) summed = fns.qadd(summed, location_hidden, qout_name='logits_bias') summed = fns.qtanh(summed) # logits is of shape [sl * tb/sb * sb, 1]. Computes dot product @@ -2045,7 +2310,8 @@ def CollapseOutDim(x): # result to be of shape [sl, tb/sb, sb]. 
logits = py_utils.Matmul( tf.reshape(summed, [-1, p.hidden_dim]), - tf.reshape(inputs.hidden_var, [p.hidden_dim, 1])) + tf.reshape(inputs.hidden_var, [p.hidden_dim, 1]), + ) logits = self.QAct('logits', logits) logits = tf.reshape(logits, py_utils.GetShape(summed)[:3]) return logits @@ -2058,11 +2324,11 @@ def AttenLogitsSameBatchSize(inputs): Args: inputs: a NestedMap containing: - - concated_source_vecs: Tensor of shape [sl, batch, dim] - - query_vec_transformed: Tensor of shape [batch, dim] - - hidden_var: Tensor of shape [dim] - - location_feats: Tensor of shape [batch, location_feature_dim, sl] - - location_var: Tensor of shape [location_feature_dim, dim] + * concated_source_vecs: Tensor of shape [sl, batch, dim] + * query_vec_transformed: Tensor of shape [batch, dim] + * hidden_var: Tensor of shape [dim] + * location_feats: Tensor of shape [batch, location_feature_dim, sl] + * location_var: Tensor of shape [location_feature_dim, dim] Returns: logits in the shape [sl, batch_size]. @@ -2075,7 +2341,8 @@ def CollapseOutDim(x): # => [sl, sb, hd] location_feats = tf.transpose(inputs.location_feats, [2, 0, 1]) location_hidden = py_utils.Matmul( - CollapseOutDim(location_feats), inputs.location_var) + CollapseOutDim(location_feats), inputs.location_var + ) location_hidden = self.QAct('logits_mul', location_hidden) sl = tf.shape(location_feats)[0] tb = tf.shape(location_feats)[1] @@ -2086,7 +2353,8 @@ def CollapseOutDim(x): summed = fns.qadd( inputs.concated_source_vecs, tf.expand_dims(inputs.query_vec_transformed, 0), - qout_name='logits_add') + qout_name='logits_add', + ) summed = fns.qadd(summed, location_hidden, qout_name='logits_bias') summed = fns.qtanh(summed) @@ -2096,20 +2364,31 @@ def CollapseOutDim(x): # result to be of shape [sl, tb]. logits = py_utils.Matmul( tf.reshape(summed, [-1, p.hidden_dim]), - tf.reshape(inputs.hidden_var, [p.hidden_dim, 1])) + tf.reshape(inputs.hidden_var, [p.hidden_dim, 1]), + ) logits = self.QAct('logits', logits) logits = tf.reshape(logits, py_utils.GetShape(summed)[:2]) return logits - def Atten(hidden_var, query_var, source_padding, concated_source_vecs, - concated_source_contexts, query_vec, attention_state, - location_filter_var, location_var, per_step_source_padding): + def Atten( + hidden_var, + query_var, + source_padding, + concated_source_vecs, + concated_source_contexts, + query_vec, + attention_state, + location_filter_var, + location_var, + per_step_source_padding, + ): """Computes the attention context vector.""" p = self.params # attention_state shape [batch, len(p.location_features), slen] # it contains previous and accumulated attention probabilites. - attention_state = py_utils.HasShape(attention_state, - [-1, len(p.location_features), -1]) + attention_state = py_utils.HasShape( + attention_state, [-1, len(p.location_features), -1] + ) location_feats = self._ApplyConv(attention_state, location_filter_var) @@ -2124,23 +2403,28 @@ def Atten(hidden_var, query_var, source_padding, concated_source_vecs, query_vec_transformed = py_utils.Matmul(query_vec, query_var) query_vec_transformed = self.QAct('atten_matmul', query_vec_transformed) # query_vec is reshaped to [1, tb/sb, sb, hidden_dims]. 
- query_vec_reshaped = tf.reshape(query_vec_transformed, - [1, multiplier, sb, p.hidden_dim]) + query_vec_reshaped = tf.reshape( + query_vec_transformed, [1, multiplier, sb, p.hidden_dim] + ) # logits is of shape [sl, tb/sb, sb] logits = _ConditionalCallDefun( - self._is_quantized, AttenLogits, + self._is_quantized, + AttenLogits, py_utils.NestedMap( concated_source_vecs=concated_source_vecs, query_vec_reshaped=query_vec_reshaped, hidden_var=hidden_var, location_feats=location_feats, - location_var=location_var)) + location_var=location_var, + ), + ) # Take out the padding states. # _source_padding is of shape [sl, sb]. # reshaped to [sl, 1, sb]. source_padding = tf.expand_dims(source_padding, 1) per_step_source_padding = tf.reshape( - tf.transpose(per_step_source_padding), [-1, multiplier, sb]) + tf.transpose(per_step_source_padding), [-1, multiplier, sb] + ) if source_padding.dtype != tf.bool: source_padding = source_padding > 0 if per_step_source_padding.dtype != tf.bool: @@ -2159,16 +2443,25 @@ def Atten(hidden_var, query_var, source_padding, concated_source_vecs, # [sb, tb/sb, sl] * [sb, sl, context_dim] = [sb, tb/sb, context_dim] summed = tf.matmul( tf.cast(probs_reshaped, concated_source_contexts.dtype), - concated_source_contexts) + concated_source_contexts, + ) summed = self.QAct('atten_context', summed) # summed is of shape [tb/sb, sb, context_dim] summed = tf.transpose(summed, [1, 0, 2]) return tf.reshape(summed, [tb, -1]), probs - def AttenSameBatchSize(hidden_var, query_var, source_padding, - concated_source_vecs, concated_source_contexts, - query_vec, attention_state, location_filter_var, - location_var, per_step_source_padding): + def AttenSameBatchSize( + hidden_var, + query_var, + source_padding, + concated_source_vecs, + concated_source_contexts, + query_vec, + attention_state, + location_filter_var, + location_var, + per_step_source_padding, + ): """Computes the attention context vector. Optimized code path for when source and target have the same batch size. @@ -2177,21 +2470,25 @@ def AttenSameBatchSize(hidden_var, query_var, source_padding, p = self.params # attention_state shape [batch, len(p.location_features), slen] # it contains previous and accumulated attention probabilites. 
- attention_state = py_utils.HasShape(attention_state, - [-1, len(p.location_features), -1]) + attention_state = py_utils.HasShape( + attention_state, [-1, len(p.location_features), -1] + ) location_feats = self._ApplyConv(attention_state, location_filter_var) query_vec_transformed = py_utils.Matmul(query_vec, query_var) query_vec_transformed = self.QAct('atten_matmul', query_vec_transformed) # logits is of shape [sl, sb] logits = _ConditionalCallDefun( - not self._is_quantized, AttenLogitsSameBatchSize, + not self._is_quantized, + AttenLogitsSameBatchSize, py_utils.NestedMap( concated_source_vecs=concated_source_vecs, query_vec_transformed=query_vec_transformed, hidden_var=hidden_var, location_feats=location_feats, - location_var=location_var)) + location_var=location_var, + ), + ) # => [sl, tb] logits.set_shape(source_padding.shape) # Reshape logits to a matrix of shape [tb, sl] and takes the @@ -2201,7 +2498,8 @@ def AttenSameBatchSize(hidden_var, query_var, source_padding, probs = self._PaddedSoftmax(logits, source_padding) summed = tf.matmul( tf.cast(tf.expand_dims(probs, 1), concated_source_contexts.dtype), - concated_source_contexts) + concated_source_contexts, + ) summed = self.QAct('atten_context', summed) return tf.squeeze(summed, 1), probs @@ -2214,7 +2512,8 @@ def EncodeSource(theta, vecs, ctxs): time, batch = py_utils.GetShape(vecs, 2) ctxs = py_utils.HasShape(ctxs, [time, batch, -1]) transformed_vecs = py_utils.Matmul( - tf.reshape(vecs, [-1, p.source_dim]), self.QWeight(theta.source_var)) + tf.reshape(vecs, [-1, p.source_dim]), self.QWeight(theta.source_var) + ) transformed_vecs = tf.reshape(transformed_vecs, [time, batch, -1]) transformed_vecs = self.QAct('encode_matmul', transformed_vecs) transposed_ctxs = tf.transpose(ctxs, [1, 0, 2]) @@ -2230,21 +2529,24 @@ def _CreateLayerVariables(self): shape=[p.source_dim, p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['LocationSensitiveAttention_vars']) + collections=['LocationSensitiveAttention_vars'], + ) self.CreateVariable('source_var', pc) pc = py_utils.WeightParams( shape=[p.query_dim, p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['LocationSensitiveAttention_vars']) + collections=['LocationSensitiveAttention_vars'], + ) self.CreateVariable('query_var', pc) pc = py_utils.WeightParams( shape=[p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['LocationSensitiveAttention_vars']) + collections=['LocationSensitiveAttention_vars'], + ) self.CreateVariable('hidden_var', pc) assert p.location_filter_size % 2 == 1 @@ -2252,21 +2554,24 @@ def _CreateLayerVariables(self): location_filter_shape = [ p.location_filter_size, - len(p.location_features), p.location_num_filters + len(p.location_features), + p.location_num_filters, ] # TODO(yonghui): Don't hard code how params are initialized. 
location_filter_pc = py_utils.WeightParams( shape=location_filter_shape, init=py_utils.WeightInit.Uniform(0.05), dtype=p.dtype, - collections=['LocationSensitiveAttention_vars']) + collections=['LocationSensitiveAttention_vars'], + ) self.CreateVariable('location_filter_var', location_filter_pc) location_var_shape = [p.location_num_filters, p.hidden_dim] location_pc = py_utils.WeightParams( shape=location_var_shape, init=py_utils.WeightInit.Uniform(0.05), dtype=p.dtype, - collections=['LocationSensitiveAttention_vars']) + collections=['LocationSensitiveAttention_vars'], + ) self.CreateVariable('location_var', location_pc) self.TrackQActs('atten_conv') @@ -2277,7 +2582,8 @@ def _CreateLayerVariables(self): 'encode_matmul', 'logits_mul', 'logits_bias', - domain='fullyconnected') + domain='fullyconnected', + ) def AddGlobalVN(self, theta): theta = super().AddGlobalVN(theta) @@ -2308,7 +2614,8 @@ def _ApplyConv(self, attention_state, location_filter_var): 1, 'SAME', data_format=data_format, - qout_name='atten_conv') + qout_name='atten_conv', + ) if py_utils.use_xla() in ('', 'cpu'): location_feats = tf.transpose(location_feats, [0, 2, 1]) if p.dtype != tf.float32: @@ -2316,17 +2623,20 @@ def _ApplyConv(self, attention_state, location_filter_var): # [sb, hd, sl] return location_feats - def PackSource(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): with tf.name_scope(self.params.name): if source_segment_id is None: source_segment_id = tf.zeros_like(source_padding) concated_source_vecs, concated_source_contexts = self._encode_source( - theta, source_vecs, source_contexts) + theta, source_vecs, source_contexts + ) return py_utils.NestedMap( # [time, batch_size, hidden_dim]. source_vecs=concated_source_vecs, @@ -2338,30 +2648,39 @@ def PackSource(self, # [time, batch_size]. source_padding=source_padding, # [time, batch_size]. - source_segment_id=source_segment_id) + source_segment_id=source_segment_id, + ) def ZeroAttentionState(self, source_length, decoder_batch_size): p = self.params dtype = p.dtype.real_dtype num_features = len(p.location_features) with tf.name_scope(p.name): - state = tf.concat([ - tf.ones([decoder_batch_size, num_features, 1], dtype=dtype), - tf.zeros([decoder_batch_size, num_features, source_length - 1], - dtype=dtype) - ], 2) + state = tf.concat( + [ + tf.ones([decoder_batch_size, num_features, 1], dtype=dtype), + tf.zeros( + [decoder_batch_size, num_features, source_length - 1], + dtype=dtype, + ), + ], + 2, + ) state = self.QRAct( - state, quant_utils.QDistribution.SOFTMAX, domain='softmax') + state, quant_utils.QDistribution.SOFTMAX, domain='softmax' + ) return state - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Args: @@ -2396,7 +2715,7 @@ def ComputeContextVectorWithSource(self, - The attention probability vector: [batch_size, time] - The new attention mechanism state: possibly nested tuple of tensors with dimensions [target_batch, ...] 
- """ + """ # pyformat: disable del query_segment_id p = self.params concated_source_vecs = packed_src.source_vecs @@ -2407,20 +2726,29 @@ def ComputeContextVectorWithSource(self, query_batch_size = py_utils.GetShape(query_vec)[0] source_length = py_utils.GetShape(source_padding)[0] if per_step_source_padding is None: - per_step_source_padding = tf.zeros([query_batch_size, source_length], - dtype=source_padding.dtype) + per_step_source_padding = tf.zeros( + [query_batch_size, source_length], dtype=source_padding.dtype + ) per_step_source_padding = py_utils.HasShape( - per_step_source_padding, [query_batch_size, source_length]) + per_step_source_padding, [query_batch_size, source_length] + ) hidden_var = self.AddVN(theta.hidden_var, per_step=True) query_var = self.AddVN(theta.query_var, per_step=True) location_filter_var = self.AddVN(theta.location_filter_var, per_step=True) location_var = self.AddVN(theta.location_var, per_step=True) - ctx_vec, prob = self._ctx_vec(hidden_var, query_var, source_padding, - concated_source_vecs, - concated_source_contexts, query_vec, - attention_state, location_filter_var, - location_var, per_step_source_padding) + ctx_vec, prob = self._ctx_vec( + hidden_var, + query_var, + source_padding, + concated_source_vecs, + concated_source_contexts, + query_vec, + attention_state, + location_filter_var, + location_var, + per_step_source_padding, + ) new_feats = {'PREV_PROBS': prob} if 'CUMULATIVE_PROBS' in p.location_features: @@ -2428,15 +2756,18 @@ def ComputeContextVectorWithSource(self, cum_prob_index = p.location_features.index('CUMULATIVE_PROBS') cum_probs = tf.add(prob, attention_state[:, cum_prob_index, :]) cum_probs = self.QRAct( - cum_probs, quant_utils.QDistribution.SOFTMAX, domain='softmax') + cum_probs, quant_utils.QDistribution.SOFTMAX, domain='softmax' + ) new_feats['CUMULATIVE_PROBS'] = cum_probs - new_attention_state = tf.stack([new_feats[f] for f in p.location_features], - axis=1) + new_attention_state = tf.stack( + [new_feats[f] for f in p.location_features], axis=1 + ) return ctx_vec, prob, new_attention_state -def MergeSourcePaddingWithPerStepSourcePadding(source_padding, - per_step_source_padding, tb): +def MergeSourcePaddingWithPerStepSourcePadding( + source_padding, per_step_source_padding, tb +): """Merges source padding with per-step source padding. Args: @@ -2458,8 +2789,9 @@ def MergeSourcePaddingWithPerStepSourcePadding(source_padding, # Transpose and reshape source_padding to [1, sb, sl]. source_padding = tf.expand_dims(tf.transpose(source_padding), 0) # Merge source_padding and per_step_source_padding. - source_padding = tf.maximum(source_padding, - tf.reshape(per_step_source_padding, [-1, sb, sl])) + source_padding = tf.maximum( + source_padding, tf.reshape(per_step_source_padding, [-1, sb, sl]) + ) return tf.reshape(source_padding, [tb, -1]) @@ -2515,22 +2847,24 @@ def __init__(self, params): """Constructs an MonotonicAttention object.""" super().__init__(params) p = self.params - assert not p.packed_input, ('Packed input not supported for Monotonic ' - 'Attention.') + assert ( + not p.packed_input + ), 'Packed input not supported for Monotonic Attention.' if p.atten_dropout_prob != 0: raise NotImplementedError('dropout is not supported') # When running eval, don't add pre-sigmoid noise, and use a hard sigmoid to # match behavior of online decoding. if self.do_eval: - p.pre_sigmoid_noise = 0. 
+ p.pre_sigmoid_noise = 0.0 p.hard_sigmoid = True def EncodeSource(theta, vecs, ctxs): time, batch = py_utils.GetShape(vecs, 2) ctxs = py_utils.HasShape(ctxs, [time, batch, -1]) transformed_vecs = py_utils.Matmul( - tf.reshape(vecs, [-1, p.source_dim]), theta.source_var) + tf.reshape(vecs, [-1, p.source_dim]), theta.source_var + ) transformed_vecs = tf.reshape(transformed_vecs, [time, batch, -1]) transposed_ctxs = tf.transpose(ctxs, [1, 0, 2]) return transformed_vecs, transposed_ctxs @@ -2546,7 +2880,8 @@ def _CreateLayerVariables(self): shape=[p.source_dim, p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['MonotonicAttention_vars']) + collections=['MonotonicAttention_vars'], + ) self.CreateVariable('source_var', pc) # query is the weight matrix for the query/decoder RNN state @@ -2554,7 +2889,8 @@ def _CreateLayerVariables(self): shape=[p.query_dim, p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['MonotonicAttention_vars']) + collections=['MonotonicAttention_vars'], + ) self.CreateVariable('query_var', pc) # hidden is the pre-softmax vector which converts from tanh to scalar @@ -2562,7 +2898,8 @@ def _CreateLayerVariables(self): shape=[p.hidden_dim], init=p.params_init, dtype=p.dtype, - collections=['MonotonicAttention_vars']) + collections=['MonotonicAttention_vars'], + ) self.CreateVariable('hidden_var', pc) # energy_bias is the bias vector which appears inside of tanh @@ -2571,7 +2908,8 @@ def _CreateLayerVariables(self): shape=[p.hidden_dim], init=py_utils.WeightInit.Constant(0.0), dtype=p.dtype, - collections=['MonotonicAttention_vars']) + collections=['MonotonicAttention_vars'], + ) self.CreateVariable('energy_bias_var', pc) # hidden_scale is the weight normalization scale for hidden @@ -2580,7 +2918,8 @@ def _CreateLayerVariables(self): shape=[], init=py_utils.WeightInit.Constant(1 / np.sqrt(p.hidden_dim)), dtype=p.dtype, - collections=['MonotonicAttention_vars']) + collections=['MonotonicAttention_vars'], + ) self.CreateVariable('hidden_scale_var', pc) # hidden_bias is the bias scalar applied before the sigmoid @@ -2589,7 +2928,8 @@ def _CreateLayerVariables(self): shape=[], init=py_utils.WeightInit.Constant(p.hidden_bias_init), dtype=p.dtype, - collections=['MonotonicAttention_vars']) + collections=['MonotonicAttention_vars'], + ) self.CreateVariable('hidden_bias_var', pc) def AddGlobalVN(self, theta): @@ -2599,17 +2939,20 @@ def AddGlobalVN(self, theta): theta.query_var = self.AddVN(theta.query_var) return theta - def PackSource(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): with tf.name_scope(self.params.name): if source_segment_id is None: source_segment_id = tf.zeros_like(source_padding) concated_source_vecs, concated_source_contexts = self._encode_source( - theta, source_vecs, source_contexts) + theta, source_vecs, source_contexts + ) return py_utils.NestedMap( # [time, batch_size, hidden_dim]. source_vecs=concated_source_vecs, @@ -2621,7 +2964,8 @@ def PackSource(self, # [time, batch_size]. source_padding=source_padding, # [time, batch_size]. 
- source_segment_id=source_segment_id) + source_segment_id=source_segment_id, + ) def ZeroAttentionState(self, source_length, decoder_batch_size): p = self.params @@ -2631,11 +2975,18 @@ def ZeroAttentionState(self, source_length, decoder_batch_size): emit_probs = tf.one_hot( tf.zeros((decoder_batch_size,), dtype=tf.int32), source_length, - dtype=dtype) + dtype=dtype, + ) return py_utils.NestedMap(emit_probs=emit_probs) - def ComputeProbabilities(self, theta, concated_source_vecs, - merged_source_padding, query_vec, attention_state): + def ComputeProbabilities( + self, + theta, + concated_source_vecs, + merged_source_padding, + query_vec, + attention_state, + ): """Computes probabilities of emissions.""" # concated_source_contexts is of shape [sb, sl, context_dim] @@ -2651,23 +3002,25 @@ def AttenLogits(inputs): Args: inputs: a NestedMap containing: - - concated_source_vecs: [sl, sb, hidden_dims]. - - query_vec: [tb, query_dim]. - - query_v: [query_dim, hidden_dim] - - energy_b: [hidden_dim]. - - hidden_v: [hidden_dim]. - - hidden_g: []. - - hidden_b: []. + * concated_source_vecs: [sl, sb, hidden_dims]. + * query_vec: [tb, query_dim]. + * query_v: [query_dim, hidden_dim] + * energy_b: [hidden_dim]. + * hidden_v: [hidden_dim]. + * hidden_g: []. + * hidden_b: []. Returns: logits shaped [tb, sl]. """ # Apply query matrix to query. Becomes [tb, hidden_dim]. query_vec_transformed = py_utils.Matmul( - inputs.query_vec, inputs.query_v, name='query_transformation') + inputs.query_vec, inputs.query_v, name='query_transformation' + ) # query_vec is reshaped to [1, tb/sb, sb, hidden_dim]. - query_vec_reshaped = tf.reshape(query_vec_transformed, - [1, multiplier, sb, p.hidden_dim]) + query_vec_reshaped = tf.reshape( + query_vec_transformed, [1, multiplier, sb, p.hidden_dim] + ) # [sl, 1, sb, hidden_dim]. concated_source_vecs = tf.expand_dims(inputs.concated_source_vecs, 1) @@ -2684,7 +3037,8 @@ def AttenLogits(inputs): # tf.reshape(v, [1, 1, 1, hidden_dim]), 3) logits = py_utils.Matmul( tf.reshape(summed, [-1, p.hidden_dim]), - tf.reshape(hidden_v, [p.hidden_dim, 1])) + tf.reshape(hidden_v, [p.hidden_dim, 1]), + ) logits += inputs.hidden_b # [tb, sl]. logits = tf.transpose(tf.reshape(logits, [-1, tb]), [1, 0]) @@ -2700,7 +3054,9 @@ def AttenLogits(inputs): energy_b=theta.energy_bias_var, hidden_v=theta.hidden_var, hidden_g=theta.hidden_scale_var, - hidden_b=theta.hidden_bias_var)) + hidden_b=theta.hidden_bias_var, + ), + ) previous_attention = attention_state.emit_probs with tf.name_scope('prob'): @@ -2716,26 +3072,31 @@ def AttenLogits(inputs): activation_noise = tf.random.stateless_normal( py_utils.GetShape(logits), py_utils.GenerateStepSeedPair(p), - dtype=logits.dtype) + dtype=logits.dtype, + ) # Compute sigmoid probabilities. - p_choose_i = tf.nn.sigmoid(logits + self.params.pre_sigmoid_noise * - activation_noise) + p_choose_i = tf.nn.sigmoid( + logits + self.params.pre_sigmoid_noise * activation_noise + ) # Never choose padded values. p_choose_i = py_utils.ApplyPadding(merged_source_padding, p_choose_i) # Compute attention distribution - probs = MonotonicAttentionProb(p_choose_i, previous_attention, - 'parallel') + probs = MonotonicAttentionProb( + p_choose_i, previous_attention, 'parallel' + ) # [tb, sl]. 
return probs, py_utils.NestedMap(emit_probs=probs) - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Args: @@ -2748,6 +3109,7 @@ def ComputeContextVectorWithSource(self, per_step_source_padding: Source sequence padding to apply at this step. If not None, it should be of shape [target_batch_size, source_length]. query_segment_id: a tensor of shape [batch_size]. + Note: concated_source_vecs are the vectors that are used to compute the attention score between the query_vec and each concated_source_vec. The concated_source_contexts are the vectors that compose the result. The @@ -2770,11 +3132,16 @@ def ComputeContextVectorWithSource(self, tb = tf.shape(query_vec)[0] multiplier = tb // sb merged_source_padding = MergeSourcePaddingWithPerStepSourcePadding( - source_padding, per_step_source_padding, tb) + source_padding, per_step_source_padding, tb + ) - probs, new_state = self.ComputeProbabilities(theta, concated_source_vecs, - merged_source_padding, - query_vec, attention_state) + probs, new_state = self.ComputeProbabilities( + theta, + concated_source_vecs, + merged_source_padding, + query_vec, + attention_state, + ) with tf.name_scope('sum'): # Reshape probs to be of shape [tb/sb, sb, sl] @@ -2804,16 +3171,24 @@ def Params(cls): p = super().Params() p.Define('source_dim', 0, 'Number of source nodes.') p.Define('query_dim', 0, 'Number of query nodes.') - p.Define('hidden_dim', 128, - 'Number of hidden units for the MLP that predicts GMM params.') - p.Define('max_offset', -1, - 'Max offset to move attention pointer, Enabled only when > 0.') + p.Define( + 'hidden_dim', + 128, + 'Number of hidden units for the MLP that predicts GMM params.', + ) + p.Define( + 'max_offset', + -1, + 'Max offset to move attention pointer, Enabled only when > 0.', + ) p.Define('num_mixtures', 5, 'Number of location GMM components.') p.Define( - 'normalize_probs', False, + 'normalize_probs', + False, 'Whether to normalize probabilities computed by GMM. Otherwise, ' 'the attention weights (i.e. probabilities) may not add up to ' - '1.0.') + '1.0.', + ) # TODO(ngyuzh): find a good initialize for both TTS and ASR. 
Consider split # the layer if it's very sensitive to the initialization @@ -2833,7 +3208,8 @@ def __init__(self, params): input_dim=p.query_dim, hidden_layer_dims=[p.hidden_dim, p.num_mixtures * 3], activation=['SIGMOID', 'NONE'], - params_init=p.params_init.Copy()) + params_init=p.params_init.Copy(), + ) self.CreateChild('GMM', ff_params) def ComputeProbs(encoder_positions, priors, means, variances): @@ -2863,16 +3239,26 @@ def ComputeProbs(encoder_positions, priors, means, variances): encoder_positions = tf.expand_dims(encoder_positions, 2) # [multiplier, source_batch, source_length, num_mixtures] - probs = ((priors * tf.math.rsqrt(2 * np.pi * variances + epsilon)) * - tf.exp(-(encoder_positions - means)**2 / - (2 * variances + epsilon))) + probs = ( + priors * tf.math.rsqrt(2 * np.pi * variances + epsilon) + ) * tf.exp( + -((encoder_positions - means) ** 2) / (2 * variances + epsilon) + ) # [multiplier, source_batch, source_length] return tf.reduce_sum(probs, axis=3) - def Atten(source_padding, concated_source_vecs, concated_source_contexts, - query_vec, priors, means, variances, encoder_positions, - per_step_source_padding): + def Atten( + source_padding, + concated_source_vecs, + concated_source_contexts, + query_vec, + priors, + means, + variances, + encoder_positions, + per_step_source_padding, + ): """Computes the attention context vector. Args: @@ -2902,8 +3288,9 @@ def Atten(source_padding, concated_source_vecs, concated_source_contexts, # [multiplier, source_batch, num_mixtures] priors = tf.reshape(priors, [multiplier, source_batch, p.num_mixtures]) means = tf.reshape(means, [multiplier, source_batch, p.num_mixtures]) - variances = tf.reshape(variances, - [multiplier, source_batch, p.num_mixtures]) + variances = tf.reshape( + variances, [multiplier, source_batch, p.num_mixtures] + ) # [multiplier, source_batch, source_length] probs = ComputeProbs(encoder_positions, priors, means, variances) @@ -2912,13 +3299,14 @@ def Atten(source_padding, concated_source_vecs, concated_source_contexts, source_padding = tf.transpose(source_padding) # [multiplier, source_batch, source_length] - per_step_source_padding = tf.reshape(per_step_source_padding, - [multiplier, source_batch, -1]) + per_step_source_padding = tf.reshape( + per_step_source_padding, [multiplier, source_batch, -1] + ) source_padding += per_step_source_padding source_padding = tf.minimum(source_padding, 1.0) # [multiplier, source_batch, source_length] - probs *= (1.0 - source_padding) + probs *= 1.0 - source_padding if p.normalize_probs: probs /= tf.maximum(tf.reduce_sum(probs, axis=2, keepdims=True), 1e-12) @@ -2929,15 +3317,18 @@ def Atten(source_padding, concated_source_vecs, concated_source_contexts, # [source_batch, multiplier, source_length] # @ [source_batch, source_length, context_dim] # -> [source_batch, multiplier, context_dim] - context_vector_transposed = tf.matmul(probs_transposed, - concated_source_contexts) + context_vector_transposed = tf.matmul( + probs_transposed, concated_source_contexts + ) # [multiplier, source_batch, context_dim] context_vector = tf.transpose(context_vector_transposed, [1, 0, 2]) # [target_batch, context_dim], [target_batch, source_length] - return (tf.reshape(context_vector, [target_batch, -1]), - tf.reshape(probs, [target_batch, -1])) + return ( + tf.reshape(context_vector, [target_batch, -1]), + tf.reshape(probs, [target_batch, -1]), + ) self._ctx_vec = Atten @@ -2950,17 +3341,20 @@ def EncodeSource(vecs, ctxs): self._encode_source = EncodeSource - def PackSource(self, - theta, - 
source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): with tf.name_scope(self.params.name): if source_segment_id is None: source_segment_id = tf.zeros_like(source_padding) concated_source_vecs, concated_source_contexts = self._encode_source( - source_vecs, source_contexts) + source_vecs, source_contexts + ) return py_utils.NestedMap( # [source_length, source_batch, hidden_dim]. source_vecs=concated_source_vecs, @@ -2972,28 +3366,32 @@ def PackSource(self, # [source_length, source_batch]. source_padding=source_padding, # [source_length, source_batch]. - source_segment_id=source_segment_id) + source_segment_id=source_segment_id, + ) def ZeroAttentionState(self, source_length, decoder_batch_size): p = self.params # [target_batch, num_mixtures] position = tf.zeros([decoder_batch_size, p.num_mixtures], dtype=p.dtype) - position_offsets = tf.zeros([decoder_batch_size, p.num_mixtures], - dtype=p.dtype) + position_offsets = tf.zeros( + [decoder_batch_size, p.num_mixtures], dtype=p.dtype + ) variances = tf.ones([decoder_batch_size, p.num_mixtures], dtype=p.dtype) priors = tf.zeros([decoder_batch_size, p.num_mixtures], dtype=p.dtype) # [target_batch, num_mixtures, 4] return tf.stack([position, position_offsets, variances, priors], axis=2) - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): """Computes the context vector given the current query output. Args: @@ -3011,6 +3409,7 @@ def ComputeContextVectorWithSource(self, per_step_source_padding: Source sequence padding to apply at this step. If not None, it should be of shape [target_batch, source_length]. query_segment_id: a tensor of shape [target_batch]. + Note: concated_source_vecs are the vectors that are used to compute the attention score between the query_vec and each concated_source_vec. The concated_source_contexts are the vectors that compose the result. 
The @@ -3024,7 +3423,7 @@ def ComputeContextVectorWithSource(self, - The attention context vector: [target_batch, context_dim] - The attention probability vector: [target_batch, source_length] - The new attention state vector: [target_batch, num_mixtures, 4] - """ + """ # pyformat: disable del query_segment_id p = self.params concated_source_vecs = packed_src.source_vecs @@ -3037,17 +3436,20 @@ def ComputeContextVectorWithSource(self, # [target_batch, source_length] if per_step_source_padding is None: - per_step_source_padding = tf.zeros([target_batch, source_length], - dtype=source_padding.dtype) - per_step_source_padding = py_utils.HasShape(per_step_source_padding, - [target_batch, source_length]) + per_step_source_padding = tf.zeros( + [target_batch, source_length], dtype=source_padding.dtype + ) + per_step_source_padding = py_utils.HasShape( + per_step_source_padding, [target_batch, source_length] + ) # [target_batch, num_mixtures * 3] out = self.GMM.FProp(theta.GMM, query_vec) # [target_batch, num_mixtures] priors_logits, position_offset_logits, log_variances = tf.split( - out, 3, axis=1, name='GMM') + out, 3, axis=1, name='GMM' + ) log_variances = tf.minimum(log_variances, layers.LOG_SCALE_CLAMP_BOUND) variances = tf.exp(log_variances) @@ -3065,18 +3467,27 @@ def ComputeContextVectorWithSource(self, # Tile and reshape encoder_positions to [source_batch, source_length] # so that it can be evaluated by locations GMMs in a vectorized way. encoder_positions = tf.expand_dims( - tf.cast(tf.range(source_length), tf.float32), 0) + tf.cast(tf.range(source_length), tf.float32), 0 + ) encoder_positions = tf.tile(encoder_positions, [source_batch, 1]) # [target_batch, context_dim], [target_batch, source_length] - ctx_vec, prob = self._ctx_vec(source_padding, concated_source_vecs, - concated_source_contexts, query_vec, priors, - new_position, variances, encoder_positions, - per_step_source_padding) + ctx_vec, prob = self._ctx_vec( + source_padding, + concated_source_vecs, + concated_source_contexts, + query_vec, + priors, + new_position, + variances, + encoder_positions, + per_step_source_padding, + ) # [target_batch, num_mixtures, 4] new_atten_states = tf.stack( - [new_position, position_offset, variances, priors], axis=2) + [new_position, position_offset, variances, priors], axis=2 + ) return ctx_vec, prob, new_atten_states @@ -3105,29 +3516,45 @@ def Params(cls): p.Define('source_dim', 0, 'Number of source nodes.') p.Define('query_dim', 0, 'Number of query nodes.') p.Define('hidden_dim', 0, 'Number of hidden nodes.') - p.Define('attention_tpl', AdditiveAttention.Params(), - 'Attention used by the merger layer when merger_op is atten.') p.Define( - 'pre_proj_input_dims', None, + 'attention_tpl', + AdditiveAttention.Params(), + 'Attention used by the merger layer when merger_op is atten.', + ) + p.Define( + 'pre_proj_input_dims', + None, 'If set, should be a list of depths for the tensors to be merged.' ' Setting this will result in a pre-projection to source_dim' - ' before the merger.') + ' before the merger.', + ) p.Define( - 'pre_proj_output_dims', None, + 'pre_proj_output_dims', + None, 'Should be a list of depths which the input tensors specified in ' 'pre_proj_input_dims need to be projected to. 
Should match the length ' - 'of pre_proj_input_dims.') + 'of pre_proj_input_dims.', + ) p.Define( 'proj_tpl', layers.ProjectionLayer.Params().Set( - batch_norm=False, weight_norm=True, has_bias=True), - 'Configs template for the projection layer.') - p.Define('gated_avg_tpl', layers.GatedAverageLayer.Params(), - 'Configs template for the gated average layer.') - p.Define('num_sources', 0, 'If merger_op=weighted_sum, then must specify ' - 'num of sources.') - p.Define('post_proj', None, - 'Post projection for the merged context vector.') + batch_norm=False, weight_norm=True, has_bias=True + ), + 'Configs template for the projection layer.', + ) + p.Define( + 'gated_avg_tpl', + layers.GatedAverageLayer.Params(), + 'Configs template for the gated average layer.', + ) + p.Define( + 'num_sources', + 0, + 'If merger_op=weighted_sum, then must specify num of sources.', + ) + p.Define( + 'post_proj', None, 'Post projection for the merged context vector.' + ) return p # Merging operation keys supported by this layer. @@ -3149,8 +3576,9 @@ def __init__(self, params): atten_params.dtype = p.dtype if atten_params.params_init is None: atten_params.params_init = py_utils.WeightInit.Gaussian( - 1. / math.sqrt(atten_params.source_dim + atten_params.query_dim), - seed=p.random_seed) + 1.0 / math.sqrt(atten_params.source_dim + atten_params.query_dim), + seed=p.random_seed, + ) self.CreateChild('atten', atten_params) if p.pre_proj_input_dims: @@ -3159,11 +3587,13 @@ def __init__(self, params): if len(p.pre_proj_input_dims) != len(p.pre_proj_output_dims): raise ValueError( 'Output dims should be the same length as input dims. ' - 'Expected: %s obtained: %s' % - (len(p.pre_proj_input_dims), len(p.pre_proj_output_dims))) + 'Expected: %s obtained: %s' + % (len(p.pre_proj_input_dims), len(p.pre_proj_output_dims)) + ) pre_proj_params = [] for i, (pre_proj_input_dim, pre_proj_output_dim) in enumerate( - zip(p.pre_proj_input_dims, p.pre_proj_output_dims)): + zip(p.pre_proj_input_dims, p.pre_proj_output_dims) + ): proj_p = p.proj_tpl.Copy() proj_p.name = 'merger_pre_proj_%d' % i proj_p.input_dim = pre_proj_input_dim @@ -3172,8 +3602,9 @@ def __init__(self, params): self.CreateChildren('pre_proj', pre_proj_params) if p.merger_op == 'gated_avg': - assert p.num_sources > 0, ('For merger_op=gated_avg, must specify ' - 'num_sources > 0.') + assert ( + p.num_sources > 0 + ), 'For merger_op=gated_avg, must specify num_sources > 0.' params = p.gated_avg_tpl.Copy() params.name = 'g_avg_merger' params.num_nodes = p.source_dim @@ -3188,22 +3619,25 @@ def _CreateLayerVariables(self): p = self.params if p.merger_op == 'weighted_sum': - assert p.num_sources > 0, ('For merger_op=weighted_sum, must specify ' - 'num_sources > 0.') + assert ( + p.num_sources > 0 + ), 'For merger_op=weighted_sum, must specify num_sources > 0.' params_init = py_utils.WeightInit.Constant(1.0 / p.num_sources) # Weights to be learned. 
pw = py_utils.WeightParams( shape=[p.num_sources], init=params_init, dtype=p.dtype, - collections=[self.__class__.__name__ + '_vars']) + collections=[self.__class__.__name__ + '_vars'], + ) self.CreateVariable('sum_weight', pw) def _child_variable_scope_override(self): return { - **super()._child_variable_scope_override(), 'atten': [], + **super()._child_variable_scope_override(), + 'atten': [], 'gated_average': [], - 'pre_proj': [] + 'pre_proj': [], } def FProp(self, theta, inputs, query_vec=None): @@ -3264,13 +3698,12 @@ def FProp(self, theta, inputs, query_vec=None): for t1, t2 in tensor_pairs ]): w = tf.expand_dims( - tf.expand_dims(tf.expand_dims(theta.sum_weight, 1), 1), 1) + tf.expand_dims(tf.expand_dims(theta.sum_weight, 1), 1), 1 + ) w = tf.tile( w, - [1, - tf.shape(inputs)[1], - tf.shape(inputs)[2], - tf.shape(inputs)[3]]) + [1, tf.shape(inputs)[1], tf.shape(inputs)[2], tf.shape(inputs)[3]], + ) output = tf.reduce_sum(inputs * w, axis=0) elif p.merger_op == 'atten': @@ -3284,13 +3717,14 @@ def FProp(self, theta, inputs, query_vec=None): paddings = tf.zeros([n_sources, batch_size], dtype=inputs.dtype) self.atten.InitForSourcePacked(theta.atten, inputs, inputs, paddings) output, _, _ = self.atten.ComputeContextVector( - theta.atten, tf.reshape(query_vec, [-1, p.query_dim])) + theta.atten, tf.reshape(query_vec, [-1, p.query_dim]) + ) elif p.merger_op == 'concat': # Concatenate over the last dim, all dims but last must match. with tf.control_dependencies([ - py_utils.assert_equal(tf.shape(t1)[:-1], - tf.shape(t2)[:-1]) for t1, t2 in tensor_pairs + py_utils.assert_equal(tf.shape(t1)[:-1], tf.shape(t2)[:-1]) + for t1, t2 in tensor_pairs ]): output = tf.concat(inputs, axis=-1) @@ -3318,21 +3752,31 @@ class MultiSourceAttention(BaseAttentionLayer): @classmethod def Params(cls): p = super().Params() - p.Define('source_atten_tpls', None, - 'A list of (source_key, attention_param) ' - 'pairs.') + p.Define( + 'source_atten_tpls', + None, + 'A list of (source_key, attention_param) pairs.', + ) p.Define('source_dim', 0, 'Default source dimension.') p.Define( - 'query_dim', 0, 'Number of query nodes. Child attention params ' - 'must have query_dim less or equal than 0 or equal to this value.') + 'query_dim', + 0, + 'Number of query nodes. 
Child attention params ' + 'must have query_dim less or equal than 0 or equal to this value.', + ) p.Define( - 'primary_source_key', 'source_0', 'Key for the primary source ' - 'whose attention probabilities will be used as an output.') + 'primary_source_key', + 'source_0', + 'Key for the primary source ' + 'whose attention probabilities will be used as an output.', + ) p.Define( 'atten_merger_tpl', MergerLayer.Params().Set( - params_init=py_utils.WeightInit.Uniform(0.04), merger_op='sum'), - 'Params to specify how to merge source attention vectors.') + params_init=py_utils.WeightInit.Uniform(0.04), merger_op='sum' + ), + 'Params to specify how to merge source attention vectors.', + ) return p def __init__(self, params): @@ -3356,21 +3800,27 @@ def __init__(self, params): merger_p.query_dim = p.query_dim self.CreateChild('atten_merger', merger_p) - def PackSource(self, - theta, - source_vecs, - source_contexts, - source_padding, - source_segment_id=None): + def PackSource( + self, + theta, + source_vecs, + source_contexts, + source_padding, + source_segment_id=None, + ): p = self.params with tf.name_scope(self.params.name): packed_src = py_utils.NestedMap() for source_key, _ in p.source_atten_tpls: - packed_src[source_key] = ( - self.children['atten_%s' % source_key].InitForSourcePacked( - theta['atten_%s' % source_key], source_vecs[source_key], - source_contexts[source_key], source_padding[source_key], - source_segment_id[source_key] if source_segment_id else None)) + packed_src[source_key] = self.children[ + 'atten_%s' % source_key + ].InitForSourcePacked( + theta['atten_%s' % source_key], + source_vecs[source_key], + source_contexts[source_key], + source_padding[source_key], + source_segment_id[source_key] if source_segment_id else None, + ) return packed_src def ZeroAttentionState(self, source_seq_length, decoder_batch_size): @@ -3378,37 +3828,42 @@ def ZeroAttentionState(self, source_seq_length, decoder_batch_size): with tf.name_scope(self.params.name): return py_utils.NestedMap({ source_key: getattr(self, 'atten_%s' % source_key).ZeroAttentionState( - source_seq_length[source_key], decoder_batch_size) + source_seq_length[source_key], decoder_batch_size + ) for source_key, _ in p.source_atten_tpls }) - def ComputeContextVectorWithSource(self, - theta, - packed_src, - query_vec, - attention_state=None, - per_step_source_padding=None, - query_segment_id=None): + def ComputeContextVectorWithSource( + self, + theta, + packed_src, + query_vec, + attention_state=None, + per_step_source_padding=None, + query_segment_id=None, + ): p = self.params assert per_step_source_padding is None with tf.name_scope(self.params.name): result_map = py_utils.NestedMap() for source_key, _ in p.source_atten_tpls: - result_map[source_key] = ( - self.children['atten_%s' % - source_key].ComputeContextVectorWithSource( - theta.get('atten_%s' % source_key), - packed_src[source_key], query_vec, - attention_state[source_key] - if attention_state else None, - per_step_source_padding, query_segment_id)) + result_map[source_key] = self.children[ + 'atten_%s' % source_key + ].ComputeContextVectorWithSource( + theta.get('atten_%s' % source_key), + packed_src[source_key], + query_vec, + attention_state[source_key] if attention_state else None, + per_step_source_padding, + query_segment_id, + ) return self._CombineContext(theta, result_map, query_vec) def _CombineContext(self, theta, context_map, query_vec): ctxs = context_map.Flatten() - combined_context = self.atten_merger.FProp(theta.atten_merger, - [ctx for ctx, _, _ 
in ctxs],
-                                               query_vec)
+    combined_context = self.atten_merger.FProp(
+        theta.atten_merger, [ctx for ctx, _, _ in ctxs], query_vec
+    )
     return (
         combined_context,
         # Return atten_probs of the primary source.
@@ -3417,4 +3872,5 @@ def _CombineContext(self, theta, context_map, query_vec):
         py_utils.NestedMap({
             src_key: context_map[src_key][2]
             for src_key, _ in self.params.source_atten_tpls
-        }))
+        }),
+    )
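
For readers following the LocationSensitiveAttention hunks above: the layer convolves the previous (and optionally cumulative) attention probabilities with a learned location filter and folds the result into an additive tanh energy. The sketch below restates that computation in plain NumPy under assumed shapes; the names (location_conv, location_sensitive_logits, w_query, loc_filt, w_loc, v_hidden) are illustrative and not part of the lingvo API, and the real layer additionally handles quantization, padding, and the source/target batch multiplier.

import numpy as np

def location_conv(atten_state, filt):
  """SAME-padded 1-D convolution over attention-probability features.

  atten_state: [batch, num_features, slen], e.g. PREV_PROBS / CUMULATIVE_PROBS.
  filt:        [filter_size, num_features, num_filters], filter_size odd.
  Returns:     [batch, num_filters, slen].
  """
  batch, _, slen = atten_state.shape
  fsize, _, num_filters = filt.shape
  pad = fsize // 2
  padded = np.pad(atten_state, ((0, 0), (0, 0), (pad, pad)))
  out = np.zeros((batch, num_filters, slen))
  for t in range(slen):
    window = padded[:, :, t:t + fsize]              # [batch, num_features, fsize]
    out[:, :, t] = np.einsum('bfs,sfk->bk', window, filt)
  return out

def location_sensitive_logits(src_enc, query, atten_state,
                              w_query, loc_filt, w_loc, v_hidden):
  """Additive attention energy with location features (same-batch-size path).

  src_enc: [slen, batch, hidden]  source vectors already projected by source_var.
  query:   [batch, query_dim]     decoder state.
  Returns: [slen, batch]          unnormalized attention logits.
  """
  loc_feats = location_conv(atten_state, loc_filt)         # [batch, nf, slen]
  loc_hidden = np.einsum('bfs,fh->sbh', loc_feats, w_loc)  # [slen, batch, hidden]
  query_hidden = query @ w_query                           # [batch, hidden]
  summed = np.tanh(src_enc + query_hidden[None, :, :] + loc_hidden)
  return np.einsum('sbh,h->sb', summed, v_hidden)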
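
MergeSourcePaddingWithPerStepSourcePadding, seen in the hunks above, combines the static [sl, sb] source padding with a per-query [tb, sl] padding by broadcasting and taking an element-wise maximum. A minimal NumPy sketch, assuming tb is a multiple of sb and the target batch is laid out as [multiplier * sb] as in the surrounding comments:

import numpy as np

def merge_paddings(source_padding, per_step_source_padding):
  """Element-wise max of static and per-step source paddings.

  source_padding:          [sl, sb], 1.0 marks a padded source position.
  per_step_source_padding: [tb, sl], tb a multiple of sb.
  Returns:                 [tb, sl].
  """
  sl, sb = source_padding.shape
  tb = per_step_source_padding.shape[0]
  merged = np.maximum(
      source_padding.T[None, :, :],                        # [1, sb, sl]
      per_step_source_padding.reshape(tb // sb, sb, sl))   # [tb/sb, sb, sl]
  return merged.reshape(tb, sl)

src_pad = np.array([[0., 0.], [0., 1.], [1., 1.]])   # sl=3, sb=2
step_pad = np.zeros((4, 3))
step_pad[0, 0] = 1.0                                 # mask position 0 for query 0
print(merge_paddings(src_pad, step_pad))             # row 0 -> [1., 0., 1.]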
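
The GMM (location-based) attention hunks predict per-mixture priors, mean offsets, and variances from the query, then score every encoder position under the resulting mixture of Gaussians (ComputeProbs). The following NumPy sketch evaluates that density with the same epsilon-stabilized form; it omits the source padding mask and the multiplier dimension, and gmm_attention_probs is an illustrative name rather than a lingvo function.

import numpy as np

def gmm_attention_probs(priors, means, variances, source_length,
                        normalize=False, epsilon=1e-8):
  """Scores encoder positions under a per-query mixture of Gaussians.

  priors, means, variances: [batch, num_mixtures].
  Returns:                  [batch, source_length].
  """
  positions = np.arange(source_length, dtype=np.float64)    # [sl]
  diff = positions[None, None, :] - means[:, :, None]       # [b, k, sl]
  var = variances[:, :, None]
  probs = (priors[:, :, None] / np.sqrt(2 * np.pi * var + epsilon)
           * np.exp(-diff**2 / (2 * var + epsilon)))
  probs = probs.sum(axis=1)                                 # sum over mixtures
  if normalize:  # mirrors p.normalize_probs
    probs /= np.maximum(probs.sum(axis=1, keepdims=True), 1e-12)
  return probs

probs = gmm_attention_probs(priors=np.array([[0.5, 0.5]]),
                            means=np.array([[2.0, 6.0]]),
                            variances=np.array([[1.0, 4.0]]),
                            source_length=10)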
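
MergerLayer's weighted_sum op, whose variable creation appears above, learns one scalar weight per source (initialized to 1 / num_sources) and reduces the stacked inputs with it. A minimal sketch of just that combination rule, with illustrative names:

import numpy as np

def weighted_sum_merge(sources, weights):
  """Merges same-shaped context tensors with per-source scalar weights.

  sources: list of arrays sharing one shape, e.g. each [batch, time, dim].
  weights: [num_sources], e.g. initialized to 1 / num_sources.
  """
  stacked = np.stack(sources, axis=0)                   # [n, batch, time, dim]
  w = weights.reshape((-1,) + (1,) * (stacked.ndim - 1))
  return (stacked * w).sum(axis=0)                      # [batch, time, dim]

ctx_a = np.ones((2, 5, 4))
ctx_b = 3.0 * np.ones((2, 5, 4))
merged = weighted_sum_merge([ctx_a, ctx_b], np.array([0.5, 0.5]))  # all 2.0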