• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/lib/util.h"
17 
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 
27 namespace tensorflow {
28 
Zeros(xla::XlaBuilder * builder,const xla::Shape & shape)29 xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
30   return xla::Broadcast(
31       xla::ConstantLiteral(builder,
32                            xla::LiteralUtil::Zero(shape.element_type())),
33       shape.dimensions());
34 }
35 
FloatLiteral(xla::XlaBuilder * builder,xla::PrimitiveType type,double value)36 xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
37                         double value) {
38   switch (type) {
39     case xla::F16:
40       return xla::ConstantR0<xla::half>(builder, static_cast<xla::half>(value));
41       break;
42     case xla::BF16:
43       return xla::ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
44       break;
45     case xla::F32:
46       return xla::ConstantR0<float>(builder, static_cast<float>(value));
47       break;
48     case xla::F64:
49       return xla::ConstantR0<double>(builder, value);
50       break;
51     case xla::C64:
52       return xla::ConstantR0<xla::complex64>(builder, value);
53       break;
54     case xla::C128:
55       return xla::ConstantR0<xla::complex128>(builder, value);
56       break;
57     default:
58       LOG(FATAL) << "unhandled element type " << type;
59   }
60 }
61 
IntegerLiteral(xla::XlaBuilder * builder,xla::PrimitiveType type,int64_t value)62 xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
63                           int64_t value) {
64   xla::Literal literal;
65   switch (type) {
66     case xla::U8:
67       literal = xla::LiteralUtil::CreateR0<uint8_t>(value);
68       break;
69     case xla::U16:
70       literal = xla::LiteralUtil::CreateR0<uint16_t>(value);
71       break;
72     case xla::U32:
73       literal = xla::LiteralUtil::CreateR0<uint32_t>(value);
74       break;
75     case xla::U64:
76       literal = xla::LiteralUtil::CreateR0<uint64_t>(value);
77       break;
78     case xla::S8:
79       literal = xla::LiteralUtil::CreateR0<int8_t>(value);
80       break;
81     case xla::S16:
82       literal = xla::LiteralUtil::CreateR0<int16_t>(value);
83       break;
84     case xla::S32:
85       literal = xla::LiteralUtil::CreateR0<int32_t>(value);
86       break;
87     case xla::S64:
88       literal = xla::LiteralUtil::CreateR0<int64_t>(value);
89       break;
90     case xla::F32:
91       literal = xla::LiteralUtil::CreateR0<float>(value);
92       break;
93     case xla::F64:
94       literal = xla::LiteralUtil::CreateR0<double>(value);
95       break;
96     case xla::C64:
97       literal = xla::LiteralUtil::CreateR0<xla::complex64>(value);
98       break;
99     case xla::C128:
100       literal = xla::LiteralUtil::CreateR0<xla::complex128>(value);
101       break;
102     case xla::PRED:
103       LOG(FATAL) << "pred element type is not integral";
104     case xla::BF16:
105       literal = xla::LiteralUtil::CreateR0<xla::bfloat16>(
106           static_cast<xla::bfloat16>(value));
107       break;
108     case xla::F16:
109       literal =
110           xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
111       break;
112     case xla::TUPLE:
113       LOG(FATAL) << "tuple element type is not integral";
114     case xla::OPAQUE_TYPE:
115       LOG(FATAL) << "opaque element type is not integral";
116     default:
117       LOG(FATAL) << "unhandled element type " << type;
118   }
119   return xla::ConstantLiteral(builder, literal);
120 }
121 
122 }  // namespace tensorflow
123