/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_

#include <memory>
#include <ostream>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

class ScopedShapedBuffer;

// Class which encapsulates a buffer or set of buffers containing data of a
// particular XLA shape.
//
// A ShapedBuffer is a NON-owning view: it stores raw se::DeviceMemoryBase
// handles and never allocates or frees device memory itself. For an owning
// variant that deallocates on destruction, see ScopedShapedBuffer below.
class ShapedBuffer {
 public:
  // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The
  // shape of the data on the host and the device may differ because the device
  // may have a different representation for different data types. Therefore,
  // both the on-host and on-device shape are required. The on-device shape
  // determines the number of device allocations (DeviceMemoryBase) held by the
  // ShapedBuffer.
  ShapedBuffer(Shape on_device_shape, int device_ordinal);

  // TODO(b/170310047): remove this overload.
  ShapedBuffer(Shape on_host_shape, Shape on_device_shape, int device_ordinal);

  // Movable, but not copyable.
  ShapedBuffer(ShapedBuffer&& s);
  ShapedBuffer& operator=(ShapedBuffer&&);
  ShapedBuffer(const ShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ShapedBuffer&) = delete;

  // Prevent (some forms of) accidental object slicing. Copying a
  // ScopedShapedBuffer into a plain ShapedBuffer would silently drop the
  // ownership semantics of the scoped type, so it is disallowed outright.
  ShapedBuffer(const ScopedShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  virtual ~ShapedBuffer();

  // Returns the shape of the on-host representation of the data held by this
  // ShapedBuffer.
  const Shape& on_host_shape() const { return on_host_shape_; }

  // Returns the shape of the on-device representation of the data held by this
  // ShapedBuffer.
  const Shape& on_device_shape() const { return on_device_shape_; }

  // Returns the ordinal of the device the buffers live on (memory itself is
  // not held by this class; see class comment).
  int device_ordinal() const { return device_ordinal_; }

  // Return the root buffer of the shape (shape index {}).
  const se::DeviceMemoryBase& root_buffer() const {
    return buffer(/*index=*/{});
  }

  // Returns the buffer at the given shape index where index is defined as in
  // ShapeUtil::GetSubshape.
  const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const {
    return buffers_.element(index);
  }

  // Sets the device memory buffer at the given index. Only the (non-owning)
  // handle is stored; no allocation or deallocation takes place.
  void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) {
    *buffers_.mutable_element(index) = buffer;
  }

  // Sets all buffers.
  //
  // Precondition: buffers.shape == on_device_shape_
  void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
    CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
    buffers_ = std::move(buffers);
    // Re-point the tree's shape at our own member so the tree does not keep a
    // dangling/duplicate copy of the shape.
    buffers_.replace_shape_ptr(&on_device_shape_);
  }

  // Reset the shape of this shaped buffer and underlying buffer structure.
  //
  // Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
  // Only the structure (tuple nesting) must match; leaf shapes may change.
  void set_shapes(const Shape& on_device_shape) {
    CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
        << "Structures are not the same. new: " << on_device_shape
        << ", old: " << on_device_shape_;
    // The host shape is derived from the device shape rather than taken from
    // the caller.
    on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
    on_device_shape_ = on_device_shape;
    buffers_.replace_shape_ptr(&on_device_shape_);
  }
  // TODO(b/170310047): remove this overload. The on_host_shape argument is
  // ignored; the host shape is always derived from on_device_shape.
  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
    set_shapes(on_device_shape);
  }

  // Returns the underlying ShapeTree containing all the device addresses in
  // the ShapedBuffer.
  const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
  ShapeTree<se::DeviceMemoryBase>& buffers() { return buffers_; }

  // Returns a view (copy of the handles, not of the memory) of the sub-tree
  // rooted at `index`.
  StatusOr<ShapedBuffer> SubShapedBuffer(const ShapeIndex& index) const;

  // Set all device memory pointers in the object to null.
  void clear();

  string ToString() const;

 protected:
  Shape on_host_shape_;

  // The shape of the data on the device.
  Shape on_device_shape_;

  // The device the memory is allocated on.
  int device_ordinal_;

  // The tree of device buffers. Its shape is on_device_shape().
  ShapeTree<se::DeviceMemoryBase> buffers_;
};

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);

// ScopedShapedBuffer takes allocated buffers as inputs, and deallocates on
// destruction. This class represents an owning wrapper around `ShapedBuffer`.
//
// TODO(timshen): Remove inheritance between ScopedShapedBuffer and
// ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer
// as a ShapedBuffer, because in that case we should just be able to pass around
// our ShapeTree<DeviceMemoryBase>. Inheritance only adds complexity. See
// discussion in cl/192849370.
class ScopedShapedBuffer : public ShapedBuffer {
 public:
  // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
  explicit ScopedShapedBuffer(Shape on_device_shape,
                              se::DeviceMemoryAllocator* allocator,
                              int device_ordinal);
  // TODO(b/170310047): remove this overload.
  explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                              se::DeviceMemoryAllocator* allocator,
                              int device_ordinal);

  // Create a ScopedShapedBuffer by taking over the memory from the incoming
  // ShapedBuffer.
  explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                              se::DeviceMemoryAllocator* allocator);

  // Movable, but not copyable.
  ScopedShapedBuffer(ScopedShapedBuffer&& s);
  ScopedShapedBuffer& operator=(ScopedShapedBuffer&&);
  ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
  ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  // All buffers in the shape are deallocated on destruction.
  ~ScopedShapedBuffer() override;

  // Return the allocator used to allocate the device memory held in this
  // ScopedShapedBuffer.
  se::DeviceMemoryAllocator* memory_allocator() const { return allocator_; }

  // Sets the device memory buffer at the given index.
  //
  // If the given buffer's device memory is non-null, its device_ordinal and
  // allocator must match those in `this`. Ownership of the memory is taken
  // over (Release() is called on the incoming OwningDeviceMemory).
  void set_buffer(se::OwningDeviceMemory buffer, const ShapeIndex& index) {
    if (!buffer.is_null()) {
      CHECK_EQ(buffer.device_ordinal(), device_ordinal());
      CHECK_EQ(buffer.allocator(), allocator_);
      *buffers_.mutable_element(index) = buffer.Release();
    } else {
      *buffers_.mutable_element(index) = se::DeviceMemoryBase();
    }
  }

  // Like unique_ptr::release(), creates and returns a regular ShapedBuffer
  // from this ScopedShapedBuffer, without freeing any of the associated
  // memory.
  //
  // It's the caller's job to ensure that the memory contained therein is
  // freed.
  TF_MUST_USE_RESULT ShapedBuffer release();

  // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer
  // that holds ownership of the subtree. Sets the buffers corresponding to the
  // subtree to null in 'this'.
  ScopedShapedBuffer TakeSubTree(ShapeIndexView index);

 protected:
  // Frees every non-null buffer in buffers_ via allocator_.
  void Deallocate();

  // Non-owning pointer to the allocator used to free the buffers. Must
  // outlive this object.
  se::DeviceMemoryAllocator* allocator_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_