# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the XLA lowering of tf.bitcast."""

from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import test


class CastOpsTest(xla_test.XLATestCase):
  """Checks how `array_ops.bitcast` behaves when compiled with XLA."""

  def testBitcastToLarger(self):
    """Bitcasts float16 pairs to float32 and compares plain vs. jit-compiled.

    A float16 tensor of shape [10, 10, 2] is bitcast to float32 and summed
    along axis 1. The result of calling the Python function directly is
    compared with the result of the `jit_compile=True` version, and the HLO
    text is checked for the expected `bitcast-convert` instruction.
    """
    device_name = 'device:{}:0'.format(self.device)
    with ops.device(device_name):

      def f(x):
        as_float32 = array_ops.bitcast(x, dtypes.float32)
        return math_ops.reduce_sum(as_float32, axis=1)

      compiled_f = def_function.function(f, jit_compile=True)

      half_input = random_ops.random_normal([10, 10, 2], dtype=dtypes.float16)
      with ops.device(self.device):
        plain_out = f(half_input)
        xla_out = compiled_f(half_input)
        self.assertAllClose(plain_out, xla_out)
        # Shape flow: 10,10,2 --(bitcast-convert)--> 10,10 --(reduce)--> 10
        leading_dim = plain_out.shape[0]
        self.assertEqual(leading_dim, 10)

      # The lowering must emit a single bitcast-convert that folds the
      # trailing dimension of size 2 into the wider element type.
      get_ir = compiled_f.experimental_get_compiler_ir(half_input)
      hlo_text = get_ir(stage='hlo')
      self.assertIn(
          'f32[10,10]{1,0} bitcast-convert(f16[10,10,2]{2,1,0}', hlo_text)

  def testBitcastToSmaller(self):
    """Placeholder for the wide-to-narrow bitcast case; asserts nothing yet."""
    # TODO: add coverage for bitcasting to a smaller element type. Until
    # then this test passes trivially.
    pass


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()