• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for interpolate_spline."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21from scipy import interpolate as sc_interpolate
22
23from tensorflow.contrib.image.python.ops import interpolate_spline
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import clip_ops
31from tensorflow.python.ops import gradients
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import googletest
35
36from tensorflow.python.training import momentum
37
38
class _InterpolationProblem(object):
  """Abstract class for interpolation problem descriptions.

  Subclasses provide `DATA_DIM` (the dimensionality of each point),
  `np_function` (evaluates the target function on numpy arrays) and
  `tf_function` (evaluates the same function on tf tensors).
  """

  # Seed for the private random generator used to sample point locations.
  # The HARDCODED_QUERY_VALUES tables in subclasses depend on this value.
  _SEED = 0

  def _sample_numpy_points(self, extrapolate=True, dtype='float32'):
    """Samples training and query point locations as numpy arrays.

    A private `np.random.RandomState` is used instead of reseeding the global
    numpy generator. The values are identical to those obtained from
    `np.random.seed(0)` followed by the same `uniform` calls, but the global
    random state seen by other code is left untouched.

    Args:
      extrapolate: If False, clamp the query points to lie within the min and
        max of the training points. Note that the min/max are taken over all
        elements (all dimensions together), not per dimension.
      dtype: The numpy data type of the returned arrays.

    Returns:
      init_points, query_points_np: arrays of shape
      [1, 10, DATA_DIM] and [1, 4, DATA_DIM] respectively.
    """
    rng = np.random.RandomState(self._SEED)

    batch_size = 1
    num_training_points = 10
    num_query_points = 4

    # Draw order matters: training points first, then query points, so the
    # stream matches the hard-coded regression values.
    init_points = rng.uniform(
        size=[batch_size, num_training_points, self.DATA_DIM]).astype(dtype)
    query_points_np = rng.uniform(
        size=[batch_size, num_query_points, self.DATA_DIM]).astype(dtype)

    if not extrapolate:
      query_points_np = np.clip(query_points_np, np.min(init_points),
                                np.max(init_points))
    return init_points, query_points_np

  def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'):
    """Make data for an interpolation problem where all x vectors are n-d.

    Args:
      optimizable: If True, then make train_points a tf.Variable.
      extrapolate: If False, then clamp the query_points values to be within
        the (global, across all dimensions) max and min of train_points.
      dtype: The data type to use.

    Returns:
      query_points, query_values, train_points, train_values: training and
      test data for the interpolation problem. query_points is a tf constant,
      query_values is a numpy array (from `np_function`), train_points is a
      tf.Variable or constant, and train_values is a tf tensor (from
      `tf_function`).
    """
    init_points, query_points_np = self._sample_numpy_points(
        extrapolate=extrapolate, dtype=dtype)

    train_points = (
        variables.Variable(init_points)
        if optimizable else constant_op.constant(init_points))
    train_values = self.tf_function(train_points)

    query_points = constant_op.constant(query_points_np)
    query_values = self.np_function(query_points_np)

    return query_points, query_values, train_points, train_values
83
84
class _QuadraticPlusSinProblem1D(_InterpolationProblem):
  """1D interpolation problem used for regression testing.

  Despite the class name, the target function is cubic plus a sinusoid:
  (x - 0.5)^3 - 0.25 * x + 10 * sin(10 * x), summed over the (single)
  coordinate. HARDCODED_QUERY_VALUES maps (interpolation order,
  regularization weight) to the expected interpolant values at the query
  points produced by `get_problem`.
  """
  DATA_DIM = 1
  HARDCODED_QUERY_VALUES = {
      (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387],
      (1.0, 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693],
      (2.0, 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059],
      (2.0, 0.01): [6.70797816797, -7.49709587663, -5.28965776238,
                    1.52284731741],
      (3.0, 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122],
      (3.0, 0.01): [4.47106304758, -5.71266128361, -3.92529303296,
                    1.86755293857],
      (4.0, 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256],
      (4.0, 0.01): [
          -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725
      ],
  }

  def np_function(self, x):
    """Takes np array, evaluates the test function, and returns np array."""
    return np.sum(
        np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10),
        axis=2,
        keepdims=True)

  def tf_function(self, x):
    """Takes tf tensor, evaluates the test function, and returns tf tensor.

    Sums over axis 2 so that it computes the same function as `np_function`.
    With DATA_DIM == 1 that axis has size one, so a sum and a mean coincide.
    """
    return math_ops.reduce_sum(
        math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10),
        2,
        keepdims=True)
120
121
class _QuadraticPlusSinProblemND(_InterpolationProblem):
  """Three-dimensional interpolation problem used for regression testing.

  The target function sums, over the coordinates of each point,
  (x - 0.5)^2 + 0.25 * x + sin(15 * x). HARDCODED_QUERY_VALUES maps
  (interpolation order, regularization weight) to the expected interpolant
  values at the query points produced by `get_problem`.
  """

  DATA_DIM = 3
  HARDCODED_QUERY_VALUES = {
      (1.0, 0.0): [
          1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885
      ],
      (1.0, 0.01): [
          1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569
      ],
      (2.0, 0.0): [
          0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215
      ],
      (2.0, 0.01): [
          0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312
      ],
      (3.0, 0.0): [
          0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491
      ],
      (3.0, 0.01): [
          0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866
      ],
      (4.0, 0.0): [
          0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833
      ],
      (4.0, 0.01): [
          0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382
      ],
  }

  def np_function(self, x):
    """Evaluates the test function on a numpy array.

    The per-coordinate terms are summed over axis 2 with keepdims=True, and
    the result is returned as a numpy array.
    """
    quadratic = np.square(x - 0.5)
    linear = 0.25 * x
    oscillation = 1 * np.sin(x * 15)
    return np.sum(quadratic + linear + oscillation, axis=2, keepdims=True)

  def tf_function(self, x):
    """Evaluates the test function on a tf tensor.

    Mirrors `np_function`: the per-coordinate terms are summed over axis 2
    with keepdims=True, and the result is returned as a tf tensor.
    """
    quadratic = math_ops.square(x - 0.5)
    linear = 0.25 * x
    oscillation = 1 * math_ops.sin(x * 15)
    return math_ops.reduce_sum(
        quadratic + linear + oscillation, axis=2, keepdims=True)
153
154
class InterpolateSplineTest(test_util.TensorFlowTestCase):
  """Tests for interpolate_spline.interpolate_spline."""

  def _check_hardcoded_regression(self, tp):
    """Compares interpolate_spline against tp.HARDCODED_QUERY_VALUES.

    Every combination of interpolation order in (1, 2, 3) and regularization
    weight in (0, 0.01) is run. The first minibatch element of the
    interpolant, with the trailing value dimension dropped, must match the
    stored values.

    Args:
      tp: An _InterpolationProblem instance that defines
        HARDCODED_QUERY_VALUES.
    """
    (query_points, _, train_points,
     train_values) = tp.get_problem(dtype='float64')

    for order in (1, 2, 3):
      for reg_weight in (0, 0.01):
        interpolator = interpolate_spline.interpolate_spline(
            train_points, train_values, query_points, order, reg_weight)

        target_interpolation = np.array(
            tp.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        with self.cached_session() as sess:
          interp_val = sess.run(interpolator)
          self.assertAllClose(interp_val[0, :, 0], target_interpolation)

  def test_1d_linear_interpolation(self):
    """For 1d linear interpolation, we can compare directly to scipy."""

    tp = _QuadraticPlusSinProblem1D()
    # extrapolate=False keeps query points inside the training range, which
    # scipy's interp1d requires by default.
    (query_points, _, train_points, train_values) = tp.get_problem(
        extrapolate=False, dtype='float64')
    interpolation_order = 1

    with ops.name_scope('interpolator'):
      interpolator = interpolate_spline.interpolate_spline(
          train_points, train_values, query_points, interpolation_order)
      with self.cached_session() as sess:
        fetches = [query_points, train_points, train_values, interpolator]
        query_points_, train_points_, train_values_, interp_ = sess.run(fetches)

        # Just look at the first element of the minibatch.
        # Also, trim the final singleton dimension.
        interp_ = interp_[0, :, 0]
        query_points_ = query_points_[0, :, 0]
        train_points_ = train_points_[0, :, 0]
        train_values_ = train_values_[0, :, 0]

        # Compute scipy interpolation.
        scipy_interp_function = sc_interpolate.interp1d(
            train_points_, train_values_, kind='linear')

        scipy_interpolation = scipy_interp_function(query_points_)
        scipy_interpolation_on_train = scipy_interp_function(train_points_)

        # Even with float64 precision, the interpolants disagree with scipy a
        # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc.
        tol = 1e-3

        self.assertAllClose(
            train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol)
        self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol)

  def test_1d_interpolation(self):
    """Regression test for interpolation with 1-D points."""
    self._check_hardcoded_regression(_QuadraticPlusSinProblem1D())

  def test_nd_linear_interpolation(self):
    """Regression test for interpolation with N-D points.

    Despite the name, this covers interpolation orders 1, 2 and 3.
    """
    self._check_hardcoded_regression(_QuadraticPlusSinProblemND())

  def test_nd_linear_interpolation_unspecified_shape(self):
    """Ensure that interpolation supports dynamic batch_size and num_points."""

    tp = _QuadraticPlusSinProblemND()
    (query_points, _, train_points,
     train_values) = tp.get_problem(dtype='float64')

    # Construct placeholders such that the batch size, number of train points,
    # and number of query points are not known at graph construction time.
    feature_dim = query_points.shape[-1]
    value_dim = train_values.shape[-1]
    train_points_ph = array_ops.placeholder(
        dtype=train_points.dtype, shape=[None, None, feature_dim])
    train_values_ph = array_ops.placeholder(
        dtype=train_values.dtype, shape=[None, None, value_dim])
    query_points_ph = array_ops.placeholder(
        dtype=query_points.dtype, shape=[None, None, feature_dim])

    order = 1
    reg_weight = 0.01

    interpolator = interpolate_spline.interpolate_spline(
        train_points_ph, train_values_ph, query_points_ph, order, reg_weight)

    target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)]
    target_interpolation = np.array(target_interpolation)
    with self.cached_session() as sess:

      (train_points_value, train_values_value, query_points_value) = sess.run(
          [train_points, train_values, query_points])

      interp_val = sess.run(
          interpolator,
          feed_dict={
              train_points_ph: train_points_value,
              train_values_ph: train_values_value,
              query_points_ph: query_points_value
          })
      self.assertAllClose(interp_val[0, :, 0], target_interpolation)

  def test_fully_unspecified_shape(self):
    """Ensure that a ValueError is raised when input/output dim is unspecified."""

    tp = _QuadraticPlusSinProblemND()
    (query_points, _, train_points,
     train_values) = tp.get_problem(dtype='float64')

    # Construct placeholders such that the batch size, number of train points,
    # and number of query points are not known at graph construction time.
    # The "_invalid" variants additionally leave the last (feature / value)
    # dimension unknown.
    feature_dim = query_points.shape[-1]
    value_dim = train_values.shape[-1]
    train_points_ph = array_ops.placeholder(
        dtype=train_points.dtype, shape=[None, None, feature_dim])
    train_points_ph_invalid = array_ops.placeholder(
        dtype=train_points.dtype, shape=[None, None, None])
    train_values_ph = array_ops.placeholder(
        dtype=train_values.dtype, shape=[None, None, value_dim])
    train_values_ph_invalid = array_ops.placeholder(
        dtype=train_values.dtype, shape=[None, None, None])
    query_points_ph = array_ops.placeholder(
        dtype=query_points.dtype, shape=[None, None, feature_dim])

    order = 1
    reg_weight = 0.01

    with self.assertRaises(ValueError):
      _ = interpolate_spline.interpolate_spline(
          train_points_ph_invalid, train_values_ph, query_points_ph, order,
          reg_weight)

    with self.assertRaises(ValueError):
      _ = interpolate_spline.interpolate_spline(
          train_points_ph, train_values_ph_invalid, query_points_ph, order,
          reg_weight)

  def test_interpolation_gradient(self):
    """Make sure that backprop can run. Correctness of gradients is assumed.

    Here, we use a small 'training' set and a more densely-sampled
    set of query points, for which we know the true value in advance. The goal
    is to choose x locations for the training data such that interpolating using
    this training data yields the best reconstruction for the function
    values at the query points. The training data locations are optimized
    iteratively using gradient descent.
    """
    tp = _QuadraticPlusSinProblemND()
    (query_points, query_values, train_points,
     train_values) = tp.get_problem(optimizable=True)

    regularization = 0.001
    for interpolation_order in (1, 2, 3, 4):
      interpolator = interpolate_spline.interpolate_spline(
          train_points, train_values, query_points, interpolation_order,
          regularization)

      loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator))

      optimizer = momentum.MomentumOptimizer(0.001, 0.9)
      grad = gradients.gradients(loss, [train_points])
      grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
      opt_func = optimizer.apply_gradients(zip(grad, [train_points]))
      # Re-created each iteration so the new optimizer's slot variables are
      # included and train_points restarts from its initial value.
      init_op = variables.global_variables_initializer()

      with self.cached_session() as sess:
        sess.run(init_op)
        for _ in range(100):
          sess.run([loss, opt_func])
337
338
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's googletest wrapper.
  googletest.main()
341