# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for ftrl ("follow the regularized leader") operations."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import training_ops


class ResourceApplyFtrlTest(xla_test.XLATestCase):
  """Test cases for ftrl ops."""

  def setUp(self):
    super().setUp()
    # NOTE(review): flag appears unused in this chunk; presumably consulted by
    # XLATestCase infrastructure when the MLIR bridge rewrites ops for TPU.
    self.rewrite_ops_for_tpu = ("TPU" in self.device and
                                test_util.is_mlir_bridge_enabled())

  def _eval(self, var, accum, linear, grad, lr, l1, l2, l2_shrinkage=0,
            lr_power=1, multiply_linear_by_lr=False):
    """Runs one FTRL update step and returns the resulting state.

    Args:
      var: initial weights (array-like, cast to float32).
      accum: initial gradient accumulator (array-like, cast to float32).
      linear: initial linear term (array-like, cast to float32).
      grad: gradient to apply (array-like, cast to float32).
      lr: learning rate scalar.
      l1: L1 regularization strength.
      l2: L2 regularization strength.
      l2_shrinkage: shrinkage term; any non-zero value selects the
        ResourceApplyFtrlV2 op (the only variant that supports it).
      lr_power: power applied to the learning rate schedule.
      multiply_linear_by_lr: if True, the op keeps linear scaled by lr.
        Incompatible with a non-zero l2_shrinkage (see assert below).

    Returns:
      Tuple of numpy arrays (var, accum, linear) after the update, each
      reshaped to its input's shape.
    """
    dtype = np.float32
    var = np.array(var, dtype=dtype)
    accum = np.array(accum, dtype=dtype)
    linear = np.array(linear, dtype=dtype)
    grad = np.array(grad, dtype=dtype)
    # Non-zero shrinkage requires the V2 op; zero shrinkage uses the V1 op so
    # both code paths get coverage.
    use_v2 = bool(l2_shrinkage)
    with self.session() as session:
      with self.test_scope():
        lr = constant_op.constant(lr, dtype=dtype)
        l1 = constant_op.constant(l1, dtype=dtype)
        l2 = constant_op.constant(l2, dtype=dtype)
        l2_shrinkage = constant_op.constant(l2_shrinkage, dtype=dtype)
        lr_power = constant_op.constant(lr_power, dtype=dtype)
        v_var = resource_variable_ops.ResourceVariable(var, dtype=dtype)
        v_accum = resource_variable_ops.ResourceVariable(accum, dtype=dtype)
        v_linear = resource_variable_ops.ResourceVariable(linear, dtype=dtype)
        session.run(v_var.create)
        session.run(v_accum.create)
        session.run(v_linear.create)
        # The V1 op (selected by l2_shrinkage == 0) is the only one exercised
        # together with multiply_linear_by_lr in these tests.
        assert not (use_v2 and multiply_linear_by_lr)
        if use_v2:
          session.run(training_ops.resource_apply_ftrl_v2(
              v_var.handle, v_accum.handle, v_linear.handle,
              grad, lr, l1, l2, l2_shrinkage, lr_power,
              multiply_linear_by_lr=multiply_linear_by_lr))
        else:
          session.run(training_ops.resource_apply_ftrl(
              v_var.handle, v_accum.handle, v_linear.handle,
              grad, lr, l1, l2, lr_power,
              multiply_linear_by_lr=multiply_linear_by_lr))
        return (v_var.read_value().eval().reshape(var.shape),
                v_accum.read_value().eval().reshape(accum.shape),
                v_linear.read_value().eval().reshape(linear.shape))

  def testAccum(self):
    """Test that accum is updated with grad^2."""
    accum = np.array([[[1, 3], [2, 5], [6, 8]]])
    grad = np.array([[[1, 3], [2, 5], [6, 8]]])
    _, new_accum, _ = self._eval(
        var=np.zeros((1, 3, 2)),
        accum=accum,
        linear=np.zeros((1, 3, 2)),
        grad=grad,
        lr=7, l1=3, l2=7, lr_power=2)
    self.assertAllClose(accum + grad*grad, new_accum)

  def testLinearNoGradient(self):
    """Test that if accum_new == accum, linear doesn't change."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=[[[1, 3], [2, 5], [6, 8]]],
        linear=[[[1, 2], [3, 4], [5, 6]]],
        grad=np.zeros((1, 3, 2)),  # make accum_new == accum
        lr=1, l1=3, l2=7, lr_power=2)
    self.assertAllClose([[[1, 2], [3, 4], [5, 6]]], linear)

  def testLinear(self):
    """Test the linear update for new_linear=2 and linear=1."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=3, l2=7, lr_power=2)
    self.assertAllClose(1.75 * np.ones((1, 3, 2)), linear)

  def testLR(self):
    """Test that the linear update is divided by lr."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=5, l1=3, l2=7, lr_power=-1)
    self.assertAllClose(0.8 * np.ones((1, 3, 2)), linear)

  def testVar(self):
    """Test computation of var with linear=1.5, quadratic=1."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=1, l2=0.25, lr_power=1)
    self.assertAllClose(-0.5 * np.ones((1, 3, 2)), var)

  def testVarClipped(self):
    """Test that var becomes 0 if |linear| < l1."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=1.6, l2=0.25, lr_power=1)
    self.assertAllClose(np.zeros((1, 3, 2)), var)

  def testQuadratic(self):
    """Test that quadratic (here: -2) is the divisor of var."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=1, l2=-1.25, lr_power=1)
    self.assertAllClose(0.25 * np.ones((1, 3, 2)), var)

  def testL2Shrinkage(self):
    """Test that 2 * l2_shrinkage * var is *not* added to the gradient."""
    _, accum, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.zeros((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.zeros((1, 3, 2)),
        lr=7, l1=3, l2=7, lr_power=2, l2_shrinkage=0.5)
    self.assertAllClose(np.zeros((1, 3, 2)), accum)

  def testL2ShrinkageOnLinear(self):
    """Test that 2 * l2_shrinkage * var is added to linear."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.zeros((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.zeros((1, 3, 2)),
        lr=2, l1=3, l2=7, lr_power=0, l2_shrinkage=11)
    self.assertAllClose(22 * np.ones((1, 3, 2)), linear)

  def testMultiplyLinearByLR(self):
    """Test multiply_linear_by_lr = true for the linear variable."""
    _, _, linear = self._eval(
        var=np.zeros((1, 3, 2)),
        accum=np.zeros((1, 3, 2)),
        linear=np.ones((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=6, l1=1, l2=-1.25, lr_power=0,
        multiply_linear_by_lr=True)
    self.assertAllClose(7 * np.ones((1, 3, 2)), linear)

  def testMultiplyLinearByLRClipping(self):
    """Test that multiply_linear_by_lr = true scales the clip margins."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=3, l1=1.0, l2=0.25, lr_power=1,
        multiply_linear_by_lr=True)
    self.assertAllClose(-0.25 * np.ones((1, 3, 2)), var)

  def testMultiplyLinearByLRClipZero(self):
    """Test that multiply_linear_by_lr = true still clips to 0."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=3, l1=1.2, l2=0.25, lr_power=1,
        multiply_linear_by_lr=True)
    self.assertAllClose(np.zeros((1, 3, 2)), var)


if __name__ == "__main__":
  googletest.main()