# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous utilities used by time series models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

import numpy as np

from tensorflow.contrib import lookup
from tensorflow.contrib.layers.python.layers import layers

from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


def normal_log_prob(loc, scale, x):
  """Computes the Normal log pdf."""
  z = (x - loc) / scale
  return -0.5 * (math_ops.square(z)
                 + np.log(2. * np.pi)) - math_ops.log(scale)


def cauchy_log_prob(loc, scale, x):
  """Computes the Cauchy log pdf."""
  z = (x - loc) / scale
  return (-np.log(np.pi) - math_ops.log(scale) -
          math_ops.log1p(math_ops.square(z)))


def mvn_tril_log_prob(loc, scale_tril, x):
  """Computes the MVN log pdf under tril scale. Doesn't handle batches."""
  x0 = x - loc
  z = linalg_ops.matrix_triangular_solve(
      scale_tril, x0[..., array_ops.newaxis])[..., 0]
  log_det_cov = 2. * math_ops.reduce_sum(math_ops.log(
      array_ops.matrix_diag_part(scale_tril)), axis=-1)
  d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype)
  return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1)
                 + d * np.log(2. * np.pi) + log_det_cov)
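

# Illustrative usage of the log-pdf helpers above (an example added for
# orientation, not part of the original API; the literal values are
# assumptions). Each helper broadcasts elementwise over its inputs:
#
#   loc = constant_op.constant([0., 1.])
#   scale = constant_op.constant([1., 2.])
#   x = constant_op.constant([0.5, 0.5])
#   normal_lp = normal_log_prob(loc, scale, x)  # shape [2]
#   cauchy_lp = cauchy_log_prob(loc, scale, x)  # shape [2]
#
# mvn_tril_log_prob instead takes a lower-triangular scale factor `scale_tril`,
# so its log-determinant term is twice the sum of the factor's log-diagonal.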


def clip_covariance(
    covariance_matrix, maximum_variance_ratio, minimum_variance):
  """Enforce constraints on a covariance matrix to improve numerical stability.

  Args:
    covariance_matrix: A [..., N, N] batch of covariance matrices.
    maximum_variance_ratio: The maximum allowed ratio of two diagonal
      entries. Any entries lower than the maximum entry divided by this ratio
      will be set to that value.
    minimum_variance: A floor for diagonal entries in the returned matrix.
  Returns:
    A new covariance matrix with the requested constraints enforced. If the
    input was positive definite, the output will be too.
  """
  # TODO(allenl): Smarter scaling here so that correlations are preserved when
  # fiddling with diagonal elements.
  diagonal = array_ops.matrix_diag_part(covariance_matrix)
  maximum = math_ops.reduce_max(diagonal, axis=-1, keepdims=True)
  new_diagonal = gen_math_ops.maximum(
      diagonal, maximum / maximum_variance_ratio)
  return array_ops.matrix_set_diag(
      covariance_matrix, math_ops.maximum(new_diagonal, minimum_variance))


def block_diagonal(matrices, dtype=dtypes.float32, name="block_diagonal"):
  r"""Constructs block-diagonal matrices from a list of batched 2D tensors.

  Args:
    matrices: A list of Tensors with shape [..., N_i, M_i] (i.e. a list of
      matrices with the same batch dimension).
    dtype: Data type to use. The Tensors in `matrices` must match this dtype.
    name: A name for the returned op.
  Returns:
    A matrix with the input matrices stacked along its main diagonal, having
    shape [..., \sum_i N_i, \sum_i M_i].
  """
  matrices = [ops.convert_to_tensor(matrix, dtype=dtype) for matrix in matrices]
  blocked_rows = tensor_shape.Dimension(0)
  blocked_cols = tensor_shape.Dimension(0)
  batch_shape = tensor_shape.TensorShape(None)
  for matrix in matrices:
    full_matrix_shape = matrix.get_shape().with_rank_at_least(2)
    batch_shape = batch_shape.merge_with(full_matrix_shape[:-2])
    blocked_rows += full_matrix_shape[-2]
    blocked_cols += full_matrix_shape[-1]
  ret_columns_list = []
  for matrix in matrices:
    matrix_shape = array_ops.shape(matrix)
    ret_columns_list.append(matrix_shape[-1])
  ret_columns = math_ops.add_n(ret_columns_list)
  row_blocks = []
  current_column = 0
  for matrix in matrices:
    matrix_shape = array_ops.shape(matrix)
    row_before_length = current_column
    current_column += matrix_shape[-1]
    row_after_length = ret_columns - current_column
    row_blocks.append(
        array_ops.pad(
            tensor=matrix,
            paddings=array_ops.concat(
                [
                    array_ops.zeros(
                        [array_ops.rank(matrix) - 1, 2], dtype=dtypes.int32),
                    [(row_before_length, row_after_length)]
                ],
                axis=0)))
  blocked = array_ops.concat(row_blocks, -2, name=name)
  blocked.set_shape(batch_shape.concatenate((blocked_rows, blocked_cols)))
  return blocked
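

# Example of `block_diagonal` (illustrative only, not from the original
# module): stacking a 2x2 and a 1x1 matrix yields a 3x3 block-diagonal result.
#
#   a = constant_op.constant([[1., 2.], [3., 4.]])
#   b = constant_op.constant([[5.]])
#   block_diagonal([a, b])
#   # => [[1., 2., 0.],
#   #     [3., 4., 0.],
#   #     [0., 0., 5.]]
#
# Any leading batch dimensions must agree across the input matrices.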


def power_sums_tensor(array_size, power_matrix, multiplier):
  r"""Computes \sum_{i=0}^{N-1} A^i B (A^i)^T for N=0..(array_size + 1).

  Args:
    array_size: The number of non-trivial sums to pre-compute.
    power_matrix: The "A" matrix above.
    multiplier: The "B" matrix above.
  Returns:
    A Tensor with S[N] = \sum_{i=0}^{N-1} A^i B (A^i)^T
      S[0] is the zero matrix
      S[1] is B
      S[2] is A B A^T + B
      ...and so on
  """
  array_size = math_ops.cast(array_size, dtypes.int32)
  power_matrix = ops.convert_to_tensor(power_matrix)
  identity_like_power_matrix = linalg_ops.eye(
      array_ops.shape(power_matrix)[0], dtype=power_matrix.dtype)
  identity_like_power_matrix.set_shape(
      ops.convert_to_tensor(power_matrix).get_shape())
  transition_powers = functional_ops.scan(
      lambda previous_power, _: math_ops.matmul(previous_power, power_matrix),
      math_ops.range(array_size - 1),
      initializer=identity_like_power_matrix)
  summed = math_ops.cumsum(
      array_ops.concat([
          array_ops.expand_dims(multiplier, 0),
          math_ops.matmul(
              batch_times_matrix(transition_powers, multiplier),
              transition_powers,
              adjoint_b=True)
      ], 0))
  return array_ops.concat(
      [array_ops.expand_dims(array_ops.zeros_like(multiplier), 0), summed], 0)
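

# The recurrence computed by `power_sums_tensor` above, in the notation of its
# docstring (a restatement for reference, not additional behavior):
#
#   S[0] = 0
#   S[N + 1] = S[N] + A^N B (A^N)^T
#
# The implementation scans out A^0 .. A^(array_size - 1), forms each
# A^i B (A^i)^T term in a single batched matmul, and takes a cumulative sum.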


def matrix_to_powers(matrix, powers):
  """Raise a single matrix to multiple powers."""
  matrix_tiled = array_ops.tile(
      array_ops.expand_dims(matrix, 0), [array_ops.size(powers), 1, 1])
  return batch_matrix_pow(matrix_tiled, powers)


def batch_matrix_pow(matrices, powers):
  """Compute powers of matrices, e.g. A^3 = matmul(matmul(A, A), A).

  Uses exponentiation by squaring, with O(log(p)) matrix multiplications to
  compute A^p.

  Args:
    matrices: [batch size x N x N]
    powers: Which integer power to raise each matrix to [batch size]
  Returns:
    The matrices raised to their respective powers, same dimensions as the
    "matrices" argument.
  """

  def terminate_when_all_zero(current_argument, residual_powers, accumulator):
    del current_argument, accumulator  # not used for condition
    do_exit = math_ops.reduce_any(
        math_ops.greater(residual_powers, array_ops.ones_like(residual_powers)))
    return do_exit

  def do_iteration(current_argument, residual_powers, accumulator):
    """Compute one step of iterative exponentiation by squaring.

    The recursive form is:
      power(A, p) = { power(matmul(A, A), p / 2) for even p
                    { matmul(A, power(matmul(A, A), (p - 1) / 2)) for odd p
      power(A, 0) = I

    The power(A, 0) = I case is handled by starting with accumulator set to the
    identity matrix; matrices with zero residual powers are passed through
    unchanged.

    Args:
      current_argument: On this step, what is the first argument (A^2..^2) to
        the (unrolled) recursive function? [batch size x N x N]
      residual_powers: On this step, what is the second argument (residual p)?
        [batch_size]
      accumulator: Accumulates the exterior multiplications from the odd
        powers (initially the identity matrix). [batch_size x N x N]
    Returns:
      Updated versions of each argument for one step of the unrolled
      computation. Does not change parts of the batch which have a residual
      power of zero.
    """
    is_even = math_ops.equal(residual_powers % 2,
                             array_ops.zeros(
                                 array_ops.shape(residual_powers),
                                 dtype=dtypes.int32))
    new_accumulator = array_ops.where(is_even, accumulator,
                                      math_ops.matmul(accumulator,
                                                      current_argument))
    new_argument = math_ops.matmul(current_argument, current_argument)
    do_update = math_ops.greater(residual_powers, 1)
    new_residual_powers = residual_powers - residual_powers % 2
    new_residual_powers //= 2
    # Stop updating if we've reached our base case; some batch elements may
    # finish sooner than others
    accumulator = array_ops.where(do_update, new_accumulator, accumulator)
    current_argument = array_ops.where(do_update, new_argument,
                                       current_argument)
    residual_powers = array_ops.where(do_update, new_residual_powers,
                                      residual_powers)
    return (current_argument, residual_powers, accumulator)

  matrices = ops.convert_to_tensor(matrices)
  powers = math_ops.cast(powers, dtype=dtypes.int32)
  ident = array_ops.expand_dims(
      array_ops.diag(
          array_ops.ones([array_ops.shape(matrices)[1]], dtype=matrices.dtype)),
      0)
  ident_tiled = array_ops.tile(ident, [array_ops.shape(matrices)[0], 1, 1])
  (final_argument,
   final_residual_power, final_accumulator) = control_flow_ops.while_loop(
       terminate_when_all_zero, do_iteration, [matrices, powers, ident_tiled])
  return array_ops.where(
      math_ops.equal(final_residual_power,
                     array_ops.zeros_like(
                         final_residual_power, dtype=dtypes.int32)),
      ident_tiled, math_ops.matmul(final_argument, final_accumulator))
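

# Illustrative use of `batch_matrix_pow` (the values are assumptions for the
# example, not from the original module): each matrix in the batch is raised
# to its own integer power.
#
#   matrices = constant_op.constant([[[2., 0.], [0., 2.]],
#                                    [[1., 1.], [0., 1.]]])
#   powers = constant_op.constant([3, 2])
#   batch_matrix_pow(matrices, powers)
#   # => [[[8., 0.], [0., 8.]],
#   #     [[1., 2.], [0., 1.]]]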


# TODO(allenl): would be useful if this was built into batch_matmul
def batch_times_matrix(batch, matrix, adj_x=False, adj_y=False):
  """Multiply a batch of matrices by a single matrix.

  Functionally equivalent to:
    tf.matmul(batch, array_ops.tile(gen_math_ops.expand_dims(matrix, 0),
                                    [array_ops.shape(batch)[0], 1, 1]),
              adjoint_a=adj_x, adjoint_b=adj_y)

  Args:
    batch: [batch_size x N x M] after optional transpose
    matrix: [M x P] after optional transpose
    adj_x: If true, transpose the second two dimensions of "batch" before
      multiplying.
    adj_y: If true, transpose "matrix" before multiplying.
  Returns:
    [batch_size x N x P]
  """
  batch = ops.convert_to_tensor(batch)
  matrix = ops.convert_to_tensor(matrix)
  assert batch.get_shape().ndims == 3
  assert matrix.get_shape().ndims == 2
  if adj_x:
    batch = array_ops.transpose(batch, [0, 2, 1])
  batch_dimension = batch.get_shape().dims[0].value
  first_dimension = batch.get_shape().dims[1].value
  tensor_batch_shape = array_ops.shape(batch)
  if batch_dimension is None:
    batch_dimension = tensor_batch_shape[0]
  if first_dimension is None:
    first_dimension = tensor_batch_shape[1]
  matrix_first_dimension, matrix_second_dimension = matrix.get_shape().as_list()
  batch_reshaped = array_ops.reshape(batch, [-1, tensor_batch_shape[2]])
  if adj_y:
    if matrix_first_dimension is None:
      matrix_first_dimension = array_ops.shape(matrix)[0]
    result_shape = [batch_dimension, first_dimension, matrix_first_dimension]
  else:
    if matrix_second_dimension is None:
      matrix_second_dimension = array_ops.shape(matrix)[1]
    result_shape = [batch_dimension, first_dimension, matrix_second_dimension]
  return array_ops.reshape(
      math_ops.matmul(batch_reshaped, matrix, adjoint_b=adj_y), result_shape)


def matrix_times_batch(matrix, batch, adj_x=False, adj_y=False):
  """Like batch_times_matrix, but with the multiplication order swapped."""
  return array_ops.transpose(
      batch_times_matrix(
          batch=batch, matrix=matrix, adj_x=not adj_y, adj_y=not adj_x),
      [0, 2, 1])


def make_toeplitz_matrix(inputs, name=None):
  """Make a symmetric Toeplitz matrix from input array of values.

  Args:
    inputs: a 3-D tensor of shape [num_blocks, block_size, block_size].
    name: the name of the operation.

  Returns:
    a symmetric Toeplitz matrix of shape
      [num_blocks*block_size, num_blocks*block_size].
  """
  num_blocks = array_ops.shape(inputs)[0]
  block_size = array_ops.shape(inputs)[1]
  output_size = block_size * num_blocks
  lags = array_ops.reshape(math_ops.range(num_blocks), shape=[1, -1])
  indices = math_ops.abs(lags - array_ops.transpose(lags))
  output = array_ops.gather(inputs, indices)
  output = array_ops.reshape(
      array_ops.transpose(output, [0, 2, 1, 3]), [output_size, output_size])
  return array_ops.identity(output, name=name)
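

# Sketch of `make_toeplitz_matrix` with block_size=1 (illustrative values, not
# from the original module): block [i, j] of the output is inputs[abs(i - j)].
#
#   inputs = constant_op.constant([[[1.]], [[2.]], [[3.]]])
#   make_toeplitz_matrix(inputs)
#   # => [[1., 2., 3.],
#   #     [2., 1., 2.],
#   #     [3., 2., 1.]]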


# TODO(allenl): Investigate alternative parameterizations.
def sign_magnitude_positive_definite(
    raw, off_diagonal_scale=0., overall_scale=0.):
  """Constructs a positive definite matrix from an unconstrained input matrix.

  We want to keep the whole matrix on a log scale, but also allow off-diagonal
  elements to be negative, so the sign of off-diagonal elements is modeled
  separately from their magnitude (using the lower and upper triangles
  respectively). Specifically:

  for i < j, we have:
    output_cholesky[i, j] = raw[j, i] / (abs(raw[j, i]) + 1) *
        exp((off_diagonal_scale + overall_scale + raw[i, j]) / 2)

  output_cholesky[i, i] = exp((raw[i, i] + overall_scale) / 2)

  output = output_cholesky^T * output_cholesky

  where raw, off_diagonal_scale, and overall_scale are
  un-constrained real-valued variables. The resulting values are stable
  around zero due to the exponential (and the softsign keeps the function
  smooth).

  Args:
    raw: A [..., M, M] Tensor.
    off_diagonal_scale: A scalar or [...] shaped Tensor controlling the relative
      scale of off-diagonal values in the output matrix.
    overall_scale: A scalar or [...] shaped Tensor controlling the overall scale
      of the output matrix.
  Returns:
    The `output` matrix described above, a [..., M, M] positive definite matrix.

  """
  raw = ops.convert_to_tensor(raw)
  diagonal = array_ops.matrix_diag_part(raw)
  def _right_pad_with_ones(tensor, target_rank):
    # Allow broadcasting even if overall_scale and off_diagonal_scale have batch
    # dimensions
    tensor = ops.convert_to_tensor(tensor, dtype=raw.dtype.base_dtype)
    return array_ops.reshape(tensor,
                             array_ops.concat(
                                 [
                                     array_ops.shape(tensor), array_ops.ones(
                                         [target_rank - array_ops.rank(tensor)],
                                         dtype=target_rank.dtype)
                                 ],
                                 axis=0))
  # We divide the log values by 2 to compensate for the squaring that happens
  # when transforming Cholesky factors into positive definite matrices.
  sign_magnitude = (gen_math_ops.exp(
      (raw + _right_pad_with_ones(off_diagonal_scale, array_ops.rank(raw)) +
       _right_pad_with_ones(overall_scale, array_ops.rank(raw))) / 2.) *
                    nn.softsign(array_ops.matrix_transpose(raw)))
  sign_magnitude.set_shape(raw.get_shape())
  cholesky_factor = array_ops.matrix_set_diag(
      input=array_ops.matrix_band_part(sign_magnitude, 0, -1),
      diagonal=gen_math_ops.exp((diagonal + _right_pad_with_ones(
          overall_scale, array_ops.rank(diagonal))) / 2.))
  return math_ops.matmul(cholesky_factor, cholesky_factor, transpose_a=True)


def transform_to_covariance_matrices(input_vectors, matrix_size):
  """Construct covariance matrices via transformations from input_vectors.

  Args:
    input_vectors: A [batch size x input size] batch of vectors to transform.
    matrix_size: An integer indicating one dimension of the (square) output
      matrix.
  Returns:
    A [batch size x matrix_size x matrix_size] batch of covariance matrices.
  """
  combined_values = layers.fully_connected(
      input_vectors, matrix_size**2 + 2, activation_fn=None)
  return sign_magnitude_positive_definite(
      raw=array_ops.reshape(combined_values[..., :-2],
                            array_ops.concat([
                                array_ops.shape(combined_values)[:-1],
                                [matrix_size, matrix_size]
                            ], 0)),
      off_diagonal_scale=combined_values[..., -2],
      overall_scale=combined_values[..., -1])
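

# How `transform_to_covariance_matrices` above splits its dense-layer output (a
# restatement for orientation, not new behavior): each batch element produces
# matrix_size**2 + 2 values; the first matrix_size**2 become the unconstrained
# `raw` matrix, and the final two become the off-diagonal and overall log-scale
# corrections passed to sign_magnitude_positive_definite.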


def variable_covariance_matrix(
    size, name, dtype, initial_diagonal_values=None,
    initial_overall_scale_log=0.):
  """Construct a Variable-parameterized positive definite matrix.

  Useful for parameterizing covariance matrices.

  Args:
    size: The size of the main diagonal, the returned matrix having shape [size
      x size].
    name: The name to use when defining variables and ops.
    dtype: The floating point data type to use.
    initial_diagonal_values: A Tensor with shape [size] with initial values for
      the diagonal values of the returned matrix. Must be positive.
    initial_overall_scale_log: Initial value of the bias term for every element
      of the matrix in log space.
  Returns:
    A Variable-parameterized covariance matrix with shape [size x size].
  """
  raw_values = variable_scope.get_variable(
      name + "_pre_transform",
      dtype=dtype,
      shape=[size, size],
      initializer=init_ops.zeros_initializer())
  if initial_diagonal_values is not None:
    raw_values += array_ops.matrix_diag(math_ops.log(initial_diagonal_values))
  return array_ops.identity(
      sign_magnitude_positive_definite(
          raw=raw_values,
          off_diagonal_scale=variable_scope.get_variable(
              name + "_off_diagonal_scale",
              dtype=dtype,
              initializer=constant_op.constant(-5., dtype=dtype)),
          overall_scale=ops.convert_to_tensor(
              initial_overall_scale_log, dtype=dtype) +
          variable_scope.get_variable(
              name + "_overall_scale",
              dtype=dtype,
              shape=[],
              initializer=init_ops.zeros_initializer())),
      name=name)


def batch_start_time(times):
  return times[:, 0]


def batch_end_time(times):
  return times[:, -1]


def log_noninformative_covariance_prior(covariance):
  """Compute a relatively uninformative prior for noise parameters.

  Helpful for avoiding noise over-estimation, where noise otherwise decreases
  very slowly during optimization.

  See:
    Villegas, C. On the A Priori Distribution of the Covariance Matrix.
    Ann. Math. Statist. 40 (1969), no. 3, 1098--1099.

  Args:
    covariance: A covariance matrix.
  Returns:
    For a [p x p] matrix:
      log(det(covariance)^(-(p + 1) / 2))
  """
  # Avoid zero/negative determinants due to numerical errors
  covariance += array_ops.diag(1e-8 * array_ops.ones(
      shape=[array_ops.shape(covariance)[0]], dtype=covariance.dtype))
  power = -(math_ops.cast(array_ops.shape(covariance)[0] + 1,
                          covariance.dtype) / 2.)
  return power * math_ops.log(linalg_ops.matrix_determinant(covariance))


def entropy_matched_cauchy_scale(covariance):
  """Approximates a similar Cauchy distribution given a covariance matrix.

  Since Cauchy distributions do not have moments, entropy matching provides one
  way to set a Cauchy's scale parameter in a way that provides a similar
  distribution. The effect is dividing the standard deviation of an independent
  Gaussian by a constant very near 3.

  To set the scale of the Cauchy distribution, we first select the diagonals of
  `covariance`. Since this ignores cross terms, it overestimates the entropy of
  the Gaussian. For each of these variances, we solve for the Cauchy scale
  parameter which gives the same entropy as the Gaussian with that
  variance. This means setting the (univariate) Gaussian entropy
    0.5 * ln(2 * variance * pi * e)
  equal to the Cauchy entropy
    ln(4 * pi * scale)
  Solving, we get scale = sqrt(variance * (e / (8 pi))).

  Args:
    covariance: A [batch size x N x N] batch of covariance matrices to produce
      Cauchy scales for.
  Returns:
    A [batch size x N] set of Cauchy scale parameters for each part of the batch
    and each dimension of the input Gaussians.
  """
  return math_ops.sqrt(math.e / (8. * math.pi) *
                       array_ops.matrix_diag_part(covariance))
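

# Quick check of the entropy-matching algebra above (a restatement of the
# docstring, not new behavior): setting 0.5 * ln(2 * pi * e * variance) equal
# to ln(4 * pi * scale) gives
#   scale**2 = 2 * pi * e * variance / (16 * pi**2) = variance * e / (8 * pi),
# so scale = sqrt(variance * e / (8 * pi)). Numerically sqrt(e / (8 * pi)) is
# about 0.33, i.e. roughly a third of the Gaussian standard deviation, matching
# the "constant very near 3" note in the docstring.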


class TensorValuedMutableDenseHashTable(lookup.MutableDenseHashTable):
  """A version of MutableDenseHashTable which stores arbitrary Tensor shapes.

  Since MutableDenseHashTable only allows vectors right now, simply adds reshape
  ops on both ends.
  """

  def __init__(self, key_dtype, value_dtype, default_value, *args, **kwargs):
    self._non_vector_value_shape = array_ops.shape(default_value)
    super(TensorValuedMutableDenseHashTable, self).__init__(
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        default_value=array_ops.reshape(default_value, [-1]),
        *args,
        **kwargs)

  def insert(self, keys, values, name=None):
    keys = ops.convert_to_tensor(keys, dtype=self._key_dtype)
    keys_flat = array_ops.reshape(keys, [-1])
    return super(TensorValuedMutableDenseHashTable, self).insert(
        keys=keys_flat,
        # Each key has one corresponding value, so the shape of the tensor of
        # values for every key is key_shape + value_shape
        values=array_ops.reshape(values, [array_ops.shape(keys_flat)[0], -1]),
        name=name)

  def lookup(self, keys, name=None):
    keys_flat = array_ops.reshape(
        ops.convert_to_tensor(keys, dtype=self._key_dtype), [-1])
    return array_ops.reshape(
        super(TensorValuedMutableDenseHashTable, self).lookup(
            keys=keys_flat, name=name),
        array_ops.concat([array_ops.shape(keys), self._non_vector_value_shape],
                         0))


class TupleOfTensorsLookup(lookup.LookupInterface):
  """A LookupInterface with nested tuples of Tensors as values.

  Creates one MutableDenseHashTable per value Tensor, which has some unnecessary
  overhead.
  """

  def __init__(self,
               key_dtype,
               default_values,
               empty_key,
               deleted_key,
               name,
               checkpoint=True):
    default_values_flat = nest.flatten(default_values)
    self._hash_tables = nest.pack_sequence_as(default_values, [
        TensorValuedMutableDenseHashTable(
            key_dtype=key_dtype,
            value_dtype=default_value.dtype.base_dtype,
            default_value=default_value,
            empty_key=empty_key,
            deleted_key=deleted_key,
            name=name + "_{}".format(table_number),
            checkpoint=checkpoint)
        for table_number, default_value in enumerate(default_values_flat)
    ])
    self._name = name

  def lookup(self, keys):
    return nest.pack_sequence_as(
        self._hash_tables,
        [hash_table.lookup(keys)
         for hash_table in nest.flatten(self._hash_tables)])

  def insert(self, keys, values):
    nest.assert_same_structure(self._hash_tables, values)
    # Avoid race conditions by requiring that all inputs are computed before any
    # inserts happen (an issue if one key's update relies on another's value).
    values_flat = [array_ops.identity(value) for value in nest.flatten(values)]
    with ops.control_dependencies(values_flat):
      insert_ops = [hash_table.insert(keys, value)
                    for hash_table, value
                    in zip(nest.flatten(self._hash_tables),
                           values_flat)]
    return control_flow_ops.group(*insert_ops)

  def check_table_dtypes(self, key_dtype, value_dtype):
    # dtype checking is done in the objects in self._hash_tables
    pass
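

# Illustrative use of TupleOfTensorsLookup (the key dtype, shapes, and tensors
# named here are assumptions for the sketch, not part of the original module):
# the value side mirrors an arbitrary nested structure, one hash table per leaf.
#
#   table = TupleOfTensorsLookup(
#       key_dtype=dtypes.int64,
#       default_values=(array_ops.zeros([3]), array_ops.zeros([2, 2])),
#       empty_key=-1,
#       deleted_key=-2,
#       name="cached_state")
#   insert_op = table.insert(keys, (vector_values, matrix_values))
#   vectors, matrices = table.lookup(keys)  # shapes [?, 3] and [?, 2, 2]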


def replicate_state(start_state, batch_size):
  """Create batch versions of state.

  Takes a list of Tensors, adds a batch dimension, and replicates
  batch_size times across that batch dimension. Used to replicate the
  non-batch state returned by get_start_state in define_loss.

  Args:
    start_state: Model-defined state to replicate.
    batch_size: Batch dimension for data.
  Returns:
    Replicated versions of the state.
  """
  flattened_state = nest.flatten(start_state)
  replicated_state = [
      array_ops.tile(
          array_ops.expand_dims(state_nonbatch, 0),
          array_ops.concat([[batch_size], array_ops.ones(
              [array_ops.rank(state_nonbatch)], dtype=dtypes.int32)], 0))
      for state_nonbatch in flattened_state
  ]
  return nest.pack_sequence_as(start_state, replicated_state)


Moments = collections.namedtuple("Moments", ["mean", "variance"])


# Currently all of these statistics are computed incrementally (i.e. are
# updated every time a new mini-batch of training data is presented) when this
# object is created in InputStatisticsFromMiniBatch.
InputStatistics = collections.namedtuple(
    "InputStatistics",
    ["series_start_moments",  # The mean and variance of each feature in a chunk
                              # (with a size configured in the statistics
                              # object) at the start of the series. A tuple of
                              # (mean, variance), each with shape [number of
                              # features], floating point. One use is in state
                              # space models, to keep priors calibrated even as
                              # earlier parts of the series are presented. If
                              # this object was created by
                              # InputStatisticsFromMiniBatch, these moments are
                              # computed based on the earliest chunk of data
                              # presented so far. However, there is a race
                              # condition in the update, so these may reflect
                              # statistics later in the series, but should
                              # eventually reflect statistics in a chunk at the
                              # series start.
     "overall_feature_moments",  # The mean and variance of each feature over
                                 # the entire series. A tuple of (mean,
                                 # variance), each with shape [number of
                                 # features]. If this object was created by
                                 # InputStatisticsFromMiniBatch, these moments
                                 # are estimates based on the data seen so far.
     "start_time",  # The first (lowest) time in the series, a scalar
                    # integer. If this object was created by
                    # InputStatisticsFromMiniBatch, this is the lowest time seen
                    # so far rather than the lowest time that will ever be seen
                    # (guaranteed to be at least as low as the lowest time
                    # presented in the current minibatch).
     "total_observation_count",  # Count of data points, a scalar integer. If
                                 # this object was created by
                                 # InputStatisticsFromMiniBatch, this is an
                                 # estimate of the total number of observations
                                 # in the whole dataset computed based on the
                                 # density of the series and the minimum and
                                 # maximum times seen.
    ])
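

# Sketch of how the two tuples above are used (a restatement for orientation,
# not new behavior): `InputStatistics` is what models consume, while
# `InputStatisticsFromMiniBatch` below owns the Variables backing it and
# updates them incrementally from each mini-batch, roughly:
#
#   stats_object = InputStatisticsFromMiniBatch(num_features=1,
#                                               dtype=dtypes.float32)
#   statistics = stats_object.initialize_graph(features)
#   # statistics.overall_feature_moments.mean has shape [num_features]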


# TODO(allenl): It would be nice to do something with full series statistics
# when the user provides that.
class InputStatisticsFromMiniBatch(object):
  """Generate statistics from mini-batch input."""

  def __init__(self, num_features, dtype, starting_variance_window_size=16):
    """Configure the input statistics object.

    Args:
      num_features: Number of features for the time series
      dtype: The floating point data type to use.
      starting_variance_window_size: The number of datapoints to use when
        computing the mean and variance at the start of the series.
    """
    self._starting_variance_window_size = starting_variance_window_size
    self._num_features = num_features
    self._dtype = dtype

  def initialize_graph(self, features, update_statistics=True):
    """Create any ops needed to provide input statistics.

    Should be called before statistics are requested.

    Args:
      features: A dictionary, the output of a `TimeSeriesInputFn` (with keys
        TrainEvalFeatures.TIMES and TrainEvalFeatures.VALUES).
      update_statistics: Whether `features` should be used to update adaptive
        statistics. Typically True for training and false for evaluation.
    Returns:
      An InputStatistics object composed of Variables, which will be updated
      based on mini-batches of data if requested.
    """
    if (TrainEvalFeatures.TIMES in features
        and TrainEvalFeatures.VALUES in features):
      times = features[TrainEvalFeatures.TIMES]
      values = features[TrainEvalFeatures.VALUES]
    else:
      # times and values may not be available, for example during prediction.
      # We still need to retrieve our variables so that they can be read from,
      # even if we're not going to update them.
      times = None
      values = None
    # Create/retrieve variables representing input statistics, initialized
    # without data to avoid deadlocking if variables are initialized before
    # queue runners are started.
    with variable_scope.variable_scope("input_statistics", use_resource=True):
      statistics = self._create_variable_statistics_object()
    with variable_scope.variable_scope(
        "input_statistics_auxiliary", use_resource=True):
      # Secondary statistics, necessary for the incremental computation of the
      # primary statistics (e.g. counts and sums for computing a mean
      # incrementally).
      auxiliary_variables = self._AdaptiveInputAuxiliaryStatistics(
          num_features=self._num_features, dtype=self._dtype)
    if update_statistics and times is not None and values is not None:
      # If we have times and values from mini-batch input, create update ops to
      # take the new data into account.
      assign_op = self._update_statistics_from_mini_batch(
          statistics, auxiliary_variables, times, values)
      with ops.control_dependencies([assign_op]):
        stat_variables = nest.pack_sequence_as(statistics, [
            array_ops.identity(tensor) for tensor in nest.flatten(statistics)
        ])
      # Since start time updates have a race condition, ensure that the
      # reported start time is at least as low as the lowest time in this
      # mini-batch. The start time should converge on the correct value
      # eventually even with the race condition, but for example state space
      # models have an assertion which could fail without this
      # post-processing.
      return stat_variables._replace(start_time=gen_math_ops.minimum(
          stat_variables.start_time, math_ops.reduce_min(times)))
    else:
      return statistics

  class _AdaptiveInputAuxiliaryStatistics(collections.namedtuple(
      "_AdaptiveInputAuxiliaryStatistics",
      ["max_time_seen",  # The maximum time seen (best effort if updated from
                         # multiple workers; see notes about race condition
                         # below).
       "chunk_count",  # The number of chunks seen.
       "inter_observation_duration_sum",  # The sum across chunks of their "time
                                          # density" (number of times per
                                          # example).
       "example_count",  # The number of examples seen (each example has a
                         # single time associated with it and one or more
                         # real-valued features).
       "overall_feature_sum",  # The sum of values for each feature. Shape
                               # [number of features].
       "overall_feature_sum_of_squares",  # The sum of squared values for each
                                          # feature. Shape [number of features]
      ])):
    """Extra statistics used to incrementally update InputStatistics."""

    def __new__(cls, num_features, dtype):
      return super(
          InputStatisticsFromMiniBatch  # pylint: disable=protected-access
          ._AdaptiveInputAuxiliaryStatistics,
          cls).__new__(
              cls,
              max_time_seen=variable_scope.get_variable(
                  name="max_time_seen",
                  initializer=dtypes.int64.min,
                  dtype=dtypes.int64,
                  trainable=False),
              chunk_count=variable_scope.get_variable(
                  name="chunk_count",
                  initializer=init_ops.zeros_initializer(),
                  shape=[],
                  dtype=dtypes.int64,
                  trainable=False),
              inter_observation_duration_sum=variable_scope.get_variable(
                  name="inter_observation_duration_sum",
                  initializer=init_ops.zeros_initializer(),
                  shape=[],
                  dtype=dtype,
                  trainable=False),
              example_count=variable_scope.get_variable(
                  name="example_count",
                  shape=[],
                  dtype=dtypes.int64,
                  initializer=init_ops.zeros_initializer(),
                  trainable=False),
              overall_feature_sum=variable_scope.get_variable(
                  name="overall_feature_sum",
                  shape=[num_features],
                  dtype=dtype,
                  initializer=init_ops.zeros_initializer(),
                  trainable=False),
              overall_feature_sum_of_squares=variable_scope.get_variable(
                  name="overall_feature_sum_of_squares",
                  shape=[num_features],
                  dtype=dtype,
                  initializer=init_ops.zeros_initializer(),
                  trainable=False))

  def _update_statistics_from_mini_batch(
      self, statistics, auxiliary_variables, times, values):
    """Given mini-batch input, update `statistics` and `auxiliary_variables`."""
    values = math_ops.cast(values, self._dtype)
    # The density (measured in times per observation) that we see in each part
    # of the mini-batch.
    batch_inter_observation_duration = (math_ops.cast(
        math_ops.reduce_max(times, axis=1) - math_ops.reduce_min(times, axis=1),
        self._dtype) / math_ops.cast(
            array_ops.shape(times)[1] - 1, self._dtype))
    # Co-locate updates with their variables to minimize race conditions when
    # updating statistics.
    with ops.device(auxiliary_variables.max_time_seen.device):
      # There is a race condition if this value is being updated from multiple
      # workers. However, it should eventually reach the correct value if the
      # last chunk is presented enough times.
      max_time_seen_assign = state_ops.assign(
          auxiliary_variables.max_time_seen,
          gen_math_ops.maximum(auxiliary_variables.max_time_seen,
                               math_ops.reduce_max(times)))
    with ops.device(auxiliary_variables.chunk_count.device):
      chunk_count_assign = state_ops.assign_add(
          auxiliary_variables.chunk_count,
          array_ops.shape(times, out_type=dtypes.int64)[0])
    with ops.device(auxiliary_variables.inter_observation_duration_sum.device):
      inter_observation_duration_assign = state_ops.assign_add(
          auxiliary_variables.inter_observation_duration_sum,
          math_ops.reduce_sum(batch_inter_observation_duration))
    with ops.device(auxiliary_variables.example_count.device):
      example_count_assign = state_ops.assign_add(
          auxiliary_variables.example_count,
          array_ops.size(times, out_type=dtypes.int64))
    # Note: These mean/variance updates assume that all points are equally
    # likely, which is not true if _chunks_ are sampled uniformly from the space
    # of all possible contiguous chunks, since points at the start and end of
    # the series are then members of fewer chunks. For series which are much
    # longer than the chunk size (the usual/expected case), this effect becomes
    # irrelevant.
    with ops.device(auxiliary_variables.overall_feature_sum.device):
      overall_feature_sum_assign = state_ops.assign_add(
          auxiliary_variables.overall_feature_sum,
          math_ops.reduce_sum(values, axis=[0, 1]))
    with ops.device(auxiliary_variables.overall_feature_sum_of_squares.device):
      overall_feature_sum_of_squares_assign = state_ops.assign_add(
          auxiliary_variables.overall_feature_sum_of_squares,
          math_ops.reduce_sum(values**2, axis=[0, 1]))
    per_chunk_aux_updates = control_flow_ops.group(
        max_time_seen_assign, chunk_count_assign,
        inter_observation_duration_assign, example_count_assign,
        overall_feature_sum_assign, overall_feature_sum_of_squares_assign)
    with ops.control_dependencies([per_chunk_aux_updates]):
      example_count_float = math_ops.cast(auxiliary_variables.example_count,
                                          self._dtype)
      new_feature_mean = (auxiliary_variables.overall_feature_sum /
                          example_count_float)
      overall_feature_mean_update = state_ops.assign(
          statistics.overall_feature_moments.mean, new_feature_mean)
      overall_feature_var_update = state_ops.assign(
          statistics.overall_feature_moments.variance,
          # De-biased n / (n - 1) variance correction
          example_count_float / (example_count_float - 1.) *
          (auxiliary_variables.overall_feature_sum_of_squares /
           example_count_float - new_feature_mean**2))
      # TODO(b/35675805): Remove this cast
      min_time_batch = math_ops.cast(math_ops.argmin(times[:, 0]), dtypes.int32)
      def series_start_updates():
        # If this is the lowest-time chunk that we have seen so far, update
        # series start moments to reflect that. Note that these statistics are
        # "best effort", as there are race conditions in the update (however,
        # they should eventually converge if the start of the series is
        # presented enough times).
        mean, variance = nn.moments(
            values[min_time_batch, :self._starting_variance_window_size],
            axes=[0])
        return control_flow_ops.group(
            state_ops.assign(statistics.series_start_moments.mean, mean),
            state_ops.assign(statistics.series_start_moments.variance,
                             variance))
      with ops.device(statistics.start_time.device):
        series_start_update = control_flow_ops.cond(
            # Update moments whenever we even match the lowest time seen so far,
            # to ensure that series start statistics are eventually updated to
            # their correct values, despite race conditions (i.e. eventually
            # statistics.start_time will reflect the global lowest time, and
            # given that we will eventually update the series start moments to
            # their correct values).
            math_ops.less_equal(times[min_time_batch, 0],
                                statistics.start_time),
            series_start_updates,
            control_flow_ops.no_op)
        with ops.control_dependencies([series_start_update]):
          # There is a race condition if this update is performed in parallel on
          # multiple workers. Since models may be sensitive to being presented
          # with times before the putative start time, the value of this
          # variable is post-processed above to guarantee that each worker is
          # presented with a start time which is at least as low as the lowest
          # time in its current mini-batch.
          start_time_update = state_ops.assign(statistics.start_time,
                                               gen_math_ops.minimum(
                                                   statistics.start_time,
                                                   math_ops.reduce_min(times)))
      inter_observation_duration_estimate = (
          auxiliary_variables.inter_observation_duration_sum / math_ops.cast(
              auxiliary_variables.chunk_count, self._dtype))
      # Estimate the total number of observations as:
      #   (end time - start time + 1) * average intra-chunk time density
      total_observation_count_update = state_ops.assign(
          statistics.total_observation_count,
          math_ops.cast(
              gen_math_ops.round(
                  math_ops.cast(max_time_seen_assign -
                                start_time_update + 1, self._dtype) /
                  inter_observation_duration_estimate), dtypes.int64))
      per_chunk_stat_updates = control_flow_ops.group(
          overall_feature_mean_update, overall_feature_var_update,
          series_start_update, start_time_update,
          total_observation_count_update)
    return per_chunk_stat_updates

  def _create_variable_statistics_object(self):
    """Creates non-trainable variables representing input statistics."""
    series_start_moments = Moments(
        mean=variable_scope.get_variable(
            name="series_start_mean",
            shape=[self._num_features],
            dtype=self._dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False),
        variance=variable_scope.get_variable(
            name="series_start_variance",
            shape=[self._num_features],
            dtype=self._dtype,
            initializer=init_ops.ones_initializer(),
            trainable=False))
    overall_feature_moments = Moments(
        mean=variable_scope.get_variable(
            name="overall_feature_mean",
            shape=[self._num_features],
            dtype=self._dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False),
        variance=variable_scope.get_variable(
            name="overall_feature_var",
            shape=[self._num_features],
            dtype=self._dtype,
            initializer=init_ops.ones_initializer(),
            trainable=False))
    start_time = variable_scope.get_variable(
        name="start_time",
        dtype=dtypes.int64,
        initializer=dtypes.int64.max,
        trainable=False)
    total_observation_count = variable_scope.get_variable(
        name="total_observation_count",
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.ones_initializer(),
        trainable=False)
    return InputStatistics(
        series_start_moments=series_start_moments,
        overall_feature_moments=overall_feature_moments,
        start_time=start_time,
        total_observation_count=total_observation_count)
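

# Worked example of the observation-count estimate in
# _update_statistics_from_mini_batch (illustrative numbers, not from the
# original module): if the lowest time seen is 0, the highest is 99, and chunks
# average 1.0 time units between observations, the estimate is
# round((99 - 0 + 1) / 1.0) = 100 total observations.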