/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helper routines for XLA compilation.

#include "tensorflow/compiler/tf2xla/xla_helpers.h"

#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/stream.h"

namespace tensorflow {

xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
}

xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
}
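
// Illustrative usage of the scalar-constant helpers above (a minimal sketch,
// not code from this file; the builder name is an assumption):
//
//   xla::XlaBuilder b("scalar_constants_example");
//   xla::XlaOp zero = XlaHelpers::Zero(&b, DT_FLOAT);  // scalar 0 of type f32
//   xla::XlaOp one = XlaHelpers::One(&b, DT_INT32);    // scalar 1 of type s32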

xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
                                      int64_t value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::IntegerLiteral(b, type, value);
}

xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
                                    double value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::FloatLiteral(b, type, value);
}

/* static */ Status XlaHelpers::ReshapeLiteral(
    const xla::Literal& input, absl::Span<const int64> dimensions,
    xla::Literal* output) {
  if (input.shape().IsTuple()) {
    return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
  }
  xla::Shape shape =
      xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
  int64_t elements_before = xla::ShapeUtil::ElementsIn(input.shape());
  int64_t elements_after = xla::ShapeUtil::ElementsIn(shape);
  if (elements_before != elements_after) {
    return errors::InvalidArgument(
        "Shapes before and after ReshapeLiteral have different numbers of "
        "elements.");
  }

  *output = input.Clone();
  output->mutable_shape_do_not_use()->Swap(&shape);
  return Status::OK();
}
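
// Example of the expected behavior (an illustrative sketch, not code from
// this file):
//
//   xla::Literal input = xla::LiteralUtil::CreateR2<float>({{1, 2, 3},
//                                                           {4, 5, 6}});
//   xla::Literal reshaped;
//   TF_CHECK_OK(
//       XlaHelpers::ReshapeLiteral(input, /*dimensions=*/{3, 2}, &reshaped));
//   // Requesting {4, 2} instead would return InvalidArgument because the
//   // element counts differ.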

Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis,
                          DataType index_type, const TensorShape& indices_shape,
                          const xla::XlaOp& indices, const xla::XlaOp& on_value,
                          const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
  // Broadcast the linspace constant across the indices along the new axis,
  // and test equality at each position.
  std::vector<int64> broadcast_dims(indices_shape.dims());
  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);

  TensorShape output_shape = indices_shape;
  output_shape.InsertDim(axis, depth);
  xla::Shape iota_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(index_type, output_shape, &iota_shape));

  // Select the user-provided on_value or off_value at each output position.
  *one_hot = xla::Select(
      xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
      xla::Broadcast(on_value, output_shape.dim_sizes()),
      xla::Broadcast(off_value, output_shape.dim_sizes()));
  return Status::OK();
}
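
// Illustrative usage (a hedged sketch; the builder, indices op, and shapes
// are assumptions, not code from this file):
//
//   // indices : s32[4], depth = 3, axis = 1  =>  *one_hot : f32[4, 3]
//   xla::XlaOp one_hot;
//   TF_RETURN_IF_ERROR(XlaHelpers::OneHot(
//       builder, /*depth=*/3, /*axis=*/1, DT_INT32, TensorShape({4}), indices,
//       /*on_value=*/XlaHelpers::One(builder, DT_FLOAT),
//       /*off_value=*/XlaHelpers::Zero(builder, DT_FLOAT), &one_hot));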

DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    return DT_FLOAT;
  }
  // Upcast small integer types to 32 bit to avoid overflow.
  if (dtype == DT_INT8 || dtype == DT_INT16) {
    return DT_INT32;
  }
  if (dtype == DT_UINT8 || dtype == DT_UINT16) {
    return DT_UINT32;
  }
  return dtype;
}
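
// For example, SumAccumulationType(DT_BFLOAT16) == DT_FLOAT and
// SumAccumulationType(DT_INT8) == DT_INT32, while wider types such as
// DT_INT64 pass through unchanged.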

xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType convert_to;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
  return xla::ConvertElementType(operand, convert_to);
}

XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
  return [](const TensorShape& shape, DataType dtype,
            bool use_fast_memory) -> StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    return xla_shape;
  };
}
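
// Illustrative usage (a sketch under assumed inputs, not code from this file):
//
//   XlaHelpers::ShapeRepresentationFn fn = IdentityShapeRepresentationFn();
//   // Maps TensorShape({2, 3}) with DT_FLOAT to the XLA shape f32[2, 3],
//   // ignoring the use_fast_memory flag.
//   TF_ASSIGN_OR_RETURN(xla::Shape shape,
//                       fn(TensorShape({2, 3}), DT_FLOAT,
//                          /*use_fast_memory=*/false));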

// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_shape) {
  if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
    // After sharding, the per-core shape might have a different layout. For
    // example, before sharding, a shape [128, 128] will be assigned the
    // default minor-to-major {1, 0}. But after we shard this shape to
    // [128, 64] * 2, the sharded shapes will have minor-to-major {0, 1}.
    //
    // As a result, for sharded shapes, we set their layout to the per-core
    // shape's layout.
    //
    // TODO(endlessroad): for variable input & update, we might have
    // different layouts which will prevent input output aliasing and
    // increase memory usage. Investigate such cases.
    int64_t device = *sharding->tile_assignment().begin();
    std::vector<int64> offset =
        sharding->TileOffsetForDevice(*xla_shape, device);
    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
    std::vector<int64> dimensions(xla_shape->rank());
    for (int64_t i = 0; i < xla_shape->rank(); ++i) {
      dimensions[i] = limit[i] - offset[i];
    }
    xla::Shape per_device_xla_shape =
        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
    TensorShape per_device_tensor_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape->element_type()));
    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
                        shape_representation_fn(per_device_tensor_shape, dtype,
                                                use_fast_memory));
    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
  }
  return Status::OK();
}
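
// Worked example (a sketch restating the comment above, under an assumed
// sharding): with xla_shape = f32[128, 128]{1, 0} and a sharding that splits
// dimension 1 across two devices, TileOffsetForDevice/TileLimitForDevice for
// the first device yield offset {0, 0} and limit {128, 64}, so the per-device
// shape is f32[128, 64]; the layout that shape_representation_fn picks for
// that shape is then copied back into xla_shape's layout.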

// If there is a shape_representation_fn or sharding for an output, this
// function uses a reshape to fix the layout.
StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
  if (original_shape.IsTuple()) {
    std::vector<xla::XlaOp> elements;
    for (int64_t i = 0; i < original_shape.tuple_shapes_size(); ++i) {
      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
      TF_ASSIGN_OR_RETURN(auto element,
                          ReshapeWithCorrectRepresentationAndSharding(
                              builder, xla::GetTupleElement(original, i),
                              original_shape.tuple_shapes(i),
                              shape_representation_fn, subsharding, fast_mem));
      elements.push_back(element);
    }
    return xla::Tuple(builder, elements);
  }
  if (!original_shape.IsArray()) return original;
  TensorShape shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                          original_shape.element_type()));
  TF_ASSIGN_OR_RETURN(auto to_shape,
                      shape_representation_fn(shape, dtype, fast_mem));
  if (sharding) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));
    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
  }
  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
    for (int64_t i = 0; i < original_shape.rank(); ++i) {
      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
    }
  }
  return xla::Reshape(to_shape, original);
}
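
// Behavior sketch (informational comment, not code from this file): for a
// tuple-shaped `original` the function recurses element by element and
// reassembles the results with xla::Tuple; for an array shape it asks
// shape_representation_fn (and, if present, the sharding) for the desired
// shape and layout and emits a single xla::Reshape to that shape, carrying
// over dynamic dimensions when the shapes are compatible.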

StatusOr<absl::optional<xla::DeviceAssignment>> ResolveDeviceAssignment(
    OpKernelContext* ctx,
    const absl::optional<XlaCompilationResult::CollectiveReduceV2OpInfo>&
        collective_reduce_info) {
  static const int kTimeoutSeconds = 30;
  if (!collective_reduce_info) {
    // An empty device assignment is sufficient for the case where no
    // collectives are present.
    return {{absl::nullopt}};
  }
  if (ctx->collective_executor() == nullptr) {
    return errors::InvalidArgument(
        "CollectiveExecutor is required but not available");
  }

  auto params = core::RefCountPtr<CollectiveParams>(new CollectiveParams());
  params->name = "xla-reduction-compilation";
  params->group.device_type =
      DeviceType{static_cast<Device*>(ctx->device())->device_type()};
  params->group.group_size = collective_reduce_info->group_size;
  params->group.group_key = collective_reduce_info->group_key;
  params->instance.type = REDUCTION_COLLECTIVE;
  params->instance.impl_details.communication_hint = "nccl";
  params->instance.impl_details.timeout_seconds = kTimeoutSeconds;
  params->instance.impl_details.collective_name = "NcclReduce";
  // TODO(cheshire): Avoid passing a dummy shape; the TF runtime does not
  // resolve devices otherwise.
  params->instance.shape = TensorShape({1});

  Status st;
  absl::Notification n;
  ctx->collective_executor()->CompleteParamsAsync(
      ctx->device()->attributes(), params.get(), ctx->cancellation_manager(),
      [&](const Status& s) {
        st = s;
        n.Notify();
      });
  if (!n.WaitForNotificationWithTimeout(absl::Seconds(kTimeoutSeconds))) {
    return errors::InvalidArgument("Timeout reached");
  }
  TF_RETURN_IF_ERROR(st);

  xla::DeviceAssignment out(params->group.group_size, 1);
  for (int device_idx = 0; device_idx < params->group.group_size;
       device_idx++) {
    const std::string& device_name = params->group.devices[device_idx].name();
    Device* resolved_device = nullptr;
    TF_RETURN_IF_ERROR(ctx->function_library()->device_mgr()->LookupDevice(
        device_name, &resolved_device));

    // TODO(cheshire): CPU support.
    // Both GPU and TPU use GpuDeviceInfo; see DeviceBase::GpuDeviceInfo.
    const DeviceBase::GpuDeviceInfo* gpu_device_info =
        resolved_device->tensorflow_gpu_device_info();
    if (!gpu_device_info || !gpu_device_info->stream) {
      return errors::Internal(
          "CollectiveReduceV2Op compilation is only supported on GPUs");
    }

    out(device_idx, 0) = gpu_device_info->stream->parent()->device_ordinal();
  }

  return {{out}};
}
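
// Informational note (not code from this file): the returned DeviceAssignment
// is a group_size x 1 matrix whose single column holds the device ordinal of
// each participating device, in the order reported by the resolved
// CollectiveParams group.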

std::string DefinitionLocationMsg(
    const absl::optional<ManagedStackTrace>& stack_trace) {
  if (stack_trace) {
    std::vector<StackFrame> stack_frames =
        stack_trace->ToStackFrames({}, IsInternalFrameForFilename,
                                   /*reverse_traversal=*/true,
                                   /*limit=*/1);
    if (!stack_frames.empty()) {
      const StackFrame& last_frame = stack_frames[0];
      return absl::StrCat(" (defined @ ", last_frame.file_name, ":",
                          last_frame.line_number, ")");
    }
  }
  return "";
}
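
// For example, a result might look like " (defined @ my_model.py:42)", and an
// empty string is returned when no user frame is available. (The file name and
// line number above are made up for illustration.)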

}  // end namespace tensorflow