/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Tracks allocations for the XLA service; allocations can be registered
// with shape/device/tag and resolved from a handle for later use.
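//
// A minimal usage sketch (not part of this header; `backend` and `buffer` are
// assumed to be a Backend* and a ScopedShapedBuffer set up by the caller):
//
//   AllocationTracker tracker(backend);
//   TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
//                       tracker.Register(std::move(buffer), "my result"));
//   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> buffers,
//                       tracker.Resolve(handle));
//   TF_RETURN_IF_ERROR(tracker.Unregister(handle));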
class AllocationTracker {
 public:
  // The backend supplies the memory allocator used to deallocate memory when
  // allocations are deregistered. All registered allocations must be on the
  // same platform as that allocator.
  AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {}

  // Registers a shaped buffer of device memory, and returns a corresponding
  // handle that can be used for talking to XLA clients. The given shaped
  // buffer will be treated as the buffer corresponding to the only replica.
  StatusOr<GlobalDataHandle> Register(ScopedShapedBuffer shaped_buffer,
                                      const string& tag);

  // Registers a vector of shaped buffers of device memory, one per replica,
  // and returns a corresponding handle that can be used for talking to XLA
  // clients.
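  //
  // Sketch, assuming `per_replica` is a std::vector<ScopedShapedBuffer> with
  // one entry per replica (names are illustrative, not part of this API):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       GlobalDataHandle handle,
  //       tracker.RegisterReplicatedBuffers(std::move(per_replica),
  //                                         "replicated result"));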
  StatusOr<GlobalDataHandle> RegisterReplicatedBuffers(
      std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag);

  // Unregisters the allocation for the given data handle.
  Status Unregister(const GlobalDataHandle& data);

  // Returns a vector of global data handles that point to the tuple elements.
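  //
  // Sketch, assuming `tuple_handle` refers to tuple-shaped data (illustrative
  // names only):
  //
  //   TF_ASSIGN_OR_RETURN(std::vector<GlobalDataHandle> element_handles,
  //                       tracker.DeconstructTuple(tuple_handle));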
  StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple(
      const GlobalDataHandle& data);

  // Resolves a handle from an XLA client to a vector of shaped buffers, one
  // per replica, or returns an error status if any of those buffers was not
  // found (or was found but has already been deallocated).
  StatusOr<std::vector<const ShapedBuffer*>> Resolve(
      const GlobalDataHandle& data) const;

  // Resolves a handle from an XLA client and a replica id to a shaped buffer,
  // or returns an error status if it was not found (or was found but has
  // already been deallocated).
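  //
  // Sketch, resolving the buffer registered for replica 0 (illustrative
  // names only):
  //
  //   TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
  //                       tracker.ResolveForReplica(handle, /*replica_id=*/0));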
  StatusOr<const ShapedBuffer*> ResolveForReplica(const GlobalDataHandle& data,
                                                  int replica_id) const;

 private:
  // Data structure encapsulating a single memory allocation on the device.
  struct Allocation {
    // The pointer to this allocation.
    se::OwningDeviceMemory device_memory;

    // The number of times this memory allocation is referred to by registered
    // data handles.
    int ref_count;
  };

  // Internal helper which resolves the given GlobalDataHandle to a list of
  // ShapedBuffers.
  StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal(
      const GlobalDataHandle& data) const TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Internal helper which registers a vector of shaped buffers, one per
  // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If
  // it's ShapedBuffer, all of the given buffers must already be tracked by
  // this object -- presumably this is a call from DeconstructTuple.
  template <typename ShapedBufferTy>
  StatusOr<GlobalDataHandle> RegisterInternal(
      std::vector<ShapedBufferTy> replicated_buffers, const string& tag)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Adds the given device address to the allocation tracker, or, if it is
  // already present, increments its reference count.
  void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory,
                                        int device_ordinal)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Decrements the reference count of the given device memory and, if it
  // reaches zero, deallocates the memory.
  Status DecrementRefCount(se::DeviceMemoryBase device_memory,
                           int device_ordinal)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // A map from device memory opaque value to allocation. One such map is
  // maintained per device ordinal.
  using AllocationMap = absl::flat_hash_map<const void*, Allocation>;

  mutable tensorflow::mutex mutex_;

  // Backend to use with this tracker. The backend supplies the memory
  // allocator to use when deallocating memory.
  Backend* backend_;

  // The next handle to assign to an allocation. Guarded by the same mutex as
  // the mapping, since they're mutated at the same time.
  int64 next_handle_ TF_GUARDED_BY(mutex_);

  // A map from device ordinal to AllocationMap.
  absl::flat_hash_map<int, AllocationMap> opaque_to_allocation_map_
      TF_GUARDED_BY(mutex_);

  // A map from data handle to a vector of shaped buffers that represent the
  // buffers for different replicas.
  //
  // The ShapedBuffers in this map's vectors need to be unique_ptrs, because
  // our public API returns pointers to them. We expect the concrete class to
  // be ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is
  // handled by opaque_to_allocation_map_.
  //
  // (In theory we could use std::list or something instead of unique_ptrs,
  // but we also want to be able to null out these elements.)
  //
  // The reason the elements can't be unique_ptr<ScopedShapedBuffer>s is the
  // existence of DeconstructTuple(). That function lets us create a
  // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
  // freed only when both the view *and* the original tuple are unregistered;
  // this refcounting is managed in opaque_to_allocation_map_.
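  //
  // Lifecycle sketch of that view semantics (illustrative handles only):
  //
  //   // `tuple_handle` owns the tuple and its sub-buffers.
  //   TF_ASSIGN_OR_RETURN(auto element_handles,
  //                       tracker.DeconstructTuple(tuple_handle));
  //   TF_RETURN_IF_ERROR(tracker.Unregister(tuple_handle));
  //   // The element buffers are still live: element_handles still holds a
  //   // reference to each sub-buffer.
  //   TF_RETURN_IF_ERROR(tracker.Unregister(element_handles[0]));
  //   // Only now can that sub-buffer's refcount drop to zero and its memory
  //   // be deallocated.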
  absl::flat_hash_map<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
      handle_to_shaped_buffers_ TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_