• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Polyharmonic spline interpolation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import linalg_ops
24from tensorflow.python.ops import math_ops
25
# Floor applied to squared distances inside _phi before they are passed to
# log, sqrt or pow, so those ops (and their gradients) never see zero.
EPSILON = 1e-10
27
28
def _cross_squared_distance_matrix(x, y):
  """Squared Euclidean distance from every row of x to every row of y.

  Uses the expansion ||a - b||^2 = a'a - 2 a'b + b'b, so all cross terms come
  from a single batched matmul.

  Args:
    x: `[batch_size, n, d]` float `Tensor`
    y: `[batch_size, m, d]` float `Tensor`

  Returns:
    squared_dists: `[batch_size, n, m]` float `Tensor`, where
    squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
  """
  # Squared norm of every row: [batch_size, n] and [batch_size, m].
  x_sq_norms = math_ops.reduce_sum(math_ops.square(x), 2)
  y_sq_norms = math_ops.reduce_sum(math_ops.square(y), 2)

  # Shape the norms as a column ([b, n, 1]) and a row ([b, 1, m]) so that they
  # broadcast against the [b, n, m] matrix of inner products.
  x_sq_norms_col = array_ops.expand_dims(x_sq_norms, 2)
  y_sq_norms_row = array_ops.expand_dims(y_sq_norms, 1)

  # inner_products[b,i,j] = x_bi'y_bj
  inner_products = math_ops.matmul(x, y, adjoint_b=True)

  # ||x_bi - y_bj||^2 = x_bi'x_bi - 2 x_bi'y_bj + y_bj'y_bj
  return x_sq_norms_col - 2 * inner_products + y_sq_norms_row
54
55
def _pairwise_squared_distance_matrix(x):
  """Squared Euclidean distance between every pair of rows of x.

  Slightly cheaper than `_cross_squared_distance_matrix(x, x)`: the squared row
  norms are read off the diagonal of x x' instead of being computed separately.

  Args:
    x: `[batch_size, n, d]` float `Tensor`

  Returns:
    squared_dists: `[batch_size, n, n]` float `Tensor`, where
    squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
  """
  # gram[b,i,j] = x_bi'x_bj; its diagonal holds the squared row norms.
  gram = math_ops.matmul(x, x, adjoint_b=True)
  sq_norms = array_ops.matrix_diag_part(gram)
  sq_norms_col = array_ops.expand_dims(sq_norms, 2)  # [b, n, 1]

  # ||x_bi - x_bj||^2 = x_bi'x_bi - 2 x_bi'x_bj + x_bj'x_bj
  squared_dists = sq_norms_col - 2 * gram
  squared_dists += array_ops.transpose(sq_norms_col, [0, 2, 1])  # + [b, 1, n]

  return squared_dists
78
79
def _solve_interpolation(train_points, train_values, order,
                         regularization_weight):
  """Fit polyharmonic interpolation coefficients to the training data.

  Builds and solves the block linear system

      [ A    B ] [ w ]   [ f ]
      [ B^T  0 ] [ v ] = [ 0 ]

  where A is the kernel phi applied to the pairwise squared distances between
  the centers (plus regularization_weight * I when that weight is positive),
  B is train_points with a column of ones appended, and f is train_values.
  The notation follows https://en.wikipedia.org/wiki/Polyharmonic_spline.

  Args:
    train_points: `[b, n, d]` interpolation centers
    train_values: `[b, n, k]` function values
    order: order of the interpolation
    regularization_weight: weight to place on smoothness regularization term

  Returns:
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d + 1, k]` weights on each input dimension, followed by the weight
      of the appended ones column (the bias) in the last row
  Raises:
    ValueError: if d or k is not fully specified.
  """

  # Batch size and number of centers are read from the dynamic shape, so they
  # may be unknown until runtime.
  batch_size, num_centers, _ = array_ops.unstack(
      array_ops.shape(train_points), num=3)

  # The feature and value dimensions, in contrast, must be known statically.
  point_dim = train_points.shape[-1]
  if tensor_shape.dimension_value(point_dim) is None:
    raise ValueError(
        'The dimensionality of the input points (d) must be '
        'statically-inferrable.')

  value_dim = train_values.shape[-1]
  if tensor_shape.dimension_value(value_dim) is None:
    raise ValueError(
        'The dimensionality of the output values (k) must be '
        'statically-inferrable.')

  dtype = train_points.dtype

  with ops.name_scope('construct_linear_system'):

    # A: kernel evaluated between every pair of centers, [b, n, n].
    kernel_block = _phi(
        _pairwise_squared_distance_matrix(train_points), order)
    if regularization_weight > 0:
      identity = array_ops.expand_dims(
          linalg_ops.eye(num_centers, dtype=dtype), 0)
      kernel_block += regularization_weight * identity

    # B: centers with a trailing column of ones for the bias, [b, n, d + 1].
    ones_column = array_ops.ones_like(train_points[..., :1], dtype=dtype)
    affine_block = array_ops.concat([train_points, ones_column], 2)

    # Left block column [A; B^T], [b, n + d + 1, n].
    left_block = array_ops.concat(
        [kernel_block,
         array_ops.transpose(affine_block, [0, 2, 1])], 1)

    # Right block column [B; 0], [b, n + d + 1, d + 1].
    affine_width = affine_block.get_shape()[2]  # d + 1
    corner_zeros = array_ops.zeros(
        [batch_size, affine_width, affine_width], dtype)
    right_block = array_ops.concat([affine_block, corner_zeros], 1)

    # Full system matrix, [b, n + d + 1, n + d + 1].
    lhs = array_ops.concat([left_block, right_block], 2)

    # Right-hand side [f; 0], [b, n + d + 1, k].
    rhs_padding = array_ops.zeros([batch_size, point_dim + 1, value_dim],
                                  dtype)
    rhs = array_ops.concat([train_values, rhs_padding], 1)

  with ops.name_scope('solve_linear_system'):
    solution = linalg_ops.matrix_solve(lhs, rhs)
    # The first n rows are the center weights; the remaining d + 1 rows are
    # the linear weights and the bias.
    w = solution[:, :num_centers, :]
    v = solution[:, num_centers:, :]

  return w, v
155
156
def _apply_interpolation(query_points, train_points, w, v, order):
  """Evaluate a fitted polyharmonic interpolant at the query locations.

  The result is the sum of an RBF term (the kernel between each query point and
  each center, weighted by w) and a linear term (the query points with a ones
  column appended, multiplied by v).

  Args:
    query_points: `[b, m, d]` x values to evaluate the interpolation at
    train_points: `[b, n, d]` x values that act as the interpolation centers
                    (the c variables in the wikipedia article)
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d + 1, k]` weights on each input dimension, with the weight of the
      appended ones column (the bias) in the last row
    order: order of the interpolation

  Returns:
    `[b, m, k]` polyharmonic interpolation evaluated at points defined in
    query_points.
  """

  # RBF contribution: kernel between every query point and every center.
  query_center_sq_dists = _cross_squared_distance_matrix(
      query_points, train_points)
  kernel_values = _phi(query_center_sq_dists, order)
  rbf_term = math_ops.matmul(kernel_values, w)

  # Linear contribution: append a ones column to the queries so that the last
  # row of v acts as the bias.
  ones_column = array_ops.ones_like(query_points[..., :1], train_points.dtype)
  padded_queries = array_ops.concat([query_points, ones_column], 2)
  linear_term = math_ops.matmul(padded_queries, v)

  return rbf_term + linear_term
190
191
def _phi(r, order):
  """Coordinate-wise nonlinearity used to define the order of the interpolation.

  See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
  The callers in this file pass squared distances as r, which is why the
  exponents below are halved and the log terms carry a factor of 0.5.

  Args:
    r: input op
    order: interpolation order

  Returns:
    phi_k evaluated coordinate-wise on r, for k = order
  """

  # Clamping with EPSILON prevents log(0), sqrt(0), etc.
  # sqrt(0) is well-defined, but its gradient is not.
  with ops.name_scope('phi'):
    if order == 1:
      return math_ops.sqrt(math_ops.maximum(r, EPSILON))

    if order == 2:
      return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON))

    if order == 4:
      return 0.5 * math_ops.square(r) * math_ops.log(
          math_ops.maximum(r, EPSILON))

    # Generic orders: even ones carry the extra log factor, odd ones do not.
    is_even = order % 2 == 0
    clamped = math_ops.maximum(r, EPSILON)
    powered = math_ops.pow(clamped, 0.5 * order)
    if is_even:
      return 0.5 * powered * math_ops.log(clamped)
    return powered
223
224
def interpolate_spline(train_points,
                       train_values,
                       query_points,
                       order,
                       regularization_weight=0.0,
                       name='interpolate_spline'):
  r"""Interpolate signal using polyharmonic interpolation.

  The interpolant has the form
  $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$

  It is the sum of two parts: (1) a weighted sum of radial basis function (RBF)
  terms centered at \\(c_1, ... c_n\\), and (2) a linear term with a bias.
  The \\(c_i\\) vectors are the 'training' points. In the code, b is absorbed
  into v by appending 1 as a final dimension to x. The coefficients w and v are
  chosen so that the interpolant exactly fits the value of the function at the
  \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the
  vector w sums to 0. Under these constraints the coefficients are the solution
  of a linear system.

  \\(\phi\\) is an RBF, parametrized by an interpolation
  order. Using order=2 produces the well-known thin-plate spline.

  Regularized interpolation is also available. In that case the interpolant is
  selected to trade off between the squared loss on the training data and a
  certain measure of its curvature
  ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
  With a regularization weight greater than zero the interpolant will no longer
  exactly fit the training data. However, it may be less vulnerable to
  overfitting, particularly for high-order interpolation.

  Note the interpolation procedure is differentiable with respect to all inputs
  besides the order parameter.

  Dynamically-shaped inputs are supported: batch_size, n, and m may be None
  at graph construction time. However, d and k must be known.

  Args:
    train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
      locations. These do not need to be regularly-spaced.
    train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional values
      evaluated at train_points.
    query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
      where we will output the interpolant's values.
    order: order of the interpolation. Common values are 1 for
      \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline),
       or 3 for \\(\phi(r) = r^3\\).
    regularization_weight: weight placed on the regularization term.
      This will depend substantially on the problem, and it should always be
      tuned. For many problems, it is reasonable to use no regularization.
      If using a non-zero value, we recommend a small value like 0.001.
    name: name prefix for ops created by this function

  Returns:
    `[b, m, k]` float `Tensor` of query values. We use train_points and
    train_values to perform polyharmonic interpolation. The query values are
    the values of the interpolant evaluated at the locations specified in
    query_points.
  """
  with ops.name_scope(name):
    train_points = ops.convert_to_tensor(train_points)
    train_values = ops.convert_to_tensor(train_values)
    query_points = ops.convert_to_tensor(query_points)

    # Step 1: fit the spline coefficients to the observed data.
    with ops.name_scope('solve'):
      center_weights, linear_weights = _solve_interpolation(
          train_points, train_values, order, regularization_weight)

    # Step 2: evaluate the fitted spline at the query locations.
    with ops.name_scope('predict'):
      return _apply_interpolation(query_points, train_points, center_weights,
                                  linear_weights, order)
300