# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""
# pylint: disable=g-classes-have-attributes

import re

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.experimental.EinsumDense")
class EinsumDense(Layer):
  """A layer that uses tf.einsum as the backing computation.

  This layer can perform einsum calculations of arbitrary dimensionality.

  Args:
    equation: An equation describing the einsum to perform. This equation must
      be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or
      `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum
      axis expression sequence.
    output_shape: The expected shape of the output tensor (excluding the batch
      dimension and any dimensions represented by ellipses). You can specify
      None for any dimension that is unknown or can be inferred from the input
      shape.
    activation: Activation function to use. If you don't specify anything, no
      activation is applied (that is, a "linear" activation: `a(x) = x`).
    bias_axes: A string containing the output dimension(s) to apply a bias to.
      Each character in the `bias_axes` string should correspond to a
      character in the output portion of the `equation` string.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Examples:

  **Biased dense layer with einsums**

  This example shows how to instantiate a standard Keras dense layer using
  einsum operations. This example is equivalent to
  `tf.keras.layers.Dense(64, use_bias=True)`.

  >>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
  >>> input_tensor = tf.keras.Input(shape=[32])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 64) dtype=...>

  **Applying a dense layer to a sequence**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence. Here, the `output_shape` has two
  values (since there are two non-batch dimensions in the output); the first
  dimension in the `output_shape` is `None`, because the sequence dimension
  `b` has an unknown shape.

  >>> layer = EinsumDense("abc,cd->abd",
  ...                     output_shape=(None, 64),
  ...                     bias_axes="d")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 64) dtype=...>

  **Applying a dense layer to a sequence using ellipses**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence, but uses the ellipsis notation
  instead of specifying the batch and sequence dimensions.

  Because we are using ellipsis notation and have specified only one axis, the
  `output_shape` arg is a single value. When instantiated in this way, the
  layer can handle any number of sequence dimensions, including the case where
  no sequence dimension exists.

  >>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 64) dtype=...>
  """

  def __init__(self,
               equation,
               output_shape,
               activation=None,
               bias_axes=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    # Forward activity_regularizer to the base layer; `get_config` below
    # reads it back via `self.activity_regularizer`.
    super(EinsumDense, self).__init__(
        activity_regularizer=activity_regularizer, **kwargs)
    self.equation = equation
    if isinstance(output_shape, int):
      self.partial_output_shape = [output_shape]
    else:
      self.partial_output_shape = list(output_shape)
    self.bias_axes = bias_axes
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    shape_data = _analyze_einsum_string(self.equation,
                                        self.bias_axes,
                                        input_shape,
                                        self.partial_output_shape)
    kernel_shape, bias_shape, self.full_output_shape = shape_data
    self.kernel = self.add_weight(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)

    if bias_shape is not None:
      self.bias = self.add_weight(
          "bias",
          shape=bias_shape,
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    super(EinsumDense, self).build(input_shape)
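
  # For a concrete sense of the shapes `build` derives (a sketch mirroring
  # the second docstring example): with equation "abc,cd->abd", an input of
  # shape (None, 32, 128), and output_shape=(None, 64), the kernel is created
  # with shape (128, 64) and, with bias_axes="d", the bias with shape (64,).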

  def compute_output_shape(self, _):
    return tensor_shape.TensorShape(self.full_output_shape)

  def get_config(self):
    config = {
        "output_shape":
            self.partial_output_shape,
        "equation":
            self.equation,
        "activation":
            activations.serialize(self.activation),
        "bias_axes":
            self.bias_axes,
        "kernel_initializer":
            initializers.serialize(self.kernel_initializer),
        "bias_initializer":
            initializers.serialize(self.bias_initializer),
        "kernel_regularizer":
            regularizers.serialize(self.kernel_regularizer),
        "bias_regularizer":
            regularizers.serialize(self.bias_regularizer),
        "activity_regularizer":
            regularizers.serialize(self.activity_regularizer),
        "kernel_constraint":
            constraints.serialize(self.kernel_constraint),
        "bias_constraint":
            constraints.serialize(self.bias_constraint),
    }
    base_config = super(EinsumDense, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    ret = special_math_ops.einsum(self.equation, inputs, self.kernel)
    if self.bias is not None:
      ret += self.bias
    if self.activation is not None:
      ret = self.activation(ret)
    return ret
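

# The three equation forms handled below, sketched on a rank-3 input (these
# comments are illustrative; only the regexes define the accepted grammar):
#   "abc,cd->abd"      - no ellipsis; every dimension is named explicitly.
#   "...bc,cd->...bd"  - ellipsis on the left; elided leading (batch)
#                        dimensions are carried through to the output.
#   "bc...,cd->bd..."  - ellipsis on the right; elided trailing dimensions
#                        are carried through to the output.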
def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
  """Analyzes an einsum string to determine the required weight shape."""

  dot_replaced_string = re.sub(r"\.\.\.", "0", equation)

  # This is the case where no ellipses are present in the string.
  split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  # This is the case where ellipses are present on the left.
  split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(
        split_string, bias_axes, input_shape, output_shape, left_elided=True)

  # This is the case where ellipses are present on the right.
  split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  raise ValueError(
      "Invalid einsum equation '%s'. Equations must be in the form "
      "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation)
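

# Worked example for the helper below (a sketch, not executed code): with the
# left-elided equation "...bc,cd->...bd", an input of shape
# (None, 10, 32, 128), and a requested output_shape of (32, 64), the groups
# are input_spec="bc", weight_spec="cd", output_spec="bd", and elided=2; the
# two elided dimensions (None, 10) are prepended to the requested shape,
# giving a full output shape of (None, 10, 32, 64) and a weight shape of
# (128, 64).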
def _analyze_split_string(split_string,
                          bias_axes,
                          input_shape,
                          output_shape,
                          left_elided=False):
  """Analyzes a pre-split einsum string to find the weight shape."""
  input_spec = split_string.group(1)
  weight_spec = split_string.group(2)
  output_spec = split_string.group(3)
  elided = len(input_shape) - len(input_spec)

  if isinstance(output_shape, int):
    output_shape = [output_shape]
  else:
    output_shape = list(output_shape)

  output_shape.insert(0, input_shape[0])

  if elided > 0 and left_elided:
    for i in range(1, elided):
      # We already inserted the 0th input dimension at dim 0, so we need to
      # start at location 1 here.
      output_shape.insert(1, input_shape[i])
  elif elided > 0 and not left_elided:
    for i in range(len(input_shape) - elided, len(input_shape)):
      output_shape.append(input_shape[i])

  if left_elided:
    # If we have beginning dimensions elided, we need to use negative indexing
    # to determine where in the input dimension our values are.
    input_dim_map = {
        dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
    }
    # Because we've constructed the full output shape already, we don't need
    # to do negative indexing.
    output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
  else:
    input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
    output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

  for dim in input_spec:
    # Index through the dim map so that left-elided specs are resolved from
    # the rear of the input shape rather than from the (elided) front.
    input_shape_at_dim = input_shape[input_dim_map[dim]]
    if dim in output_dim_map:
      output_shape_at_dim = output_shape[output_dim_map[dim]]
      if (output_shape_at_dim is not None and
          output_shape_at_dim != input_shape_at_dim):
        raise ValueError(
            "Input shape and output shape do not match at shared "
            "dimension '%s'. Input shape is %s, and output shape "
            "is %s." %
            (dim, input_shape_at_dim, output_shape[output_dim_map[dim]]))

  for dim in output_spec:
    if dim not in input_spec and dim not in weight_spec:
      raise ValueError("Dimension '%s' was specified in the output '%s' but "
                       "has no corresponding dim in the input spec '%s' or "
                       "weight spec '%s'." % (dim, output_spec, input_spec,
                                              weight_spec))

  weight_shape = []
  for dim in weight_spec:
    if dim in input_dim_map:
      weight_shape.append(input_shape[input_dim_map[dim]])
    elif dim in output_dim_map:
      weight_shape.append(output_shape[output_dim_map[dim]])
    else:
      raise ValueError("Weight dimension '%s' did not have a match in either "
                       "the input spec '%s' or the output spec '%s'. For this "
                       "layer, the weight must be fully specified." %
                       (dim, input_spec, output_spec))

  if bias_axes is not None:
    num_left_elided = elided if left_elided else 0
    idx_map = {
        char: output_shape[i + num_left_elided]
        for i, char in enumerate(output_spec)
    }

    for char in bias_axes:
      if char not in output_spec:
        raise ValueError("Bias dimension '%s' was requested, but is not part "
                         "of the output specification '%s'." %
                         (char, output_spec))

    first_bias_location = min([output_spec.find(char) for char in bias_axes])
    bias_output_spec = output_spec[first_bias_location:]

    bias_shape = [
        idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
    ]

    if not left_elided:
      for _ in range(elided):
        bias_shape.append(1)
  else:
    bias_shape = None

  return weight_shape, bias_shape, output_shape
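

# A minimal usage sketch (illustrative only; the shapes follow from the
# helpers above and mirror the ellipsis docstring example):
#
#   layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y")
#   layer.build(tensor_shape.TensorShape([None, 32, 128]))
#   # layer.kernel has shape (128, 64): 'x' -> 128 from the input,
#   # 'y' -> 64 from output_shape. layer.bias has shape (64,), and the full
#   # output shape resolves to (None, 32, 64).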