• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18 
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
24 
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/resource_handle.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/tensor_types.h"
31 #include "tensorflow/core/framework/type_index.h"
32 #include "tensorflow/core/framework/variant_tensor_data.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/refcount.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/platform/mutex.h"
40 #include "tensorflow/core/platform/thread_annotations.h"
41 
42 namespace tensorflow {
43 
44 // A ResourceMgr instance keeps track of named and typed resources
45 // grouped into containers.
46 //
47 // Each resource must be represented as a sub-class of ResourceBase,
48 // which is reference counted explicitly.  Each named resource is
49 // registered with ResourceMgr under a named "container" name. At any
50 // time, there is at most one instance of a resource given the container
51 // name, the resource type and the resource name.
52 //
53 // All resources for a given container can be dropped by one call of
54 // Cleanup().
55 //
56 // E.g.,
57 //   struct MyVar : public ResourceBase {
58 //     mutex mu;
59 //     Tensor val;
60 //   }
61 //
62 //   ResourceMgr rm;
63 //
64 //   // Create a var.
65 //   MyVar* my_var = new MyVar;
66 //   my_var->val = Tensor(DT_FLOAT, my_shape);
67 //   my_var->val.flat<float>().setZeros();   // 0 initialized.
68 //   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
69 //
70 //   // += a variable.
71 //   MyVar* my_var = nullptr;
72 //   Status s = rm.Lookup("my_container", "my_name", &my_var);
73 //   if (s.ok()) {
74 //     my_var->val.flat<float>() += grad;
75 //   }
76 //   my_var->Unref();   // Or use ScopedUnref().
77 //   ctx->SetStatus(s);
78 class ResourceBase : public core::RefCounted {
79  public:
80   // Returns a debug string for *this.
81   virtual string DebugString() const = 0;
82 
83   // Returns memory used by this resource.
MemoryUsed()84   virtual int64 MemoryUsed() const { return 0; }
85 };
86 
87 // Container used for per-step resources.
88 class ScopedStepContainer {
89  public:
90   // step_id: the unique ID of this step. Doesn't have to be sequential, just
91   // has to be unique.
92   // cleanup: callback to delete a container of this name.
93   // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup)94   ScopedStepContainer(const int64 step_id,
95                       std::function<void(const string&)> cleanup)
96       : container_(strings::StrCat("__per_step_", step_id)),
97         step_id_(step_id),
98         cleanup_(cleanup),
99         dirty_(false) {}
100 
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup,const string & prefix)101   ScopedStepContainer(const int64 step_id,
102                       std::function<void(const string&)> cleanup,
103                       const string& prefix)
104       : container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
105         step_id_(step_id),
106         cleanup_(cleanup),
107         dirty_(false) {}
108 
~ScopedStepContainer()109   ~ScopedStepContainer() { CleanUp(); }
110 
CleanUp()111   void CleanUp() NO_THREAD_SAFETY_ANALYSIS {
112     // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
113     // clean.
114     if (dirty_) {
115       mutex_lock ml(mu_);
116       cleanup_(container_);
117       dirty_ = false;
118     }
119   }
120 
121   // Pass through functions for resource lookup and creation. We do this to
122   // ensure that we can appropriately set the dirty_ bit in the
123   // ScopedStepContainer if the name of the container is used to create
124   // resources.
125 
126   // Pass through to MakeResourceHandle with the container name
127   template <typename T>
128   ResourceHandle MakeResourceHandle(
129       const string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
130   // Pass through to ResourceMgr::Create with the container name
131   template <typename T>
132   Status Create(ResourceMgr* rm, const string& name,
133                 T* resource) TF_MUST_USE_RESULT;
134   // Pass through to ResourceMgr::Delete with the container name
135   template <typename T>
136   Status Delete(ResourceMgr* rm, const string& name) TF_MUST_USE_RESULT;
137   // Pass through to ResourceMgr::Lookup with the container name
138   template <typename T>
139   Status Lookup(ResourceMgr* rm, const string& name,
140                 T** resource) const TF_MUST_USE_RESULT;
141   // Pass through to ResourceMgr::LookupOrCreate with the container name
142   template <typename T>
143   Status LookupOrCreate(ResourceMgr* rm, const string& name, T** resource,
144                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
145 
step_id()146   const int64 step_id() const { return step_id_; }
147 
148  private:
149   const string container_;
150   const int64 step_id_;
151   const std::function<void(const string&)> cleanup_;
152   mutex mu_;
153   mutable std::atomic<bool> dirty_ GUARDED_BY(mu_);
154 };
155 
156 class ResourceMgr {
157  public:
158   ResourceMgr();
159   explicit ResourceMgr(const string& default_container);
160   ~ResourceMgr();
161 
162   // Returns the default container name for *this.
default_container()163   const string& default_container() const { return default_container_; }
164 
165   // Creates a resource "name" in the "container".  The caller transfers
166   // the ownership of one ref on "resource" to *this, regardless of whether this
167   // operation succeeds or fails.
168   //
169   // REQUIRES: std::is_base_of<ResourceBase, T>
170   // REQUIRES: resource != nullptr.
171   template <typename T>
172   Status Create(const string& container, const string& name,
173                 T* resource) TF_MUST_USE_RESULT;
174 
175   // If "container" has a resource "name", returns it in "*resource" and
176   // the caller takes the ownership of one ref on "*resource".
177   //
178   // REQUIRES: std::is_base_of<ResourceBase, T>
179   // REQUIRES: resource != nullptr
180   template <typename T, bool use_dynamic_cast = false>
181   Status Lookup(const string& container, const string& name,
182                 T** resource) const TF_MUST_USE_RESULT;
183 
184   // Similar to Lookup, but looks up multiple resources at once, with only a
185   // single lock acquisition.  If containers_and_names[i] is uninitialized
186   // then this function does not modify resources[i].
187   template <typename T, bool use_dynamic_cast = false>
188   Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
189                         containers_and_names,
190                     std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
191                         resources) const TF_MUST_USE_RESULT;
192 
193   // If "container" has a resource "name", returns it in
194   // "*resource". Otherwise, invokes creator() to create the resource.
195   // The caller takes the ownership of one ref on "*resource".
196   //
197   // WARNING: creator() must not call any methods on ResourceMgr during its
198   // execution, because a non-reentrant lock is held during the creator() call
199   // in order to guarantee atomicity of LookupOrCreate().
200   //
201   // REQUIRES: std::is_base_of<ResourceBase, T>
202   // REQUIRES: resource != nullptr
203   template <typename T, bool use_dynamic_cast = false>
204   Status LookupOrCreate(const string& container, const string& name,
205                         T** resource,
206                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
207 
208   // Deletes the resource "name" from the "container".
209   //
210   // REQUIRES: std::is_base_of<ResourceBase, T>
211   template <typename T>
212   Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
213 
214   // Deletes the resource pointed by "handle".
215   Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
216 
217   // Deletes all resources from the "container" and removes the container.
218   Status Cleanup(const string& container) TF_MUST_USE_RESULT;
219 
220   // Deletes all resources in all containers.
221   void Clear();
222 
223   // Returns a text description for all resources.
224   string DebugString() const;
225 
226  private:
227   typedef std::pair<uint64, StringPiece> Key;
228   struct KeyHash {
operatorKeyHash229     std::size_t operator()(const Key& k) const {
230       return Hash64(k.second.data(), k.second.size(), k.first);
231     }
232   };
233   struct KeyEqual {
operatorKeyEqual234     bool operator()(const Key& x, const Key& y) const {
235       return (x.second == y.second) && (x.first == y.first);
236     }
237   };
238   struct ResourceAndName {
239     core::RefCountPtr<ResourceBase> resource;
240     std::unique_ptr<string> name;
241 
242     ResourceAndName();
243     ResourceAndName(ResourceBase* resource, string name);
244     ResourceAndName(ResourceAndName&& other) noexcept;
245     ~ResourceAndName();
246 
247     ResourceAndName& operator=(ResourceAndName&&) noexcept;
248 
249    private:
250     TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
251   };
252   typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
253 
254   const string default_container_;
255   mutable mutex mu_;
256   std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
257 
258   template <typename T, bool use_dynamic_cast = false>
259   Status LookupInternal(const string& container, const string& name,
260                         T** resource) const
261       SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
262 
263   Status DoCreate(const string& container, TypeIndex type, const string& name,
264                   ResourceBase* resource)
265       EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
266 
267   Status DoLookup(const string& container, TypeIndex type, const string& name,
268                   ResourceBase** resource) const
269       SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
270 
271   Status DoDelete(const string& container, uint64 type_hash_code,
272                   const string& resource_name,
273                   const string& type_name) TF_MUST_USE_RESULT;
274   Status DoDelete(const string& container, TypeIndex type,
275                   const string& resource_name) TF_MUST_USE_RESULT;
276 
277   // Inserts the type name for 'hash_code' into the hash_code to type name map.
278   Status InsertDebugTypeName(uint64 hash_code, const string& type_name)
279       EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
280 
281   // Returns the type name for the 'hash_code'.
282   // Returns "<unknown>" if a resource with such a type was never inserted into
283   // the container.
284   const char* DebugTypeName(uint64 hash_code) const
285       EXCLUSIVE_LOCKS_REQUIRED(mu_);
286 
287   // Map from type hash_code to type name.
288   std::unordered_map<uint64, string> debug_type_names_ GUARDED_BY(mu_);
289 
290   TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
291 };
292 
293 // Makes a resource handle with the specified type for a given container /
294 // name.
295 ResourceHandle MakeResourceHandle(
296     const string& container, const string& name, const DeviceBase& device,
297     const TypeIndex& type_index,
298     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
299     TF_MUST_USE_RESULT;
300 
301 template <typename T>
302 ResourceHandle MakeResourceHandle(
303     OpKernelContext* ctx, const string& container, const string& name,
304     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
305   return MakeResourceHandle(
306       container.empty() ? ctx->resource_manager()->default_container()
307                         : container,
308       name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
309 }
310 
311 template <typename T>
312 ResourceHandle MakeResourceHandle(
313     OpKernelConstruction* ctx, const string& container, const string& name,
314     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
315   return MakeResourceHandle(
316       container.empty() ? ctx->resource_manager()->default_container()
317                         : container,
318       name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
319 }
320 
321 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
322                                   const string& container, const string& name,
323                                   const TypeIndex& type_index);
324 
325 // Returns a resource handle from a numbered op input.
326 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
327 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
328                        ResourceHandle* handle);
329 
330 // Create a resource pointed by a given resource handle.
331 //
332 // If successful, the caller transfers the ownership of one ref on `resource` to
333 // `ctx->resource_mgr()`.
334 template <typename T>
335 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
336 
337 // Looks up a resource pointed by a given resource handle.
338 //
339 // If the lookup is successful, the caller takes the ownership of one ref on
340 // `*value`, and must call its `Unref()` method when it has finished using it.
341 template <typename T, bool use_dynamic_cast = false>
342 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
343 
344 // Looks up a resource pointed by a given resource handle.
345 //
346 // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
347 // requiring the caller to explicitly call `Unref()`.
348 template <typename T>
349 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
350                       core::RefCountPtr<T>* value);
351 
352 // Looks up multiple resources pointed by a sequence of resource handles.  If
353 // p[i] is uninitialized then values[i] is unmodified.
354 template <typename T>
355 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
356                        std::vector<core::RefCountPtr<T>>* values);
357 
358 // Looks up or creates a resource.
359 //
360 // If successful, the caller takes the ownership of one ref on `*value`, and
361 // must call its `Unref()` method when it has finished using it. If the
362 // `creator` is invoked, its reference on the created resource is transferred
363 // to `ctx->resource_mgr()`.
364 //
365 // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
366 // requiring the caller to explicitly call `Unref()`.
367 template <typename T>
368 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
369                               T** value, std::function<Status(T**)> creator);
370 
371 // Looks up or creates a resource.
372 template <typename T>
373 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
374                               core::RefCountPtr<T>* value,
375                               std::function<Status(T**)> creator);
376 
377 // Destroys a resource pointed by a given resource handle.
378 template <typename T>
379 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
380 
381 // Same as above, but uses the hash code of the type directly.
382 // The type name information will be missing in the debug output when the
383 // resource is not present in the container.
384 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
385 
386 // Policy helper to decide which container/shared_name to use for a
387 // stateful kernel that accesses shared resource.
388 class ContainerInfo {
389  public:
390   // Analyze the node attribute of 'ndef' and decides the container and
391   // resource name the kernel should use for accessing the shared
392   // resource.
393   //
394   // 'ndef' is expected to have node attribute "container" and
395   // "shared_name". Returns non-OK if they are not provided or they are
396   // invalid.
397   //
398   // The policy is as following:
399   // * If the attribute "container" is non-empty, it is used as is.
400   //   Otherwise, uses the resource manager's default container.
401   // * If the attribute "shared_name" is non-empty, it is used as is.
402   //   Otherwise, if "use_node_name_as_default" is true, the kernel's
403   //   node name is used as the resource name. Otherwise, a string
404   //   unique to this process is used.
405   Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
406               bool use_node_name_as_default);
Init(ResourceMgr * rmgr,const NodeDef & ndef)407   Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
408     return Init(rmgr, ndef, false);
409   }
410 
411   // The policy decides that the kernel should access the resource in
412   // resource_manager(), the resource is in the container() and its
413   // name is name().  If resource_is_private_to_kernel() is true, the
414   // kernel should delete the resource when the kernel is deleted.
resource_manager()415   ResourceMgr* resource_manager() const { return rmgr_; }
container()416   const string& container() const { return container_; }
name()417   const string& name() const { return name_; }
resource_is_private_to_kernel()418   bool resource_is_private_to_kernel() const {
419     return resource_is_private_to_kernel_;
420   }
421 
422   // Returns a readable string for *this.
423   string DebugString() const;
424 
425  private:
426   ResourceMgr* rmgr_ = nullptr;
427   string container_;
428   string name_;
429   bool resource_is_private_to_kernel_ = false;
430 };
431 
432 // Helper for kernels to obtain 'resource' from the
433 // ctx->resource_manager().
434 //
435 // "input_name" specifies the kernel's ref input which gives a string
436 // tensor with two elements, which specifies the container and
437 // resource name.
438 //
439 // Returns OK if the resource is found and transfers one ref of
440 // *resource to the caller. Otherwise, returns an error.
441 template <typename T>
442 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
443                               T** resource);
444 
445 // Utility op kernel to check if a handle to resource type T is initialized.
446 template <typename T>
447 class IsResourceInitialized : public OpKernel {
448  public:
IsResourceInitialized(OpKernelConstruction * c)449   explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
450 
451   void Compute(OpKernelContext* ctx) override;
452 };
453 
// Registers an op named "<Type>HandleOp" whose only output is a resource
// handle to a resource of the given type.  The op has string attributes
// "container" and "shared_name" (both defaulting to empty), a scalar output
// shape, and is marked stateful.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type) \
  REGISTER_OP(#Type "HandleOp")           \
      .Attr("container: string = ''")     \
      .Attr("shared_name: string = ''")   \
      .Output("resource: resource")       \
      .SetIsStateful()                    \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
465 
466 // Utility op kernel to produce a handle to a resource of type T.
467 template <typename T>
468 class ResourceHandleOp : public OpKernel {
469  public:
470   explicit ResourceHandleOp(OpKernelConstruction* context);
471 
472   void Compute(OpKernelContext* ctx) override;
473 
IsExpensive()474   bool IsExpensive() override { return false; }
475 
476  private:
477   string container_;
478   string name_;
479   mutex mutex_;
480   Tensor resource_;
481   std::atomic<bool> initialized_{false};
482 };
483 
484 // Utility op kernel to produce a handle to a resource of type T.
485 template <typename T>
486 class ResourceHandlesOp : public OpKernel {
487  public:
488   explicit ResourceHandlesOp(OpKernelConstruction* context);
489 
490   void Compute(OpKernelContext* ctx) override;
491 
IsExpensive()492   bool IsExpensive() override { return false; }
493 
494  private:
495   std::vector<string> containers_;
496   std::vector<string> names_;
497   mutex mutex_;
498   std::vector<Tensor> resources_;
499   std::atomic<bool> initialized_{false};
500 };
501 
502 Status ResourceHandlesShape(shape_inference::InferenceContext* c);
503 
// Registers ResourceHandleOp<Type> as the CPU kernel of the op named
// "<Type>HandleOp", i.e. the op that produces a handle to a resource of the
// given type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
509 
510 // This class is used to guarantee that an anonymous resource is deleted
511 // (irrespective of whether a resource deleter op is called explicitly or
512 // the execution encounters an error before the op runs).
513 //
514 // This is achieved by wrapping an instance of this class into a variant
515 // tensor which is passed as an input to a resource deleter op. If the
516 // execution encounters an error before the op runs, the tensor will be
517 // destroyed, essentially triggering the iterator deletion.
518 // NOTE: This is not a feature-complete implementation of the DT_VARIANT
519 // specification. In particular, we cannot serialize the `ResourceMgr`
520 // object, so the `Encode()` and `Decode()` methods are not implemented.
521 class ResourceDeleter {
522  public:
ResourceDeleter()523   ResourceDeleter() : deleter_() {}
524 
ResourceDeleter(ResourceHandle handle,ResourceMgr * resource_manager)525   ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
526       : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
527 
ResourceDeleter(ResourceDeleter && rhs)528   ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
529     VLOG(3) << "ResourceDeleter move constructor called.";
530   }
531 
ResourceDeleter(const ResourceDeleter & rhs)532   ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
533     VLOG(3) << "ResourceDeleter copy constructor called.";
534   }
535 
536   ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
537 
538   ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
539 
~ResourceDeleter()540   virtual ~ResourceDeleter() {
541     VLOG(3) << "ResourceDeleter destructor called.";
542   }
543 
Encode(VariantTensorData *)544   void Encode(VariantTensorData*) const {
545     LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
546                   "objects.";
547   }
548 
Decode(const VariantTensorData &)549   bool Decode(const VariantTensorData&) {
550     LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
551                   "objects";
552     return false;  // Not supported.
553   }
554 
555  private:
556   // Helper that performs reference counting for the parent class and deletes
557   // the iterator resource when the refcount goes to zero.
558   //
559   // NOTE: The object is borrowing a pointer to the resource manager.
560   // Consequently, the tensor containing this object should not escape the
561   // function in which was created (so that it is guaranteed that the resource
562   // manager will outlive it).
563   struct Helper {
HelperHelper564     Helper(ResourceHandle handle, ResourceMgr* resource_manager)
565         : handle(handle), resource_manager(resource_manager) {}
566 
567     Helper(const Helper& rhs) = delete;
568     Helper(Helper&& rhs) = delete;
569 
~HelperHelper570     ~Helper() {
571       VLOG(3) << "Deleting Resource: " << handle.DebugString();
572       resource_manager->Delete(handle).IgnoreError();
573     }
574 
575     ResourceHandle handle;
576     ResourceMgr* resource_manager;  // not owned
577   };
578 
579   std::shared_ptr<Helper> deleter_;
580 };
581 
582 // Implementation details below.
583 
584 template <typename T>
CheckDeriveFromResourceBase()585 void CheckDeriveFromResourceBase() {
586   static_assert(std::is_base_of<ResourceBase, T>::value,
587                 "T must derive from ResourceBase");
588 }
589 
590 template <typename T>
Create(const string & container,const string & name,T * resource)591 Status ResourceMgr::Create(const string& container, const string& name,
592                            T* resource) {
593   CheckDeriveFromResourceBase<T>();
594   CHECK(resource != nullptr);
595   mutex_lock l(mu_);
596   return DoCreate(container, MakeTypeIndex<T>(), name, resource);
597 }
598 
599 template <typename T, bool use_dynamic_cast>
Lookup(const string & container,const string & name,T ** resource)600 Status ResourceMgr::Lookup(const string& container, const string& name,
601                            T** resource) const {
602   CheckDeriveFromResourceBase<T>();
603   tf_shared_lock l(mu_);
604   return LookupInternal<T, use_dynamic_cast>(container, name, resource);
605 }
606 
607 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)608 Status ResourceMgr::LookupMany(
609     absl::Span<std::pair<const string*, const string*> const>
610         containers_and_names,
611     std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
612   CheckDeriveFromResourceBase<T>();
613   tf_shared_lock l(mu_);
614   resources->resize(containers_and_names.size());
615   for (size_t i = 0; i < containers_and_names.size(); ++i) {
616     T* resource;
617     Status s = LookupInternal<T, use_dynamic_cast>(
618         *containers_and_names[i].first, *containers_and_names[i].second,
619         &resource);
620     if (s.ok()) {
621       (*resources)[i].reset(resource);
622     }
623   }
624   return Status::OK();
625 }
626 
627 // Simple wrapper to allow conditional dynamic / static casts.
628 template <typename T, bool use_dynamic_cast>
629 struct TypeCastFunctor {
CastTypeCastFunctor630   static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
631 };
632 
633 template <typename T>
634 struct TypeCastFunctor<T, true> {
635   static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
636 };
637 
638 template <typename T, bool use_dynamic_cast>
639 Status ResourceMgr::LookupInternal(const string& container, const string& name,
640                                    T** resource) const {
641   ResourceBase* found = nullptr;
642   Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
643   if (s.ok()) {
644     // It's safe to down cast 'found' to T* since
645     // typeid(T).hash_code() is part of the map key.
646     *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
647   }
648   return s;
649 }
650 
651 template <typename T, bool use_dynamic_cast>
652 Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
653                                    T** resource,
654                                    std::function<Status(T**)> creator) {
655   CheckDeriveFromResourceBase<T>();
656   *resource = nullptr;
657   Status s;
658   {
659     tf_shared_lock l(mu_);
660     s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
661     if (s.ok()) return s;
662   }
663   mutex_lock l(mu_);
664   s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
665   if (s.ok()) return s;
666   TF_RETURN_IF_ERROR(creator(resource));
667   s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
668   if (!s.ok()) {
669     return errors::Internal("LookupOrCreate failed unexpectedly");
670   }
671   (*resource)->Ref();
672   return s;
673 }
674 
675 template <typename T>
676 Status ResourceMgr::Delete(const string& container, const string& name) {
677   CheckDeriveFromResourceBase<T>();
678   return DoDelete(container, MakeTypeIndex<T>(), name);
679 }
680 
681 template <typename T>
682 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
683                               T** resource) {
684   DataType dtype;
685   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
686   if (dtype == DT_RESOURCE) {
687     const Tensor* handle;
688     TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
689     return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
690   }
691   string container;
692   string shared_name;
693   {
694     mutex* mu;
695     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
696     mutex_lock l(*mu);
697     Tensor tensor;
698     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
699     if (tensor.NumElements() != 2) {
700       return errors::InvalidArgument(
701           "Resource handle must have 2 elements, but had shape: ",
702           tensor.shape().DebugString());
703     }
704     container = tensor.flat<tstring>()(0);
705     shared_name = tensor.flat<tstring>()(1);
706   }
707   return ctx->resource_manager()->Lookup(container, shared_name, resource);
708 }
709 
710 namespace internal {
711 
712 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
713 
714 template <typename T>
715 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
716   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
717   auto type_index = MakeTypeIndex<T>();
718   if (type_index.hash_code() != p.hash_code()) {
719     return errors::InvalidArgument(
720         "Trying to access resource using the wrong type. Expected ",
721         p.maybe_type_name(), " got ", type_index.name());
722   }
723   return Status::OK();
724 }
725 
726 }  // namespace internal
727 
728 template <typename T>
729 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
730   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
731   return ctx->resource_manager()->Create(p.container(), p.name(), value);
732 }
733 
734 template <typename T, bool use_dynamic_cast>
735 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
736                       T** value) {
737   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
738   return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
739                                                               p.name(), value);
740 }
741 
742 template <typename T>
743 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
744                       core::RefCountPtr<T>* value) {
745   T* raw_ptr = nullptr;
746   TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
747   value->reset(raw_ptr);
748 
749   return Status::OK();
750 }
751 
752 template <typename T>
753 Status LookupResources(OpKernelContext* ctx,
754                        absl::Span<ResourceHandle const* const> p,
755                        std::vector<core::RefCountPtr<T>>* values) {
756   std::vector<std::pair<const string*, const string*>> containers_and_names(
757       p.size());
758   for (size_t i = 0; i < p.size(); ++i) {
759     TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
760     containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
761   }
762   return ctx->resource_manager()->LookupMany(containers_and_names, values);
763 }
764 
765 template <typename T>
766 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
767                               T** value, std::function<Status(T**)> creator) {
768   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
769   return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
770                                                  creator);
771 }
772 
773 template <typename T>
774 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
775                               core::RefCountPtr<T>* value,
776                               std::function<Status(T**)> creator) {
777   T* raw_ptr = nullptr;
778   TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
779   value->reset(raw_ptr);
780 
781   return Status::OK();
782 }
783 
784 template <typename T>
785 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
786   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
787   return ctx->resource_manager()->Delete<T>(p.container(), p.name());
788 }
789 
790 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
791 
792 template <typename T>
793 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
794   Tensor* output;
795   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
796   T* object;
797   bool found;
798   if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
799     found = true;
800     object->Unref();
801   } else {
802     found = false;
803   }
804 
805   output->flat<bool>()(0) = found;
806 }
807 
808 template <typename T>
809 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
810     : OpKernel(context) {
811   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
812   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
813 }
814 
815 template <typename T>
816 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
817   if (name_ == ResourceHandle::ANONYMOUS_NAME) {
818     AllocatorAttributes attr;
819     attr.set_on_host(true);
820     Tensor handle;
821     OP_REQUIRES_OK(
822         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
823     handle.scalar<ResourceHandle>()() =
824         MakeResourceHandle<T>(ctx, container_, name_);
825     ctx->set_output(0, handle);
826   } else {
827     if (!initialized_.load()) {
828       mutex_lock ml(mutex_);
829       // Checking again to see if another thread has initialized the resource.
830       if (!initialized_.load()) {
831         AllocatorAttributes attr;
832         attr.set_on_host(true);
833         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
834                                                &resource_, attr));
835         resource_.scalar<ResourceHandle>()() =
836             MakeResourceHandle<T>(ctx, container_, name_);
837         initialized_.store(true);
838       }
839     }
840     ctx->set_output(0, resource_);
841   }
842 }
843 
844 template <typename T>
845 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
846     : OpKernel(context) {
847   int n;
848   OP_REQUIRES_OK(context, context->GetAttr("N", &n));
849   OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
850   OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
851   OP_REQUIRES(
852       context, containers_.size() == n,
853       errors::InvalidArgument("Number of containers (", containers_.size(),
854                               ") must be equal to N (", n, ")"));
855   OP_REQUIRES(context, names_.size() == n,
856               errors::InvalidArgument("Number of names (", containers_.size(),
857                                       ") must be equal to N (", n, ")"));
858   resources_.resize(n);
859 }
860 
861 template <typename T>
862 void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
863   if (!initialized_.load()) {
864     mutex_lock ml(mutex_);
865     // Checking again to see if another thread has initialized the resource.
866     if (!initialized_.load()) {
867       AllocatorAttributes attr;
868       attr.set_on_host(true);
869       for (size_t i = 0; i < resources_.size(); ++i) {
870         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
871                                                &resources_[i], attr));
872         ResourceHandle h =
873             MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
874         resources_[i].template scalar<ResourceHandle>()() = h;
875       }
876       initialized_.store(true);
877     }
878   }
879   for (size_t i = 0; i < resources_.size(); ++i) {
880     ctx->set_output(i, resources_[i]);
881   }
882 }
883 
884 template <typename T>
885 ResourceHandle ScopedStepContainer::MakeResourceHandle(
886     const string& name, const DeviceBase& device) {
887   mutex_lock ml(mu_);
888   dirty_ = true;
889   return tensorflow::MakeResourceHandle(container_, name, device,
890                                         MakeTypeIndex<T>(), {});
891 }
892 
893 template <typename T>
894 Status ScopedStepContainer::Lookup(ResourceMgr* rm, const string& name,
895                                    T** resource) const {
896   return rm->Lookup<T>(container_, name, resource);
897 }
898 
899 template <typename T>
900 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name,
901                                            T** resource,
902                                            std::function<Status(T**)> creator) {
903   mutex_lock ml(mu_);
904   dirty_ = true;
905   return rm->LookupOrCreate<T>(container_, name, resource, creator);
906 }
907 
908 template <typename T>
909 Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name,
910                                    T* resource) {
911   mutex_lock ml(mu_);
912   dirty_ = true;
913   return rm->Create<T>(container_, name, resource);
914 }
915 
916 template <typename T>
917 Status ScopedStepContainer::Delete(ResourceMgr* rm, const string& name) {
918   return rm->Delete<T>(container_, name);
919 }
920 
921 }  //  end namespace tensorflow
922 
923 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
924