• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18 
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
24 
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/resource_handle.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/tensor_types.h"
32 #include "tensorflow/core/framework/type_index.h"
33 #include "tensorflow/core/framework/variant_tensor_data.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/refcount.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/lib/hash/hash.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/macros.h"
40 #include "tensorflow/core/platform/mutex.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42 
43 namespace tensorflow {
44 
// A ResourceMgr instance keeps track of named and typed resources
// grouped into containers.
//
// Each resource must be represented as a sub-class of ResourceBase,
// which is reference counted explicitly.  Each named resource is
// registered with ResourceMgr under a container name. At any time,
// there is at most one instance of a resource given the container
// name, the resource type and the resource name.
//
// All resources for a given container can be dropped by one call of
// Cleanup().
//
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//     std::string DebugString() const override { return "MyVar"; }
//   };
//
//   ResourceMgr rm;
//
//   // Create a var. `rm` takes over the reference held by `my_var`.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();   // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* found_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &found_var);
//   if (s.ok()) {
//     {
//       mutex_lock l(found_var->mu);
//       found_var->val.flat<float>() += grad;
//     }
//     found_var->Unref();   // Only on success; or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
79 class ResourceBase : public core::RefCounted {
80  public:
81   // Returns a debug string for *this.
82   virtual std::string DebugString() const = 0;
83 
84   // Returns memory used by this resource.
MemoryUsed()85   virtual int64 MemoryUsed() const { return 0; }
86 };
87 
88 // Container used for per-step resources.
89 class ScopedStepContainer {
90  public:
91   // step_id: the unique ID of this step. Doesn't have to be sequential, just
92   // has to be unique.
93   // cleanup: callback to delete a container of this name.
94   // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup)95   ScopedStepContainer(const int64 step_id,
96                       std::function<void(const string&)> cleanup)
97       : step_id_(step_id),
98         container_(strings::StrCat("__per_step_", step_id)),
99         cleanup_(cleanup),
100         dirty_(false) {}
101 
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup,const std::string & prefix)102   ScopedStepContainer(const int64 step_id,
103                       std::function<void(const string&)> cleanup,
104                       const std::string& prefix)
105       : step_id_(step_id),
106         container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
107         cleanup_(cleanup),
108         dirty_(false) {}
109 
~ScopedStepContainer()110   ~ScopedStepContainer() { CleanUp(); }
111 
CleanUp()112   void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS {
113     // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
114     // clean.
115     if (dirty_) {
116       mutex_lock ml(mu_);
117       cleanup_(container_);
118       dirty_ = false;
119     }
120   }
121 
122   // Pass through functions for resource lookup and creation. We do this to
123   // ensure that we can appropriately set the dirty_ bit in the
124   // ScopedStepContainer if the name of the container is used to create
125   // resources.
126 
127   // Pass through to MakeResourceHandle with the container name
128   template <typename T>
129   ResourceHandle MakeResourceHandle(
130       const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
131   // Pass through to ResourceMgr::Create with the container name
132   template <typename T>
133   Status Create(ResourceMgr* rm, const std::string& name,
134                 T* resource) TF_MUST_USE_RESULT;
135   // Pass through to ResourceMgr::Delete with the container name
136   template <typename T>
137   Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
138   // Pass through to ResourceMgr::Lookup with the container name
139   template <typename T>
140   Status Lookup(ResourceMgr* rm, const std::string& name,
141                 T** resource) const TF_MUST_USE_RESULT;
142   // Pass through to ResourceMgr::LookupOrCreate with the container name
143   template <typename T>
144   Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
145                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
StepId()146   int64 StepId() const { return step_id_; }
147 
148  private:
149   const int64 step_id_;
150   const std::string container_;
151   const std::function<void(const string&)> cleanup_;
152   mutex mu_;
153   mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
154 };
155 
156 class ResourceMgr {
157  public:
158   ResourceMgr();
159   explicit ResourceMgr(const std::string& default_container);
160   ~ResourceMgr();
161 
162   // Returns the default container name for *this.
default_container()163   const std::string& default_container() const { return default_container_; }
164 
165   // Creates a resource "name" in the "container".  The caller transfers
166   // the ownership of one ref on "resource" to *this, regardless of whether this
167   // operation succeeds or fails.
168   //
169   // REQUIRES: std::is_base_of<ResourceBase, T>
170   // REQUIRES: resource != nullptr.
171   template <typename T>
172   Status Create(const std::string& container, const std::string& name,
173                 T* resource) TF_MUST_USE_RESULT;
174 
175   // If "container" has a resource "name", returns it in "*resource" and
176   // the caller takes the ownership of one ref on "*resource".
177   //
178   // REQUIRES: std::is_base_of<ResourceBase, T>
179   // REQUIRES: resource != nullptr
180   template <typename T, bool use_dynamic_cast = false>
181   Status Lookup(const std::string& container, const std::string& name,
182                 T** resource) const TF_MUST_USE_RESULT;
183 
184   // Similar to Lookup, but looks up multiple resources at once, with only a
185   // single lock acquisition.  If containers_and_names[i] is uninitialized
186   // then this function does not modify resources[i].
187   template <typename T, bool use_dynamic_cast = false>
188   Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
189                         containers_and_names,
190                     std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
191                         resources) const TF_MUST_USE_RESULT;
192 
193   // If "container" has a resource "name", returns it in
194   // "*resource". Otherwise, invokes creator() to create the resource.
195   // The caller takes the ownership of one ref on "*resource".
196   //
197   // WARNING: creator() must not call any methods on ResourceMgr during its
198   // execution, because a non-reentrant lock is held during the creator() call
199   // in order to guarantee atomicity of LookupOrCreate().
200   //
201   // REQUIRES: std::is_base_of<ResourceBase, T>
202   // REQUIRES: resource != nullptr
203   template <typename T, bool use_dynamic_cast = false>
204   Status LookupOrCreate(const std::string& container, const std::string& name,
205                         T** resource,
206                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
207 
208   // Deletes the resource "name" from the "container".
209   //
210   // REQUIRES: std::is_base_of<ResourceBase, T>
211   template <typename T>
212   Status Delete(const std::string& container,
213                 const std::string& name) TF_MUST_USE_RESULT;
214 
215   // Deletes the resource pointed by "handle".
216   Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
217 
218   // Deletes all resources from the "container" and removes the container.
219   Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;
220 
221   // Deletes all resources in all containers.
222   void Clear();
223 
224   // Returns a text description for all resources.
225   std::string DebugString() const;
226 
227  private:
228   typedef std::pair<uint64, StringPiece> Key;
229   struct KeyHash {
operatorKeyHash230     std::size_t operator()(const Key& k) const {
231       return Hash64(k.second.data(), k.second.size(), k.first);
232     }
233   };
234   struct KeyEqual {
operatorKeyEqual235     bool operator()(const Key& x, const Key& y) const {
236       return (x.second == y.second) && (x.first == y.first);
237     }
238   };
239   struct ResourceAndName {
240     core::RefCountPtr<ResourceBase> resource;
241     std::unique_ptr<string> name;
242 
243     ResourceAndName();
244     ResourceAndName(ResourceBase* resource, std::string name);
245     ResourceAndName(ResourceAndName&& other) noexcept;
246     ~ResourceAndName();
247 
248     ResourceAndName& operator=(ResourceAndName&&) noexcept;
249 
250    private:
251     TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
252   };
253   typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
254 
255   const std::string default_container_;
256   mutable mutex mu_;
257   std::unordered_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
258 
259   template <typename T, bool use_dynamic_cast = false>
260   Status LookupInternal(const std::string& container, const std::string& name,
261                         T** resource) const
262       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
263 
264   Status DoCreate(const std::string& container, TypeIndex type,
265                   const std::string& name, ResourceBase* resource)
266       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
267 
268   Status DoLookup(const std::string& container, TypeIndex type,
269                   const std::string& name, ResourceBase** resource) const
270       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
271 
272   Status DoDelete(const std::string& container, uint64 type_hash_code,
273                   const std::string& resource_name,
274                   const std::string& type_name) TF_MUST_USE_RESULT;
275   Status DoDelete(const std::string& container, TypeIndex type,
276                   const std::string& resource_name) TF_MUST_USE_RESULT;
277 
278   // Inserts the type name for 'hash_code' into the hash_code to type name map.
279   Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
280       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
281 
282   // Returns the type name for the 'hash_code'.
283   // Returns "<unknown>" if a resource with such a type was never inserted into
284   // the container.
285   const char* DebugTypeName(uint64 hash_code) const
286       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
287 
288   // Map from type hash_code to type name.
289   std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_);
290 
291   TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
292 };
293 
294 // Makes a resource handle with the specified type for a given container /
295 // name.
296 ResourceHandle MakeResourceHandle(
297     const std::string& container, const std::string& name,
298     const DeviceBase& device, const TypeIndex& type_index,
299     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
300     const absl::optional<ManagedStackTrace>& definition_stack_trace = {})
301     TF_MUST_USE_RESULT;
302 
303 template <typename T>
304 ResourceHandle MakeResourceHandle(
305     OpKernelContext* ctx, const std::string& container, const std::string& name,
306     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
307     const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
308   return MakeResourceHandle(container.empty()
309                                 ? ctx->resource_manager()->default_container()
310                                 : container,
311                             name, *ctx->device(), TypeIndex::Make<T>(),
312                             dtypes_and_shapes, definition_stack_trace);
313 }
314 
315 template <typename T>
316 ResourceHandle MakeResourceHandle(
317     OpKernelConstruction* ctx, const std::string& container,
318     const std::string& name,
319     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
320     const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
321   return MakeResourceHandle(container.empty()
322                                 ? ctx->resource_manager()->default_container()
323                                 : container,
324                             name, *ctx->device(), TypeIndex::Make<T>(),
325                             dtypes_and_shapes, definition_stack_trace);
326 }
327 
328 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
329                                   const std::string& container,
330                                   const std::string& name,
331                                   const TypeIndex& type_index);
332 
333 // Returns a resource handle from a numbered op input.
334 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
335 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
336                        ResourceHandle* handle);
337 
338 // Create a resource pointed by a given resource handle.
339 //
340 // If successful, the caller transfers the ownership of one ref on `resource` to
341 // `ctx->resource_mgr()`.
342 template <typename T>
343 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
344 
345 // Looks up a resource pointed by a given resource handle.
346 //
347 // If the lookup is successful, the caller takes the ownership of one ref on
348 // `*value`, and must call its `Unref()` method when it has finished using it.
349 template <typename T, bool use_dynamic_cast = false>
350 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
351 
352 // Looks up a resource pointed by a given resource handle.
353 //
354 // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
355 // requiring the caller to explicitly call `Unref()`.
356 template <typename T>
357 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
358                       core::RefCountPtr<T>* value);
359 
360 // Looks up multiple resources pointed by a sequence of resource handles.  If
361 // p[i] is uninitialized then values[i] is unmodified.
362 template <typename T>
363 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
364                        std::vector<core::RefCountPtr<T>>* values);
365 
366 // Looks up or creates a resource.
367 //
368 // If successful, the caller takes the ownership of one ref on `*value`, and
369 // must call its `Unref()` method when it has finished using it. If the
370 // `creator` is invoked, its reference on the created resource is transferred
371 // to `ctx->resource_mgr()`.
372 //
373 // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
374 // requiring the caller to explicitly call `Unref()`.
375 template <typename T>
376 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
377                               T** value, std::function<Status(T**)> creator);
378 
379 // Looks up or creates a resource.
380 template <typename T>
381 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
382                               core::RefCountPtr<T>* value,
383                               std::function<Status(T**)> creator);
384 
385 // Destroys a resource pointed by a given resource handle.
386 template <typename T>
387 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
388 
389 // Same as above, but uses the hash code of the type directly.
390 // The type name information will be missing in the debug output when the
391 // resource is not present in the container.
392 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
393 
394 // Policy helper to decide which container/shared_name to use for a
395 // stateful kernel that accesses shared resource.
396 class ContainerInfo {
397  public:
398   // Analyze the node attribute of 'ndef' and decides the container and
399   // resource name the kernel should use for accessing the shared
400   // resource.
401   //
402   // 'ndef' is expected to have node attribute "container" and
403   // "shared_name". Returns non-OK if they are not provided or they are
404   // invalid.
405   //
406   // The policy is as following:
407   // * If the attribute "container" is non-empty, it is used as is.
408   //   Otherwise, uses the resource manager's default container.
409   // * If the attribute "shared_name" is non-empty, it is used as is.
410   //   Otherwise, if "use_node_name_as_default" is true, the kernel's
411   //   node name is used as the resource name. Otherwise, a string
412   //   unique to this process is used.
413   Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
414               bool use_node_name_as_default);
Init(ResourceMgr * rmgr,const NodeDef & ndef)415   Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
416     return Init(rmgr, ndef, false);
417   }
418 
419   // The policy decides that the kernel should access the resource in
420   // resource_manager(), the resource is in the container() and its
421   // name is name().  If resource_is_private_to_kernel() is true, the
422   // kernel should delete the resource when the kernel is deleted.
resource_manager()423   ResourceMgr* resource_manager() const { return rmgr_; }
container()424   const std::string& container() const { return container_; }
name()425   const std::string& name() const { return name_; }
resource_is_private_to_kernel()426   bool resource_is_private_to_kernel() const {
427     return resource_is_private_to_kernel_;
428   }
429 
430   // Returns a readable string for *this.
431   std::string DebugString() const;
432 
433  private:
434   ResourceMgr* rmgr_ = nullptr;
435   std::string container_;
436   std::string name_;
437   bool resource_is_private_to_kernel_ = false;
438 };
439 
440 // Helper for kernels to obtain 'resource' from the
441 // ctx->resource_manager().
442 //
443 // "input_name" specifies the kernel's ref input which gives a string
444 // tensor with two elements, which specifies the container and
445 // resource name.
446 //
447 // Returns OK if the resource is found and transfers one ref of
448 // *resource to the caller. Otherwise, returns an error.
449 template <typename T>
450 Status GetResourceFromContext(OpKernelContext* ctx,
451                               const std::string& input_name, T** resource);
452 
453 // Utility op kernel to check if a handle to resource type T is initialized.
454 template <typename T>
455 class IsResourceInitialized : public OpKernel {
456  public:
IsResourceInitialized(OpKernelConstruction * c)457   explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
458 
459   void Compute(OpKernelContext* ctx) override;
460 };
461 
// Registers an op named "<Type>HandleOp" that outputs a scalar handle to a
// resource of type `Type`. It has optional string attributes "container"
// and "shared_name", both defaulting to ''.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type) \
  REGISTER_OP(#Type "HandleOp")           \
      .Attr("container: string = ''")     \
      .Attr("shared_name: string = ''")   \
      .Output("resource: resource")       \
      .SetIsStateful()                    \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
473 
474 // Utility op kernel to produce a handle to a resource of type T.
475 template <typename T>
476 class ResourceHandleOp : public OpKernel {
477  public:
478   explicit ResourceHandleOp(OpKernelConstruction* context);
479 
480   void Compute(OpKernelContext* ctx) override;
481 
IsExpensive()482   bool IsExpensive() override { return false; }
483 
484  private:
485   std::string container_;
486   std::string name_;
487   mutex mutex_;
488   Tensor resource_;
489   std::atomic<bool> initialized_{false};
490 };
491 
492 // Utility op kernel to produce a handle to a resource of type T.
493 template <typename T>
494 class ResourceHandlesOp : public OpKernel {
495  public:
496   explicit ResourceHandlesOp(OpKernelConstruction* context);
497 
498   void Compute(OpKernelContext* ctx) override;
499 
IsExpensive()500   bool IsExpensive() override { return false; }
501 
502  private:
503   std::vector<string> containers_;
504   std::vector<string> names_;
505   mutex mutex_;
506   std::vector<Tensor> resources_;
507   std::atomic<bool> initialized_{false};
508 };
509 
// Registers the CPU kernel ResourceHandleOp<Type> for the op
// "<Type>HandleOp", which is declared by REGISTER_RESOURCE_HANDLE_OP(Type).
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
515 
516 // This class is used to guarantee that an anonymous resource is deleted
517 // (irrespective of whether a resource deleter op is called explicitly or
518 // the execution encounters an error before the op runs).
519 //
520 // This is achieved by wrapping an instance of this class into a variant
521 // tensor which is passed as an input to a resource deleter op. If the
522 // execution encounters an error before the op runs, the tensor will be
523 // destroyed, essentially triggering the iterator deletion.
524 // NOTE: This is not a feature-complete implementation of the DT_VARIANT
525 // specification. In particular, we cannot serialize the `ResourceMgr`
526 // object, so the `Encode()` and `Decode()` methods are not implemented.
527 class ResourceDeleter {
528  public:
ResourceDeleter()529   ResourceDeleter() : deleter_() {}
530 
ResourceDeleter(ResourceHandle handle,ResourceMgr * resource_manager)531   ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
532       : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
533 
ResourceDeleter(ResourceDeleter && rhs)534   ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
535     VLOG(3) << "ResourceDeleter move constructor called.";
536   }
537 
ResourceDeleter(const ResourceDeleter & rhs)538   ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
539     VLOG(3) << "ResourceDeleter copy constructor called.";
540   }
541 
542   ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
543 
544   ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
545 
~ResourceDeleter()546   virtual ~ResourceDeleter() {
547     VLOG(3) << "ResourceDeleter destructor called.";
548   }
549 
Encode(VariantTensorData *)550   void Encode(VariantTensorData*) const {
551     LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
552                   "objects.";
553   }
554 
Decode(const VariantTensorData &)555   bool Decode(const VariantTensorData&) {
556     LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
557                   "objects";
558     return false;  // Not supported.
559   }
560 
561  private:
562   // Helper that performs reference counting for the parent class and deletes
563   // the iterator resource when the refcount goes to zero.
564   //
565   // NOTE: The object is borrowing a pointer to the resource manager.
566   // Consequently, the tensor containing this object should not escape the
567   // function in which was created (so that it is guaranteed that the resource
568   // manager will outlive it).
569   struct Helper {
HelperHelper570     Helper(ResourceHandle handle, ResourceMgr* resource_manager)
571         : handle(handle), resource_manager(resource_manager) {}
572 
573     Helper(const Helper& rhs) = delete;
574     Helper(Helper&& rhs) = delete;
575 
~HelperHelper576     ~Helper() {
577       VLOG(3) << "Deleting Resource: " << handle.DebugString();
578       resource_manager->Delete(handle).IgnoreError();
579     }
580 
581     ResourceHandle handle;
582     ResourceMgr* resource_manager;  // not owned
583   };
584 
585   std::shared_ptr<Helper> deleter_;
586 };
587 
588 // Implementation details below.
589 
590 template <typename T>
CheckDeriveFromResourceBase()591 void CheckDeriveFromResourceBase() {
592   static_assert(std::is_base_of<ResourceBase, T>::value,
593                 "T must derive from ResourceBase");
594 }
595 
596 template <typename T>
Create(const std::string & container,const std::string & name,T * resource)597 Status ResourceMgr::Create(const std::string& container,
598                            const std::string& name, T* resource) {
599   CheckDeriveFromResourceBase<T>();
600   CHECK(resource != nullptr);
601   mutex_lock l(mu_);
602   return DoCreate(container, TypeIndex::Make<T>(), name, resource);
603 }
604 
605 template <typename T, bool use_dynamic_cast>
Lookup(const std::string & container,const std::string & name,T ** resource)606 Status ResourceMgr::Lookup(const std::string& container,
607                            const std::string& name, T** resource) const {
608   CheckDeriveFromResourceBase<T>();
609   tf_shared_lock l(mu_);
610   return LookupInternal<T, use_dynamic_cast>(container, name, resource);
611 }
612 
613 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)614 Status ResourceMgr::LookupMany(
615     absl::Span<std::pair<const string*, const string*> const>
616         containers_and_names,
617     std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
618   CheckDeriveFromResourceBase<T>();
619   tf_shared_lock l(mu_);
620   resources->resize(containers_and_names.size());
621   for (size_t i = 0; i < containers_and_names.size(); ++i) {
622     T* resource;
623     Status s = LookupInternal<T, use_dynamic_cast>(
624         *containers_and_names[i].first, *containers_and_names[i].second,
625         &resource);
626     if (s.ok()) {
627       (*resources)[i].reset(resource);
628     }
629   }
630   return Status::OK();
631 }
632 
633 // Simple wrapper to allow conditional dynamic / static casts.
634 template <typename T, bool use_dynamic_cast>
635 struct TypeCastFunctor {
CastTypeCastFunctor636   static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
637 };
638 
639 template <typename T>
640 struct TypeCastFunctor<T, true> {
641   static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
642 };
643 
644 template <typename T, bool use_dynamic_cast>
645 Status ResourceMgr::LookupInternal(const std::string& container,
646                                    const std::string& name,
647                                    T** resource) const {
648   ResourceBase* found = nullptr;
649   Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
650   if (s.ok()) {
651     // It's safe to down cast 'found' to T* since
652     // typeid(T).hash_code() is part of the map key.
653     *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
654   }
655   return s;
656 }
657 
658 template <typename T, bool use_dynamic_cast>
659 Status ResourceMgr::LookupOrCreate(const std::string& container,
660                                    const std::string& name, T** resource,
661                                    std::function<Status(T**)> creator) {
662   CheckDeriveFromResourceBase<T>();
663   *resource = nullptr;
664   Status s;
665   {
666     tf_shared_lock l(mu_);
667     s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
668     if (s.ok()) return s;
669   }
670   mutex_lock l(mu_);
671   s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
672   if (s.ok()) return s;
673   TF_RETURN_IF_ERROR(creator(resource));
674   s = DoCreate(container, TypeIndex::Make<T>(), name, *resource);
675   if (!s.ok()) {
676     return errors::Internal("LookupOrCreate failed unexpectedly");
677   }
678   (*resource)->Ref();
679   return s;
680 }
681 
682 template <typename T>
683 Status ResourceMgr::Delete(const std::string& container,
684                            const std::string& name) {
685   CheckDeriveFromResourceBase<T>();
686   return DoDelete(container, TypeIndex::Make<T>(), name);
687 }
688 
689 template <typename T>
690 Status GetResourceFromContext(OpKernelContext* ctx,
691                               const std::string& input_name, T** resource) {
692   DataType dtype;
693   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
694   if (dtype == DT_RESOURCE) {
695     const Tensor* handle;
696     TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
697     return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
698   }
699   std::string container;
700   std::string shared_name;
701   {
702     mutex* mu;
703     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
704     mutex_lock l(*mu);
705     Tensor tensor;
706     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
707     if (tensor.NumElements() != 2) {
708       return errors::InvalidArgument(
709           "Resource handle must have 2 elements, but had shape: ",
710           tensor.shape().DebugString());
711     }
712     container = tensor.flat<tstring>()(0);
713     shared_name = tensor.flat<tstring>()(1);
714   }
715   return ctx->resource_manager()->Lookup(container, shared_name, resource);
716 }
717 
718 namespace internal {
719 
720 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
721 
722 template <typename T>
723 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
724   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
725   auto type_index = TypeIndex::Make<T>();
726   if (type_index.hash_code() != p.hash_code()) {
727     return errors::InvalidArgument(
728         "Trying to access resource using the wrong type. Expected ",
729         p.maybe_type_name(), " got ", type_index.name());
730   }
731   return Status::OK();
732 }
733 
734 }  // namespace internal
735 
736 template <typename T>
737 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
738   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
739   return ctx->resource_manager()->Create(p.container(), p.name(), value);
740 }
741 
742 template <typename T, bool use_dynamic_cast>
743 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
744                       T** value) {
745   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
746   return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
747                                                               p.name(), value);
748 }
749 
750 template <typename T>
751 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
752                       core::RefCountPtr<T>* value) {
753   T* raw_ptr = nullptr;
754   TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
755   value->reset(raw_ptr);
756 
757   return Status::OK();
758 }
759 
760 template <typename T>
761 Status LookupResources(OpKernelContext* ctx,
762                        absl::Span<ResourceHandle const* const> p,
763                        std::vector<core::RefCountPtr<T>>* values) {
764   std::vector<std::pair<const string*, const string*>> containers_and_names(
765       p.size());
766   for (size_t i = 0; i < p.size(); ++i) {
767     TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
768     containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
769   }
770   return ctx->resource_manager()->LookupMany(containers_and_names, values);
771 }
772 
773 template <typename T>
774 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
775                               T** value, std::function<Status(T**)> creator) {
776   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
777   return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
778                                                  creator);
779 }
780 
781 template <typename T>
782 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
783                               core::RefCountPtr<T>* value,
784                               std::function<Status(T**)> creator) {
785   T* raw_ptr = nullptr;
786   TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
787   value->reset(raw_ptr);
788 
789   return Status::OK();
790 }
791 
792 template <typename T>
793 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
794   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
795   return ctx->resource_manager()->Delete<T>(p.container(), p.name());
796 }
797 
// Type-erased overload of DeleteResource for callers that have no static
// resource type T. Defined out of line; see the .cc for which checks it
// performs on `p` before deleting.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
799 
800 template <typename T>
801 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
802   Tensor* output;
803   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
804   T* object;
805   bool found;
806   if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
807     found = true;
808     object->Unref();
809   } else {
810     found = false;
811   }
812 
813   output->flat<bool>()(0) = found;
814 }
815 
816 template <typename T>
817 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
818     : OpKernel(context) {
819   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
820   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
821 }
822 
// Emits this kernel's resource handle (for `container_` / `name_`) as a
// host-allocated scalar DT_RESOURCE tensor on output 0.
//
// When `name_` is ResourceHandle::ANONYMOUS_NAME, a new handle tensor is
// built on every call. Otherwise the tensor is built once, cached in
// `resource_`, and the cached tensor is emitted on every later call.
template <typename T>
void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
  if (name_ == ResourceHandle::ANONYMOUS_NAME) {
    // Anonymous handles are not cached: allocate and fill a fresh tensor.
    AllocatorAttributes attr;
    attr.set_on_host(true);
    Tensor handle;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
    handle.scalar<ResourceHandle>()() =
        MakeResourceHandle<T>(ctx, container_, name_);
    ctx->set_output(0, handle);
  } else {
    // Double-checked initialization. Once `initialized_` reads true, the
    // mutex is skipped and `resource_` is used directly. `initialized_` is
    // stored only after `resource_` has been fully built.
    // NOTE(review): this relies on `initialized_` being an atomic whose
    // load/store give acquire/release (or stronger) ordering. Confirm its
    // declaration.
    if (!initialized_.load()) {
      mutex_lock ml(mutex_);
      // Checking again to see if another thread has initialized the resource.
      if (!initialized_.load()) {
        AllocatorAttributes attr;
        attr.set_on_host(true);
        // If this allocation fails, the macro returns before `initialized_`
        // is set, so a later call will retry the initialization.
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
                                               &resource_, attr));
        resource_.scalar<ResourceHandle>()() =
            MakeResourceHandle<T>(ctx, container_, name_);
        // Publish only after `resource_` holds the finished handle.
        initialized_.store(true);
      }
    }
    ctx->set_output(0, resource_);
  }
}
851 
852 template <typename T>
853 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
854     : OpKernel(context) {
855   int n;
856   OP_REQUIRES_OK(context, context->GetAttr("N", &n));
857   OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
858   OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
859   OP_REQUIRES(
860       context, containers_.size() == n,
861       errors::InvalidArgument("Number of containers (", containers_.size(),
862                               ") must be equal to N (", n, ")"));
863   OP_REQUIRES(context, names_.size() == n,
864               errors::InvalidArgument("Number of names (", containers_.size(),
865                                       ") must be equal to N (", n, ")"));
866   resources_.resize(n);
867 }
868 
// Emits one resource handle per (containers_[i], names_[i]) pair as output i.
//
// The handle tensors (host-allocated DT_RESOURCE scalars) are built once, under
// `mutex_`, on the first call and cached in `resources_`. Every call then
// forwards the cached tensors.
// NOTE(review): unlike ResourceHandleOp::Compute, names equal to
// ResourceHandle::ANONYMOUS_NAME get no special treatment here. Confirm that
// is intended.
template <typename T>
void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
  // Double-checked initialization: `initialized_` is stored only after every
  // entry of `resources_` has been built.
  if (!initialized_.load()) {
    mutex_lock ml(mutex_);
    // Checking again to see if another thread has initialized the resource.
    if (!initialized_.load()) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      for (size_t i = 0; i < resources_.size(); ++i) {
        // On failure the macro returns with `initialized_` still false, so a
        // later call rebuilds all entries.
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
                                               &resources_[i], attr));
        ResourceHandle h =
            MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
        resources_[i].template scalar<ResourceHandle>()() = h;
      }
      // Publish only after every handle tensor is filled in.
      initialized_.store(true);
    }
  }
  for (size_t i = 0; i < resources_.size(); ++i) {
    ctx->set_output(i, resources_[i]);
  }
}
891 
892 template <typename T>
893 ResourceHandle ScopedStepContainer::MakeResourceHandle(
894     const std::string& name, const DeviceBase& device) {
895   mutex_lock ml(mu_);
896   dirty_ = true;
897   return tensorflow::MakeResourceHandle(container_, name, device,
898                                         TypeIndex::Make<T>(), {});
899 }
900 
901 template <typename T>
902 Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
903                                    T** resource) const {
904   return rm->Lookup<T>(container_, name, resource);
905 }
906 
907 template <typename T>
908 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
909                                            const std::string& name,
910                                            T** resource,
911                                            std::function<Status(T**)> creator) {
912   mutex_lock ml(mu_);
913   dirty_ = true;
914   return rm->LookupOrCreate<T>(container_, name, resource, creator);
915 }
916 
917 template <typename T>
918 Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
919                                    T* resource) {
920   mutex_lock ml(mu_);
921   dirty_ = true;
922   return rm->Create<T>(container_, name, resource);
923 }
924 
925 template <typename T>
926 Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
927   return rm->Delete<T>(container_, name);
928 }
929 
930 }  //  end namespace tensorflow
931 
932 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
933