# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.Einsum."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test


class EinsumOpTest(test.TestCase):

  def _check(self, s, *input_shapes, **kwargs):
    dtype = kwargs.pop('dtype', np.float32)
    r = np.random.RandomState(0)
    inputs = []
    for shape in input_shapes:
      arr = np.array(r.randn(*shape)).astype(dtype)
      if dtype == np.complex64 or dtype == np.complex128:
        arr += 1j * np.array(r.randn(*shape)).astype(dtype)
      inputs.append(arr)
    input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
    a = np.einsum(s, *inputs)
    b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
    self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

  def testUnary(self):
    self._check('->', ())
    self._check('ab->', (3, 3))
    self._check('ab->ab', (3, 3))
    self._check('abc->b', (3, 4, 5))
    self._check('abc->ca', (3, 4, 5))
    self._check('abc->cab', (3, 4, 5))

  def testUnaryWithRepeatedLabels(self):
    self._check('aa->', (3, 3))
    self._check('aa->a', (3, 3))
    self._check('aaa->', (3, 3, 3))
    self._check('aaa->a', (3, 3, 3))
    self._check('aab->a', (3, 3, 4))
    self._check('aabcc->a', (3, 3, 5, 4, 4))
    self._check('aabcc->ac', (3, 3, 5, 4, 4))
    self._check('aabcd->ad', (3, 3, 5, 4, 4))

  def testUnaryEllipsis(self):
    # Unary cases with ellipsis.
    # Edge cases.
    self._check('...->...', ())
    self._check('...->', ())
    self._check('->...', ())

    # Tests from dask.
    self._check('a...a->a...', (2, 2))
    self._check('a...a->', (2, 2))
    self._check('a...a->...', (2, 5, 1, 2))
    self._check('a...a->a...', (2, 1, 2))
    self._check('a...a->a...', (2, 3, 4, 5, 2))

    # Regular cases.
    self._check('...ijk->...ki', (3, 4, 5))
    self._check('...ijk->...ki', (1, 3, 4, 5))
    self._check('...ijk->...ki', (2, 2, 3, 4, 5))

    # Repeated indices.
    self._check('i...ii->...i', (3, 2, 3, 3))

  def testBinarySimple(self):
    # Binary cases in XLA mode must have either (a) each index appearing
    # exactly once in both the inputs (batch or contraction index), or
    # (b) appearing exactly once in an input and in the output (free index).
    self._check(',->', (), ())
    self._check('a,a->', (3,), (3,))
    self._check('a,a->a', (3,), (3,))
    self._check('ab,b->a', (3, 4), (4,))
    self._check('ab,ab->', (3, 4), (3, 4))
    self._check('ab,bc->ac', (3, 4), (4, 5))
    self._check('nij,jk->nik', (5, 2, 3), (3, 4))
    self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
    self._check('sa,shb->shab', (2, 1), (2, 3, 4))

  def testReducedIndices(self):
    self._check('ba,b->', (3, 2), (3,))
    self._check('ab,ab->', (3, 4), (3, 4))
    self._check('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))

  def testRepeatedIndices(self):
    # Repeated indices.
    self._check('ijj,k->ik', (2, 3, 3), (4,))
    self._check('aba,a->b', (3, 4, 3), (3,))
    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
    self._check('aab,bc->ac', (2, 2, 3), (3, 4))
    self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))

  def testEllipsis(self):
    # Batch matmul with ellipsis but without broadcasting.
    self._check('...mk,...kn->...mn', (5, 1, 2, 3), (5, 1, 3, 4))
    # Empty batch dimensions.
    self._check('...mk,...kn->...mn', (2, 3), (3, 4))
    # Tensor contraction with transpose.
    self._check('...ija,aijb...->ba...ij', (1, 2, 2, 3, 1), (1, 2, 3, 4, 1, 2))
    # Output subscripts may omit ellipsis when batch shape is empty.
    self._check('...mk,...kn->mn', (2, 3), (3, 4))
    self._check('...mk,kn->mn', (2, 3), (3, 4))
    self._check('mk,...kn->mn', (2, 3), (3, 4))

  def testBroadcasting(self):
    # Batch matmul with broadcasting.
    self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
    self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
    self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
    self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
    self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
    self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
    # Following 2 from https://stackoverflow.com/a/19203475/1611416
    self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
    self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
    self._check('i...j,j...k->i...k', (3, 1, 2, 2), (2, 2, 3, 1, 4))

  def testBroadcastingWithRepeatedIndices(self):
    # Broadcasting with repeated indices.
    self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
    self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
    self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
    self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))

  def testDtypes(self):
    bfloat16 = dtypes.bfloat16.as_numpy_dtype

    def check(dtype):
      r = np.random.RandomState(0)
      equation = 'ij,jk->ik'
      input_shapes = [(2, 2), (2, 2)]
      inputs = []
      for shape in input_shapes:
        arr = np.array(r.randn(*shape)).astype(dtype)
        if dtype == np.complex64 or dtype == np.complex128:
          arr += 1j * np.array(r.randn(*shape)).astype(dtype)
        inputs.append(arr)
      input_tensors = [constant_op.constant(x) for x in inputs]
      if dtype == bfloat16:
        # np.einsum doesn't support bfloat16.
        a = np.einsum(equation,
                      *[x.astype(np.float32) for x in inputs]).astype(dtype)
      else:
        a = np.einsum(equation, *inputs)

      b = self.evaluate(gen_linalg_ops.einsum(input_tensors, equation))
      tol = 1e-2 if dtype == bfloat16 else 1e-4
      self.assertAllClose(a, b, atol=tol, rtol=tol)

    for dtype in [
        bfloat16, np.float32, np.float64, np.complex64, np.complex128,
        np.int32, np.int64
    ]:
      check(dtype)

  @test_util.disable_xla('b/131919749')
  @test_util.run_in_graph_and_eager_modes
  def testInvalid(self):
    r = np.random.RandomState(0)
    cases = [
        # Incorrect rank.
        ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
        ('...ij,jk->ik', r.randn(3), r.randn(3, 4)),
        # Inconsistent dimensions.
        ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
        # Invalid broadcasting: the batch dimensions are incompatible.
        ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
        # The output should have an ellipsis when the broadcasting shape is
        # non-empty.
        ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
    ]
    for args in cases:
      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
        _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))

      placeholders = [
          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
      ]
      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
        _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))

  @test_util.run_in_graph_and_eager_modes
  def testPlaceholder(self):

    def check(equation, *input_and_placeholder_shapes):
      r = np.random.RandomState(0)
      inputs = []
      input_placeholders = []
      for actual_shape, placeholder_shape in input_and_placeholder_shapes:
        input_np = np.array(r.randn(*actual_shape))
        inputs.append(input_np)
        input_placeholders.append(
            array_ops.placeholder_with_default(input_np, placeholder_shape))

      a = np.einsum(equation, *inputs)
      b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
      self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)

    check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
          ((9, 3, 4, 7), (None, None, 4, None)))
    check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
    check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
          ((4, 3), (None, 3)))
    check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))

  @test_util.disable_xla('b/131919749')
  def testOutputRepeatedLabels(self):
    # This is the reverse operation of generalized traces, to be used for
    # computing symbolic gradients of einsum. Note: this operation is not
    # supported by np.einsum as it's only required for gradients.
    r = np.random.RandomState(0)
    a = r.randn(2, 2)
    s = 'a->aa'
    diag_a = np.diag(np.diag(a))
    b = self.evaluate(gen_linalg_ops.einsum([np.diag(a)], s))
    self.assertAllClose(diag_a, b, atol=1e-4, rtol=1e-4)

  def testEmpty(self):

    def check(equation, input_shapes, output_shape):
      # All these cases result in an output filled with zeros, so we don't
      # call np.einsum. Also np.einsum doesn't support generalized diagonals,
      # which are needed for EinsumOp gradients.
      r = np.random.RandomState(0)
      inputs = [np.array(r.randn(*shape)) for shape in input_shapes]
      output = self.evaluate(gen_linalg_ops.einsum(inputs, equation))
      self.assertAllClose(output, np.zeros(output_shape), atol=1e-4, rtol=1e-4)

    # Contractions along zero-sized dimensions.
    check('ab,bc->ac', [(0, 10), (10, 10)], (0, 10))
    # From Transformer XL.
    check('ibnd,ijbn->jnd', [(1, 0, 5, 10), (1, 1, 0, 5)], (1, 5, 10))

  @test_util.disable_xla('b/131919749')
  def testEmptyWithRepeatedLabels(self):

    def check(equation, input_shapes, output_shape):
      # All these cases result in an output filled with zeros, so we don't
      # call np.einsum. Also np.einsum doesn't support generalized diagonals,
      # which are needed for EinsumOp gradients.
      r = np.random.RandomState(0)
      inputs = [np.array(r.randn(*shape)) for shape in input_shapes]
      output = self.evaluate(gen_linalg_ops.einsum(inputs, equation))
      self.assertAllClose(output, np.zeros(output_shape), atol=1e-4, rtol=1e-4)

    # Generalized traces with zero-sized dimensions.
    check('aab,bc->ac', [(0, 0, 10), (10, 10)], (0, 10))
    check('aaab,bc->c', [(0, 0, 0, 3), (3, 4)], (4,))
    # Generalized diagonals along with contraction.
    check('ab,bc->aaca', [(0, 10), (10, 5)], (0, 0, 5, 0))
    check('ab,bc->aaa', [(0, 10), (10, 5)], (0, 0, 0))
    check('ab,bc->cc', [(0, 10), (10, 5)], (5, 5))
    check('ab,ab->aaa', [(0, 5), (0, 5)], (0, 0, 0))


@test_util.run_all_in_graph_and_eager_modes
class EinsumGradTest(test.TestCase):

  def _check_gradient(self, s, *input_shapes):
    with self.cached_session():
      r = np.random.RandomState(0)
      inputs = [np.array(r.randn(*shape), np.float64) for shape in input_shapes]
      input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
      analytical, numerical = gradient_checker_v2.compute_gradient(
          lambda *xs: gen_linalg_ops.einsum(xs, s), input_tensors)
      self.assertLess(
          gradient_checker_v2.max_error(analytical, numerical), 1e-4)

  @test_util.disable_xla('b/131919749')
  def testUnary(self):
    # Unary cases.
    self._check_gradient('->', ())
    self._check_gradient('aaa->a', (3, 3, 3))
    self._check_gradient('aabcd->ad', (3, 3, 5, 4, 4))
    self._check_gradient('aabcd->add', (3, 3, 5, 4, 4))
    self._check_gradient('abcd->da', (3, 5, 4, 2))

  @test_util.disable_xla('b/131919749')
  def testUnaryEllipsis(self):
    self._check_gradient('...->...', ())
    self._check_gradient('...->', ())
    self._check_gradient('->...', ())

    # Tests from dask.
    self._check_gradient('a...a->a...', (2, 2))
    self._check_gradient('a...a->', (2, 2))
    self._check_gradient('a...a->...', (2, 5, 1, 2))
    self._check_gradient('a...a->a...', (2, 1, 2))
    self._check_gradient('a...a->a...', (2, 3, 4, 5, 2))

    self._check_gradient('...ijk->...ki', (3, 4, 5))
    self._check_gradient('...ijk->...ki', (1, 3, 4, 5))
    self._check_gradient('...ijk->...ki', (2, 2, 3, 4, 5))
    self._check_gradient('ab...cd->da...', (3, 5, 2, 3, 4, 2))

  def testBinarySimple(self):
    # Binary cases in XLA mode must have either (a) each index appearing
    # exactly once in both the inputs (batch or contraction index), or
    # (b) appearing exactly once in an input and in the output (free index).
    self._check_gradient(',->', (), ())
    self._check_gradient('a,a->', (3,), (3,))
    self._check_gradient('a,a->a', (3,), (3,))
    self._check_gradient('ab,b->a', (3, 4), (4,))
    self._check_gradient('ab,ab->', (3, 4), (3, 4))
    self._check_gradient('ab,bc->ac', (3, 4), (4, 5))
    self._check_gradient('nij,jk->nik', (5, 2, 3), (3, 4))
    self._check_gradient('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
    self._check_gradient('sa,shb->shab', (2, 1), (2, 3, 4))

  def testEmpty(self):
    # From Transformer XL.
    self._check_gradient('ibnd,ijbn->jnd', (1, 0, 5, 10), (1, 1, 0, 5))

  def testReducedIndices(self):
    self._check_gradient('ba,b->', (3, 2), (3,))
    self._check_gradient('ab,ab->', (3, 4), (3, 4))
    self._check_gradient('ijkm,ijln->ijmn', (2, 3, 3, 4), (2, 3, 3, 2))
    self._check_gradient('abce,badf->abcd', (1, 2, 3, 4), (2, 1, 4, 3))

  @test_util.disable_xla('b/131919749')
  def testReducedIndicesWithRepeatedLabels(self):
    self._check_gradient('abce,badf->bcba', (1, 2, 3, 4), (2, 1, 4, 3))

  @test_util.disable_xla('b/131919749')
  def testRepeatedLabels(self):
    # Repeated indices.
    self._check_gradient('aba,a->b', (3, 4, 3), (3,))
    self._check_gradient('ijj,k->ik', (2, 3, 3), (4,))
    self._check_gradient('ill,k->ik', (2, 3, 3), (4,))
    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
    self._check_gradient('aab,bc->ac', (1, 1, 3), (3, 4))
    self._check_gradient('aab,bcc->ac', (2, 2, 3), (3, 4, 4))

  @test_util.disable_xla('b/131919749')
  def testEmptyWithRepeatedLabels(self):
    self._check_gradient('aab,bc->ac', (0, 0, 10), (10, 10))
    self._check_gradient('aab,bc->ac', (1, 1, 0), (0, 10))
    self._check_gradient('aaab,bc->c', (0, 0, 0, 3), (3, 4))

  def testBroadcasting(self):
    self._check_gradient('...ij,...jk->...ik', (3, 2), (2, 4))
    self._check_gradient('ij...,jk...->ik...', (3, 2, 1), (2, 4))
    self._check_gradient('...ij,...jk->...ik', (3, 1, 3, 2), (1, 5, 2, 4))
    self._check_gradient('i...j,j...k->i...k', (3, 1, 2, 2), (2, 2, 3, 1, 4))

  @test_util.disable_xla('b/131919749')
  def testBroadcastingWithRepeatedLabels(self):
    self._check_gradient('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
    self._check_gradient('aab,b...c->a...c', (1, 1, 3), (3, 1, 1, 4))


class EinsumBenchmark(test.Benchmark):
  cases = [
      # Unary cases.
      ['ijk->i', 100],
      ['ijk->kji', 100],
      # Regular matmul or batch matmul.
      ['ij,jk->ik', 1000],
      ['ji,kj->ik', 1000],
      ['ab,ab->', 100],
      ['ab,ba->', 100],
      ['abc,abc->', 100],
      ['abc,bac->', 100],
      ['abc,cba->', 100],
      ['bij,bjk->bik', 100],
      ['bji,bjk->bki', 100],
      ['ikl,kji->kl', 100],
      ['klj,lki->ij', 100],
      ['ijk,ilj->kli', 100],
      ['kij,mkb->ijmb', 100],
      ['abcd,ad->bc', 40],
      # Larger binary contractions.
      ['ijk,jklm->il', 40],
      ['efabc,eabcd->efd', 30],
      ['fabec,abcde->fde', 30],
      ['efabc,edabc->efd', 30],
      ['eadbf,dfebc->ecfad', 30],
      ['abcdef,bcdfg->abcdeg', 30],
  ]

  def benchmarkEinsum(self):
    for equation, dim in self.cases:
      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device('/cpu:0'):
        r = np.random.RandomState(0)
        input_subscripts = equation.split('->')[0].split(',')
        input_vars = []
        for subscript in input_subscripts:
          input_shape = (dim,) * len(subscript)
          input_vars.append(
              variables.Variable(np.array(r.randn(*input_shape), np.float32)))
        variables.global_variables_initializer().run()

        # Call einsum_v1.
        self.run_op_benchmark(
            sess,
            special_math_ops.einsum(equation, *input_vars),
            min_iters=50,
            name='einsum_v1_cpu_({})_{}'.format(equation, dim))

        # Call gen_linalg_ops.einsum.
        self.run_op_benchmark(
            sess,
            gen_linalg_ops.einsum(input_vars, equation),
            min_iters=50,
            name='einsum_v2_cpu_({})_{}'.format(equation, dim))


if __name__ == '__main__':
  test.main()