1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for interpolate_spline.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21from scipy import interpolate as sc_interpolate 22 23from tensorflow.contrib.image.python.ops import interpolate_spline 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import test_util 28 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import clip_ops 31from tensorflow.python.ops import gradients 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import variables 34from tensorflow.python.platform import googletest 35 36from tensorflow.python.training import momentum 37 38 39class _InterpolationProblem(object): 40 """Abstract class for interpolation problem descriptions.""" 41 42 def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'): 43 """Make data for an interpolation problem where all x vectors are n-d. 44 45 Args: 46 optimizable: If True, then make train_points a tf.Variable. 47 extrapolate: If False, then clamp the query_points values to be within 48 the max and min of train_points. 49 dtype: The data type to use. 50 51 Returns: 52 query_points, query_values, train_points, train_values: training and 53 test tensors for interpolation problem 54 """ 55 56 # The values generated here depend on a seed of 0. 57 np.random.seed(0) 58 59 batch_size = 1 60 num_training_points = 10 61 num_query_points = 4 62 63 init_points = np.random.uniform( 64 size=[batch_size, num_training_points, self.DATA_DIM]) 65 66 init_points = init_points.astype(dtype) 67 train_points = ( 68 variables.Variable(init_points) 69 if optimizable else constant_op.constant(init_points)) 70 train_values = self.tf_function(train_points) 71 72 query_points_np = np.random.uniform( 73 size=[batch_size, num_query_points, self.DATA_DIM]) 74 query_points_np = query_points_np.astype(dtype) 75 if not extrapolate: 76 query_points_np = np.clip(query_points_np, np.min(init_points), 77 np.max(init_points)) 78 79 query_points = constant_op.constant(query_points_np) 80 query_values = self.np_function(query_points_np) 81 82 return query_points, query_values, train_points, train_values 83 84 85class _QuadraticPlusSinProblem1D(_InterpolationProblem): 86 """1D interpolation problem used for regression testing.""" 87 DATA_DIM = 1 88 HARDCODED_QUERY_VALUES = { 89 (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387], 90 (1.0, 91 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693], 92 (2.0, 93 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059], 94 (2.0, 95 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741], 96 (3.0, 97 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122], 98 (3.0, 99 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857], 100 (4.0, 101 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256], 102 (4.0, 0.01): [ 103 -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725 104 ] 105 } 106 107 def np_function(self, x): 108 """Takes np array, evaluates the test function, and returns np array.""" 109 return np.sum( 110 np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10), 111 axis=2, 112 keepdims=True) 113 114 def tf_function(self, x): 115 """Takes tf tensor, evaluates the test function, and returns tf tensor.""" 116 return math_ops.reduce_mean( 117 math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10), 118 2, 119 keepdims=True) 120 121 122class _QuadraticPlusSinProblemND(_InterpolationProblem): 123 """3D interpolation problem used for regression testing.""" 124 125 DATA_DIM = 3 126 HARDCODED_QUERY_VALUES = { 127 (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885], 128 (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569], 129 (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215], 130 (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312], 131 (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491], 132 (3.0, 133 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866], 134 (4.0, 135 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833], 136 (4.0, 137 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382], 138 } 139 140 def np_function(self, x): 141 """Takes np array, evaluates the test function, and returns np array.""" 142 return np.sum( 143 np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15), 144 axis=2, 145 keepdims=True) 146 147 def tf_function(self, x): 148 """Takes tf tensor, evaluates the test function, and returns tf tensor.""" 149 return math_ops.reduce_sum( 150 math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15), 151 2, 152 keepdims=True) 153 154 155class InterpolateSplineTest(test_util.TensorFlowTestCase): 156 157 def test_1d_linear_interpolation(self): 158 """For 1d linear interpolation, we can compare directly to scipy.""" 159 160 tp = _QuadraticPlusSinProblem1D() 161 (query_points, _, train_points, train_values) = tp.get_problem( 162 extrapolate=False, dtype='float64') 163 interpolation_order = 1 164 165 with ops.name_scope('interpolator'): 166 interpolator = interpolate_spline.interpolate_spline( 167 train_points, train_values, query_points, interpolation_order) 168 with self.cached_session() as sess: 169 fetches = [query_points, train_points, train_values, interpolator] 170 query_points_, train_points_, train_values_, interp_ = sess.run(fetches) 171 172 # Just look at the first element of the minibatch. 173 # Also, trim the final singleton dimension. 174 interp_ = interp_[0, :, 0] 175 query_points_ = query_points_[0, :, 0] 176 train_points_ = train_points_[0, :, 0] 177 train_values_ = train_values_[0, :, 0] 178 179 # Compute scipy interpolation. 180 scipy_interp_function = sc_interpolate.interp1d( 181 train_points_, train_values_, kind='linear') 182 183 scipy_interpolation = scipy_interp_function(query_points_) 184 scipy_interpolation_on_train = scipy_interp_function(train_points_) 185 186 # Even with float64 precision, the interpolants disagree with scipy a 187 # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc. 188 tol = 1e-3 189 190 self.assertAllClose( 191 train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol) 192 self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol) 193 194 def test_1d_interpolation(self): 195 """Regression test for interpolation with 1-D points.""" 196 197 tp = _QuadraticPlusSinProblem1D() 198 (query_points, _, train_points, 199 train_values) = tp.get_problem(dtype='float64') 200 201 for order in (1, 2, 3): 202 for reg_weight in (0, 0.01): 203 interpolator = interpolate_spline.interpolate_spline( 204 train_points, train_values, query_points, order, reg_weight) 205 206 target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] 207 target_interpolation = np.array(target_interpolation) 208 with self.cached_session() as sess: 209 interp_val = sess.run(interpolator) 210 self.assertAllClose(interp_val[0, :, 0], target_interpolation) 211 212 def test_nd_linear_interpolation(self): 213 """Regression test for interpolation with N-D points.""" 214 215 tp = _QuadraticPlusSinProblemND() 216 (query_points, _, train_points, 217 train_values) = tp.get_problem(dtype='float64') 218 219 for order in (1, 2, 3): 220 for reg_weight in (0, 0.01): 221 interpolator = interpolate_spline.interpolate_spline( 222 train_points, train_values, query_points, order, reg_weight) 223 224 target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] 225 target_interpolation = np.array(target_interpolation) 226 with self.cached_session() as sess: 227 interp_val = sess.run(interpolator) 228 self.assertAllClose(interp_val[0, :, 0], target_interpolation) 229 230 def test_nd_linear_interpolation_unspecified_shape(self): 231 """Ensure that interpolation supports dynamic batch_size and num_points.""" 232 233 tp = _QuadraticPlusSinProblemND() 234 (query_points, _, train_points, 235 train_values) = tp.get_problem(dtype='float64') 236 237 # Construct placeholders such that the batch size, number of train points, 238 # and number of query points are not known at graph construction time. 239 feature_dim = query_points.shape[-1] 240 value_dim = train_values.shape[-1] 241 train_points_ph = array_ops.placeholder( 242 dtype=train_points.dtype, shape=[None, None, feature_dim]) 243 train_values_ph = array_ops.placeholder( 244 dtype=train_values.dtype, shape=[None, None, value_dim]) 245 query_points_ph = array_ops.placeholder( 246 dtype=query_points.dtype, shape=[None, None, feature_dim]) 247 248 order = 1 249 reg_weight = 0.01 250 251 interpolator = interpolate_spline.interpolate_spline( 252 train_points_ph, train_values_ph, query_points_ph, order, reg_weight) 253 254 target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] 255 target_interpolation = np.array(target_interpolation) 256 with self.cached_session() as sess: 257 258 (train_points_value, train_values_value, query_points_value) = sess.run( 259 [train_points, train_values, query_points]) 260 261 interp_val = sess.run( 262 interpolator, 263 feed_dict={ 264 train_points_ph: train_points_value, 265 train_values_ph: train_values_value, 266 query_points_ph: query_points_value 267 }) 268 self.assertAllClose(interp_val[0, :, 0], target_interpolation) 269 270 def test_fully_unspecified_shape(self): 271 """Ensure that erreor is thrown when input/output dim unspecified.""" 272 273 tp = _QuadraticPlusSinProblemND() 274 (query_points, _, train_points, 275 train_values) = tp.get_problem(dtype='float64') 276 277 # Construct placeholders such that the batch size, number of train points, 278 # and number of query points are not known at graph construction time. 279 feature_dim = query_points.shape[-1] 280 value_dim = train_values.shape[-1] 281 train_points_ph = array_ops.placeholder( 282 dtype=train_points.dtype, shape=[None, None, feature_dim]) 283 train_points_ph_invalid = array_ops.placeholder( 284 dtype=train_points.dtype, shape=[None, None, None]) 285 train_values_ph = array_ops.placeholder( 286 dtype=train_values.dtype, shape=[None, None, value_dim]) 287 train_values_ph_invalid = array_ops.placeholder( 288 dtype=train_values.dtype, shape=[None, None, None]) 289 query_points_ph = array_ops.placeholder( 290 dtype=query_points.dtype, shape=[None, None, feature_dim]) 291 292 order = 1 293 reg_weight = 0.01 294 295 with self.assertRaises(ValueError): 296 _ = interpolate_spline.interpolate_spline( 297 train_points_ph_invalid, train_values_ph, query_points_ph, order, 298 reg_weight) 299 300 with self.assertRaises(ValueError): 301 _ = interpolate_spline.interpolate_spline( 302 train_points_ph, train_values_ph_invalid, query_points_ph, order, 303 reg_weight) 304 305 def test_interpolation_gradient(self): 306 """Make sure that backprop can run. Correctness of gradients is assumed. 307 308 Here, we create a use a small 'training' set and a more densely-sampled 309 set of query points, for which we know the true value in advance. The goal 310 is to choose x locations for the training data such that interpolating using 311 this training data yields the best reconstruction for the function 312 values at the query points. The training data locations are optimized 313 iteratively using gradient descent. 314 """ 315 tp = _QuadraticPlusSinProblemND() 316 (query_points, query_values, train_points, 317 train_values) = tp.get_problem(optimizable=True) 318 319 regularization = 0.001 320 for interpolation_order in (1, 2, 3, 4): 321 interpolator = interpolate_spline.interpolate_spline( 322 train_points, train_values, query_points, interpolation_order, 323 regularization) 324 325 loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator)) 326 327 optimizer = momentum.MomentumOptimizer(0.001, 0.9) 328 grad = gradients.gradients(loss, [train_points]) 329 grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) 330 opt_func = optimizer.apply_gradients(zip(grad, [train_points])) 331 init_op = variables.global_variables_initializer() 332 333 with self.cached_session() as sess: 334 sess.run(init_op) 335 for _ in range(100): 336 sess.run([loss, opt_func]) 337 338 339if __name__ == '__main__': 340 googletest.main() 341