# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests lowering of tf.bitcast."""

from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import test


class CastOpsTest(xla_test.XLATestCase):
  """Compares jit-compiled tf.bitcast against the non-compiled op."""

  def testBitcastToLarger(self):
    """Bitcasting f16 -> f32 consumes the innermost dimension of size 2."""
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        t = array_ops.bitcast(x, dtypes.float32)
        return math_ops.reduce_sum(t, axis=1)

      compiled_f = def_function.function(f, jit_compile=True)

      x = random_ops.random_normal([10, 10, 2], dtype=dtypes.float16)
      with ops.device(self.device):
        out = f(x)
        compiled_out = compiled_f(x)
        self.assertAllClose(out, compiled_out)
        # 10,10,2--(bitcast-convert)-->10,10--(reduce)-->10
        self.assertEqual(out.shape.as_list(), [10])

      hlo = compiled_f.experimental_get_compiler_ir(x)(stage='hlo')
      self.assertIn('f32[10,10]{1,0} bitcast-convert(f16[10,10,2]{2,1,0}', hlo)

  def testBitcastToSmaller(self):
    """Bitcasting a 32-bit type to a 16-bit type appends a dimension of 2."""
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        # An integer destination keeps the comparison exact. Reinterpreting
        # arbitrary f32 mantissa halves as f16 could yield NaN/Inf bit
        # patterns, which would make equality checks unreliable.
        return array_ops.bitcast(x, dtypes.int16)

      compiled_f = def_function.function(f, jit_compile=True)

      x = random_ops.random_normal([10, 10], dtype=dtypes.float32)
      with ops.device(self.device):
        out = f(x)
        compiled_out = compiled_f(x)
        self.assertAllEqual(out, compiled_out)
        # 10,10--(bitcast-convert)-->10,10,2
        self.assertEqual(out.shape.as_list(), [10, 10, 2])

      hlo = compiled_f.experimental_get_compiler_ir(x)(stage='hlo')
      # Only check substrings that do not depend on exactly how the
      # narrowing bitcast is lowered (direct vs. via an unsigned type).
      self.assertIn('bitcast-convert', hlo)
      self.assertIn('s16[10,10,2]', hlo)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()