1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file defines helper routines for XLA compilation.
17
18 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
19 #include "tensorflow/compiler/tf2xla/lib/util.h"
20
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/tf2xla/literal_util.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/lib/constants.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/client/xla_computation.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/lib/core/status.h"
32
33 namespace tensorflow {
34
Zero(xla::XlaBuilder * b,DataType data_type)35 xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
36 xla::PrimitiveType type;
37 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
38 return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
39 }
40
One(xla::XlaBuilder * b,DataType data_type)41 xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
42 xla::PrimitiveType type;
43 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
44 return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
45 }
46
IntegerLiteral(xla::XlaBuilder * b,DataType data_type,int64 value)47 xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
48 int64 value) {
49 xla::PrimitiveType type;
50 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
51 return ::tensorflow::IntegerLiteral(b, type, value);
52 }
53
FloatLiteral(xla::XlaBuilder * b,DataType data_type,double value)54 xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
55 double value) {
56 xla::PrimitiveType type;
57 TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
58 return ::tensorflow::FloatLiteral(b, type, value);
59 }
60
ReshapeLiteral(const xla::Literal & input,absl::Span<const int64> dimensions,xla::Literal * output)61 /* static */ Status XlaHelpers::ReshapeLiteral(
62 const xla::Literal& input, absl::Span<const int64> dimensions,
63 xla::Literal* output) {
64 if (input.shape().IsTuple()) {
65 return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
66 }
67 xla::Shape shape =
68 xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
69 int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape());
70 int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
71 if (elements_before != elements_after) {
72 return errors::InvalidArgument(
73 "Shapes before and after ReshapeLiteral have different numbers of "
74 "elements.");
75 }
76
77 *output = input.Clone();
78 output->mutable_shape_do_not_use()->Swap(&shape);
79 return Status::OK();
80 }
81
OneHot(xla::XlaBuilder * builder,int64 depth,int axis,DataType index_type,const TensorShape & indices_shape,const xla::XlaOp & indices,const xla::XlaOp & on_value,const xla::XlaOp & off_value,xla::XlaOp * one_hot)82 Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
83 DataType index_type, const TensorShape& indices_shape,
84 const xla::XlaOp& indices, const xla::XlaOp& on_value,
85 const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
86 // Broadcast the linspace constant across the indices along the new axis,
87 // and test equality at each position.
88 std::vector<int64> broadcast_dims(indices_shape.dims());
89 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
90 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
91
92 TensorShape output_shape = indices_shape;
93 output_shape.InsertDim(axis, depth);
94 xla::Shape iota_shape;
95 TF_RETURN_IF_ERROR(
96 TensorShapeToXLAShape(index_type, output_shape, &iota_shape));
97
98 // Selects the user-provided off_value and on_value values.
99 *one_hot = xla::Select(
100 xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
101 xla::Broadcast(on_value, output_shape.dim_sizes()),
102 xla::Broadcast(off_value, output_shape.dim_sizes()));
103 return Status::OK();
104 }
105
// Returns the DataType in which a sum reduction over `dtype` should
// accumulate. Narrow types are widened to 32 bits; everything else
// accumulates in its own type.
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  switch (dtype) {
    // 16-bit float sums accumulate in 32 bits to reduce the precision loss
    // from repeated floating point additions.
    case DT_BFLOAT16:
    case DT_HALF:
      return DT_FLOAT;
    // Small integer types are widened to 32 bits to avoid overflow.
    case DT_INT8:
    case DT_INT16:
      return DT_INT32;
    case DT_UINT8:
    case DT_UINT16:
      return DT_UINT32;
    default:
      return dtype;
  }
}
121
// Casts `operand` to the XLA primitive type corresponding to
// `new_element_type`. CHECK-fails on unmapped DataTypes.
xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType target_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &target_type));
  return xla::ConvertElementType(operand, target_type);
}
128
IdentityShapeRepresentationFn()129 XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
130 return [](const TensorShape& shape, DataType dtype,
131 bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
132 xla::Shape xla_shape;
133 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
134 return xla_shape;
135 };
136 }
137
138 // Rewrites the layout of xla_shape if there is tiled sharding.
RewriteLayoutWithShardedShape(const absl::optional<xla::HloSharding> & sharding,bool use_fast_memory,XlaHelpers::ShapeRepresentationFn shape_representation_fn,xla::Shape * xla_shape)139 Status RewriteLayoutWithShardedShape(
140 const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
141 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
142 xla::Shape* xla_shape) {
143 if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
144 // After sharding, per core shape might have different layout. For example,
145 // before sharding, a shape [128, 128] will be assigned default
146 // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
147 // the sharded shapes will have minor-to-major {0, 1}.
148 //
149 // As a result, for sharded shapes, we set their layout to per core shape's
150 // layout.
151 //
152 // TODO(endlessroad): for variable input & update, we might have
153 // different layouts which will prevent input output aliasing and
154 // increase memory usage. Investigate such cases.
155 int64 device = *sharding->tile_assignment().begin();
156 std::vector<int64> offset =
157 sharding->TileOffsetForDevice(*xla_shape, device);
158 std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
159 std::vector<int64> dimensions(xla_shape->rank());
160 for (int64 i = 0; i < xla_shape->rank(); ++i) {
161 dimensions[i] = limit[i] - offset[i];
162 }
163 xla::Shape per_device_xla_shape =
164 xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
165 TensorShape per_device_tensor_shape;
166 TF_RETURN_IF_ERROR(
167 XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
168 TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
169 xla_shape->element_type()));
170 TF_ASSIGN_OR_RETURN(per_device_xla_shape,
171 shape_representation_fn(per_device_tensor_shape, dtype,
172 use_fast_memory));
173 *xla_shape->mutable_layout() = per_device_xla_shape.layout();
174 }
175 return Status::OK();
176 }
177
178 // There is a shape_representation_fn or sharding for an output, this function
179 // uses a reshape to fix the layout.
ReshapeWithCorrectRepresentationAndSharding(xla::XlaBuilder * builder,xla::XlaOp original,xla::Shape original_shape,XlaHelpers::ShapeRepresentationFn shape_representation_fn,absl::optional<xla::OpSharding> sharding,bool fast_mem)180 xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
181 xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
182 XlaHelpers::ShapeRepresentationFn shape_representation_fn,
183 absl::optional<xla::OpSharding> sharding, bool fast_mem) {
184 if (original_shape.IsTuple()) {
185 std::vector<xla::XlaOp> elements;
186 for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
187 auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
188 TF_ASSIGN_OR_RETURN(auto element,
189 ReshapeWithCorrectRepresentationAndSharding(
190 builder, xla::GetTupleElement(original, i),
191 original_shape.tuple_shapes(i),
192 shape_representation_fn, subsharding, fast_mem));
193 elements.push_back(element);
194 }
195 return xla::Tuple(builder, elements);
196 }
197 if (!original_shape.IsArray()) return original;
198 TensorShape shape;
199 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
200 TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
201 original_shape.element_type()));
202 TF_ASSIGN_OR_RETURN(auto to_shape,
203 shape_representation_fn(shape, dtype, fast_mem));
204 if (sharding) {
205 TF_ASSIGN_OR_RETURN(auto hlo_sharding,
206 xla::HloSharding::FromProto(*sharding));
207 TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
208 hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
209 }
210 if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
211 for (int64 i = 0; i < original_shape.rank(); ++i) {
212 to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
213 }
214 }
215 return xla::Reshape(to_shape, original);
216 }
217
218 } // end namespace tensorflow
219