# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polyharmonic spline interpolation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

EPSILON = 0.0000000001


def _cross_squared_distance_matrix(x, y):
  """Pairwise squared distance between two (batch) matrices' rows (2nd dim).

  Computes the pairwise distances between rows of x and rows of y.

  Args:
    x: `[batch_size, n, d]` float `Tensor`
    y: `[batch_size, m, d]` float `Tensor`

  Returns:
    squared_dists: `[batch_size, n, m]` float `Tensor`, where
      squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
  """
  x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2)
  y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2)

  # Expand so that we can broadcast.
  x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2)
  y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1)

  x_y_transpose = math_ops.matmul(x, y, adjoint_b=True)

  # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi - 2x_bi'y_bj + y_bj'y_bj
  squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile

  return squared_dists


def _pairwise_squared_distance_matrix(x):
  """Pairwise squared distance among a (batch) matrix's rows (2nd dim).

  This saves a bit of computation vs. using _cross_squared_distance_matrix(x, x).

  Args:
    x: `[batch_size, n, d]` float `Tensor`

  Returns:
    squared_dists: `[batch_size, n, n]` float `Tensor`, where
      squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
  """

  x_x_transpose = math_ops.matmul(x, x, adjoint_b=True)
  x_norm_squared = array_ops.matrix_diag_part(x_x_transpose)
  x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2)

  # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi - 2x_bi'x_bj + x_bj'x_bj
  squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose(
      x_norm_squared_tile, [0, 2, 1])

  return squared_dists


def _solve_interpolation(train_points, train_values, order,
                         regularization_weight):
  """Solve for interpolation coefficients.

  Computes the coefficients of the polyharmonic interpolant for the 'training'
  data defined by (train_points, train_values) using the kernel phi.
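
  For reference, the coefficients (w, v) are obtained by solving the block
  linear system

      [A   B] [w]   [f]
      [B^T 0] [v] = [0]

  where A[i, j] = phi(||c_i - c_j||) is the kernel matrix over the training
  points and B = [c, 1] is the points with a column of ones appended for the
  bias term; see https://en.wikipedia.org/wiki/Polyharmonic_spline.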

  Args:
    train_points: `[b, n, d]` interpolation centers
    train_values: `[b, n, k]` function values
    order: order of the interpolation
    regularization_weight: weight to place on smoothness regularization term

  Returns:
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d, k]` weights on each input dimension
  Raises:
    ValueError: if d or k is not fully specified.
  """

  # These dimensions are set dynamically at runtime.
  b, n, _ = array_ops.unstack(array_ops.shape(train_points), num=3)

  d = train_points.shape[-1]
  if tensor_shape.dimension_value(d) is None:
    raise ValueError('The dimensionality of the input points (d) must be '
                     'statically-inferrable.')

  k = train_values.shape[-1]
  if tensor_shape.dimension_value(k) is None:
    raise ValueError('The dimensionality of the output values (k) must be '
                     'statically-inferrable.')

  # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
  # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
  # To account for python style guidelines we use
  # matrix_a for A and matrix_b for B.

  c = train_points
  f = train_values

  # Next, construct the linear system.
  with ops.name_scope('construct_linear_system'):

    matrix_a = _phi(_pairwise_squared_distance_matrix(c), order)  # [b, n, n]
    if regularization_weight > 0:
      batch_identity_matrix = array_ops.expand_dims(
          linalg_ops.eye(n, dtype=c.dtype), 0)
      matrix_a += regularization_weight * batch_identity_matrix

    # Append ones to the feature values for the bias term in the linear model.
    ones = array_ops.ones_like(c[..., :1], dtype=c.dtype)
    matrix_b = array_ops.concat([c, ones], 2)  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = array_ops.concat(
        [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1)

    num_b_cols = matrix_b.get_shape()[2]  # d + 1
    lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
    right_block = array_ops.concat([matrix_b, lhs_zeros],
                                   1)  # [b, n + d + 1, d + 1]
    lhs = array_ops.concat([left_block, right_block],
                           2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype)
    rhs = array_ops.concat([f, rhs_zeros], 1)  # [b, n + d + 1, k]

  # Then, solve the linear system and unpack the results.
  with ops.name_scope('solve_linear_system'):
    w_v = linalg_ops.matrix_solve(lhs, rhs)
    w = w_v[:, :n, :]
    v = w_v[:, n:, :]

  return w, v


def _apply_interpolation(query_points, train_points, w, v, order):
  """Apply polyharmonic interpolation model to data.

  Given coefficients w and v for the interpolation model, we evaluate
  interpolated function values at query_points.

  Args:
    query_points: `[b, m, d]` x values to evaluate the interpolation at
    train_points: `[b, n, d]` x values that act as the interpolation centers
      (the c variables in the wikipedia article)
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d, k]` weights on each input dimension
    order: order of the interpolation

  Returns:
    Polyharmonic interpolation evaluated at points defined in query_points.
  """

  # First, compute the contribution from the rbf term.
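  # For each query point x, this is sum_i w[b, i, :] * phi(||x - c_i||).
  # Note that _phi takes *squared* distances as input and accounts for the
  # squaring internally, so the cross squared-distance matrix can be passed
  # to it directly.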
  pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
  phi_pairwise_dists = _phi(pairwise_dists, order)

  rbf_term = math_ops.matmul(phi_pairwise_dists, w)

  # Then, compute the contribution from the linear term.
  # Pad query_points with ones, for the bias term in the linear model.
  query_points_pad = array_ops.concat([
      query_points,
      array_ops.ones_like(query_points[..., :1], train_points.dtype)
  ], 2)
  linear_term = math_ops.matmul(query_points_pad, v)

  return rbf_term + linear_term


def _phi(r, order):
  """Coordinate-wise nonlinearity used to define the order of the interpolation.

  See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

  Args:
    r: input op
    order: interpolation order

  Returns:
    phi_k evaluated coordinate-wise on r, for k = order
  """

  # using EPSILON prevents log(0), sqrt(0), etc.
  # sqrt(0) is well-defined, but its gradient is not
  with ops.name_scope('phi'):
    if order == 1:
      r = math_ops.maximum(r, EPSILON)
      r = math_ops.sqrt(r)
      return r
    elif order == 2:
      return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON))
    elif order == 4:
      return 0.5 * math_ops.square(r) * math_ops.log(
          math_ops.maximum(r, EPSILON))
    elif order % 2 == 0:
      r = math_ops.maximum(r, EPSILON)
      return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r)
    else:
      r = math_ops.maximum(r, EPSILON)
      return math_ops.pow(r, 0.5 * order)


def interpolate_spline(train_points,
                       train_values,
                       query_points,
                       order,
                       regularization_weight=0.0,
                       name='interpolate_spline'):
  r"""Interpolate signal using polyharmonic interpolation.

  The interpolant has the form
  $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$

  This is a sum of two terms: (1) a weighted sum of radial basis function (RBF)
  terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias.
  The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v
  by appending 1 as a final dimension to x. The coefficients w and v are
  estimated such that the interpolant exactly fits the value of the function at
  the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the
  vector w sums to 0. With these constraints, the coefficients can be obtained
  by solving a linear system.

  \\(\phi\\) is an RBF, parametrized by an interpolation
  order. Using order=2 produces the well-known thin-plate spline.

  We also provide the option to perform regularized interpolation. Here, the
  interpolant is selected to trade off between the squared loss on the training
  data and a certain measure of its curvature
  ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
  Using a regularization weight greater than zero has the effect that the
  interpolant will no longer exactly fit the training data. However, it may be
  less vulnerable to overfitting, particularly for high-order interpolation.

  Note the interpolation procedure is differentiable with respect to all inputs
  besides the order parameter.

  We support dynamically-shaped inputs, where batch_size, n, and m are None
  at graph construction time. However, d and k must be known.

  Args:
    train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
      locations. These do not need to be regularly-spaced.
    train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional values
      evaluated at train_points.
    query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
      where we will output the interpolant's values.
    order: order of the interpolation. Common values are 1 for
      \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline),
      or 3 for \\(\phi(r) = r^3\\).
    regularization_weight: weight placed on the regularization term.
      This will depend substantially on the problem, and it should always be
      tuned. For many problems, it is reasonable to use no regularization.
      If using a non-zero value, we recommend a small value like 0.001.
    name: name prefix for ops created by this function

  Returns:
    `[b, m, k]` float `Tensor` of query values. We use train_points and
    train_values to perform polyharmonic interpolation. The query values are
    the values of the interpolant evaluated at the locations specified in
    query_points.
  """
  with ops.name_scope(name):
    train_points = ops.convert_to_tensor(train_points)
    train_values = ops.convert_to_tensor(train_values)
    query_points = ops.convert_to_tensor(query_points)

    # First, fit the spline to the observed data.
    with ops.name_scope('solve'):
      w, v = _solve_interpolation(train_points, train_values, order,
                                  regularization_weight)

    # Then, evaluate the spline at the query locations.
    with ops.name_scope('predict'):
      query_values = _apply_interpolation(query_points, train_points, w, v,
                                          order)

    return query_values
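

# A minimal usage sketch, kept as a comment so it is not executed on import
# and is not part of the library API. It fits a thin-plate spline (order=2)
# to four scattered 2-D points sampling z = x + y and evaluates the
# interpolant at two query locations. Shapes follow the `interpolate_spline`
# docstring: `[batch_size, n, d]` points, `[batch_size, n, k]` values, and
# `[batch_size, m, d]` queries.
#
#   import tensorflow as tf
#
#   train_points = tf.constant(
#       [[[0., 0.], [0., 1.], [1., 0.], [1., 1.]]])  # [1, 4, 2]
#   train_values = tf.constant([[[0.], [1.], [1.], [2.]]])  # [1, 4, 1]
#   query_points = tf.constant([[[0.5, 0.5], [0.25, 0.75]]])  # [1, 2, 2]
#
#   query_values = interpolate_spline(
#       train_points, train_values, query_points, order=2)  # [1, 2, 1]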