# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Arithmetic Operations that don't fit into math_ops due to dependencies.

To avoid circular dependencies, some math_ops should go here.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import re
import string

import numpy as np
import opt_einsum
import six

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_special_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# TODO(b/27419586) Change docstring for required dtype of x once int allowed.
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.

  Given one-dimensional $z = [z_1,...,z_K]$, we define

  $$Beta(z) = \frac{\prod_j \Gamma(z_j)}{\Gamma(\sum_j z_j)},$$

  where $\Gamma$ is the gamma function.

  And for $n + 1$ dimensional $x$ with shape $[N_1, ..., N_n, K]$, we define

  $$lbeta(x)[i_1, ..., i_n] = \log{|Beta(x[i_1, ..., i_n, :])|}.$$

  In other words, the last dimension is treated as the $z$ vector.

  Note that if $z = [u, v]$, then

  $$Beta(z) = \frac{\Gamma(u)\Gamma(v)}{\Gamma(u + v)}
            = \int_0^1 t^{u-1} (1 - t)^{v-1} \mathrm{d}t,$$

  which defines the traditional bivariate beta function.

  If the last dimension is empty, we follow the convention that the sum over
  the empty set is zero, and the product is one.

  Args:
    x: A rank `n + 1` `Tensor`, `n >= 0`, with type `float` or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of \\(|Beta(x)|\\), reducing along the last dimension.
  """
  # In the event that the last dimension has zero entries, we return -inf.
  # This is consistent with the convention that the sum over the empty set is
  # 0 and the product is 1.
  # This is standard. See https://en.wikipedia.org/wiki/Empty_set.
  with ops.name_scope(name, 'lbeta', [x]):
    x = ops.convert_to_tensor(x, name='x')

    # Note reduce_sum([]) = 0.
    log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])

    # Note lgamma(0) = infinity, so if x = []
    #  log_gamma_sum_x = lgamma(0) = infinity, and
    #  log_prod_gamma_x = lgamma(1) = 0,
    #  so result = -infinity
    sum_x = math_ops.reduce_sum(x, axis=[-1])
    log_gamma_sum_x = math_ops.lgamma(sum_x)
    result = log_prod_gamma_x - log_gamma_sum_x

    return result


@tf_export('math.special.dawsn')
@dispatch.add_dispatch_support
def dawsn(x, name=None):
  """Computes Dawson's integral of `x` element-wise.

  Dawson's integral is defined as `exp(-x**2)` times the integral of
  `exp(t**2)` from `0` to `x`, with the domain of definition all real numbers.

  Dawson's function is odd.

  >>> tf.math.special.dawsn([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5380795, -0.4244364, 0.4244364, 0.5380795], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.dawsn
  @end_compatibility
  """
  with ops.name_scope(name, 'dawsn', [x]):
    return gen_special_math_ops.dawsn(x)


@tf_export('math.special.expint')
@dispatch.add_dispatch_support
def expint(x, name=None):
  """Computes the Exponential integral of `x` element-wise.

  The Exponential integral is defined as the integral of `exp(t) / t` from
  `-inf` to `x`, with the domain of definition all positive real numbers.

  >>> tf.math.special.expint([1., 1.1, 2.1, 4.1]).numpy()
  array([ 1.8951179, 2.1673784, 5.3332353, 21.048464], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.expi
  @end_compatibility
  """
  with ops.name_scope(name, 'expint', [x]):
    return gen_special_math_ops.expint(x)


@tf_export('math.special.fresnel_cos')
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None):
  """Computes Fresnel's cosine integral of `x` element-wise.

  The Fresnel cosine integral is defined as the integral of `cos(t^2)` from
  `0` to `x`, with the domain of definition all real numbers.

  The Fresnel cosine integral is odd.

  >>> tf.math.special.fresnel_cos([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.7798934 , -0.09999753, 0.09999753, 0.7798934 ], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to the second output of scipy.special.fresnel.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_cos', [x]):
    return gen_special_math_ops.fresnel_cos(x)


@tf_export('math.special.fresnel_sin')
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None):
  """Computes Fresnel's sine integral of `x` element-wise.

  The Fresnel sine integral is defined as the integral of `sin(t^2)` from
  `0` to `x`, with the domain of definition all real numbers.

  >>> tf.math.special.fresnel_sin([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.43825912, -0.00052359, 0.00052359, 0.43825912], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to the first output of scipy.special.fresnel.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_sin', [x]):
    return gen_special_math_ops.fresnel_sin(x)


@tf_export('math.special.spence')
@dispatch.add_dispatch_support
def spence(x, name=None):
  """Computes Spence's integral of `x` element-wise.

  Spence's integral is defined as the integral of `log(t) / (1 - t)` from
  `1` to `x`, with the domain of definition all non-negative real numbers.

  >>> tf.math.special.spence([0.5, 1., 2., 3.]).numpy()
  array([ 0.58224034, 0. , -0.82246685, -1.4367464], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.spence
  @end_compatibility
  """
  with ops.name_scope(name, 'spence', [x]):
    return gen_special_math_ops.spence(x)


@tf_export('math.bessel_i0', 'math.special.bessel_i0')
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
  """Computes the Bessel i0 function of `x` element-wise.

  Modified Bessel function of the first kind of order 0.

  It is preferable to use the numerically more stable function `i0e(x)`
  instead.

  >>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy()
  array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0', [x]):
    return gen_special_math_ops.bessel_i0(x)


@tf_export('math.bessel_i0e', 'math.special.bessel_i0e')
@dispatch.add_dispatch_support
def bessel_i0e(x, name=None):
  """Computes the Bessel i0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 0.

  >>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy()
  array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0e', [x]):
    return gen_special_math_ops.bessel_i0e(x)


@tf_export('math.bessel_i1', 'math.special.bessel_i1')
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
  """Computes the Bessel i1 function of `x` element-wise.

  Modified Bessel function of the first kind of order 1.

  It is preferable to use the numerically more stable function `i1e(x)`
  instead.

  >>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5651591 , -0.25789431, 0.25789431, 0.5651591 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1', [x]):
    return gen_special_math_ops.bessel_i1(x)


@tf_export('math.bessel_i1e', 'math.special.bessel_i1e')
@dispatch.add_dispatch_support
def bessel_i1e(x, name=None):
  """Computes the Bessel i1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 1.

  >>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.20791042, -0.15642083, 0.15642083, 0.20791042], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1e', [x]):
    return gen_special_math_ops.bessel_i1e(x)


@tf_export('math.special.bessel_k0')
@dispatch.add_dispatch_support
def bessel_k0(x, name=None):
  """Computes the Bessel k0 function of `x` element-wise.

  Modified Bessel function of the second kind of order 0.

  It is preferable to use the numerically more stable function `k0e(x)`
  instead.

  >>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy()
  array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0', [x]):
    return gen_special_math_ops.bessel_k0(x)


@tf_export('math.special.bessel_k0e')
@dispatch.add_dispatch_support
def bessel_k0e(x, name=None):
  """Computes the Bessel k0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 0.

  >>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy()
  array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0e', [x]):
    return gen_special_math_ops.bessel_k0e(x)


@tf_export('math.special.bessel_k1')
@dispatch.add_dispatch_support
def bessel_k1(x, name=None):
  """Computes the Bessel k1 function of `x` element-wise.

  Modified Bessel function of the second kind of order 1.

  It is preferable to use the numerically more stable function `k1e(x)`
  instead.

  >>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy()
  array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1', [x]):
    return gen_special_math_ops.bessel_k1(x)


@tf_export('math.special.bessel_k1e')
@dispatch.add_dispatch_support
def bessel_k1e(x, name=None):
  """Computes the Bessel k1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 1.

  >>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy()
  array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1e', [x]):
    return gen_special_math_ops.bessel_k1e(x)


@tf_export('math.special.bessel_j0')
@dispatch.add_dispatch_support
def bessel_j0(x, name=None):
  """Computes the Bessel j0 function of `x` element-wise.

  Bessel function of the first kind of order 0.

  >>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy()
  array([ 0.93846981, 0.76519769, 0.22389078, -0.39714981], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j0', [x]):
    return gen_special_math_ops.bessel_j0(x)


@tf_export('math.special.bessel_j1')
@dispatch.add_dispatch_support
def bessel_j1(x, name=None):
  """Computes the Bessel j1 function of `x` element-wise.

  Bessel function of the first kind of order 1.

  >>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy()
  array([ 0.24226846, 0.44005059, 0.57672481, -0.06604333], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j1', [x]):
    return gen_special_math_ops.bessel_j1(x)


@tf_export('math.special.bessel_y0')
@dispatch.add_dispatch_support
def bessel_y0(x, name=None):
  """Computes the Bessel y0 function of `x` element-wise.

  Bessel function of the second kind of order 0.

  >>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy()
  array([-0.44451873, 0.08825696, 0.51037567, -0.01694074], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y0', [x]):
    return gen_special_math_ops.bessel_y0(x)


@tf_export('math.special.bessel_y1')
@dispatch.add_dispatch_support
def bessel_y1(x, name=None):
  """Computes the Bessel y1 function of `x` element-wise.

  Bessel function of the second kind of order 1.

  >>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy()
  array([-1.47147239, -0.78121282, -0.10703243, 0.39792571], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y1', [x]):
    return gen_special_math_ops.bessel_y1(x)


@ops.RegisterGradient('XlaEinsum')
def _einsum_grad(op, grad):
  equation = op.get_attr('equation')
  if isinstance(equation, bytes):
    equation = equation.decode()

  inputs, output = equation.split('->')
  left, right = inputs.split(',')

  return [
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[1],
          equation='{},{}->{}'.format(output, right, left),
          name=None),
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[0],
          equation='{},{}->{}'.format(output, left, right),
          name=None)
  ]


def _enclosing_tpu_context():
  # pylint: disable=protected-access
  context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while context is not None and not isinstance(
      context, control_flow_ops.XLAControlFlowContext):
    context = context.outer_context
  return context


@tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs):
  r"""Tensor contraction over specified indices and outer product.

  Einsum allows defining Tensors by defining their element-wise computation.
  This computation is defined by `equation`, a shorthand form based on Einstein
  summation. As an example, consider multiplying two matrices A and B to form a
  matrix C.
  The elements of C are given by:

  $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$

  or

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding einsum `equation` is:

  ```
  ij,jk->ik
  ```

  In general, to convert the element-wise equation into the `equation` string,
  use the following procedure (intermediate strings for the matrix
  multiplication example are provided in parentheses):

  1. remove variable names, brackets, and commas, (`ik = sum_j ij * jk`)
  2. replace "*" with ",", (`ik = sum_j ij , jk`)
  3. drop summation signs, and (`ik = ij, jk`)
  4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`)

  Note: If the output indices are not specified, repeated indices are summed.
  So `ij,jk->ik` can be simplified to `ij,jk`.

  Many common operations can be expressed in this way. For example:

  **Matrix multiplication**

  >>> m0 = tf.random.normal(shape=[2, 3])
  >>> m1 = tf.random.normal(shape=[3, 5])
  >>> e = tf.einsum('ij,jk->ik', m0, m1)
  >>> # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  Repeated indices are summed if the output indices are not specified.

  >>> e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  **Dot product**

  >>> u = tf.random.normal(shape=[5])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
  >>> print(e.shape)
  ()

  **Outer product**

  >>> u = tf.random.normal(shape=[3])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
  >>> print(e.shape)
  (3, 5)

  **Transpose**

  >>> m = tf.ones([2, 3])
  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
  >>> print(e.shape)
  (3, 2)

  **Diag**

  >>> m = tf.reshape(tf.range(9), [3,3])
  >>> diag = tf.einsum('ii->i', m)
  >>> print(diag.shape)
  (3,)

  **Trace**

  >>> # Repeated indices are summed.
  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
  >>> assert trace == sum(diag)
  >>> print(trace.shape)
  ()

  **Batch matrix multiplication**

  >>> s = tf.random.normal(shape=[7,5,3])
  >>> t = tf.random.normal(shape=[7,3,2])
  >>> e = tf.einsum('bij,bjk->bik', s, t)
  >>> # output[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  >>> print(e.shape)
  (7, 5, 2)

  This method does not support broadcasting on named axes. All axes with
  matching labels should have the same length. If you have length-1 axes,
  use `tf.squeeze` or `tf.reshape` to eliminate them.

  To write code that is agnostic to the number of indices in the input,
  use an ellipsis. The ellipsis is a placeholder for "whatever other indices
  fit here".

  For example, to perform a NumPy-style broadcasting batch matrix
  multiplication, where the matrix multiply acts on the last two axes of the
  input, use:

  >>> s = tf.random.normal(shape=[11, 7, 5, 3])
  >>> t = tf.random.normal(shape=[11, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Einsum **will** broadcast over axes covered by the ellipsis.

  >>> s = tf.random.normal(shape=[11, 1, 5, 3])
  >>> t = tf.random.normal(shape=[1, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    **kwargs:
      - optimize: Optimization strategy to use to find the contraction path
        using opt_einsum. Must be 'greedy', 'optimal', 'branch-2', 'branch-all'
        or 'auto'. (optional, default: 'greedy').
      - name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs or their shapes are inconsistent with `equation`.
  """
  return _einsum_v2(equation, *inputs, **kwargs)


def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy implementation of einsum without using EinsumOp."""
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
        [format(key) for key in sorted(list(kwargs.keys()))]))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.shape for x in inputs]
    input_axis_labels, output_axis_labels = (
        _einsum_v1_parse_and_resolve_equation(equation, input_shapes))

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)
    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum_v1(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2-input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1], input_axis_labels[0] + ',' +
          input_axis_labels[1] + '->' + output_axis_labels)

    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in xrange(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_v1_reduction(temp, temp_axis_labels,
                                                    inputs[i + 1],
                                                    input_axis_labels[i + 1],
                                                    axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)


def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
  """Helper for einsum() that splits/resolves inputs & outputs.

  Args:
    equation: Equation string given as argument to einsum().
    input_shapes: List of the shapes of all inputs given to einsum().

  Returns:
    input_axis_labels, output_axis_labels where:
      input_axis_labels: List of length len(input_shapes) of strings
        representing the character label for each dimension of each given
        input, resolving any broadcast (...) axes,
      output_axis_labels: A string of character labels for each axis of the
        output tensor, filling in missing output subscripts and broadcast axes.

  Raises:
    ValueError: If equation is in the incorrect format, an incorrect number of
      inputs is given, or broadcast axes "..." or output axes could not be
      resolved.
  """
  equation = equation.replace(' ', '')
  match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)
  if not match:
    raise ValueError('Indices have incorrect format: %s' % equation)

  input_axis_labels = match.group(1).split(',')
  output_axis_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_axis_labels):
    raise ValueError('Got %d arguments for equation "%s", expecting %d' %
                     (len(input_shapes), equation, len(input_axis_labels)))

  # Resolve Ellipsis.
  # Assign axis labels for unspecified dimensions in inputs. Labels are taken
  # from unused labels. Follow numpy einsum broadcasting conventions for
  # tensors of different length and unlabeled output.
  ellipsis_axes = ''
  if '...' in equation:
    unused = ''.join(
        c for c in string.ascii_letters if c not in ''.join(input_axis_labels))
    for i, ax in enumerate(input_axis_labels):
      if '...' in ax:
        parts = ax.split('...')
        if len(parts) != 2:
          raise ValueError(
              'Unable to resolve ellipsis. Excess number found.')
        if input_shapes[i].ndims is None:
          raise ValueError('Unable to statically infer ellipsis axes.')
        n = input_shapes[i].ndims - len(''.join(parts))
        if n < 0:
          raise ValueError('Ellipses lengths do not match.')
        if len(unused) < n:
          raise ValueError(
              'Unable to resolve ellipsis, too many distinct labels.')
        replace_axes = unused[-n:] if n > 0 else ''
        input_axis_labels[i] = input_axis_labels[i].replace('...',
                                                            replace_axes)
        if len(replace_axes) > len(ellipsis_axes):
          ellipsis_axes = replace_axes

    if any('.' in ax for ax in input_axis_labels):
      raise ValueError('period "." found outside of ellipsis')

    if output_axis_labels is not None:
      output_axis_labels = output_axis_labels.replace('...', ellipsis_axes)
      if '.' in output_axis_labels:
        raise ValueError('period "." found outside of ellipsis')

  if output_axis_labels is None:
    # Infer the output subscripts if not given, assume alphabetical order,
    # but always place ellipsis axes before the given ones.
    axis_labels = set(''.join(input_axis_labels)) - set(ellipsis_axes)
    indices = ''.join(sorted(axis_labels))
    counts = {ax: 0 for ax in indices}
    for axes_ in input_axis_labels:
      for ax in axes_:
        if ax not in ellipsis_axes:
          counts[ax] += 1

    output_axis_labels = ellipsis_axes + ''.join(
        sorted(ax for ax in axis_labels if counts[ax] == 1))

  return input_axis_labels, output_axis_labels


def _einsum_v1_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
  """Helper for einsum() that computes the result of a two-argument einsum().

  Args:
    t0: a `Tensor`
    t0_axis_labels: a string of axis labels. This string's length must equal
      the rank of t0.
    t1: a `Tensor`
    t1_axis_labels: a string of axis labels. This string's length must equal
      the rank of t1.
    axes_to_sum: set of labels of axes to be summed over

  Returns:
    A `Tensor` whose elements are obtained by summing, over all axes in
    `axes_to_sum`, the corresponding elements of `t0` and `t1`.

    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
    axes_to_sum == {j,k}, this will return a tensor where

      out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]

  Raises:
    ValueError: if the rank of `t0` does not match the length of
      `t0_axis_labels`, or that of `t1` does not match the length of
      `t1_axis_labels`.
  """
  if len(t0_axis_labels) != len(t0.shape):
    raise ValueError(
        'Tensor t0 of rank %d does not match einsum reduction of length %d' %
        (len(t0.shape), len(t0_axis_labels)))
  if len(t1_axis_labels) != len(t1.shape):
    raise ValueError(
        'Tensor t1 of rank %d does not match einsum reduction of length %d' %
        (len(t1.shape), len(t1_axis_labels)))

  # This function computes the result of a two-argument einsum() using batch
  # matrix multiplication. This involves
  # 1. transposing t0 and t1 so that the axes are in the correct order for
  #    batch matrix multiplication, and
  # 2. reshaping t0 and t1 so that they are both of rank 3.

  # First, we divide axes into three groups:
  #  * "preserved" axes are present in both inputs and the output
  #  * "summed" axes are present in both inputs but not the output
  #  * "broadcast" axes are present in exactly one input and the output
  #
  # As an example, if the einsum is abijk,acjkl->abcil, then "a" is a
  # preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are
  # summed axes.
  assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum)
  preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
  broadcast_axes = {}
  for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]):
    broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum

  # Reorder the axes so that:
  # 1. preserved axes come first in both inputs
  # 2. in input 0, broadcast axes come next, followed by summed axes
  # 3. in input 1, summed axes come next, followed by broadcast axes
  def sort_key(input_index, a):
    if a in preserved_axes:
      return (-1, a)
    elif ((input_index == 0 and a in broadcast_axes[0]) or
          (input_index == 1 and a in axes_to_sum)):
      return (0, a)
    else:
      return (1, a)

  axis_labels = [t0_axis_labels, t1_axis_labels]
  sorted_axes = [
      sorted(sym_list, key=lambda a: sort_key(i, a))
      for i, sym_list in enumerate(axis_labels)
  ]
  inputs = [t0, t1]
  for i, axes_str in enumerate(axis_labels):
    perm = [axes_str.find(a) for a in sorted_axes[i]]
    inputs[i] = _transpose_if_necessary(inputs[i], perm)
  t0, t1 = inputs

  if not axes_to_sum:
    # In the special case where there are no axes to sum over, reduce to mul()
    # rather than to batch matrix multiplication.
    for _ in broadcast_axes[1]:
      t0 = array_ops.expand_dims(t0, -1)
    for _ in broadcast_axes[0]:
      t1 = array_ops.expand_dims(t1, len(preserved_axes))
    product = math_ops.multiply(t0, t1)
    product_axes = sorted_axes[0] + sorted_axes[1][len(preserved_axes):]
    return product, ''.join(product_axes)
  else:
    # Reduce to matmul().

    # Reshape both inputs so as to combine multiple broadcast axes
    # into a single axis, and combine multiple summed axes into a
    # single axis.
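    # Illustrative sketch (comment only; the example shapes below are assumed
    # for illustration, not taken from any caller): for the einsum
    # 'abijk,acjkl->abcil' mentioned above, with t0.shape == [2, 3, 4, 5, 6]
    # and t1.shape == [2, 7, 5, 6, 8], the transposition step has already
    # arranged t0 with axes 'abijk' and t1 with axes 'ajkcl'. The reshapes
    # below then produce
    #   t0 -> [2, 3 * 4, 5 * 6]   (preserved, broadcast, summed)
    #   t1 -> [2, 5 * 6, 7 * 8]   (preserved, summed, broadcast)
    # so matmul(t0, t1) has shape [2, 3 * 4, 7 * 8], which is finally reshaped
    # back to [2, 3, 4, 7, 8] with axes 'abicl'.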

    t0_shape = _get_shape(t0)
    num_broadcast_elements_t0 = _total_size(
        t0_shape[len(preserved_axes):-len(axes_to_sum)])
    num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
    new_shape = (
        t0_shape[:len(preserved_axes)] +
        [num_broadcast_elements_t0, num_summed_elements])
    t0 = _reshape_if_necessary(t0, new_shape)

    t1_shape = _get_shape(t1)
    num_broadcast_elements_t1 = _total_size(
        t1_shape[len(preserved_axes) + len(axes_to_sum):])
    new_shape = (
        t1_shape[:len(preserved_axes)] +
        [num_summed_elements, num_broadcast_elements_t1])
    t1 = _reshape_if_necessary(t1, new_shape)

    product = math_ops.matmul(t0, t1)

    # Undo compaction of broadcast axes
    uncompacted_shape = (
        t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
        t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
    product = _reshape_if_necessary(product, uncompacted_shape)

    product_axes = (
        sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
        sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])

    return product, ''.join(product_axes)


def _transpose_if_necessary(tensor, perm):
  """Like transpose(), but avoids creating a new tensor if possible."""
  if perm != list(range(len(perm))):
    return array_ops.transpose(tensor, perm=perm)
  else:
    return tensor


def _reshape_if_necessary(tensor, new_shape):
  """Like reshape(), but avoids creating a new tensor if possible."""
  # Accept None as an alias for -1 in new_shape.
  new_shape = tuple(-1 if x is None else x for x in new_shape)
  cur_shape = tuple(x.value for x in tensor.shape.dims)
  if (len(new_shape) == len(cur_shape) and
      all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
          for d0, d1 in zip(cur_shape, new_shape))):
    return tensor
  else:
    return array_ops.reshape(tensor, new_shape)


def _get_shape(tensor):
  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
  if necessary to ensure that the returned value contains no unknown value."""

  shape = tensor.shape.as_list()
  none_indices = [i for i, d in enumerate(shape) if d is None]
  if none_indices:
    # Query the shape if the static shape contains None values.
    shape_tensor = array_ops.shape(tensor)
    for i in none_indices:
      shape[i] = shape_tensor[i]
  return shape


def _total_size(shape_values):
  """Given a list of tensor shape values, returns the total size.

  If shape_values contains tensor values (which are results of
  array_ops.shape), then it returns a scalar tensor.
  If not, it returns an integer.
  """

  result = 1
  for val in shape_values:
    result *= val
  return result


def _exponential_space_einsum_v1(equation, *inputs):
  """Fallback implementation that supports summing an index over > 2 inputs."""
  inputs = list(inputs)
  input_shapes = [x.shape for x in inputs]
  idx_in, idx_out = _einsum_v1_parse_and_resolve_equation(
      equation, input_shapes)

  idx_all = set(''.join(idx_in) + idx_out)
  indices = ''.join(sorted(idx_all))

  missing_idx = set(idx_out).difference(idx_all)
  if missing_idx:
    raise ValueError('Unknown output axes: %s' % missing_idx)

  axis_order = {}
  for ax in indices:
    if ax not in idx_out:
      axis_order[ax] = len(axis_order)
  for ax in idx_out:
    axis_order[ax] = len(axis_order)

  # Transpose inputs so axes are in order.
  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
    if input_.shape.ndims != len(axes_):
      raise ValueError(
          'Input %d with axes %s has incorrect'
          ' number of dimensions (expected %d, got %d)' % (
              i, axes_, len(axes_), input_.shape.ndims
          )
      )

    sorted_idx = sorted(axes_, key=axis_order.get)

    if len(set(axes_)) != len(axes_):
      raise ValueError(
          'Subscript not supported: an axis appears more than once: %s' % axes_)

    if list(axes_) != sorted_idx:
      permuted = [axes_.find(ax) for ax in sorted_idx]
      inputs[i] = array_ops.transpose(input_, permuted)
      idx_in[i] = sorted_idx

  reduction_idx = []
  shapes = [[dim if dim else -1
             for dim in tensor.shape.as_list()]
            for tensor in inputs]

  # Validate shapes for broadcasting.
  for j, ax in enumerate(sorted(idx_all, key=axis_order.get)):
    dims = []
    for i, idx in enumerate(idx_in):
      if ax not in idx:
        shapes[i].insert(j, 1)
      else:
        dim = shapes[i][j]
        if isinstance(dim, int) and dim > 1:
          dims.append(dim)

    if len(set(dims)) > 1:
      raise ValueError('Dimension mismatch on axis: %s' % ax)

    if ax not in idx_out:
      reduction_idx.append(j)

  # Reshape, multiply.
  expanded_inputs = [
      array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
  ]
  expanded_output = 1
  for input_ in expanded_inputs:
    expanded_output *= input_

  # Contract.
  return math_ops.reduce_sum(expanded_output, reduction_idx)


def _einsum_v2(equation, *inputs, **kwargs):
  """Implementation of einsum utilizing opt_einsum and EinsumOp."""
  name = kwargs.pop('name', None)
  optimize = kwargs.pop('optimize', 'greedy')
  if kwargs:
    msg = 'Invalid keyword arguments for einsum: {}'
    raise TypeError(msg.format(', '.join(kwargs)))

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = []
    for operand in inputs:
      if isinstance(operand.shape, tensor_shape.TensorShape):
        input_shapes.append(operand.shape.as_list() if operand.shape else None)
      else:
        input_shapes.append(list(operand.shape))
    # Validate and sanitize the equation and resolve static input shapes, as
    # opt_einsum requires that all shapes be a tuple of positive integers.
    # Also remove the ellipsis from the equation, as opt_einsum would replace
    # it with named labels and then broadcasting between different shapes or
    # ranks wouldn't work. (E.g. [1, 1, 2] wouldn't broadcast with [3, 1]).
    resolved_equation, resolved_input_shapes, ellipsis_label = (
        _einsum_v2_parse_and_resolve_equation(equation, input_shapes))

    if len(inputs) <= 2:  # No need to call opt_einsum.
      # Replace back the ellipses that were removed for opt_einsum.
      if ellipsis_label:
        resolved_equation = resolved_equation.replace(ellipsis_label, '...')
      return gen_linalg_ops.einsum(inputs, resolved_equation)

    # Send fully specified shapes to opt_einsum, since it cannot handle unknown
    # dimensions. For unknown dimensions, we guess that the dimension equals 1.
    # Instead of creating Tensors or NumPy arrays with the specified shape,
    # create a dummy `shaped` object with a `shape` property.
    shaped = collections.namedtuple('shaped', ['shape'])
    shaped_inputs = tuple(
        [shaped(tuple(shape)) for shape in resolved_input_shapes])
    # opt_einsum breaks down an n-ary einsum operation into n-1 binary einsums.
    # Obtain the sequence of equations and the indices of operands involved in
    # each einsum operation.
    indices_and_equations = _get_opt_einsum_contract_path(
        resolved_equation, shaped_inputs, optimize)
    for operand_indices, binary_equation in indices_and_equations:
      if ellipsis_label:
        # Replace back the ellipses that were removed for opt_einsum.
        binary_equation = binary_equation.replace(ellipsis_label, '...')
      operands = list(map(inputs.pop, operand_indices))
      inputs.append(gen_linalg_ops.einsum(operands, binary_equation))
    return inputs[0]


def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
  """Returns the (memoized) result of opt_einsum.contract_path."""
  # Note: We use einsum_call=True, which is an internal API of opt_einsum,
  # to get the contraction path without having opt_einsum perform the actual
  # contractions.
  _, contractions = opt_einsum.contract_path(
      equation,
      *shaped_inputs_tuple,
      optimize=optimize,
      einsum_call=True,
      use_blas=True)
  # Return a tuple so that the cached value is not mutable.
  indices_and_equations = tuple([(expr[0], expr[2]) for expr in contractions])
  return indices_and_equations


# Cache the possibly expensive opt_einsum.contract_path call using lru_cache
# from the Python 3 standard library.
if not six.PY2:
  _get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
      _get_opt_einsum_contract_path)


def _einsum_v2_parse_and_resolve_equation(equation, input_shapes):
  """Helper which validates the einsum equation and resolves input shapes."""
  resolved_equation = equation.replace(' ', '')
  ellipsis_label = None
  if '...' in equation:
    # Replace the ellipsis ('...') with '0' for (a) ease of parsing and (b) to
    # prevent opt_einsum from resolving it into named labels, as opt_einsum
    # doesn't support broadcasting.
    ellipsis_label = '0'
    if ellipsis_label in resolved_equation:
      raise ValueError('Invalid character "0" in equation: {}'.format(equation))
    resolved_equation = resolved_equation.replace('...', ellipsis_label)

  # Ensure there are no non-alphanumeric characters in the equation, including
  # periods (`.`) outside of ellipses. This is not a hard requirement; it is
  # just that we use the special character '0' as the ellipsis label.
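  # For illustration (comment only; the equation below is an assumed example,
  # not taken from any caller): with this scheme an equation such as
  # '...ij,jk->...ik' has at this point been rewritten to '0ij,jk->0ik', and
  # the '0' label is translated back to '...' in _einsum_v2 before the
  # underlying EinsumOp is created.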
  allowed_labels = 'a-zA-Z'
  if ellipsis_label:
    allowed_labels += ellipsis_label
  match = re.match('^([{0},]*)(->[{0}]*)?$'.format(allowed_labels),
                   resolved_equation)
  if not match:
    raise ValueError(
        'Subscripts have incorrect format: {}'.format(resolved_equation))
  input_labels = match.group(1).split(',')
  output_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_labels):
    raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
        len(input_shapes), equation, len(input_labels)))

  # Special case: if there is no '->', then we create output subscripts from
  # labels appearing only once.
  if '->' not in resolved_equation:
    label_counts = collections.Counter(match.group(1))
    output_labels = ''.join([
        x for x in sorted(list(label_counts))
        if x != ',' and label_counts[x] == 1
    ])
    resolved_equation += '->' + output_labels
  # Validate output_labels.
  if output_labels and len(set(output_labels)) != len(output_labels):
    raise ValueError(
        'Output subscripts contain a label appearing more than once: {}'.format(
            equation))
  input_label_set = set(match.group(1))
  for label in output_labels:
    if label != ellipsis_label and label not in input_label_set:
      raise ValueError('Output subscripts contain the label {} not present '
                       'in the input subscripts.'.format(label))
  if ellipsis_label and output_labels:
    num_output_ellipses = output_labels.count(ellipsis_label)
    if num_output_ellipses > 1:
      raise ValueError(
          'Output subscripts contain multiple ellipses: {}'.format(equation))

  # Early return if <= 2 inputs. Resolved shapes are not needed.
  if len(input_shapes) <= 2:
    return resolved_equation, None, ellipsis_label

  # Create a map from axis labels to known dimensions. This is used to infer
  # unknown dimensions if a known dimension also has the same label.
  label_to_dim = collections.defaultdict(lambda: 1)
  for i, (labels, shape) in enumerate(zip(input_labels, input_shapes)):
    if shape is None:
      continue
    ellipsis_start = labels.find(ellipsis_label) if ellipsis_label else -1
    if ellipsis_start != -1:  # This input contains an ellipsis.
      if ellipsis_start != labels.rfind(ellipsis_label):
        raise ValueError('Too many ellipses.')
      if len(labels) > len(shape) + 1:
        raise ValueError('Too many named labels in {}th subscript string of'
                         ' equation {} for input shape {} '.format(
                             i, equation, shape))
      ellipsis_end = ellipsis_start + len(shape) + 1 - len(labels)
      shape[ellipsis_start:ellipsis_end] = ([
          np.prod(
              list(filter(None, shape[ellipsis_start:ellipsis_end])),
              dtype=np.int64)
      ])
    else:
      # This input does not contain an ellipsis.
      if len(labels) != len(shape):
        raise ValueError(
            'Number of named labels in input #{} of equation {} '
            'must be equal to the number of dimensions in shape {}'.format(
                i, equation, shape))

    for dim, label in zip(shape, labels):
      if dim is not None:
        label_to_dim[label] = max(label_to_dim[label], dim)

  resolved_shapes = []
  for labels in input_labels:
    resolved_shapes.append([label_to_dim[label] for label in labels])
  return resolved_equation, resolved_shapes, ellipsis_label
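

# Usage sketch (illustrative comment only; the shapes and the exact pairwise
# decomposition shown here are assumptions, since the actual contraction path
# is chosen by opt_einsum.contract_path):
#
#   a = tf.random.normal([2, 3])
#   b = tf.random.normal([3, 4])
#   c = tf.random.normal([4, 5])
#   d = tf.einsum('ij,jk,kl->il', a, b, c, optimize='optimal')
#
# With more than two operands, _einsum_v2 asks opt_einsum for a contraction
# path and then issues a sequence of binary EinsumOp calls, for example
# 'ij,jk->ik' followed by 'ik,kl->il'.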