/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shaped_buffer.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

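// A ShapedBuffer records the on-host and on-device shapes of a (possibly
// tuple-shaped) value together with one se::DeviceMemoryBase handle per
// sub-shape. It does not own the device memory; ownership is the job of
// ScopedShapedBuffer below.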
ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                           const se::Platform* platform, int device_ordinal)
    : on_host_shape_(std::move(on_host_shape)),
      on_device_shape_(std::move(on_device_shape)),
      platform_(platform),
      device_ordinal_(device_ordinal),
      buffers_(&on_device_shape_) {}

ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
    : on_host_shape_(std::move(s.on_host_shape_)),
      on_device_shape_(std::move(s.on_device_shape_)),
      platform_(s.platform_),
      device_ordinal_(s.device_ordinal_),
      buffers_(std::move(s.buffers_)) {
  // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
  // into buffers_, we also need to update this pointer so that buffers_
  // doesn't point into s.
  buffers_.replace_shape_ptr(&on_device_shape_);
}

ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
  on_host_shape_ = std::move(s.on_host_shape_);
  on_device_shape_ = std::move(s.on_device_shape_);
  platform_ = s.platform_;
  device_ordinal_ = s.device_ordinal_;
  buffers_ = std::move(s.buffers_);
  // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
  // into buffers_, we also need to update this pointer so that buffers_
  // doesn't point into s.
  buffers_.replace_shape_ptr(&on_device_shape_);
  return *this;
}

ShapedBuffer::~ShapedBuffer() {}

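// Returns a ShapedBuffer for the sub-shape rooted at `index`. The returned
// buffer copies the DeviceMemoryBase handles, so it aliases (and does not
// own) the same underlying device memory as this buffer.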
StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
    const ShapeIndex& index) const {
  TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape,
                      ShapeUtil::TryGetSubshape(on_host_shape(), index));
  TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
                      ShapeUtil::TryGetSubshape(on_device_shape(), index));
  ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_,
                                 device_ordinal_);
  TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
                      buffers_.SubShapeTree(index));
  sub_shaped_buffer.set_buffers(std::move(sub_buffers));
  return std::move(sub_shaped_buffer);
}

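// Nulls out every buffer handle without deallocating. Whoever owns the
// underlying device memory remains responsible for freeing it.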
void ShapedBuffer::clear() {
  for (auto& pair : buffers_) {
    // A default constructed DeviceMemoryBase is a null pointer.
    pair.second = se::DeviceMemoryBase();
  }
}

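// Renders one line per sub-buffer: its device address, size in bytes, and
// the (sub)shape it backs, indented by tuple nesting depth.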
string ShapedBuffer::ToString() const {
  string s = absl::StrCat(
      "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
      "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
      ", on-device shape=" +
          ShapeUtil::HumanStringWithLayout(on_device_shape()),
      ":\n");
  ShapeUtil::ForEachSubshape(
      on_device_shape(),
      [this, &s](const Shape& subshape, const ShapeIndex& index) {
        string shape_str;
        if (subshape.IsTuple()) {
          shape_str = "tuple";
        } else {
          shape_str = ShapeUtil::HumanStringWithLayout(subshape);
        }
        const se::DeviceMemoryBase& memory = buffer(index);
        absl::StrAppendFormat(&s, "  %s%p (%d bytes) : %s\n",
                              string(index.size() * 2, ' '), memory.opaque(),
                              memory.size(), shape_str);
      });
  return s;
}

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
  out << buffer.ToString();
  return out;
}

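// A ScopedShapedBuffer is a ShapedBuffer that owns its device memory: any
// non-null buffers still present at destruction are freed via `allocator`.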
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
                                       Shape on_device_shape,
                                       se::DeviceMemoryAllocator* allocator,
                                       int device_ordinal)
    : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
                   allocator->platform(), device_ordinal),
      allocator_(allocator) {}

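// Takes ownership of the memory held by `shaped_buffer`; it will be freed
// via `allocator` when this object is destroyed.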
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                                       se::DeviceMemoryAllocator* allocator)
    : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}

ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
    : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
  // Null out s.allocator_ so it doesn't try to free anything in its
  // destructor.
  s.allocator_ = nullptr;
}

ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) {
  Deallocate();

  *static_cast<ShapedBuffer*>(this) = std::move(static_cast<ShapedBuffer&>(s));
  allocator_ = s.allocator_;
  // Null out s.allocator_ so it doesn't try to free anything in its
  // destructor.
  s.allocator_ = nullptr;
  return *this;
}

ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); }

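// Relinquishes ownership: the returned ShapedBuffer holds the memory
// handles, and this object's buffer tree is emptied so the destructor
// deallocates nothing.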
ShapedBuffer ScopedShapedBuffer::release() {
  ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
  buffers_ = ShapeTree<se::DeviceMemoryBase>();
  return shaped_buffer;
}

void ScopedShapedBuffer::Deallocate() {
  // allocator_ will be null if we were moved-from.
  if (allocator_ == nullptr) {
    return;
  }
  // Deallocate all non-null buffers. A buffer may appear in more than one spot
  // in the shape (eg, a tuple with a repeated element) so keep track of what
  // has been deallocated.
  absl::flat_hash_set<void*> deallocated_ptrs;
  for (auto& pair : buffers_) {
    se::DeviceMemoryBase& memory_base = pair.second;
    if (!memory_base.is_null() &&
        deallocated_ptrs.insert(memory_base.opaque()).second) {
      TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
    }
  }
}

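// Moves ownership of the sub-tree rooted at `index` into a new
// ScopedShapedBuffer. The transferred entries in this buffer are nulled out
// so the memory cannot be freed twice.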
ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
  const xla::Shape& sub_on_host_shape =
      xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
  const xla::Shape& sub_on_device_shape =
      xla::ShapeUtil::GetSubshape(on_device_shape(), {index});

  ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
                            memory_allocator(), device_ordinal());
  auto src_it = buffers().find(index);
  auto dst_it = output.buffers().begin();
  while (dst_it != output.buffers().end()) {
    dst_it->second = src_it->second;
    src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0);
    ++src_it;
    ++dst_it;
  }
  return output;
}

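// Minimal usage sketch (illustrative only, not part of the build; `allocator`
// is assumed to be a valid se::DeviceMemoryAllocator* supplied by the caller):
//
//   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
//   ScopedShapedBuffer owned(shape, shape, allocator, /*device_ordinal=*/0);
//   // ... allocate device memory and install it into owned's buffer tree ...
//   ShapedBuffer unowned = owned.release();  // Caller now owns the memory.
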
}  // namespace xla