/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_

#include <memory>
#include <ostream>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class ScopedShapedBuffer;

// Class which encapsulates a buffer or set of buffers containing data of a
// particular XLA shape.
class ShapedBuffer {
 public:
  // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The
  // shape of the data on the host and the device may differ because the device
  // may have a different representation for different data types. Therefore,
  // both the on-host and on-device shapes are required. The on-device shape
  // determines the number of device allocations (DeviceMemoryBase) held by the
  // ShapedBuffer.
  ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape,
               const se::Platform* platform, int device_ordinal);

  // Movable, but not copyable.
  ShapedBuffer(ShapedBuffer&& s);
  ShapedBuffer& operator=(ShapedBuffer&&);
  ShapedBuffer(const ShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ShapedBuffer&) = delete;

  // Prevent (some forms of) accidental object slicing.
  ShapedBuffer(const ScopedShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  virtual ~ShapedBuffer();

  // Returns the shape of the on-host representation of the data held by this
  // ShapedBuffer.
  const Shape& on_host_shape() const { return on_host_shape_; }

  // Returns the shape of the on-device representation of the data held by this
  // ShapedBuffer.
  const Shape& on_device_shape() const { return on_device_shape_; }

  const se::Platform* platform() const { return platform_; }
  int device_ordinal() const { return device_ordinal_; }

  // Return the root buffer of the shape (shape index {}).
  const se::DeviceMemoryBase& root_buffer() const {
    return buffer(/*index=*/{});
  }

  // Returns the buffer at the given shape index where index is defined as in
  // ShapeUtil::GetSubshape.
  const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const {
    return buffers_.element(index);
  }
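
  // Illustrative usage sketch (not part of this API; the shapes, the
  // `platform` pointer, and the `f32_mem`/`s32_mem` device allocations below
  // are hypothetical). A tuple-shaped ShapedBuffer holds one DeviceMemoryBase
  // per node of the on-device shape, addressed by ShapeIndex:
  //
  //   Shape shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {16}), ShapeUtil::MakeShape(S32, {})});
  //   ShapedBuffer result(/*on_host_shape=*/shape, /*on_device_shape=*/shape,
  //                       platform, /*device_ordinal=*/0);
  //   result.set_buffer(f32_mem, /*index=*/{0});  // memory for tuple leaf 0
  //   result.set_buffer(s32_mem, /*index=*/{1});  // memory for tuple leaf 1
  //   const se::DeviceMemoryBase& tuple_table = result.root_buffer();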

  // Sets the device memory buffer at the given index.
  void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) {
    *buffers_.mutable_element(index) = buffer;
  }

  // Sets all buffers.
  //
  // Precondition: buffers.shape == on_device_shape_
  void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
    CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
    buffers_ = std::move(buffers);
  }

  // Returns the underlying ShapeTree containing all the device addresses in
  // the ShapedBuffer.
  const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
  ShapeTree<se::DeviceMemoryBase>& buffers() { return buffers_; }

  // Set all device memory pointers in the object to null.
  void clear();

  string ToString() const;

 protected:
  // The shape of the data when represented on the host.
  Shape on_host_shape_;

  // The shape of the data on the device.
  Shape on_device_shape_;

  // The platform the memory is allocated on.
  const se::Platform* platform_;

  // The device the memory is allocated on.
  int device_ordinal_;

  // The tree of device buffers. Its shape is on_device_shape().
  ShapeTree<se::DeviceMemoryBase> buffers_;
};

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);

// ShapedBuffer derived class which allocates all internal buffers on
// construction and deallocates the memory when the object is destructed.
//
// TODO(timshen): Remove inheritance between ScopedShapedBuffer and
// ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer
// as a ShapedBuffer, because in that case we should just be able to pass
// around our ShapeTree<DeviceMemoryBase>. Inheritance only adds complexity.
// See discussion in cl/192849370.
class ScopedShapedBuffer : public ShapedBuffer {
 public:
  // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
  explicit ScopedShapedBuffer(const Shape& on_host_shape,
                              const Shape& on_device_shape,
                              DeviceMemoryAllocator* allocator,
                              int device_ordinal);

  // Create a ScopedShapedBuffer by taking over the memory from the incoming
  // ShapedBuffer.
  explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                              DeviceMemoryAllocator* allocator);

  // Movable, but not copyable.
  ScopedShapedBuffer(ScopedShapedBuffer&& s);
  ScopedShapedBuffer& operator=(ScopedShapedBuffer&&);
  ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
  ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;

  // All buffers in the shape are deallocated on destruction.
  ~ScopedShapedBuffer() override;

  // Return the allocator used to allocate the device memory held in this
  // ScopedShapedBuffer.
  DeviceMemoryAllocator* memory_allocator() const { return allocator_; }
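
  // Illustrative ownership sketch (not part of this API; the `shaped_buffer`
  // and `allocator` below are hypothetical). Wrapping a ShapedBuffer hands
  // responsibility for freeing its device memory to this object; release()
  // hands it back:
  //
  //   ScopedShapedBuffer scoped(std::move(shaped_buffer), allocator);
  //   // ... read scoped.buffer(index) as with any ShapedBuffer ...
  //   ShapedBuffer unowned = scoped.release();  // caller must free again
  //   // Had `scoped` been destroyed instead, its non-null buffers would have
  //   // been deallocated through `allocator`.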

  // Sets the device memory buffer at the given index.
  //
  // If the given buffer's device memory is non-null, its device_ordinal and
  // allocator must match those in `this`.
  void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) {
    if (!buffer.is_null()) {
      CHECK_EQ(buffer.device_ordinal(), device_ordinal());
      CHECK_EQ(buffer.allocator(), allocator_);
      *buffers_.mutable_element(index) = buffer.Forget();
    } else {
      *buffers_.mutable_element(index) = se::DeviceMemoryBase();
    }
  }

  // Like unique_ptr::release(), creates and returns a regular ShapedBuffer
  // from this ScopedShapedBuffer, without freeing any of the associated
  // memory.
  //
  // It's the caller's job to ensure that the memory contained therein is
  // freed.
  TF_MUST_USE_RESULT ShapedBuffer release();

  // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer
  // that holds ownership of the subtree. Sets the buffers corresponding to
  // the subtree to null in 'this'.
  ScopedShapedBuffer TakeSubTree(ShapeIndexView index);

 protected:
  void Deallocate();

  DeviceMemoryAllocator* allocator_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_