1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/constants.h"
17
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/util.h"
20
21 namespace xla {
22
Zero(XlaBuilder * builder,PrimitiveType type)23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
24 return ConstantLiteral(builder, LiteralUtil::Zero(type));
25 }
26
Zeros(XlaBuilder * builder,const Shape & shape)27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
28 return Broadcast(Zero(builder, shape.element_type()), shape.dimensions());
29 }
30
ZerosLike(XlaOp prototype)31 XlaOp ZerosLike(XlaOp prototype) {
32 XlaBuilder* builder = prototype.builder();
33 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
34 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
35 return Zeros(builder, shape);
36 });
37 }
38
One(XlaBuilder * builder,PrimitiveType type)39 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
40 return ConstantLiteral(builder, LiteralUtil::One(type));
41 }
42
Epsilon(XlaBuilder * builder,PrimitiveType type)43 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
44 switch (type) {
45 case F16:
46 return ConstantR0<Eigen::half>(
47 builder,
48 static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
49 case BF16:
50 return ConstantR0<Eigen::bfloat16>(
51 builder, static_cast<Eigen::bfloat16>(
52 Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
53 case F32:
54 return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
55 case F64:
56 return ConstantR0<double>(builder,
57 std::numeric_limits<double>::epsilon());
58 default:
59 return builder->ReportError(InvalidArgument(
60 "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
61 }
62 }
63
MinValue(XlaBuilder * builder,PrimitiveType type)64 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
65 return ConstantLiteral(builder, LiteralUtil::MinValue(type));
66 }
67
MinFiniteValue(XlaBuilder * builder,PrimitiveType type)68 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
69 switch (type) {
70 case F16:
71 return ConstantR0<Eigen::half>(builder,
72 Eigen::NumTraits<Eigen::half>::lowest());
73 case BF16:
74 return ConstantR0<Eigen::bfloat16>(
75 builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
76 case F32:
77 return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
78 case F64:
79 return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
80 default:
81 return MinValue(builder, type);
82 }
83 }
84
MinPositiveNormalValue(XlaBuilder * builder,PrimitiveType type)85 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
86 switch (type) {
87 case F16:
88 return ConstantR0<Eigen::half>(builder,
89 std::numeric_limits<Eigen::half>::min());
90 case BF16:
91 return ConstantR0<Eigen::bfloat16>(
92 builder, std::numeric_limits<Eigen::bfloat16>::min());
93 case F32:
94 return ConstantR0<float>(builder, std::numeric_limits<float>::min());
95 case F64:
96 return ConstantR0<double>(builder, std::numeric_limits<double>::min());
97 default:
98 return builder->ReportError(
99 InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
100 PrimitiveType_Name(type)));
101 }
102 }
103
MaxValue(XlaBuilder * builder,PrimitiveType type)104 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
105 return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
106 }
107
MaxFiniteValue(XlaBuilder * builder,PrimitiveType type)108 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
109 switch (type) {
110 case F16:
111 return ConstantR0<Eigen::half>(builder,
112 Eigen::NumTraits<Eigen::half>::highest());
113 case BF16:
114 return ConstantR0<Eigen::bfloat16>(
115 builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
116 case F32:
117 return ConstantR0<float>(builder, std::numeric_limits<float>::max());
118 case F64:
119 return ConstantR0<double>(builder, std::numeric_limits<double>::max());
120 default:
121 return MaxValue(builder, type);
122 }
123 }
124
NanValue(XlaBuilder * builder,PrimitiveType type)125 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
126 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
127 switch (type) {
128 case F16:
129 return ConstantR0<Eigen::half>(
130 builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
131 case BF16:
132 return ConstantR0<Eigen::bfloat16>(
133 builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
134 case F32:
135 return ConstantR0<float>(builder,
136 std::numeric_limits<float>::quiet_NaN());
137 case F64:
138 return ConstantR0<double>(builder,
139 std::numeric_limits<double>::quiet_NaN());
140 default:
141 return InvalidArgument(
142 "Operand to NanValue was %s, but must be a real-valued "
143 "floating-point type.",
144 PrimitiveType_Name(type));
145 }
146 });
147 }
148
149 } // namespace xla
150