/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {
namespace gpu {
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
                           std::vector<ShapedSlice> source_slices)
    : Thunk(Kind::kOutfeed, thunk_info),
      source_slices_(std::move(source_slices)) {}

Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
  se::Stream& stream = *params.stream;
  const BufferAllocations& buffer_allocations = *params.buffer_allocations;

  VLOG(2) << "Outfeeding from GPU";

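  // Get the outfeed manager for this device and dequeue the destination
  // buffers posted by the host-side consumer (see
  // `GpuTransferManager::TransferLiteralFromOutfeed`). This call blocks until
  // a destination is available.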
  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(stream.parent());
  ShapeTree<std::unique_ptr<OutfeedBuffer>>* output_buffers =
      outfeed_manager->BlockingGetNextDestination();

  // Nothing to be done for an outfeed with no inputs.
  // Note: this check cannot be moved before the `BlockingGetNextDestination`
  // call above, since we must still dequeue an entry from the outfeed manager.
  if (source_slices_.empty()) {
    return OkStatus();
  }

  const int64_t leaf_count = output_buffers->leaf_count();
  TF_RET_CHECK(source_slices_.size() == leaf_count)
      << "Mismatch between number of outfeed inputs (" << source_slices_.size()
      << ") and outputs (" << leaf_count << ")";

  auto output_leaf_it = output_buffers->leaf_begin();
  for (int64_t index = 0; index < leaf_count; ++index) {
    // Assert that the shapes are compatible.
    const ShapeIndex& shape_index = output_leaf_it->first;
    std::unique_ptr<OutfeedBuffer>& buffer = output_leaf_it->second;

    // NOTE: This code needs to deal with the `output_buffers` object getting
    // deleted while this code is executing. Specifically, objects in the
    // outfeed queue are pointers to instances of stack-allocated objects in
    // `GpuTransferManager::TransferLiteralFromOutfeed`. When all leaf node
    // buffers are notified via "buffer->Done()" below in the stream host
    // callback, `TransferLiteralFromOutfeed` deletes this stack-allocated
    // object when it returns. This means that it is possible that during the
    // last iteration, after the call to "buffer->Done()" is scheduled onto the
    // stream, the `output_buffers` object might get deleted, so we should
    // avoid accessing the object after that.
    //
    // To achieve that, increment the leaf iterator here, before the last
    // "Done" is enqueued, instead of in the loop increment, which would run
    // after the "Done" is scheduled.
    ++output_leaf_it;
    const Shape& output_shape =
        ShapeUtil::GetSubshape(output_buffers->shape(), shape_index);
    TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape))
        << "Mismatch between outfeed output buffer shape "
        << ShapeUtil::HumanStringWithLayout(output_shape)
        << " and outfeed source buffer shape "
        << ShapeUtil::HumanStringWithLayout(source_slices_[index].shape);

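    // Resolve this leaf's source slice to a concrete device address; a slice
    // without a backing allocation is an internal error.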
    BufferAllocation::Slice source_slice = source_slices_[index].slice;
    if (!source_slice.allocation())
      return InternalError("outfeed source missing buffer allocation");
    se::DeviceMemoryBase data_address =
        buffer_allocations.GetDeviceAddress(source_slice);

    // TODO(b/111309141): Run this on a separate stream so it doesn't block
    // the GPU from doing work during the transfer.
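    // Enqueue a device-to-host copy of this leaf into its destination buffer,
    // followed by a host callback that notifies the consumer via Done().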
    stream
        .ThenMemcpy(buffer->destination()->untyped_data(), data_address,
                    buffer->length())
        .ThenDoHostCallback([&buffer]() { buffer->Done(); });
  }

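  // Wait until all copies and Done() callbacks enqueued above have run before
  // returning; once notified, the consumer may free the destination buffers.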
  Status block_status = stream.BlockHostUntilDone();
  if (!block_status.ok()) {
    return InternalError("Failed to complete data transfer on stream %p: %s",
                         &stream, block_status.error_message());
  }

  VLOG(2) << "Outfeeding from GPU complete";
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla