1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/lib/util.h"
17
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/lib/core/errors.h"
26
27 namespace tensorflow {
28
Zeros(xla::XlaBuilder * builder,const xla::Shape & shape)29 xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
30 return xla::Broadcast(
31 xla::ConstantLiteral(builder,
32 xla::LiteralUtil::Zero(shape.element_type())),
33 shape.dimensions());
34 }
35
FloatLiteral(xla::XlaBuilder * builder,xla::PrimitiveType type,double value)36 xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
37 double value) {
38 switch (type) {
39 case xla::F16:
40 return xla::ConstantR0<xla::half>(builder, static_cast<xla::half>(value));
41 break;
42 case xla::BF16:
43 return xla::ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
44 break;
45 case xla::F32:
46 return xla::ConstantR0<float>(builder, static_cast<float>(value));
47 break;
48 case xla::F64:
49 return xla::ConstantR0<double>(builder, value);
50 break;
51 case xla::C64:
52 return xla::ConstantR0<xla::complex64>(builder, value);
53 break;
54 case xla::C128:
55 return xla::ConstantR0<xla::complex128>(builder, value);
56 break;
57 default:
58 LOG(FATAL) << "unhandled element type " << type;
59 }
60 }
61
IntegerLiteral(xla::XlaBuilder * builder,xla::PrimitiveType type,int64_t value)62 xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
63 int64_t value) {
64 xla::Literal literal;
65 switch (type) {
66 case xla::U8:
67 literal = xla::LiteralUtil::CreateR0<uint8_t>(value);
68 break;
69 case xla::U16:
70 literal = xla::LiteralUtil::CreateR0<uint16_t>(value);
71 break;
72 case xla::U32:
73 literal = xla::LiteralUtil::CreateR0<uint32_t>(value);
74 break;
75 case xla::U64:
76 literal = xla::LiteralUtil::CreateR0<uint64_t>(value);
77 break;
78 case xla::S8:
79 literal = xla::LiteralUtil::CreateR0<int8_t>(value);
80 break;
81 case xla::S16:
82 literal = xla::LiteralUtil::CreateR0<int16_t>(value);
83 break;
84 case xla::S32:
85 literal = xla::LiteralUtil::CreateR0<int32_t>(value);
86 break;
87 case xla::S64:
88 literal = xla::LiteralUtil::CreateR0<int64_t>(value);
89 break;
90 case xla::F32:
91 literal = xla::LiteralUtil::CreateR0<float>(value);
92 break;
93 case xla::F64:
94 literal = xla::LiteralUtil::CreateR0<double>(value);
95 break;
96 case xla::C64:
97 literal = xla::LiteralUtil::CreateR0<xla::complex64>(value);
98 break;
99 case xla::C128:
100 literal = xla::LiteralUtil::CreateR0<xla::complex128>(value);
101 break;
102 case xla::PRED:
103 LOG(FATAL) << "pred element type is not integral";
104 case xla::BF16:
105 literal = xla::LiteralUtil::CreateR0<xla::bfloat16>(
106 static_cast<xla::bfloat16>(value));
107 break;
108 case xla::F16:
109 literal =
110 xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
111 break;
112 case xla::TUPLE:
113 LOG(FATAL) << "tuple element type is not integral";
114 case xla::OPAQUE_TYPE:
115 LOG(FATAL) << "opaque element type is not integral";
116 default:
117 LOG(FATAL) << "unhandled element type " << type;
118 }
119 return xla::ConstantLiteral(builder, literal);
120 }
121
122 } // namespace tensorflow
123