• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests the XLA lowering of tf.bitcast."""
16
17from tensorflow.compiler.tests import xla_test
18from tensorflow.python.eager import def_function
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops import random_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import image_ops
26from tensorflow.python.ops import io_ops
27from tensorflow.python.platform import test
28
29
class CastOpsTest(xla_test.XLATestCase):
  """Checks that tf.bitcast is lowered to an XLA bitcast-convert.

  Each test runs the same function eagerly and through a jit_compile=True
  tf.function, compares the two results, and inspects the pre-optimization
  HLO text for the expected bitcast-convert instruction.
  """

  def testBitcastToLarger(self):
    """Bitcasts f16[10, 10, 2] to f32[10, 10]: two halves become one float."""
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        # Bitcasting to a wider element type consumes the trailing
        # dimension (size 2 here), then axis 1 is summed away.
        t = array_ops.bitcast(x, dtypes.float32)
        return math_ops.reduce_sum(t, axis=1)

      compiled_f = def_function.function(f, jit_compile=True)

      x = random_ops.random_normal([10, 10, 2], dtype=dtypes.float16)
      with ops.device(self.device):
        out = f(x)
        compiled_out = compiled_f(x)
        self.assertAllClose(out, compiled_out)
        # 10,10,2--(bitcast-convert)-->10,10--(reduce)-->10
        # Check the whole shape, not just its leading dimension.
        self.assertEqual(out.shape.as_list(), [10])

      hlo = compiled_f.experimental_get_compiler_ir(x)(stage='hlo')
      self.assertIn('f32[10,10]{1,0} bitcast-convert(f16[10,10,2]{2,1,0}', hlo)

  def testBitcastToSmaller(self):
    """Bitcasts f32[10, 10] to f16[10, 10, 2]: one float becomes two halves."""
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        # Bitcasting to a narrower element type appends a trailing
        # dimension of size 2; that new dimension is then summed away.
        t = array_ops.bitcast(x, dtypes.float16)
        return math_ops.reduce_sum(t, axis=2)

      compiled_f = def_function.function(f, jit_compile=True)

      # Build the float32 input by packing random float16 pairs. Unpacking
      # it inside f then yields finite float16 values, whereas arbitrary
      # float32 bit patterns can decode to float16 inf/NaN and make the
      # eager-vs-compiled comparison meaningless.
      halves = random_ops.random_normal([10, 10, 2], dtype=dtypes.float16)
      x = array_ops.bitcast(halves, dtypes.float32)
      with ops.device(self.device):
        out = f(x)
        compiled_out = compiled_f(x)
        self.assertAllClose(out, compiled_out)
        # 10,10--(bitcast-convert)-->10,10,2--(reduce)-->10,10
        self.assertEqual(out.shape.as_list(), [10, 10])

      hlo = compiled_f.experimental_get_compiler_ir(x)(stage='hlo')
      self.assertIn('f16[10,10,2]{2,1,0} bitcast-convert(f32[10,10]{1,0}', hlo)
54
55
if __name__ == '__main__':
  # Turn eager execution on before handing control to the test runner:
  # the tests call both the plain Python function and its jit-compiled
  # tf.function directly on tensors and compare the returned values.
  ops.enable_eager_execution()
  test.main()
59