/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_

#include <memory>
#include <ostream>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

class ScopedShapedBuffer;

// Class which encapsulates a buffer or set of buffers containing data of a
// particular XLA shape.
class ShapedBuffer {
 public:
  // Constructs a ShapedBuffer with null DeviceMemoryBases at each index. The
  // shape of the data on the host and the device may differ because the device
  // may have a different representation for different data types. Therefore,
  // both the on-host and on-device shape are required. The on-device shape
  // determines the number of device allocations (DeviceMemoryBase) held by the
  // ShapedBuffer.
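  //
  // A minimal usage sketch (illustrative only; `host_shape`, `device_shape`,
  // and `platform` are assumed to be supplied by the caller):
  //
  //   ShapedBuffer buffer(host_shape, device_shape, platform,
  //                       /*device_ordinal=*/0);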
  ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
               const se::Platform* platform, int device_ordinal);

  // Movable, but not copyable.
  ShapedBuffer(ShapedBuffer&& s);
  ShapedBuffer& operator=(ShapedBuffer&&);
  ShapedBuffer(const ShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ShapedBuffer&) = delete;

  // Prevent (some forms of) accidental object slicing.
  ShapedBuffer(const ScopedShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  virtual ~ShapedBuffer();

  // Returns the shape of the on-host representation of the data held by this
  // ShapedBuffer.
  const Shape& on_host_shape() const { return on_host_shape_; }

  // Returns the shape of the on-device representation of the data held by this
  // ShapedBuffer.
  const Shape& on_device_shape() const { return on_device_shape_; }

  const se::Platform* platform() const { return platform_; }
  int device_ordinal() const { return device_ordinal_; }

  // Returns the root buffer of the shape (shape index {}).
  const se::DeviceMemoryBase& root_buffer() const {
    return buffer(/*index=*/{});
  }

  // Returns the buffer at the given shape index where index is defined as in
  // ShapeUtil::GetSubshape.
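  //
  // For example (a sketch, assuming the on-device shape is a two-element
  // tuple): index {} addresses the tuple buffer itself, and index {1} its
  // second element:
  //
  //   const se::DeviceMemoryBase& tuple_mem = shaped_buffer.buffer({});
  //   const se::DeviceMemoryBase& element1 = shaped_buffer.buffer({1});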
  const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const {
    return buffers_.element(index);
  }

  // Sets the device memory buffer at the given index.
  void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) {
    *buffers_.mutable_element(index) = buffer;
  }

  // Sets all buffers.
  //
  // Precondition: buffers.shape == on_device_shape_
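  //
  // A sketch of typical usage (assumes `root_memory` is device memory for a
  // buffer whose on-device shape is the non-tuple `device_shape`):
  //
  //   ShapeTree<se::DeviceMemoryBase> tree(device_shape);
  //   *tree.mutable_element(/*index=*/{}) = root_memory;
  //   shaped_buffer.set_buffers(std::move(tree));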
  void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
    CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
    buffers_ = std::move(buffers);
    buffers_.replace_shape_ptr(&on_device_shape_);
  }

  // Returns the underlying ShapeTree containing all the device addresses in
  // the ShapedBuffer.
  const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
  ShapeTree<se::DeviceMemoryBase>& buffers() { return buffers_; }

  StatusOr<ShapedBuffer> SubShapedBuffer(const ShapeIndex& index) const;

  // Sets all device memory pointers in the object to null.
  void clear();

  string ToString() const;

 protected:
  // The shape of the data when represented on the host.
  Shape on_host_shape_;

  // The shape of the data on the device.
  Shape on_device_shape_;

  // The platform the memory is allocated on.
  const se::Platform* platform_;

  // The device the memory is allocated on.
  int device_ordinal_;

  // The tree of device buffers. Its shape is on_device_shape().
  ShapeTree<se::DeviceMemoryBase> buffers_;
};

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);

// ShapedBuffer derived class which allocates all internal buffers on
// construction and deallocates the memory when the object is
// destructed.
//
// TODO(timshen): Remove inheritance between ScopedShapedBuffer and
// ShapedBuffer.  There should never be a need to consider a ScopedShapedBuffer
// as a ShapedBuffer, because in that case we should just be able to pass around
// our ShapeTree<DeviceMemoryBase>.  Inheritance only adds complexity.  See
// discussion in cl/192849370.
class ScopedShapedBuffer : public ShapedBuffer {
 public:
  // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
  explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                              se::DeviceMemoryAllocator* allocator,
                              int device_ordinal);

  // Creates a ScopedShapedBuffer by taking over the memory from the incoming
  // ShapedBuffer.
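  //
  // A sketch (assumes `shaped_buffer` holds memory that was allocated by
  // `allocator` on the same device):
  //
  //   ScopedShapedBuffer scoped(std::move(shaped_buffer), allocator);
  //   // The buffers are now freed when `scoped` is destroyed.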
  explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                              se::DeviceMemoryAllocator* allocator);

  // Movable, but not copyable.
  ScopedShapedBuffer(ScopedShapedBuffer&& s);
  ScopedShapedBuffer& operator=(ScopedShapedBuffer&&);
  ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
  ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  // All buffers in the shape are deallocated on destruction.
  ~ScopedShapedBuffer() override;

  // Returns the allocator used to allocate the device memory held in this
  // ScopedShapedBuffer.
  se::DeviceMemoryAllocator* memory_allocator() const { return allocator_; }

  // Sets the device memory buffer at the given index.
  //
  // If the given buffer's device memory is non-null, its device_ordinal and
  // allocator must match those in `this`.
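  //
  // A sketch (assumes `size` bytes are needed at `index`; error handling is
  // elided):
  //
  //   se::OwningDeviceMemory mem =
  //       scoped_buffer.memory_allocator()
  //           ->Allocate(scoped_buffer.device_ordinal(), size)
  //           .ValueOrDie();
  //   scoped_buffer.set_buffer(std::move(mem), index);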
  void set_buffer(se::OwningDeviceMemory buffer, const ShapeIndex& index) {
    if (!buffer.is_null()) {
      CHECK_EQ(buffer.device_ordinal(), device_ordinal());
      CHECK_EQ(buffer.allocator(), allocator_);
      *buffers_.mutable_element(index) = buffer.Release();
    } else {
      *buffers_.mutable_element(index) = se::DeviceMemoryBase();
    }
  }

  // Like unique_ptr::release(), creates and returns a regular ShapedBuffer
  // from this ScopedShapedBuffer, without freeing any of the associated
  // memory.
  //
  // It's the caller's job to ensure that the memory contained therein is
  // freed.
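  //
  // A sketch of a caller that takes ownership and later frees the memory by
  // hand (assumes `allocator` is the allocator that backed the buffers, and
  // that no two indices alias the same allocation):
  //
  //   ShapedBuffer released = scoped_buffer.release();
  //   ...
  //   for (auto& index_buffer : released.buffers()) {
  //     if (!index_buffer.second.is_null()) {
  //       TF_CHECK_OK(allocator->Deallocate(released.device_ordinal(),
  //                                         index_buffer.second));
  //     }
  //   }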
  TF_MUST_USE_RESULT ShapedBuffer release();

  // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer
  // that holds ownership of the subtree. Sets the buffers corresponding to the
  // subtree to null in 'this'.
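  //
  // A sketch (assumes the on-device shape is a tuple with at least one
  // element):
  //
  //   ScopedShapedBuffer element0 = scoped_buffer.TakeSubTree(/*index=*/{0});
  //   // scoped_buffer's buffers under {0} are now null; element0 owns them.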
  ScopedShapedBuffer TakeSubTree(ShapeIndexView index);

 protected:
  void Deallocate();

  se::DeviceMemoryAllocator* allocator_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_