• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/layout_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/platform/logging.h"
31 
32 namespace xla {
33 
ShapedBuffer(Shape on_host_shape,Shape on_device_shape,const se::Platform * platform,int device_ordinal)34 ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
35                            const se::Platform* platform, int device_ordinal)
36     : on_host_shape_(std::move(on_host_shape)),
37       on_device_shape_(std::move(on_device_shape)),
38       platform_(platform),
39       device_ordinal_(device_ordinal),
40       buffers_(&on_device_shape_) {}
41 
ShapedBuffer(ShapedBuffer && s)42 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
43     : on_host_shape_(std::move(s.on_host_shape_)),
44       on_device_shape_(std::move(s.on_device_shape_)),
45       platform_(s.platform_),
46       device_ordinal_(s.device_ordinal_),
47       buffers_(std::move(s.buffers_)) {
48   // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
49   // into buffers_, we also need to update this pointer so that buffers_ doesn't
50   // point into s.
51   buffers_.replace_shape_ptr(&on_device_shape_);
52 }
53 
operator =(ShapedBuffer && s)54 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
55   on_host_shape_ = std::move(s.on_host_shape_);
56   on_device_shape_ = std::move(s.on_device_shape_);
57   platform_ = s.platform_;
58   device_ordinal_ = s.device_ordinal_;
59   buffers_ = std::move(s.buffers_);
60   // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
61   // into buffers_, we also need to update this pointer so that buffers_ doesn't
62   // point into s.
63   buffers_.replace_shape_ptr(&on_device_shape_);
64   return *this;
65 }
66 
~ShapedBuffer()67 ShapedBuffer::~ShapedBuffer() {}
68 
SubShapedBuffer(const ShapeIndex & index) const69 StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
70     const ShapeIndex& index) const {
71   TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape,
72                       ShapeUtil::TryGetSubshape(on_host_shape(), index));
73   TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
74                       ShapeUtil::TryGetSubshape(on_device_shape(), index));
75   ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_,
76                                  device_ordinal_);
77   TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
78                       buffers_.SubShapeTree(index));
79   sub_shaped_buffer.set_buffers(std::move(sub_buffers));
80   return std::move(sub_shaped_buffer);
81 }
82 
clear()83 void ShapedBuffer::clear() {
84   for (auto& pair : buffers_) {
85     // A default constructed DeviceMemoryBase is a null pointer.
86     pair.second = se::DeviceMemoryBase();
87   }
88 }
89 
ToString() const90 string ShapedBuffer::ToString() const {
91   string s = absl::StrCat(
92       "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
93       "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
94       ", on-device shape=" +
95           ShapeUtil::HumanStringWithLayout(on_device_shape()),
96       ":\n");
97   ShapeUtil::ForEachSubshape(
98       on_device_shape(),
99       [this, &s](const Shape& subshape, const ShapeIndex& index) {
100         string shape_str;
101         if (subshape.IsTuple()) {
102           shape_str = "tuple";
103         } else {
104           shape_str = ShapeUtil::HumanStringWithLayout(subshape);
105         }
106         const se::DeviceMemoryBase& memory = buffer(index);
107         absl::StrAppendFormat(&s, "  %s%p (%d bytes) : %s\n",
108                               string(index.size() * 2, ' '), memory.opaque(),
109                               memory.size(), shape_str);
110       });
111   return s;
112 }
113 
operator <<(std::ostream & out,const ShapedBuffer & buffer)114 std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
115   out << buffer.ToString();
116   return out;
117 }
118 
ScopedShapedBuffer(Shape on_host_shape,Shape on_device_shape,se::DeviceMemoryAllocator * allocator,int device_ordinal)119 ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
120                                        Shape on_device_shape,
121                                        se::DeviceMemoryAllocator* allocator,
122                                        int device_ordinal)
123     : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
124                    allocator->platform(), device_ordinal),
125       allocator_(allocator) {}
126 
ScopedShapedBuffer(ShapedBuffer shaped_buffer,se::DeviceMemoryAllocator * allocator)127 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
128                                        se::DeviceMemoryAllocator* allocator)
129     : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}
130 
ScopedShapedBuffer(ScopedShapedBuffer && s)131 ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
132     : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
133   // Null out s.allocator_ so it doesn't try to free anything in its destructor.
134   s.allocator_ = nullptr;
135 }
136 
operator =(ScopedShapedBuffer && s)137 ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) {
138   Deallocate();
139 
140   *static_cast<ShapedBuffer*>(this) = std::move(static_cast<ShapedBuffer&>(s));
141   allocator_ = s.allocator_;
142   // Null out s.allocator_ so it doesn't try to free anything in its destructor.
143   s.allocator_ = nullptr;
144   return *this;
145 }
146 
~ScopedShapedBuffer()147 ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); }
148 
release()149 ShapedBuffer ScopedShapedBuffer::release() {
150   ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
151   buffers_ = ShapeTree<se::DeviceMemoryBase>();
152   return shaped_buffer;
153 }
154 
Deallocate()155 void ScopedShapedBuffer::Deallocate() {
156   // allocator_ will be null if we were moved-from.
157   if (allocator_ == nullptr) {
158     return;
159   }
160   // Deallocate all non-null buffers. A buffer may appear in more than one spot
161   // in the shape (eg, a tuple with a repeated element) so keep track of what
162   // has been deallocated.
163   absl::flat_hash_set<void*> deallocated_ptrs;
164   for (auto& pair : buffers_) {
165     se::DeviceMemoryBase& memory_base = pair.second;
166     if (!memory_base.is_null() &&
167         deallocated_ptrs.insert(memory_base.opaque()).second) {
168       TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
169     }
170   }
171 }
172 
TakeSubTree(ShapeIndexView index)173 ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
174   const xla::Shape& sub_on_host_shape =
175       xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
176   const xla::Shape& sub_on_device_shape =
177       xla::ShapeUtil::GetSubshape(on_device_shape(), {index});
178 
179   ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
180                             memory_allocator(), device_ordinal());
181   auto src_it = buffers().find(index);
182   auto dst_it = output.buffers().begin();
183   while (dst_it != output.buffers().end()) {
184     dst_it->second = src_it->second;
185     src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0);
186     ++src_it;
187     ++dst_it;
188   }
189   return output;
190 }
191 
192 }  // namespace xla
193