/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/constants.h"

#include <limits>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/util.h"
20
21 namespace xla {
22
Zero(XlaBuilder * builder,PrimitiveType type)23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
24 return ConstantLiteral(builder, LiteralUtil::Zero(type));
25 }
26
Zeros(XlaBuilder * builder,const Shape & shape)27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
28 return Broadcast(Zero(builder, shape.element_type()),
29 AsInt64Slice(shape.dimensions()));
30 }
31
ZerosLike(XlaOp prototype)32 XlaOp ZerosLike(XlaOp prototype) {
33 XlaBuilder* builder = prototype.builder();
34 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
35 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
36 return Zeros(builder, shape);
37 });
38 }
39
One(XlaBuilder * builder,PrimitiveType type)40 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
41 return ConstantLiteral(builder, LiteralUtil::One(type));
42 }
43
Epsilon(XlaBuilder * builder,PrimitiveType type)44 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
45 switch (type) {
46 case F16:
47 return ConstantR0<Eigen::half>(
48 builder,
49 static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
50 case BF16:
51 return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
52 case F32:
53 return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
54 case F64:
55 return ConstantR0<double>(builder,
56 std::numeric_limits<double>::epsilon());
57 default:
58 return builder->ReportError(InvalidArgument(
59 "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
60 }
61 }
62
MinValue(XlaBuilder * builder,PrimitiveType type)63 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
64 return ConstantLiteral(builder, LiteralUtil::MinValue(type));
65 }
66
MinFiniteValue(XlaBuilder * builder,PrimitiveType type)67 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
68 switch (type) {
69 case F16:
70 return ConstantR0<Eigen::half>(builder,
71 Eigen::NumTraits<Eigen::half>::lowest());
72 case BF16:
73 return ConstantR0<bfloat16>(builder, bfloat16::lowest());
74 case F32:
75 return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
76 case F64:
77 return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
78 default:
79 return MinValue(builder, type);
80 }
81 }
82
MinPositiveNormalValue(XlaBuilder * builder,PrimitiveType type)83 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
84 switch (type) {
85 case F16:
86 return ConstantR0<Eigen::half>(builder,
87 std::numeric_limits<Eigen::half>::min());
88 case BF16:
89 return ConstantR0<bfloat16>(builder, bfloat16::min_positive_normal());
90 case F32:
91 return ConstantR0<float>(builder, std::numeric_limits<float>::min());
92 case F64:
93 return ConstantR0<double>(builder, std::numeric_limits<double>::min());
94 default:
95 return builder->ReportError(
96 InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
97 PrimitiveType_Name(type)));
98 }
99 }
100
MaxValue(XlaBuilder * builder,PrimitiveType type)101 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
102 return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
103 }
104
MaxFiniteValue(XlaBuilder * builder,PrimitiveType type)105 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
106 switch (type) {
107 case F16:
108 return ConstantR0<Eigen::half>(builder,
109 Eigen::NumTraits<Eigen::half>::highest());
110 case BF16:
111 return ConstantR0<bfloat16>(builder, bfloat16::highest());
112 case F32:
113 return ConstantR0<float>(builder, std::numeric_limits<float>::max());
114 case F64:
115 return ConstantR0<double>(builder, std::numeric_limits<double>::max());
116 default:
117 return MaxValue(builder, type);
118 }
119 }
120
NanValue(XlaBuilder * builder,PrimitiveType type)121 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
122 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
123 switch (type) {
124 case F16:
125 return ConstantR0<Eigen::half>(
126 builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
127 case BF16:
128 return ConstantR0<bfloat16>(
129 builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
130 case F32:
131 return ConstantR0<float>(builder,
132 std::numeric_limits<float>::quiet_NaN());
133 case F64:
134 return ConstantR0<double>(builder,
135 std::numeric_limits<double>::quiet_NaN());
136 default:
137 return InvalidArgument(
138 "Operand to NanValue was %s, but must be a real-valued "
139 "floating-point type.",
140 PrimitiveType_Name(type));
141 }
142 });
143 }
144
145 } // namespace xla
146