• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/layout_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31 
32 namespace xla {
33 
ShapedBuffer(const Shape & on_host_shape,const Shape & on_device_shape,const se::Platform * platform,int device_ordinal)34 ShapedBuffer::ShapedBuffer(const Shape& on_host_shape,
35                            const Shape& on_device_shape,
36                            const se::Platform* platform, int device_ordinal)
37     : on_host_shape_(on_host_shape),
38       on_device_shape_(on_device_shape),
39       platform_(platform),
40       device_ordinal_(device_ordinal),
41       buffers_(&on_device_shape_) {}
42 
ShapedBuffer(ShapedBuffer && s)43 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
44     : on_host_shape_(std::move(s.on_host_shape_)),
45       on_device_shape_(std::move(s.on_device_shape_)),
46       platform_(s.platform_),
47       device_ordinal_(s.device_ordinal_),
48       buffers_(std::move(s.buffers_)) {
49   // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
50   // into buffers_, we also need to update this pointer so that buffers_ doesn't
51   // point into s.
52   buffers_.replace_shape_ptr(&on_device_shape_);
53 }
54 
operator =(ShapedBuffer && s)55 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
56   on_host_shape_ = std::move(s.on_host_shape_);
57   on_device_shape_ = std::move(s.on_device_shape_);
58   platform_ = s.platform_;
59   device_ordinal_ = s.device_ordinal_;
60   buffers_ = std::move(s.buffers_);
61   // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
62   // into buffers_, we also need to update this pointer so that buffers_ doesn't
63   // point into s.
64   buffers_.replace_shape_ptr(&on_device_shape_);
65   return *this;
66 }
67 
~ShapedBuffer()68 ShapedBuffer::~ShapedBuffer() {}
69 
clear()70 void ShapedBuffer::clear() {
71   for (auto& pair : buffers_) {
72     // A default constructed DeviceMemoryBase is a null pointer.
73     pair.second = se::DeviceMemoryBase();
74   }
75 }
76 
ToString() const77 string ShapedBuffer::ToString() const {
78   string s = absl::StrCat(
79       "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
80       "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
81       ", on-device shape=" +
82           ShapeUtil::HumanStringWithLayout(on_device_shape()),
83       ":\n");
84   ShapeUtil::ForEachSubshape(
85       on_device_shape(),
86       [this, &s](const Shape& subshape, const ShapeIndex& index) {
87         string shape_str;
88         if (subshape.IsTuple()) {
89           shape_str = "tuple";
90         } else {
91           shape_str = ShapeUtil::HumanStringWithLayout(subshape);
92         }
93         const se::DeviceMemoryBase& memory = buffer(index);
94         absl::StrAppendFormat(&s, "  %s%p (%d bytes) : %s\n",
95                               string(index.size() * 2, ' '), memory.opaque(),
96                               memory.size(), shape_str);
97       });
98   return s;
99 }
100 
operator <<(std::ostream & out,const ShapedBuffer & buffer)101 std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
102   out << buffer.ToString();
103   return out;
104 }
105 
ScopedShapedBuffer(const Shape & on_host_shape,const Shape & on_device_shape,DeviceMemoryAllocator * allocator,int device_ordinal)106 ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape,
107                                        const Shape& on_device_shape,
108                                        DeviceMemoryAllocator* allocator,
109                                        int device_ordinal)
110     : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(),
111                    device_ordinal),
112       allocator_(allocator) {}
113 
ScopedShapedBuffer(ShapedBuffer shaped_buffer,DeviceMemoryAllocator * allocator)114 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
115                                        DeviceMemoryAllocator* allocator)
116     : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}
117 
ScopedShapedBuffer(ScopedShapedBuffer && s)118 ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
119     : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
120   // Null out s.allocator_ so it doesn't try to free anything in its destructor.
121   s.allocator_ = nullptr;
122 }
123 
operator =(ScopedShapedBuffer && s)124 ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) {
125   Deallocate();
126 
127   *static_cast<ShapedBuffer*>(this) = std::move(static_cast<ShapedBuffer&>(s));
128   allocator_ = s.allocator_;
129   // Null out s.allocator_ so it doesn't try to free anything in its destructor.
130   s.allocator_ = nullptr;
131   return *this;
132 }
133 
~ScopedShapedBuffer()134 ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); }
135 
release()136 ShapedBuffer ScopedShapedBuffer::release() {
137   ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
138   buffers_ = ShapeTree<se::DeviceMemoryBase>();
139   return shaped_buffer;
140 }
141 
Deallocate()142 void ScopedShapedBuffer::Deallocate() {
143   // allocator_ will be null if we were moved-from.
144   if (allocator_ == nullptr) {
145     return;
146   }
147   // Deallocate all non-null buffers. A buffer may appear in more than one spot
148   // in the shape (eg, a tuple with a repeated element) so keep track of what
149   // has been deallocated.
150   absl::flat_hash_set<void*> deallocated_ptrs;
151   for (auto& pair : buffers_) {
152     se::DeviceMemoryBase& memory_base = pair.second;
153     if (!memory_base.is_null() &&
154         deallocated_ptrs.insert(memory_base.opaque()).second) {
155       TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
156     }
157   }
158 }
159 
TakeSubTree(ShapeIndexView index)160 ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
161   const xla::Shape& sub_on_host_shape =
162       xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
163   const xla::Shape& sub_on_device_shape =
164       xla::ShapeUtil::GetSubshape(on_device_shape(), {index});
165 
166   ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
167                             memory_allocator(), device_ordinal());
168   auto src_it = buffers().find(index);
169   auto dst_it = output.buffers().begin();
170   while (dst_it != output.buffers().end()) {
171     dst_it->second = src_it->second;
172     src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0);
173     ++src_it;
174     ++dst_it;
175   }
176   return output;
177 }
178 
179 }  // namespace xla
180