1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/constants.h"
17
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/util.h"
20
21 namespace xla {
22
Zero(XlaBuilder * builder,PrimitiveType type)23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
24 return ConstantLiteral(builder, LiteralUtil::Zero(type));
25 }
26
Zeros(XlaBuilder * builder,const Shape & shape)27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
28 return Broadcast(Zero(builder, shape.element_type()),
29 AsInt64Slice(shape.dimensions()));
30 }
31
ZerosLike(XlaOp prototype)32 XlaOp ZerosLike(XlaOp prototype) {
33 XlaBuilder* builder = prototype.builder();
34 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
35 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
36 return Zeros(builder, shape);
37 });
38 }
39
One(XlaBuilder * builder,PrimitiveType type)40 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
41 return ConstantLiteral(builder, LiteralUtil::One(type));
42 }
43
Epsilon(XlaBuilder * builder,PrimitiveType type)44 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
45 switch (type) {
46 case F16:
47 return ConstantR0<Eigen::half>(
48 builder,
49 static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
50 case BF16:
51 return ConstantR0<Eigen::bfloat16>(
52 builder, static_cast<Eigen::bfloat16>(
53 Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
54 case F32:
55 return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
56 case F64:
57 return ConstantR0<double>(builder,
58 std::numeric_limits<double>::epsilon());
59 default:
60 return builder->ReportError(InvalidArgument(
61 "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
62 }
63 }
64
MinValue(XlaBuilder * builder,PrimitiveType type)65 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
66 return ConstantLiteral(builder, LiteralUtil::MinValue(type));
67 }
68
MinFiniteValue(XlaBuilder * builder,PrimitiveType type)69 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
70 switch (type) {
71 case F16:
72 return ConstantR0<Eigen::half>(builder,
73 Eigen::NumTraits<Eigen::half>::lowest());
74 case BF16:
75 return ConstantR0<Eigen::bfloat16>(
76 builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
77 case F32:
78 return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
79 case F64:
80 return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
81 default:
82 return MinValue(builder, type);
83 }
84 }
85
MinPositiveNormalValue(XlaBuilder * builder,PrimitiveType type)86 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
87 switch (type) {
88 case F16:
89 return ConstantR0<Eigen::half>(builder,
90 std::numeric_limits<Eigen::half>::min());
91 case BF16:
92 return ConstantR0<Eigen::bfloat16>(
93 builder, std::numeric_limits<Eigen::bfloat16>::min());
94 case F32:
95 return ConstantR0<float>(builder, std::numeric_limits<float>::min());
96 case F64:
97 return ConstantR0<double>(builder, std::numeric_limits<double>::min());
98 default:
99 return builder->ReportError(
100 InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
101 PrimitiveType_Name(type)));
102 }
103 }
104
MaxValue(XlaBuilder * builder,PrimitiveType type)105 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
106 return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
107 }
108
MaxFiniteValue(XlaBuilder * builder,PrimitiveType type)109 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
110 switch (type) {
111 case F16:
112 return ConstantR0<Eigen::half>(builder,
113 Eigen::NumTraits<Eigen::half>::highest());
114 case BF16:
115 return ConstantR0<Eigen::bfloat16>(
116 builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
117 case F32:
118 return ConstantR0<float>(builder, std::numeric_limits<float>::max());
119 case F64:
120 return ConstantR0<double>(builder, std::numeric_limits<double>::max());
121 default:
122 return MaxValue(builder, type);
123 }
124 }
125
NanValue(XlaBuilder * builder,PrimitiveType type)126 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
127 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
128 switch (type) {
129 case F16:
130 return ConstantR0<Eigen::half>(
131 builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
132 case BF16:
133 return ConstantR0<Eigen::bfloat16>(
134 builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
135 case F32:
136 return ConstantR0<float>(builder,
137 std::numeric_limits<float>::quiet_NaN());
138 case F64:
139 return ConstantR0<double>(builder,
140 std::numeric_limits<double>::quiet_NaN());
141 default:
142 return InvalidArgument(
143 "Operand to NanValue was %s, but must be a real-valued "
144 "floating-point type.",
145 PrimitiveType_Name(type));
146 }
147 });
148 }
149
150 } // namespace xla
151