• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18 
19 #include <memory>
20 #include <string>
21 #include <typeindex>
22 #include <typeinfo>
23 #include <unordered_map>
24 
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/resource_base.h"
29 #include "tensorflow/core/framework/resource_handle.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_types.h"
33 #include "tensorflow/core/framework/type_index.h"
34 #include "tensorflow/core/framework/variant_tensor_data.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/platform/mutex.h"
40 #include "tensorflow/core/platform/thread_annotations.h"
41 
42 namespace tensorflow {
43 
44 // A ResourceMgr instance keeps track of named and typed resources
45 // grouped into containers.
46 //
47 // Each named resource is
48 // registered with ResourceMgr under a named "container" name. At any
49 // time, there is at most one instance of a resource given the container
50 // name, the resource type and the resource name.
51 //
52 // All resources for a given container can be dropped by one call of
53 // Cleanup().
54 //
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//   };
//
//   ResourceMgr rm;
//
//   // Create a var.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZeros();   // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     my_var->val.flat<float>() += grad;
//     // Only release the ref when the lookup actually handed one out;
//     // on failure my_var is still nullptr.
//     my_var->Unref();   // Or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
77 
78 // Container used for per-step resources.
79 class ScopedStepContainer {
80  public:
81   // step_id: the unique ID of this step. Doesn't have to be sequential, just
82   // has to be unique.
83   // cleanup: callback to delete a container of this name.
84   // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64_t step_id,std::function<void (const string &)> cleanup)85   ScopedStepContainer(const int64_t step_id,
86                       std::function<void(const string&)> cleanup)
87       : step_id_(step_id),
88         container_(strings::StrCat("__per_step_", step_id)),
89         cleanup_(cleanup),
90         dirty_(false) {}
91 
ScopedStepContainer(const int64_t step_id,std::function<void (const string &)> cleanup,const std::string & prefix)92   ScopedStepContainer(const int64_t step_id,
93                       std::function<void(const string&)> cleanup,
94                       const std::string& prefix)
95       : step_id_(step_id),
96         container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
97         cleanup_(cleanup),
98         dirty_(false) {}
99 
~ScopedStepContainer()100   ~ScopedStepContainer() { CleanUp(); }
101 
CleanUp()102   void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS {
103     // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
104     // clean.
105     if (dirty_) {
106       mutex_lock ml(mu_);
107       cleanup_(container_);
108       dirty_ = false;
109     }
110   }
111 
112   // Pass through functions for resource lookup and creation. We do this to
113   // ensure that we can appropriately set the dirty_ bit in the
114   // ScopedStepContainer if the name of the container is used to create
115   // resources.
116 
117   // Pass through to MakeResourceHandle with the container name
118   template <typename T>
119   ResourceHandle MakeResourceHandle(
120       const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
121   // Pass through to ResourceMgr::Create with the container name
122   template <typename T>
123   Status Create(ResourceMgr* rm, const std::string& name,
124                 T* resource) TF_MUST_USE_RESULT;
125   // Pass through to ResourceMgr::Delete with the container name
126   template <typename T>
127   Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
128   // Pass through to ResourceMgr::Lookup with the container name
129   template <typename T>
130   Status Lookup(ResourceMgr* rm, const std::string& name,
131                 T** resource) const TF_MUST_USE_RESULT;
132   // Pass through to ResourceMgr::LookupOrCreate with the container name
133   template <typename T>
134   Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
135                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
StepId()136   int64 StepId() const { return step_id_; }
137 
138  private:
139   const int64 step_id_;
140   const std::string container_;
141   const std::function<void(const string&)> cleanup_;
142   mutex mu_;
143   mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
144 };
145 
146 class ResourceMgr {
147  public:
148   ResourceMgr();
149   explicit ResourceMgr(const std::string& default_container);
150   ~ResourceMgr();
151 
152   // Returns the default container name for *this.
default_container()153   const std::string& default_container() const { return default_container_; }
154 
155   // Creates a resource "name" in the "container".  The caller transfers
156   // the ownership of one ref on "resource" to *this, regardless of whether this
157   // operation succeeds or fails.
158   //
159   // REQUIRES: std::is_base_of<ResourceBase, T>
160   // REQUIRES: resource != nullptr.
161   template <typename T>
162   Status Create(const std::string& container, const std::string& name,
163                 T* resource) TF_MUST_USE_RESULT;
164 
165   // If "container" has a resource "name", returns it in "*resource" and
166   // the caller takes the ownership of one ref on "*resource".
167   //
168   // REQUIRES: std::is_base_of<ResourceBase, T>
169   // REQUIRES: resource != nullptr
170   template <typename T, bool use_dynamic_cast = false>
171   Status Lookup(const std::string& container, const std::string& name,
172                 T** resource) const TF_MUST_USE_RESULT;
173 
174   // If the resource manager has a resource matching "handle", returns it in
175   // "*resource" and the caller takes the ownership of one ref on "*resource".
176   //
177   // REQUIRES: resource != nullptr
178   Status Lookup(const ResourceHandle& handle,
179                 ResourceBase** resource) const TF_MUST_USE_RESULT;
180 
181   // Similar to Lookup, but looks up multiple resources at once, with only a
182   // single lock acquisition.  If containers_and_names[i] is uninitialized
183   // then this function does not modify resources[i].
184   template <typename T, bool use_dynamic_cast = false>
185   Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
186                         containers_and_names,
187                     std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
188                         resources) const TF_MUST_USE_RESULT;
189 
190   // If "container" has a resource "name", returns it in
191   // "*resource". Otherwise, invokes creator() to create the resource.
192   // The caller takes the ownership of one ref on "*resource".
193   //
194   // WARNING: creator() must not call any methods on ResourceMgr during its
195   // execution, because a non-reentrant lock is held during the creator() call
196   // in order to guarantee atomicity of LookupOrCreate().
197   //
198   // REQUIRES: std::is_base_of<ResourceBase, T>
199   // REQUIRES: resource != nullptr
200   template <typename T, bool use_dynamic_cast = false>
201   Status LookupOrCreate(const std::string& container, const std::string& name,
202                         T** resource,
203                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
204 
205   // Deletes the resource "name" from the "container".
206   //
207   // REQUIRES: std::is_base_of<ResourceBase, T>
208   template <typename T>
209   Status Delete(const std::string& container,
210                 const std::string& name) TF_MUST_USE_RESULT;
211 
212   // Deletes the resource pointed by "handle".
213   Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
214 
215   // Deletes all resources from the "container" and removes the container.
216   Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;
217 
218   // Deletes all resources in all containers.
219   void Clear();
220 
221   // Returns a text description for all resources.
222   std::string DebugString() const;
223 
224  private:
225   typedef std::pair<uint64, StringPiece> Key;
226   struct KeyHash {
operatorKeyHash227     std::size_t operator()(const Key& k) const {
228       return Hash64(k.second.data(), k.second.size(), k.first);
229     }
230   };
231   struct KeyEqual {
operatorKeyEqual232     bool operator()(const Key& x, const Key& y) const {
233       return (x.second == y.second) && (x.first == y.first);
234     }
235   };
236   struct ResourceAndName {
237     core::RefCountPtr<ResourceBase> resource;
238     std::unique_ptr<string> name;
239 
240     ResourceAndName();
241     ResourceAndName(ResourceBase* resource, std::string name);
242     ResourceAndName(ResourceAndName&& other) noexcept;
243     ~ResourceAndName();
244 
245     ResourceAndName& operator=(ResourceAndName&&) noexcept;
246 
247    private:
248     TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
249   };
250   typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
251 
252   const std::string default_container_;
253   mutable mutex mu_;
254   std::unordered_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
255 
256   template <typename T, bool use_dynamic_cast = false>
257   Status LookupInternal(const std::string& container, const std::string& name,
258                         T** resource) const
259       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
260   Status LookupInternal(const std::string& container, uint64 type_hash_code,
261                         const std::string& name, ResourceBase** resource) const
262       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
263 
264   Status DoCreate(const std::string& container, TypeIndex type,
265                   const std::string& name, ResourceBase* resource)
266       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
267 
268   Status DoLookup(const std::string& container, TypeIndex type,
269                   const std::string& name, ResourceBase** resource) const
270       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
271   Status DoLookup(const std::string& container, uint64 type_hash_code,
272                   const std::string& type_name,
273                   const std::string& resource_name,
274                   ResourceBase** resource) const
275       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
276 
277   Status DoDelete(const std::string& container, uint64 type_hash_code,
278                   const std::string& resource_name,
279                   const std::string& type_name) TF_MUST_USE_RESULT;
280   Status DoDelete(const std::string& container, TypeIndex type,
281                   const std::string& resource_name) TF_MUST_USE_RESULT;
282 
283   // Inserts the type name for 'hash_code' into the hash_code to type name map.
284   Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
285       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
286 
287   // Returns the type name for the 'hash_code'.
288   // Returns "<unknown>" if a resource with such a type was never inserted into
289   // the container.
290   const char* DebugTypeName(uint64 hash_code) const
291       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
292 
293   // Map from type hash_code to type name.
294   std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_);
295 
296   TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
297 };
298 
299 // Makes a resource handle with the specified type for a given container /
300 // name.
301 ResourceHandle MakeResourceHandle(
302     const std::string& container, const std::string& name,
303     const DeviceBase& device, const TypeIndex& type_index,
304     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
305     const absl::optional<ManagedStackTrace>& definition_stack_trace = {})
306     TF_MUST_USE_RESULT;
307 
308 template <typename T>
309 ResourceHandle MakeResourceHandle(
310     OpKernelContext* ctx, const std::string& container, const std::string& name,
311     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
312     const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
313   return MakeResourceHandle(container.empty()
314                                 ? ctx->resource_manager()->default_container()
315                                 : container,
316                             name, *ctx->device(), TypeIndex::Make<T>(),
317                             dtypes_and_shapes, definition_stack_trace);
318 }
319 
320 template <typename T>
321 ResourceHandle MakeResourceHandle(
322     OpKernelConstruction* ctx, const std::string& container,
323     const std::string& name,
324     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
325     const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
326   return MakeResourceHandle(container.empty()
327                                 ? ctx->resource_manager()->default_container()
328                                 : container,
329                             name, *ctx->device(), TypeIndex::Make<T>(),
330                             dtypes_and_shapes, definition_stack_trace);
331 }
332 
333 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
334                                   const std::string& container,
335                                   const std::string& name,
336                                   const TypeIndex& type_index);
337 
338 // Returns a resource handle from a numbered op input.
339 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
340 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
341                        ResourceHandle* handle);
342 
343 // Create a resource pointed by a given resource handle.
344 //
345 // If successful, the caller transfers the ownership of one ref on `resource` to
346 // `ctx->resource_mgr()`.
347 template <typename T>
348 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
349 
350 // Looks up a resource pointed by a given resource handle.
351 //
352 // If the lookup is successful, the caller takes the ownership of one ref on
353 // `*value`, and must call its `Unref()` method when it has finished using it.
354 template <typename T, bool use_dynamic_cast = false>
355 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
356 
357 // Looks up a resource pointed by a given resource handle.
358 //
359 // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
360 // requiring the caller to explicitly call `Unref()`.
361 template <typename T>
362 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
363                       core::RefCountPtr<T>* value);
364 
365 // Looks up multiple resources pointed by a sequence of resource handles.  If
366 // p[i] is uninitialized then values[i] is unmodified.
367 template <typename T>
368 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
369                        std::vector<core::RefCountPtr<T>>* values);
370 
371 // Looks up or creates a resource.
372 //
373 // If successful, the caller takes the ownership of one ref on `*value`, and
374 // must call its `Unref()` method when it has finished using it. If the
375 // `creator` is invoked, its reference on the created resource is transferred
376 // to `ctx->resource_mgr()`.
377 //
378 // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
379 // requiring the caller to explicitly call `Unref()`.
380 template <typename T>
381 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
382                               T** value, std::function<Status(T**)> creator);
383 
384 // Looks up or creates a resource.
385 template <typename T>
386 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
387                               core::RefCountPtr<T>* value,
388                               std::function<Status(T**)> creator);
389 
390 // Destroys a resource pointed by a given resource handle.
391 template <typename T>
392 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
393 
394 // Same as above, but uses the hash code of the type directly.
395 // The type name information will be missing in the debug output when the
396 // resource is not present in the container.
397 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
398 
399 // Policy helper to decide which container/shared_name to use for a
400 // stateful kernel that accesses shared resource.
401 class ContainerInfo {
402  public:
403   // Analyze the node attribute of 'ndef' and decides the container and
404   // resource name the kernel should use for accessing the shared
405   // resource.
406   //
407   // 'ndef' is expected to have node attribute "container" and
408   // "shared_name". Returns non-OK if they are not provided or they are
409   // invalid.
410   //
411   // The policy is as following:
412   // * If the attribute "container" is non-empty, it is used as is.
413   //   Otherwise, uses the resource manager's default container.
414   // * If the attribute "shared_name" is non-empty, it is used as is.
415   //   Otherwise, if "use_node_name_as_default" is true, the kernel's
416   //   node name is used as the resource name. Otherwise, a string
417   //   unique to this process is used.
418   Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
419               bool use_node_name_as_default);
Init(ResourceMgr * rmgr,const NodeDef & ndef)420   Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
421     return Init(rmgr, ndef, false);
422   }
423 
424   // The policy decides that the kernel should access the resource in
425   // resource_manager(), the resource is in the container() and its
426   // name is name().  If resource_is_private_to_kernel() is true, the
427   // kernel should delete the resource when the kernel is deleted.
resource_manager()428   ResourceMgr* resource_manager() const { return rmgr_; }
container()429   const std::string& container() const { return container_; }
name()430   const std::string& name() const { return name_; }
resource_is_private_to_kernel()431   bool resource_is_private_to_kernel() const {
432     return resource_is_private_to_kernel_;
433   }
434 
435   // Returns a readable string for *this.
436   std::string DebugString() const;
437 
438  private:
439   ResourceMgr* rmgr_ = nullptr;
440   std::string container_;
441   std::string name_;
442   bool resource_is_private_to_kernel_ = false;
443 };
444 
445 // Helper for kernels to obtain 'resource' from the
446 // ctx->resource_manager().
447 //
448 // "input_name" specifies the kernel's ref input which gives a string
449 // tensor with two elements, which specifies the container and
450 // resource name.
451 //
452 // Returns OK if the resource is found and transfers one ref of
453 // *resource to the caller. Otherwise, returns an error.
454 template <typename T>
455 Status GetResourceFromContext(OpKernelContext* ctx,
456                               const std::string& input_name, T** resource);
457 
458 // Utility op kernel to check if a handle to resource type T is initialized.
459 template <typename T>
460 class IsResourceInitialized : public OpKernel {
461  public:
IsResourceInitialized(OpKernelConstruction * c)462   explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
463 
464   void Compute(OpKernelContext* ctx) override;
465 };
466 
// Registers an op named "<Type>HandleOp" that produces just a resource handle
// to a resource of the specified type.
//
// The op has two string attrs, "container" and "shared_name" (both default
// to ''), a single output "resource" of type resource, and a scalar shape
// function.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type) \
  REGISTER_OP(#Type "HandleOp")           \
      .Attr("container: string = ''")     \
      .Attr("shared_name: string = ''")   \
      .Output("resource: resource")       \
      .SetIsStateful()                    \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
478 
479 // Utility op kernel to produce a handle to a resource of type T.
480 template <typename T>
481 class ResourceHandleOp : public OpKernel {
482  public:
483   explicit ResourceHandleOp(OpKernelConstruction* context);
484 
485   void Compute(OpKernelContext* ctx) override;
486 
IsExpensive()487   bool IsExpensive() override { return false; }
488 
489  private:
490   std::string container_;
491   std::string name_;
492   mutex mutex_;
493   Tensor resource_;
494   std::atomic<bool> initialized_{false};
495 };
496 
497 // Utility op kernel to produce a handle to a resource of type T.
498 template <typename T>
499 class ResourceHandlesOp : public OpKernel {
500  public:
501   explicit ResourceHandlesOp(OpKernelConstruction* context);
502 
503   void Compute(OpKernelContext* ctx) override;
504 
IsExpensive()505   bool IsExpensive() override { return false; }
506 
507  private:
508   std::vector<string> containers_;
509   std::vector<string> names_;
510   mutex mutex_;
511   std::vector<Tensor> resources_;
512   std::atomic<bool> initialized_{false};
513 };
514 
// Registers ResourceHandleOp<Type> as the CPU kernel for the op named
// "<Type>HandleOp", i.e. the op that produces a handle to a resource of the
// specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
520 
521 // This class is used to guarantee that an anonymous resource is deleted
522 // (irrespective of whether a resource deleter op is called explicitly or
523 // the execution encounters an error before the op runs).
524 //
525 // This is achieved by wrapping an instance of this class into a variant
526 // tensor which is passed as an input to a resource deleter op. If the
527 // execution encounters an error before the op runs, the tensor will be
528 // destroyed, essentially triggering the iterator deletion.
529 // NOTE: This is not a feature-complete implementation of the DT_VARIANT
530 // specification. In particular, we cannot serialize the `ResourceMgr`
531 // object, so the `Encode()` and `Decode()` methods are not implemented.
532 class ResourceDeleter {
533  public:
ResourceDeleter()534   ResourceDeleter() : deleter_() {}
535 
ResourceDeleter(ResourceHandle handle,ResourceMgr * resource_manager)536   ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
537       : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
538 
ResourceDeleter(ResourceDeleter && rhs)539   ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
540     VLOG(3) << "ResourceDeleter move constructor called.";
541   }
542 
ResourceDeleter(const ResourceDeleter & rhs)543   ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
544     VLOG(3) << "ResourceDeleter copy constructor called.";
545   }
546 
547   ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
548 
549   ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
550 
~ResourceDeleter()551   virtual ~ResourceDeleter() {
552     VLOG(3) << "ResourceDeleter destructor called.";
553   }
554 
Encode(VariantTensorData *)555   void Encode(VariantTensorData*) const {
556     LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
557                   "objects.";
558   }
559 
Decode(const VariantTensorData &)560   bool Decode(const VariantTensorData&) {
561     LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
562                   "objects";
563     return false;  // Not supported.
564   }
565 
566  private:
567   // Helper that performs reference counting for the parent class and deletes
568   // the iterator resource when the refcount goes to zero.
569   //
570   // NOTE: The object is borrowing a pointer to the resource manager.
571   // Consequently, the tensor containing this object should not escape the
572   // function in which was created (so that it is guaranteed that the resource
573   // manager will outlive it).
574   struct Helper {
HelperHelper575     Helper(ResourceHandle handle, ResourceMgr* resource_manager)
576         : handle(handle), resource_manager(resource_manager) {}
577 
578     Helper(const Helper& rhs) = delete;
579     Helper(Helper&& rhs) = delete;
580 
~HelperHelper581     ~Helper() {
582       VLOG(3) << "Deleting Resource: " << handle.DebugString();
583       resource_manager->Delete(handle).IgnoreError();
584     }
585 
586     ResourceHandle handle;
587     ResourceMgr* resource_manager;  // not owned
588   };
589 
590   std::shared_ptr<Helper> deleter_;
591 };
592 
593 // Implementation details below.
594 
595 template <typename T>
CheckDeriveFromResourceBase()596 void CheckDeriveFromResourceBase() {
597   static_assert(std::is_base_of<ResourceBase, T>::value,
598                 "T must derive from ResourceBase");
599 }
600 
601 template <typename T>
Create(const std::string & container,const std::string & name,T * resource)602 Status ResourceMgr::Create(const std::string& container,
603                            const std::string& name, T* resource) {
604   CheckDeriveFromResourceBase<T>();
605   CHECK(resource != nullptr);
606   mutex_lock l(mu_);
607   return DoCreate(container, TypeIndex::Make<T>(), name, resource);
608 }
609 
610 template <typename T, bool use_dynamic_cast>
Lookup(const std::string & container,const std::string & name,T ** resource)611 Status ResourceMgr::Lookup(const std::string& container,
612                            const std::string& name, T** resource) const {
613   CheckDeriveFromResourceBase<T>();
614   tf_shared_lock l(mu_);
615   return LookupInternal<T, use_dynamic_cast>(container, name, resource);
616 }
617 
618 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)619 Status ResourceMgr::LookupMany(
620     absl::Span<std::pair<const string*, const string*> const>
621         containers_and_names,
622     std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
623   CheckDeriveFromResourceBase<T>();
624   tf_shared_lock l(mu_);
625   resources->resize(containers_and_names.size());
626   for (size_t i = 0; i < containers_and_names.size(); ++i) {
627     T* resource;
628     Status s = LookupInternal<T, use_dynamic_cast>(
629         *containers_and_names[i].first, *containers_and_names[i].second,
630         &resource);
631     if (s.ok()) {
632       (*resources)[i].reset(resource);
633     }
634   }
635   return Status::OK();
636 }
637 
638 // Simple wrapper to allow conditional dynamic / static casts.
639 template <typename T, bool use_dynamic_cast>
640 struct TypeCastFunctor {
CastTypeCastFunctor641   static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
642 };
643 
644 template <typename T>
645 struct TypeCastFunctor<T, true> {
646   static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
647 };
648 
649 template <typename T, bool use_dynamic_cast>
650 Status ResourceMgr::LookupInternal(const std::string& container,
651                                    const std::string& name,
652                                    T** resource) const {
653   ResourceBase* found = nullptr;
654   Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
655   if (s.ok()) {
656     // It's safe to down cast 'found' to T* since
657     // typeid(T).hash_code() is part of the map key.
658     *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
659   }
660   return s;
661 }
662 
663 template <typename T, bool use_dynamic_cast>
664 Status ResourceMgr::LookupOrCreate(const std::string& container,
665                                    const std::string& name, T** resource,
666                                    std::function<Status(T**)> creator) {
667   CheckDeriveFromResourceBase<T>();
668   *resource = nullptr;
669   Status s;
670   {
671     tf_shared_lock l(mu_);
672     s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
673     if (s.ok()) return s;
674   }
675   mutex_lock l(mu_);
676   s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
677   if (s.ok()) return s;
678   TF_RETURN_IF_ERROR(creator(resource));
679   s = DoCreate(container, TypeIndex::Make<T>(), name, *resource);
680   if (!s.ok()) {
681     return errors::Internal("LookupOrCreate failed unexpectedly");
682   }
683   (*resource)->Ref();
684   return s;
685 }
686 
687 template <typename T>
688 Status ResourceMgr::Delete(const std::string& container,
689                            const std::string& name) {
690   CheckDeriveFromResourceBase<T>();
691   return DoDelete(container, TypeIndex::Make<T>(), name);
692 }
693 
694 template <typename T>
695 Status GetResourceFromContext(OpKernelContext* ctx,
696                               const std::string& input_name, T** resource) {
697   DataType dtype;
698   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
699   if (dtype == DT_RESOURCE) {
700     const Tensor* handle;
701     TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
702     return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
703   }
704   std::string container;
705   std::string shared_name;
706   {
707     mutex* mu;
708     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
709     mutex_lock l(*mu);
710     Tensor tensor;
711     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
712     if (tensor.NumElements() != 2) {
713       return errors::InvalidArgument(
714           "Resource handle must have 2 elements, but had shape: ",
715           tensor.shape().DebugString());
716     }
717     container = tensor.flat<tstring>()(0);
718     shared_name = tensor.flat<tstring>()(1);
719   }
720   return ctx->resource_manager()->Lookup(container, shared_name, resource);
721 }
722 
723 namespace internal {
724 
725 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
726 
727 template <typename T>
728 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
729   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
730   TF_RETURN_IF_ERROR(p.ValidateType<T>());
731   return Status::OK();
732 }
733 
734 }  // namespace internal
735 
736 // Creates the resource pointed at by "p". The caller transfers the ownership of
737 // one ref on "*value" to the resource manager in "ctx", regardless of whether
738 // this operation succeeds or fails.
739 template <typename T>
740 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
741   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
742   return ctx->resource_manager()->Create(p.container(), p.name(), value);
743 }
744 
745 // If the resource manager in "ctx" has a resource matching "p", returns it in
746 // "*value" and the caller takes the ownership of one ref on "*value"
747 template <typename T, bool use_dynamic_cast>
748 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
749                       T** value) {
750   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
751   return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
752                                                               p.name(), value);
753 }
754 
// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value", and the caller takes ownership of one ref on "*value".
757 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
758                       ResourceBase** value);
759 
760 // If the resource manager in "ctx" has a resource matching "p", returns it in
761 // "*value".
762 template <typename T>
763 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
764                       core::RefCountPtr<T>* value) {
765   T* raw_ptr = nullptr;
766   TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
767   value->reset(raw_ptr);
768 
769   return Status::OK();
770 }
771 
772 // Similar to Lookup, but looks up multiple resources at once, with only a
773 // single lock acquisition.
774 template <typename T>
775 Status LookupResources(OpKernelContext* ctx,
776                        absl::Span<ResourceHandle const* const> p,
777                        std::vector<core::RefCountPtr<T>>* values) {
778   std::vector<std::pair<const string*, const string*>> containers_and_names(
779       p.size());
780   for (size_t i = 0; i < p.size(); ++i) {
781     TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
782     containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
783   }
784   return ctx->resource_manager()->LookupMany(containers_and_names, values);
785 }
786 
787 // If the resource manager in "ctx" has a resource pointed at by "p", returns
788 // it in "*value". Otherwise, invokes creator() to create the resource.
789 // The caller takes the ownership of one ref on "*value".
790 //
791 // WARNING: creator() must not call any methods on the resource manager during
792 // its execution, because a non-reentrant lock is held during the creator() call
793 // in order to guarantee atomicity of LookupOrCreateResource().
794 template <typename T>
795 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
796                               T** value, std::function<Status(T**)> creator) {
797   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
798   return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
799                                                  creator);
800 }
801 
802 // If the resource manager in "ctx" has a resource pointed at by "p", returns
803 // it in "*value". Otherwise, invokes creator() to create the resource.
804 //
805 // WARNING: creator() must not call any methods on the resource manager during
806 // its execution, because a non-reentrant lock is held during the creator() call
807 // in order to guarantee atomicity of LookupOrCreateResource().
808 template <typename T>
809 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
810                               core::RefCountPtr<T>* value,
811                               std::function<Status(T**)> creator) {
812   T* raw_ptr = nullptr;
813   TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
814   value->reset(raw_ptr);
815 
816   return Status::OK();
817 }
818 
819 // Deletes the resource pointed by "p", using the resource manager in "ctx".
820 template <typename T>
821 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
822   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
823   return ctx->resource_manager()->Delete<T>(p.container(), p.name());
824 }
825 
// Deletes the resource pointed to by "p", using the resource manager in "ctx".
827 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
828 
829 template <typename T>
830 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
831   Tensor* output;
832   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
833   T* object;
834   bool found;
835   if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
836     found = true;
837     object->Unref();
838   } else {
839     found = false;
840   }
841 
842   output->flat<bool>()(0) = found;
843 }
844 
845 template <typename T>
846 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
847     : OpKernel(context) {
848   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
849   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
850 }
851 
852 template <typename T>
853 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
854   if (name_ == ResourceHandle::ANONYMOUS_NAME) {
855     AllocatorAttributes attr;
856     attr.set_on_host(true);
857     Tensor handle;
858     OP_REQUIRES_OK(
859         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
860     handle.scalar<ResourceHandle>()() =
861         MakeResourceHandle<T>(ctx, container_, name_);
862     ctx->set_output(0, handle);
863   } else {
864     if (!initialized_.load()) {
865       mutex_lock ml(mutex_);
866       // Checking again to see if another thread has initialized the resource.
867       if (!initialized_.load()) {
868         AllocatorAttributes attr;
869         attr.set_on_host(true);
870         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
871                                                &resource_, attr));
872         resource_.scalar<ResourceHandle>()() =
873             MakeResourceHandle<T>(ctx, container_, name_);
874         initialized_.store(true);
875       }
876     }
877     ctx->set_output(0, resource_);
878   }
879 }
880 
881 template <typename T>
882 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
883     : OpKernel(context) {
884   int n;
885   OP_REQUIRES_OK(context, context->GetAttr("N", &n));
886   OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
887   OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
888   OP_REQUIRES(
889       context, containers_.size() == n,
890       errors::InvalidArgument("Number of containers (", containers_.size(),
891                               ") must be equal to N (", n, ")"));
892   OP_REQUIRES(context, names_.size() == n,
893               errors::InvalidArgument("Number of names (", containers_.size(),
894                                       ") must be equal to N (", n, ")"));
895   resources_.resize(n);
896 }
897 
898 template <typename T>
899 void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
900   if (!initialized_.load()) {
901     mutex_lock ml(mutex_);
902     // Checking again to see if another thread has initialized the resource.
903     if (!initialized_.load()) {
904       AllocatorAttributes attr;
905       attr.set_on_host(true);
906       for (size_t i = 0; i < resources_.size(); ++i) {
907         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
908                                                &resources_[i], attr));
909         ResourceHandle h =
910             MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
911         resources_[i].template scalar<ResourceHandle>()() = h;
912       }
913       initialized_.store(true);
914     }
915   }
916   for (size_t i = 0; i < resources_.size(); ++i) {
917     ctx->set_output(i, resources_[i]);
918   }
919 }
920 
921 template <typename T>
922 ResourceHandle ScopedStepContainer::MakeResourceHandle(
923     const std::string& name, const DeviceBase& device) {
924   mutex_lock ml(mu_);
925   dirty_ = true;
926   return tensorflow::MakeResourceHandle(container_, name, device,
927                                         TypeIndex::Make<T>(), {});
928 }
929 
930 template <typename T>
931 Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
932                                    T** resource) const {
933   return rm->Lookup<T>(container_, name, resource);
934 }
935 
936 template <typename T>
937 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
938                                            const std::string& name,
939                                            T** resource,
940                                            std::function<Status(T**)> creator) {
941   mutex_lock ml(mu_);
942   dirty_ = true;
943   return rm->LookupOrCreate<T>(container_, name, resource, creator);
944 }
945 
946 template <typename T>
947 Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
948                                    T* resource) {
949   mutex_lock ml(mu_);
950   dirty_ = true;
951   return rm->Create<T>(container_, name, resource);
952 }
953 
954 template <typename T>
955 Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
956   return rm->Delete<T>(container_, name);
957 }
958 
959 }  //  end namespace tensorflow
960 
961 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
962