1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Linear regression tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.learn.python import learn 24from tensorflow.python.platform import test 25 26 27class RegressionTest(test.TestCase): 28 """Linear regression tests.""" 29 30 def testLinearRegression(self): 31 rng = np.random.RandomState(67) 32 n = 1000 33 n_weights = 10 34 bias = 2 35 x = rng.uniform(-1, 1, (n, n_weights)) 36 weights = 10 * rng.randn(n_weights) 37 y = np.dot(x, weights) 38 y += rng.randn(len(x)) * 0.05 + rng.normal(bias, 0.01) 39 regressor = learn.LinearRegressor( 40 feature_columns=learn.infer_real_valued_columns_from_input(x), 41 optimizer="SGD") 42 regressor.fit(x, y, steps=200) 43 self.assertIn("linear//weight", regressor.get_variable_names()) 44 regressor_weights = regressor.get_variable_value("linear//weight") 45 # Have to flatten weights since they come in (x, 1) shape. 46 self.assertAllClose(weights, regressor_weights.flatten(), rtol=0.01) 47 # TODO(ispir): Disable centered_bias. 48 # assert abs(bias - regressor.bias_) < 0.1 49 50 51if __name__ == "__main__": 52 test.main() 53