• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Linear regression tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.learn.python import learn
24from tensorflow.python.platform import test
25
26
class RegressionTest(test.TestCase):
  """Tests that LinearRegressor recovers weights of a noisy linear model."""

  def testLinearRegression(self):
    # Build a synthetic dataset: uniform(-1, 1) features, labels from a random
    # linear map plus small Gaussian noise and a near-constant offset of `bias`.
    prng = np.random.RandomState(67)
    num_examples, num_features = 1000, 10
    bias = 2
    features = prng.uniform(-1, 1, (num_examples, num_features))
    true_weights = 10 * prng.randn(num_features)
    labels = features.dot(true_weights)
    # Noise is drawn before the offset so the RNG stream is consumed in a
    # fixed order and the fixture stays reproducible.
    labels += prng.randn(len(features)) * 0.05 + prng.normal(bias, 0.01)

    feature_columns = learn.infer_real_valued_columns_from_input(features)
    regressor = learn.LinearRegressor(
        feature_columns=feature_columns, optimizer="SGD")
    regressor.fit(features, labels, steps=200)

    weight_name = "linear//weight"
    self.assertIn(weight_name, regressor.get_variable_names())
    # The weight variable comes back with a trailing unit dimension, so flatten
    # it before comparing against the 1-D true_weights.
    learned_weights = regressor.get_variable_value(weight_name).flatten()
    self.assertAllClose(true_weights, learned_weights, rtol=0.01)
    # TODO(ispir): Disable centered_bias, then also check the learned bias,
    # e.g. `assert abs(bias - regressor.bias_) < 0.1`.
50
if __name__ == "__main__":
  # Entry point when the module is run directly: hand off to the TensorFlow
  # test runner imported at the top of the file.
  test.main()
53