/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shaped_buffer.h"

#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

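// Illustrative use of the constructor below (a sketch, not part of this file;
// `platform` and `device_memory` are hypothetical values obtained elsewhere):
//
//   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
//   ShapedBuffer buffer(/*on_host_shape=*/shape, /*on_device_shape=*/shape,
//                       platform, /*device_ordinal=*/0);
//   buffer.set_buffer(device_memory, /*index=*/{});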
ShapedBuffer::ShapedBuffer(const Shape& on_host_shape,
                           const Shape& on_device_shape,
                           const se::Platform* platform, int device_ordinal)
    : on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      platform_(platform),
      device_ordinal_(device_ordinal),
      buffers_(&on_device_shape_) {}

ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
    : on_host_shape_(std::move(s.on_host_shape_)),
      on_device_shape_(std::move(s.on_device_shape_)),
      platform_(s.platform_),
      device_ordinal_(s.device_ordinal_),
      buffers_(std::move(s.buffers_)) {
  // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
  // into buffers_, we also need to update this pointer so that buffers_
  // doesn't point into s.
  buffers_.replace_shape_ptr(&on_device_shape_);
}

ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
  on_host_shape_ = std::move(s.on_host_shape_);
  on_device_shape_ = std::move(s.on_device_shape_);
  platform_ = s.platform_;
  device_ordinal_ = s.device_ordinal_;
  buffers_ = std::move(s.buffers_);
  // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
  // into buffers_, we also need to update this pointer so that buffers_
  // doesn't point into s.
  buffers_.replace_shape_ptr(&on_device_shape_);
  return *this;
}

ShapedBuffer::~ShapedBuffer() {}

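// Drops all buffer references by resetting them to null DeviceMemoryBase
// values. Note that this does not deallocate any device memory; the caller
// retains responsibility for the underlying allocations.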
void ShapedBuffer::clear() {
  for (auto& pair : buffers_) {
    // A default constructed DeviceMemoryBase is a null pointer.
    pair.second = se::DeviceMemoryBase();
  }
}

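// Renders the buffer tree for debugging. Illustrative output for a two-element
// tuple (platform name, addresses, and layouts are hypothetical):
//
//   ShapedBuffer(Host:0), on-host shape=(f32[2,3]{1,0}, f32[4]{0}), on-device
//   shape=(f32[2,3]{1,0}, f32[4]{0}):
//     0x7f0000000000 (16 bytes) : tuple
//       0x7f0000010000 (24 bytes) : f32[2,3]{1,0}
//       0x7f0000020000 (16 bytes) : f32[4]{0}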
string ShapedBuffer::ToString() const {
  string s = absl::StrCat(
      "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
      "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
      ", on-device shape=" +
          ShapeUtil::HumanStringWithLayout(on_device_shape()),
      ":\n");
  ShapeUtil::ForEachSubshape(
      on_device_shape(),
      [this, &s](const Shape& subshape, const ShapeIndex& index) {
        string shape_str;
        if (subshape.IsTuple()) {
          shape_str = "tuple";
        } else {
          shape_str = ShapeUtil::HumanStringWithLayout(subshape);
        }
        const se::DeviceMemoryBase& memory = buffer(index);
        absl::StrAppendFormat(&s, "  %s%p (%d bytes) : %s\n",
                              string(index.size() * 2, ' '), memory.opaque(),
                              memory.size(), shape_str);
      });
  return s;
}

operator <<(std::ostream & out,const ShapedBuffer & buffer)101 std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
102 out << buffer.ToString();
103 return out;
104 }
105
ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape,
                                       const Shape& on_device_shape,
                                       DeviceMemoryAllocator* allocator,
                                       int device_ordinal)
    : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(),
                   device_ordinal),
      allocator_(allocator) {}

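// This overload takes ownership of the memory held by `shaped_buffer`; once
// constructed, the destructor will free those allocations via `allocator`.
// Illustrative use (assumes `allocator` is the DeviceMemoryAllocator that
// produced the buffers):
//
//   ScopedShapedBuffer owned(std::move(unowned), allocator);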
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                                       DeviceMemoryAllocator* allocator)
    : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}

ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
    : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
  // Null out s.allocator_ so it doesn't try to free anything in its
  // destructor.
  s.allocator_ = nullptr;
}

ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) {
  Deallocate();

  *static_cast<ShapedBuffer*>(this) = std::move(static_cast<ShapedBuffer&>(s));
  allocator_ = s.allocator_;
  // Null out s.allocator_ so it doesn't try to free anything in its
  // destructor.
  s.allocator_ = nullptr;
  return *this;
}

ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); }

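// Like std::unique_ptr::release(): hands the underlying ShapedBuffer back to
// the caller and empties this object's buffer tree, so the destructor will
// free nothing. Illustrative use:
//
//   ShapedBuffer unowned = owned.release();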
ShapedBuffer ScopedShapedBuffer::release() {
  ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
  buffers_ = ShapeTree<se::DeviceMemoryBase>();
  return shaped_buffer;
}

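// Frees every distinct allocation exactly once, even when the same buffer
// appears at multiple shape indices.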
void ScopedShapedBuffer::Deallocate() {
  // allocator_ will be null if we were moved-from.
  if (allocator_ == nullptr) {
    return;
  }
  // Deallocate all non-null buffers. A buffer may appear in more than one spot
  // in the shape (eg, a tuple with a repeated element) so keep track of what
  // has been deallocated.
  absl::flat_hash_set<void*> deallocated_ptrs;
  for (auto& pair : buffers_) {
    se::DeviceMemoryBase& memory_base = pair.second;
    if (!memory_base.is_null() &&
        deallocated_ptrs.insert(memory_base.opaque()).second) {
      TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
    }
  }
}

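// Carves the sub-tree of buffers rooted at `index` out of this buffer: the
// returned ScopedShapedBuffer takes ownership of those allocations, and the
// corresponding entries here are nulled out. Illustrative use (splitting off
// tuple element 0):
//
//   ScopedShapedBuffer element0 = tuple_buffer.TakeSubTree(/*index=*/{0});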
ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
  const xla::Shape& sub_on_host_shape =
      xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
  const xla::Shape& sub_on_device_shape =
      xla::ShapeUtil::GetSubshape(on_device_shape(), {index});

  ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
                            memory_allocator(), device_ordinal());
  auto src_it = buffers().find(index);
  auto dst_it = output.buffers().begin();
  while (dst_it != output.buffers().end()) {
    dst_it->second = src_it->second;
    src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0);
    ++src_it;
    ++dst_it;
  }
  return output;
}

}  // namespace xla