/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helper routines for XLA compilation.

#include "tensorflow/compiler/tf2xla/xla_helpers.h"

#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/stream.h"

namespace tensorflow {

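// Returns a scalar `0` constant of the XLA primitive type that corresponds to
// `data_type`.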
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
}

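// Returns a scalar `1` constant of the XLA primitive type that corresponds to
// `data_type`.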
xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
}

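// Returns a scalar constant holding the integer `value`, converted to the XLA
// primitive type that corresponds to `data_type`.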
xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
                                      int64_t value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::IntegerLiteral(b, type, value);
}

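// Returns a scalar constant holding the floating-point `value`, converted to
// the XLA primitive type that corresponds to `data_type`.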
xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
                                    double value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::FloatLiteral(b, type, value);
}

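// Reshapes the non-tuple literal `input` to `dimensions`; the new shape must
// contain the same total number of elements as the input shape.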
/* static */ Status XlaHelpers::ReshapeLiteral(
    const xla::Literal& input, absl::Span<const int64> dimensions,
    xla::Literal* output) {
  if (input.shape().IsTuple()) {
    return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
  }
  xla::Shape shape =
      xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
  int64_t elements_before = xla::ShapeUtil::ElementsIn(input.shape());
  int64_t elements_after = xla::ShapeUtil::ElementsIn(shape);
  if (elements_before != elements_after) {
    return errors::InvalidArgument(
        "Shapes before and after ReshapeLiteral have different numbers of "
        "elements.");
  }

  *output = input.Clone();
  output->mutable_shape_do_not_use()->Swap(&shape);
  return Status::OK();
}

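// Expands `indices` into a one-hot tensor: inserts a new dimension of size
// `depth` at `axis`, writing `on_value` at each index position and `off_value`
// everywhere else.
//
// Illustrative usage (a sketch; `builder` and a rank-1 S32 `indices` operand
// are assumed to be set up by the caller):
//
//   // indices holds e.g. {1, 3}, shape [2].
//   xla::XlaOp one_hot;
//   TF_RETURN_IF_ERROR(XlaHelpers::OneHot(
//       builder, /*depth=*/4, /*axis=*/1, DT_INT32, TensorShape({2}),
//       indices, /*on_value=*/XlaHelpers::One(builder, DT_FLOAT),
//       /*off_value=*/XlaHelpers::Zero(builder, DT_FLOAT), &one_hot));
//   // one_hot has shape [2, 4]: {{0, 1, 0, 0}, {0, 0, 0, 1}}.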
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis,
                          DataType index_type, const TensorShape& indices_shape,
                          const xla::XlaOp& indices, const xla::XlaOp& on_value,
                          const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
  // Broadcast the iota constant across the indices along the new axis, and
  // test equality at each position.
  std::vector<int64> broadcast_dims(indices_shape.dims());
  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);

  TensorShape output_shape = indices_shape;
  output_shape.InsertDim(axis, depth);
  xla::Shape iota_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(index_type, output_shape, &iota_shape));

  // Select the user-provided on_value or off_value at each position.
  *one_hot = xla::Select(
      xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
      xla::Broadcast(on_value, output_shape.dim_sizes()),
      xla::Broadcast(off_value, output_shape.dim_sizes()));
  return Status::OK();
}

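// Returns the type in which a sum over values of `dtype` should be
// accumulated: 16-bit floats accumulate in DT_FLOAT, and small integer types
// accumulate in their 32-bit counterparts.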
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    return DT_FLOAT;
  }
  // Upcast small integer types to 32 bit to avoid overflow.
  if (dtype == DT_INT8 || dtype == DT_INT16) {
    return DT_INT32;
  }
  if (dtype == DT_UINT8 || dtype == DT_UINT16) {
    return DT_UINT32;
  }
  return dtype;
}

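// Converts `operand` to the XLA primitive type that corresponds to
// `new_element_type`.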
xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType convert_to;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
  return xla::ConvertElementType(operand, convert_to);
}

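// Returns a ShapeRepresentationFn that maps a TensorShape/DataType pair to
// the default XLA shape, ignoring the fast-memory hint.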
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
  return [](const TensorShape& shape, DataType dtype,
            bool use_fast_memory) -> StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    return xla_shape;
  };
}

// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_shape) {
  if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
    // After sharding, the per-core shape might have a different layout. For
    // example, before sharding, a shape [128, 128] will be assigned the
    // default minor-to-major {1, 0}; but after we shard this shape to
    // [128, 64] * 2, the sharded shapes will have minor-to-major {0, 1}.
    //
    // As a result, for sharded shapes, we set their layout to the per-core
    // shape's layout.
    //
    // TODO(endlessroad): for variable input & update, we might have
    // different layouts which will prevent input output aliasing and
    // increase memory usage. Investigate such cases.
    int64_t device = *sharding->tile_assignment().begin();
    std::vector<int64> offset =
        sharding->TileOffsetForDevice(*xla_shape, device);
    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
    std::vector<int64> dimensions(xla_shape->rank());
    for (int64_t i = 0; i < xla_shape->rank(); ++i) {
      dimensions[i] = limit[i] - offset[i];
    }
    xla::Shape per_device_xla_shape =
        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
    TensorShape per_device_tensor_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape->element_type()));
    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
                        shape_representation_fn(per_device_tensor_shape, dtype,
                                                use_fast_memory));
    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
  }
  return Status::OK();
}

// If there is a shape_representation_fn or sharding for an output, this
// function uses a reshape to fix the layout.
StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
  if (original_shape.IsTuple()) {
    std::vector<xla::XlaOp> elements;
    for (int64_t i = 0; i < original_shape.tuple_shapes_size(); ++i) {
      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
      TF_ASSIGN_OR_RETURN(auto element,
                          ReshapeWithCorrectRepresentationAndSharding(
                              builder, xla::GetTupleElement(original, i),
                              original_shape.tuple_shapes(i),
                              shape_representation_fn, subsharding, fast_mem));
      elements.push_back(element);
    }
    return xla::Tuple(builder, elements);
  }
  if (!original_shape.IsArray()) return original;
  TensorShape shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                          original_shape.element_type()));
  TF_ASSIGN_OR_RETURN(auto to_shape,
                      shape_representation_fn(shape, dtype, fast_mem));
  if (sharding) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));
    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
  }
  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
    for (int64_t i = 0; i < original_shape.rank(); ++i) {
      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
    }
  }
  return xla::Reshape(to_shape, original);
}

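// Resolves the xla::DeviceAssignment to use for compilation: returns
// absl::nullopt when no collective reduce is present; otherwise completes the
// collective group parameters and maps each participating device to its
// device ordinal (GPU only).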
StatusOr<absl::optional<xla::DeviceAssignment>> ResolveDeviceAssignment(
    OpKernelContext* ctx,
    const absl::optional<XlaCompilationResult::CollectiveReduceV2OpInfo>&
        collective_reduce_info) {
  static const int kTimeoutSeconds = 30;
  if (!collective_reduce_info) {
    // An empty device assignment is sufficient for the case where no
    // collectives are present.
    return {{absl::nullopt}};
  }
  if (ctx->collective_executor() == nullptr) {
    return errors::InvalidArgument(
        "CollectiveExecutor is required but not available");
  }

  auto params = core::RefCountPtr<CollectiveParams>(new CollectiveParams());
  params->name = "xla-reduction-compilation";
  params->group.device_type =
      DeviceType{static_cast<Device*>(ctx->device())->device_type()};
  params->group.group_size = collective_reduce_info->group_size;
  params->group.group_key = collective_reduce_info->group_key;
  params->instance.type = REDUCTION_COLLECTIVE;
  params->instance.impl_details.communication_hint = "nccl";
  params->instance.impl_details.timeout_seconds = kTimeoutSeconds;
  params->instance.impl_details.collective_name = "NcclReduce";
  // TODO(cheshire): Avoid passing a dummy shape; the TF runtime does not
  // resolve devices otherwise.
  params->instance.shape = TensorShape({1});

  Status st;
  absl::Notification n;
  ctx->collective_executor()->CompleteParamsAsync(
      ctx->device()->attributes(), params.get(), ctx->cancellation_manager(),
      [&](const Status& s) {
        st = s;
        n.Notify();
      });
  if (!n.WaitForNotificationWithTimeout(absl::Seconds(kTimeoutSeconds))) {
    return errors::InvalidArgument("Timeout reached");
  }
  TF_RETURN_IF_ERROR(st);

  xla::DeviceAssignment out(params->group.group_size, 1);
  for (int device_idx = 0; device_idx < params->group.group_size;
       device_idx++) {
    const std::string& device_name = params->group.devices[device_idx].name();
    Device* resolved_device = nullptr;
    TF_RETURN_IF_ERROR(ctx->function_library()->device_mgr()->LookupDevice(
        device_name, &resolved_device));

    // TODO(cheshire): CPU support.
    // Both GPU and TPU use GpuDeviceInfo, see DeviceBase::GpuDeviceInfo.
    const DeviceBase::GpuDeviceInfo* gpu_device_info =
        resolved_device->tensorflow_gpu_device_info();
    if (!gpu_device_info || !gpu_device_info->stream) {
      return errors::Internal(
          "CollectiveReduceV2Op compilation is only supported on GPUs");
    }

    out(device_idx, 0) = gpu_device_info->stream->parent()->device_ordinal();
  }

  return {{out}};
}

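// Returns a string of the form " (defined @ <file>:<line>)" identifying where
// the op was defined in user code, or an empty string if no stack trace is
// available.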
std::string DefinitionLocationMsg(
    const absl::optional<ManagedStackTrace>& stack_trace) {
  if (stack_trace) {
    std::vector<StackFrame> stack_frames =
        stack_trace->ToStackFrames({}, IsInternalFrameForFilename,
                                   /*reverse_traversal=*/true,
                                   /*limit=*/1);
    if (!stack_frames.empty()) {
      const StackFrame& last_frame = stack_frames[0];
      return absl::StrCat(" (defined @ ", last_frame.file_name, ":",
                          last_frame.line_number, ")");
    }
  }
  return "";
}

}  // end namespace tensorflow