1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Unified APIs' python bindings.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import timeit 22 23from absl.testing import parameterized 24 25from tensorflow.python.eager import backprop 26from tensorflow.python.eager import context 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework.experimental import _unified_api 31from tensorflow.python.framework.experimental import context_stack as context_lib 32from tensorflow.python.framework.experimental import def_function 33from tensorflow.python.framework.experimental import math_ops as unified_math_ops 34from tensorflow.python.framework.experimental import nn_ops as unified_nn_ops 35from tensorflow.python.framework.experimental import tape as tape_lib 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import nn_grad # pylint: disable=unused-import 39from tensorflow.python.ops import nn_ops 40from tensorflow.python.ops import random_ops 41from tensorflow.python.platform import test 42 43SetTracingImplementation = _unified_api.SetTracingImplementation 44TensorCastHelper = _unified_api.EagerTensorToImmediateExecutionTensorHandle 45 46 47def get_immediate_execution_context(): 48 context._reset_context() 49 context.context().ensure_initialized() 50 return _unified_api.EagerContextToImmediateExecutionContext( 51 context.context()._handle) 52 53 54def maybe_cast(t, perform_cast): 55 if perform_cast: 56 return TensorCastHelper(t) 57 return t 58 59 60class UnifiedApiTest(test.TestCase, parameterized.TestCase): 61 62 @parameterized.named_parameters([ 63 ("Graph", False), 64 ("Mlir", True), 65 ]) 66 def testAdd(self, use_mlir): 67 if use_mlir: 68 SetTracingImplementation("mlir") 69 70 def model(a, b): 71 return unified_math_ops.add(a, b) 72 73 with context_lib.set_default(get_immediate_execution_context()): 74 a = TensorCastHelper(constant_op.constant([1., 2.])) 75 b = TensorCastHelper(constant_op.constant([3., 4.])) 76 77 func_output = def_function.function(model)(a, b) 78 self.assertAllEqual(func_output.numpy(), [4., 6.]) 79 80 eager_output = model(a, b) 81 self.assertAllEqual(eager_output.numpy(), [4., 6.]) 82 83 @parameterized.named_parameters([ 84 ("Graph", False), 85 ("Mlir", True), 86 ]) 87 def testAddGrad(self, use_mlir): 88 if use_mlir: 89 SetTracingImplementation("mlir") 90 91 def model(a, b): 92 with tape_lib.GradientTape() as tape: 93 tape.watch(a) 94 tape.watch(b) 95 result = unified_math_ops.add(a, b) 96 grads = tape.gradient(result, [a, b]) 97 return grads 98 99 with context_lib.set_default(get_immediate_execution_context()): 100 a = TensorCastHelper(constant_op.constant([1., 2.])) 101 b = TensorCastHelper(constant_op.constant([3., 4.])) 102 103 func_outputs = def_function.function(model)(a, b) 104 self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0]) 105 self.assertAllEqual(func_outputs[1].numpy(), [1.0, 1.0]) 106 107 eager_outputs = model(a, b) 108 self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0]) 109 self.assertAllEqual(eager_outputs[1].numpy(), [1.0, 1.0]) 110 111 @parameterized.named_parameters([ 112 ("Graph", False), 113 ("Mlir", True), 114 ]) 115 def testRelu(self, use_mlir): 116 if use_mlir: 117 SetTracingImplementation("mlir") 118 119 def model(t): 120 return unified_nn_ops.relu(t) 121 122 with context_lib.set_default(get_immediate_execution_context()): 123 positive = TensorCastHelper(constant_op.constant([1.])) 124 negative = TensorCastHelper(constant_op.constant([-1.])) 125 126 model_fn = def_function.function(model) 127 func_output = model_fn(positive) 128 self.assertAllEqual(func_output.numpy(), [1.]) 129 func_output = model_fn(negative) 130 self.assertAllEqual(func_output.numpy(), [0.]) 131 132 eager_output = model(positive) 133 self.assertAllEqual(eager_output.numpy(), [1.]) 134 eager_output = model(negative) 135 self.assertAllEqual(eager_output.numpy(), [0.]) 136 137 @parameterized.named_parameters([ 138 ("Graph", False), 139 ("Mlir", True), 140 ]) 141 def testReluGrad(self, use_mlir): 142 if use_mlir: 143 SetTracingImplementation("mlir") 144 145 def model(t): 146 with tape_lib.GradientTape() as tape: 147 tape.watch(t) 148 result = unified_nn_ops.relu(t) 149 grads = tape.gradient(result, t) 150 return grads 151 152 with context_lib.set_default(get_immediate_execution_context()): 153 positive = TensorCastHelper(constant_op.constant([1.])) 154 negative = TensorCastHelper(constant_op.constant([-1.])) 155 156 model_fn = def_function.function(model) 157 func_output = model_fn(positive) 158 self.assertAllEqual(func_output.numpy(), [1.]) 159 func_output = model_fn(negative) 160 self.assertAllEqual(func_output.numpy(), [0.]) 161 162 eager_output = model(positive) 163 self.assertAllEqual(eager_output.numpy(), [1.]) 164 eager_output = model(negative) 165 self.assertAllEqual(eager_output.numpy(), [0.]) 166 167 @parameterized.named_parameters([ 168 ("Graph", False), 169 ("Mlir", True), 170 ]) 171 def testNeg(self, use_mlir): 172 if use_mlir: 173 SetTracingImplementation("mlir") 174 175 def model(a): 176 return unified_math_ops.neg(a) 177 178 with context_lib.set_default(get_immediate_execution_context()): 179 a = TensorCastHelper(constant_op.constant([2.])) 180 181 func_output = def_function.function(model)(a) 182 self.assertAllEqual(func_output.numpy(), [-2.]) 183 184 eager_output = model(a) 185 self.assertAllEqual(eager_output.numpy(), [-2.]) 186 187 @parameterized.named_parameters([ 188 ("Graph", False), 189 ("Mlir", True), 190 ]) 191 def testNegGrad(self, use_mlir): 192 if use_mlir: 193 SetTracingImplementation("mlir") 194 195 def model(a): 196 with tape_lib.GradientTape() as tape: 197 tape.watch(a) 198 result = unified_math_ops.neg(a) 199 grads = tape.gradient(result, a) 200 return grads 201 202 with context_lib.set_default(get_immediate_execution_context()): 203 a = TensorCastHelper(constant_op.constant([2.])) 204 205 func_outputs = def_function.function(model)(a) 206 self.assertAllEqual(func_outputs.numpy(), [-1.0]) 207 208 eager_outputs = model(a) 209 self.assertAllEqual(eager_outputs.numpy(), [-1.0]) 210 211 @parameterized.named_parameters([ 212 ("Graph", False), 213 ("Mlir", True), 214 ]) 215 def testSub(self, use_mlir): 216 if use_mlir: 217 SetTracingImplementation("mlir") 218 219 def model(a, b): 220 return unified_math_ops.sub(a, b) 221 222 with context_lib.set_default(get_immediate_execution_context()): 223 a = TensorCastHelper(constant_op.constant([1., 2.])) 224 b = TensorCastHelper(constant_op.constant([3., 4.])) 225 226 func_output = def_function.function(model)(a, b) 227 self.assertAllEqual(func_output.numpy(), [-2., -2.]) 228 229 eager_output = model(a, b) 230 self.assertAllEqual(eager_output.numpy(), [-2., -2.]) 231 232 @parameterized.named_parameters([ 233 ("Graph", False), 234 ("Mlir", True), 235 ]) 236 def testSubGrad(self, use_mlir): 237 if use_mlir: 238 SetTracingImplementation("mlir") 239 240 def model(a, b): 241 with tape_lib.GradientTape() as tape: 242 tape.watch(a) 243 tape.watch(b) 244 result = unified_math_ops.sub(a, b) 245 grads = tape.gradient(result, [a, b]) 246 return grads 247 248 with context_lib.set_default(get_immediate_execution_context()): 249 a = TensorCastHelper(constant_op.constant([1., 2.])) 250 b = TensorCastHelper(constant_op.constant([3., 4.])) 251 252 func_outputs = def_function.function(model)(a, b) 253 self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0]) 254 self.assertAllEqual(func_outputs[1].numpy(), [-1.0, -1.0]) 255 256 eager_outputs = model(a, b) 257 self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0]) 258 self.assertAllEqual(eager_outputs[1].numpy(), [-1.0, -1.0]) 259 260 @parameterized.named_parameters([ 261 ("Graph", False), 262 ("Mlir", True), 263 ]) 264 def testMul(self, use_mlir): 265 if use_mlir: 266 SetTracingImplementation("mlir") 267 268 def model(a, b): 269 return unified_math_ops.mul(a, b) 270 271 with context_lib.set_default(get_immediate_execution_context()): 272 a = TensorCastHelper(constant_op.constant([1., 2.])) 273 b = TensorCastHelper(constant_op.constant([3., 4.])) 274 275 func_output = def_function.function(model)(a, b) 276 self.assertAllEqual(func_output.numpy(), [3., 8.]) 277 278 eager_output = model(a, b) 279 self.assertAllEqual(eager_output.numpy(), [3., 8.]) 280 281 @parameterized.named_parameters([ 282 ("Graph", False), 283 ("Mlir", True), 284 ]) 285 def testMulGrad(self, use_mlir): 286 if use_mlir: 287 SetTracingImplementation("mlir") 288 289 def model(a, b): 290 with tape_lib.GradientTape() as tape: 291 tape.watch(a) 292 tape.watch(b) 293 result = unified_math_ops.mul(a, b) 294 grads = tape.gradient(result, [a, b]) 295 return grads 296 297 with context_lib.set_default(get_immediate_execution_context()): 298 a = TensorCastHelper(constant_op.constant([1., 2.])) 299 b = TensorCastHelper(constant_op.constant([3., 4.])) 300 301 func_outputs = def_function.function(model)(a, b) 302 self.assertAllEqual(func_outputs[0].numpy(), [3., 4.]) 303 self.assertAllEqual(func_outputs[1].numpy(), [1., 2.]) 304 305 eager_outputs = model(a, b) 306 self.assertAllEqual(eager_outputs[0].numpy(), [3., 4.]) 307 self.assertAllEqual(eager_outputs[1].numpy(), [1., 2.]) 308 309 @parameterized.named_parameters([ 310 ("Graph", False), 311 ("Mlir", True), 312 ]) 313 def testLog1p(self, use_mlir): 314 if use_mlir: 315 SetTracingImplementation("mlir") 316 317 def model(a): 318 return unified_math_ops.log1p(a) 319 320 with context_lib.set_default(get_immediate_execution_context()): 321 a = TensorCastHelper(constant_op.constant([1.])) 322 323 func_output = def_function.function(model)(a) 324 self.assertArrayNear(func_output.numpy(), [0.69314], 0.001) 325 326 eager_output = model(a) 327 self.assertArrayNear(eager_output.numpy(), [0.69314], 0.001) 328 329 @parameterized.named_parameters([ 330 ("Graph", False), 331 ("Mlir", True), 332 ]) 333 def testLog1pGrad(self, use_mlir): 334 if use_mlir: 335 SetTracingImplementation("mlir") 336 337 def model(a): 338 with tape_lib.GradientTape() as tape: 339 tape.watch(a) 340 result = unified_math_ops.log1p(a) 341 grads = tape.gradient(result, a) 342 return grads 343 344 with context_lib.set_default(get_immediate_execution_context()): 345 a = TensorCastHelper(constant_op.constant([1.])) 346 347 func_outputs = def_function.function(model)(a) 348 self.assertArrayNear(func_outputs.numpy(), [0.5], 0.001) 349 350 eager_outputs = model(a) 351 self.assertArrayNear(eager_outputs.numpy(), [0.5], 0.001) 352 353 @parameterized.named_parameters([ 354 ("Graph", False), 355 ("Mlir", True), 356 ]) 357 def testDivNoNan(self, use_mlir): 358 if use_mlir: 359 SetTracingImplementation("mlir") 360 361 def model(a, b): 362 return unified_math_ops.div_no_nan(a, b) 363 364 with context_lib.set_default(get_immediate_execution_context()): 365 a = TensorCastHelper(constant_op.constant([2.])) 366 b = TensorCastHelper(constant_op.constant([4.])) 367 368 func_output = def_function.function(model)(a, b) 369 self.assertArrayNear(func_output.numpy(), [0.5], 0.001) 370 371 eager_output = model(a, b) 372 self.assertArrayNear(eager_output.numpy(), [0.5], 0.001) 373 374 @parameterized.named_parameters([ 375 ("Graph", False), 376 ("Mlir", True), 377 ]) 378 def testDivNoNanGrad(self, use_mlir): 379 if use_mlir: 380 SetTracingImplementation("mlir") 381 382 def model(a, b): 383 with tape_lib.GradientTape() as tape: 384 tape.watch(a) 385 tape.watch(b) 386 result = unified_math_ops.div_no_nan(a, b) 387 grads = tape.gradient(result, [a, b]) 388 return grads 389 390 with context_lib.set_default(get_immediate_execution_context()): 391 a = TensorCastHelper(constant_op.constant([2.])) 392 b = TensorCastHelper(constant_op.constant([4.])) 393 394 func_outputs = def_function.function(model)(a, b) 395 self.assertArrayNear(func_outputs[0].numpy(), [0.25], 0.001) 396 self.assertArrayNear(func_outputs[1].numpy(), [-0.125], 0.001) 397 398 eager_outputs = model(a, b) 399 self.assertArrayNear(eager_outputs[0].numpy(), [0.25], 0.001) 400 self.assertArrayNear(eager_outputs[1].numpy(), [-0.125], 0.001) 401 402 403class UnifiedTapeBenchmark(test.Benchmark): 404 405 def _computeMnistMlpGrads(self, math_ops_lib, nn_ops_lib, backprop_lib, cast, 406 num_iters, hidden_layers, hidden_size, batch_size): 407 batch_size = 1 408 image_size = 28 * 28 409 num_classes = 10 410 411 def model(x, hidden_weights, softmax_weight, labels): 412 with backprop_lib.GradientTape() as tape: 413 for weight in hidden_weights + [softmax_weight]: 414 tape.watch(weight) 415 for hidden_weight in hidden_weights: 416 x = math_ops_lib.mat_mul(x, hidden_weight) 417 x = nn_ops_lib.relu(x) 418 logits = math_ops_lib.mat_mul(x, softmax_weight) 419 loss = nn_ops_lib.sparse_softmax_cross_entropy_with_logits( 420 logits=logits, labels=labels) 421 422 grads = tape.gradient(loss, hidden_weights + [softmax_weight]) 423 return grads 424 425 x = maybe_cast(array_ops.ones([batch_size, image_size]), cast) 426 hidden_weights = [] 427 for i in range(hidden_layers): 428 hidden_weights.append( 429 maybe_cast( 430 random_ops.random_uniform( 431 [hidden_size if i else image_size, hidden_size]), cast)) 432 softmax_weight = maybe_cast( 433 random_ops.random_uniform([hidden_size, num_classes]), cast) 434 labels = maybe_cast(array_ops.zeros([batch_size], dtype=dtypes.int32), cast) 435 436 with context_lib.set_default(get_immediate_execution_context()): 437 # Warm up. 438 for _ in range(10): 439 model(x, hidden_weights, softmax_weight, labels) 440 runtimes = timeit.repeat( 441 lambda: model(x, hidden_weights, softmax_weight, labels), 442 repeat=num_iters, 443 number=10) 444 return min(runtimes) / 10 445 446 def benchmarkTwoHiddenLayerMnistEagerUnified(self): 447 num_iters = 100 448 duration = self._computeMnistMlpGrads( 449 unified_math_ops, 450 unified_nn_ops, 451 tape_lib, 452 True, 453 num_iters, 454 hidden_layers=2, 455 hidden_size=100, 456 batch_size=1) 457 self.report_benchmark( 458 name="TwoHiddenLayerMnistEagerUnified", 459 iters=num_iters, 460 wall_time=duration) 461 462 def benchmarkTwoHiddenLayerMnistEager(self): 463 num_iters = 100 464 duration = self._computeMnistMlpGrads( 465 math_ops, 466 nn_ops, 467 backprop, 468 False, 469 num_iters, 470 hidden_layers=2, 471 hidden_size=100, 472 batch_size=1) 473 self.report_benchmark( 474 name="TwoHiddenLayerMnistEager", iters=num_iters, wall_time=duration) 475 476 def benchmarkTenHiddenLayerMnistEagerUnified(self): 477 num_iters = 100 478 duration = self._computeMnistMlpGrads( 479 unified_math_ops, 480 unified_nn_ops, 481 tape_lib, 482 True, 483 num_iters, 484 hidden_layers=10, 485 hidden_size=100, 486 batch_size=1) 487 self.report_benchmark( 488 name="TenHiddenLayerMnistEagerUnified", 489 iters=num_iters, 490 wall_time=duration) 491 492 def benchmarkTenHiddenLayerMnistEager(self): 493 num_iters = 100 494 duration = self._computeMnistMlpGrads( 495 math_ops, 496 nn_ops, 497 backprop, 498 False, 499 num_iters, 500 hidden_layers=10, 501 hidden_size=100, 502 batch_size=1) 503 self.report_benchmark( 504 name="TenHiddenLayerMnistEager", iters=num_iters, wall_time=duration) 505 506 507if __name__ == "__main__": 508 ops.enable_eager_execution() 509 test.main() 510