# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for ftrl ("follow the regularized leader") operations."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import training_ops


class ResourceApplyFtrlTest(xla_test.XLATestCase):
  """Test cases for ftrl ops."""

  def setUp(self):
    super().setUp()
    self.rewrite_ops_for_tpu = ("TPU" in self.device and
                                test_util.is_mlir_bridge_enabled())

  def _eval(self, var, accum, linear, grad, lr, l1, l2, l2_shrinkage=0,
            lr_power=1, multiply_linear_by_lr=False):
    dtype = np.float32
    var = np.array(var, dtype=dtype)
    accum = np.array(accum, dtype=dtype)
    linear = np.array(linear, dtype=dtype)
    grad = np.array(grad, dtype=dtype)
    use_v2 = bool(l2_shrinkage)
    with self.session() as session:
      with self.test_scope():
        lr = constant_op.constant(lr, dtype=dtype)
        l1 = constant_op.constant(l1, dtype=dtype)
        l2 = constant_op.constant(l2, dtype=dtype)
        l2_shrinkage = constant_op.constant(l2_shrinkage, dtype=dtype)
        lr_power = constant_op.constant(lr_power, dtype=dtype)
        v_var = resource_variable_ops.ResourceVariable(var, dtype=dtype)
        v_accum = resource_variable_ops.ResourceVariable(accum, dtype=dtype)
        v_linear = resource_variable_ops.ResourceVariable(linear, dtype=dtype)
        session.run(v_var.create)
        session.run(v_accum.create)
        session.run(v_linear.create)
        # The tests never combine l2_shrinkage (the v2 op) with
        # multiply_linear_by_lr.
        assert not (use_v2 and multiply_linear_by_lr)
        if use_v2:
          session.run(training_ops.resource_apply_ftrl_v2(
              v_var.handle, v_accum.handle, v_linear.handle,
              grad, lr, l1, l2, l2_shrinkage, lr_power,
              multiply_linear_by_lr=multiply_linear_by_lr))
        else:
          session.run(training_ops.resource_apply_ftrl(
              v_var.handle, v_accum.handle, v_linear.handle,
              grad, lr, l1, l2, lr_power,
              multiply_linear_by_lr=multiply_linear_by_lr))
        return (v_var.read_value().eval().reshape(var.shape),
                v_accum.read_value().eval().reshape(accum.shape),
                v_linear.read_value().eval().reshape(linear.shape))
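
  # As a reading aid, the constants asserted below can be reproduced with the
  # following pure-NumPy sketch of the ftrl update, hand-derived from the
  # expected values in this file. `_ftrl_reference` is a hypothetical local
  # helper: it is not invoked by the tests and is not part of the TensorFlow
  # API. (The multiply_linear_by_lr variant is described further below.)
  def _ftrl_reference(self, var, accum, linear, grad, lr, l1, l2,
                      l2_shrinkage=0, lr_power=1):
    # l2_shrinkage contributes to the gradient seen by linear ...
    grad_shrunk = grad + 2 * l2_shrinkage * var
    # ... but accum always accumulates the raw squared gradient.
    accum_new = accum + grad * grad
    linear = linear + grad_shrunk - (
        accum_new**(-lr_power) - accum**(-lr_power)) / lr * var
    quadratic = accum_new**(-lr_power) / lr + 2 * l2
    # var is soft-thresholded: it clips to 0 wherever |linear| <= l1.
    var_new = np.where(np.abs(linear) > l1,
                       (np.sign(linear) * l1 - linear) / quadratic,
                       0.0)
    return var_new, accum_new, linear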

  def testAccum(self):
    """Test that accum is updated with grad^2."""
    accum = np.array([[[1, 3], [2, 5], [6, 8]]])
    grad = np.array([[[1, 3], [2, 5], [6, 8]]])
    _, new_accum, _ = self._eval(
        var=np.zeros((1, 3, 2)),
        accum=accum,
        linear=np.zeros((1, 3, 2)),
        grad=grad,
        lr=7, l1=3, l2=7, lr_power=2)
    self.assertAllClose(accum + grad * grad, new_accum)

  def testLinearNoGradient(self):
    """Test that if accum_new == accum, linear doesn't change."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=[[[1, 3], [2, 5], [6, 8]]],
        linear=[[[1, 2], [3, 4], [5, 6]]],
        grad=np.zeros((1, 3, 2)),  # makes accum_new == accum
        lr=1, l1=3, l2=7, lr_power=2)
    self.assertAllClose([[[1, 2], [3, 4], [5, 6]]], linear)

  def testLinear(self):
    """Test the linear update for accum_new=2 and accum=1."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=3, l2=7, lr_power=2)
    self.assertAllClose(1.75 * np.ones((1, 3, 2)), linear)

  def testLR(self):
    """Test that the linear update is divided by lr."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=5, l1=3, l2=7, lr_power=-1)
    self.assertAllClose(0.8 * np.ones((1, 3, 2)), linear)

  def testVar(self):
    """Test computation of var with linear=1.5, quadratic=1."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=1, l2=0.25, lr_power=1)
    self.assertAllClose(-0.5 * np.ones((1, 3, 2)), var)

  def testVarClipped(self):
    """Test that var becomes 0 if |linear| < l1."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=1.6, l2=0.25, lr_power=1)
    self.assertAllClose(np.zeros((1, 3, 2)), var)

  def testQuadratic(self):
    """Test that quadratic (here: -2) is the divisor of var."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=1, l1=1, l2=-1.25, lr_power=1)
    self.assertAllClose(0.25 * np.ones((1, 3, 2)), var)

  def testL2Shrinkage(self):
    """Test that 2 * l2_shrinkage * var is *not* added to the accum update."""
    _, accum, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.zeros((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.zeros((1, 3, 2)),
        lr=7, l1=3, l2=7, lr_power=2, l2_shrinkage=0.5)
    self.assertAllClose(np.zeros((1, 3, 2)), accum)

  def testL2ShrinkageOnLinear(self):
    """Test that 2 * l2_shrinkage * var is added to linear."""
    _, _, linear = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.zeros((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.zeros((1, 3, 2)),
        lr=2, l1=3, l2=7, lr_power=0, l2_shrinkage=11)
    self.assertAllClose(22 * np.ones((1, 3, 2)), linear)
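
  # The three multiply_linear_by_lr tests below exercise the variant in which
  # the linear accumulator stores lr * linear. Hand-deriving from the expected
  # values, the update then effectively becomes (a sketch, not necessarily the
  # kernel's exact form):
  #
  #   linear += lr * grad - (accum_new**(-lr_power) - accum**(-lr_power)) * var
  #   var     = (sign(linear) * l1 * lr - linear) / (lr * quadratic)
  #             if |linear| > l1 * lr else 0
  #
  # i.e. both the l1 clip margin and the quadratic divisor are scaled by lr.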

  def testMultiplyLinearByLR(self):
    """Test multiply_linear_by_lr = true for the linear variable."""
    _, _, linear = self._eval(
        var=np.zeros((1, 3, 2)),
        accum=np.zeros((1, 3, 2)),
        linear=np.ones((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=6, l1=1, l2=-1.25, lr_power=0,
        multiply_linear_by_lr=True)
    self.assertAllClose(7 * np.ones((1, 3, 2)), linear)

  def testMultiplyLinearByLRClipping(self):
    """Test that multiply_linear_by_lr = true scales the clip margins."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=3, l1=1.0, l2=0.25, lr_power=1,
        multiply_linear_by_lr=True)
    self.assertAllClose(-0.25 * np.ones((1, 3, 2)), var)

  def testMultiplyLinearByLRClipZero(self):
    """Test that multiply_linear_by_lr = true still clips to 0."""
    var, _, _ = self._eval(
        var=np.ones((1, 3, 2)),
        accum=np.ones((1, 3, 2)),
        linear=np.zeros((1, 3, 2)),
        grad=np.ones((1, 3, 2)),
        lr=3, l1=1.2, l2=0.25, lr_power=1,
        multiply_linear_by_lr=True)
    self.assertAllClose(np.zeros((1, 3, 2)), var)


if __name__ == "__main__":
  googletest.main()