1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18
19 #include <memory>
20 #include <string>
21 #include <typeindex>
22 #include <typeinfo>
23 #include <unordered_map>
24
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/resource_base.h"
29 #include "tensorflow/core/framework/resource_handle.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_types.h"
33 #include "tensorflow/core/framework/type_index.h"
34 #include "tensorflow/core/framework/variant_tensor_data.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/platform/mutex.h"
40 #include "tensorflow/core/platform/thread_annotations.h"
41
42 namespace tensorflow {
43
// A ResourceMgr instance keeps track of named and typed resources
// grouped into containers.
//
// Each named resource is registered with ResourceMgr under a named
// "container". At any time, there is at most one instance of a resource
// for a given container name, resource type, and resource name.
//
// All resources for a given container can be dropped by one call of
// Cleanup().
//
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//     std::string DebugString() const override { return "MyVar"; }
//   };
//
//   ResourceMgr rm;
//
//   // Create a var. One ref on my_var is transferred to rm.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();  // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     {
//       mutex_lock l(my_var->mu);
//       my_var->val.flat<float>() += grad;
//     }
//     // Only a successful Lookup() hands the caller a ref to release.
//     my_var->Unref();  // Or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
77
78 // Container used for per-step resources.
79 class ScopedStepContainer {
80 public:
81 // step_id: the unique ID of this step. Doesn't have to be sequential, just
82 // has to be unique.
83 // cleanup: callback to delete a container of this name.
84 // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64_t step_id,std::function<void (const string &)> cleanup)85 ScopedStepContainer(const int64_t step_id,
86 std::function<void(const string&)> cleanup)
87 : step_id_(step_id),
88 container_(strings::StrCat("__per_step_", step_id)),
89 cleanup_(cleanup),
90 dirty_(false) {}
91
ScopedStepContainer(const int64_t step_id,std::function<void (const string &)> cleanup,const std::string & prefix)92 ScopedStepContainer(const int64_t step_id,
93 std::function<void(const string&)> cleanup,
94 const std::string& prefix)
95 : step_id_(step_id),
96 container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
97 cleanup_(cleanup),
98 dirty_(false) {}
99
~ScopedStepContainer()100 ~ScopedStepContainer() { CleanUp(); }
101
CleanUp()102 void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS {
103 // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
104 // clean.
105 if (dirty_) {
106 mutex_lock ml(mu_);
107 cleanup_(container_);
108 dirty_ = false;
109 }
110 }
111
112 // Pass through functions for resource lookup and creation. We do this to
113 // ensure that we can appropriately set the dirty_ bit in the
114 // ScopedStepContainer if the name of the container is used to create
115 // resources.
116
117 // Pass through to MakeResourceHandle with the container name
118 template <typename T>
119 ResourceHandle MakeResourceHandle(
120 const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
121 // Pass through to ResourceMgr::Create with the container name
122 template <typename T>
123 Status Create(ResourceMgr* rm, const std::string& name,
124 T* resource) TF_MUST_USE_RESULT;
125 // Pass through to ResourceMgr::Delete with the container name
126 template <typename T>
127 Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
128 // Pass through to ResourceMgr::Lookup with the container name
129 template <typename T>
130 Status Lookup(ResourceMgr* rm, const std::string& name,
131 T** resource) const TF_MUST_USE_RESULT;
132 // Pass through to ResourceMgr::LookupOrCreate with the container name
133 template <typename T>
134 Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
135 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
StepId()136 int64 StepId() const { return step_id_; }
137
138 private:
139 const int64 step_id_;
140 const std::string container_;
141 const std::function<void(const string&)> cleanup_;
142 mutex mu_;
143 mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
144 };
145
146 class ResourceMgr {
147 public:
148 ResourceMgr();
149 explicit ResourceMgr(const std::string& default_container);
150 ~ResourceMgr();
151
152 // Returns the default container name for *this.
default_container()153 const std::string& default_container() const { return default_container_; }
154
155 // Creates a resource "name" in the "container". The caller transfers
156 // the ownership of one ref on "resource" to *this, regardless of whether this
157 // operation succeeds or fails.
158 //
159 // REQUIRES: std::is_base_of<ResourceBase, T>
160 // REQUIRES: resource != nullptr.
161 template <typename T>
162 Status Create(const std::string& container, const std::string& name,
163 T* resource) TF_MUST_USE_RESULT;
164
165 // If "container" has a resource "name", returns it in "*resource" and
166 // the caller takes the ownership of one ref on "*resource".
167 //
168 // REQUIRES: std::is_base_of<ResourceBase, T>
169 // REQUIRES: resource != nullptr
170 template <typename T, bool use_dynamic_cast = false>
171 Status Lookup(const std::string& container, const std::string& name,
172 T** resource) const TF_MUST_USE_RESULT;
173
174 // If the resource manager has a resource matching "handle", returns it in
175 // "*resource" and the caller takes the ownership of one ref on "*resource".
176 //
177 // REQUIRES: resource != nullptr
178 Status Lookup(const ResourceHandle& handle,
179 ResourceBase** resource) const TF_MUST_USE_RESULT;
180
181 // Similar to Lookup, but looks up multiple resources at once, with only a
182 // single lock acquisition. If containers_and_names[i] is uninitialized
183 // then this function does not modify resources[i].
184 template <typename T, bool use_dynamic_cast = false>
185 Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
186 containers_and_names,
187 std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
188 resources) const TF_MUST_USE_RESULT;
189
190 // If "container" has a resource "name", returns it in
191 // "*resource". Otherwise, invokes creator() to create the resource.
192 // The caller takes the ownership of one ref on "*resource".
193 //
194 // WARNING: creator() must not call any methods on ResourceMgr during its
195 // execution, because a non-reentrant lock is held during the creator() call
196 // in order to guarantee atomicity of LookupOrCreate().
197 //
198 // REQUIRES: std::is_base_of<ResourceBase, T>
199 // REQUIRES: resource != nullptr
200 template <typename T, bool use_dynamic_cast = false>
201 Status LookupOrCreate(const std::string& container, const std::string& name,
202 T** resource,
203 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
204
205 // Deletes the resource "name" from the "container".
206 //
207 // REQUIRES: std::is_base_of<ResourceBase, T>
208 template <typename T>
209 Status Delete(const std::string& container,
210 const std::string& name) TF_MUST_USE_RESULT;
211
212 // Deletes the resource pointed by "handle".
213 Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
214
215 // Deletes all resources from the "container" and removes the container.
216 Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;
217
218 // Deletes all resources in all containers.
219 void Clear();
220
221 // Returns a text description for all resources.
222 std::string DebugString() const;
223
224 private:
225 typedef std::pair<uint64, StringPiece> Key;
226 struct KeyHash {
operatorKeyHash227 std::size_t operator()(const Key& k) const {
228 return Hash64(k.second.data(), k.second.size(), k.first);
229 }
230 };
231 struct KeyEqual {
operatorKeyEqual232 bool operator()(const Key& x, const Key& y) const {
233 return (x.second == y.second) && (x.first == y.first);
234 }
235 };
236 struct ResourceAndName {
237 core::RefCountPtr<ResourceBase> resource;
238 std::unique_ptr<string> name;
239
240 ResourceAndName();
241 ResourceAndName(ResourceBase* resource, std::string name);
242 ResourceAndName(ResourceAndName&& other) noexcept;
243 ~ResourceAndName();
244
245 ResourceAndName& operator=(ResourceAndName&&) noexcept;
246
247 private:
248 TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
249 };
250 typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
251
252 const std::string default_container_;
253 mutable mutex mu_;
254 std::unordered_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
255
256 template <typename T, bool use_dynamic_cast = false>
257 Status LookupInternal(const std::string& container, const std::string& name,
258 T** resource) const
259 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
260 Status LookupInternal(const std::string& container, uint64 type_hash_code,
261 const std::string& name, ResourceBase** resource) const
262 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
263
264 Status DoCreate(const std::string& container, TypeIndex type,
265 const std::string& name, ResourceBase* resource)
266 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
267
268 Status DoLookup(const std::string& container, TypeIndex type,
269 const std::string& name, ResourceBase** resource) const
270 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
271 Status DoLookup(const std::string& container, uint64 type_hash_code,
272 const std::string& type_name,
273 const std::string& resource_name,
274 ResourceBase** resource) const
275 TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
276
277 Status DoDelete(const std::string& container, uint64 type_hash_code,
278 const std::string& resource_name,
279 const std::string& type_name) TF_MUST_USE_RESULT;
280 Status DoDelete(const std::string& container, TypeIndex type,
281 const std::string& resource_name) TF_MUST_USE_RESULT;
282
283 // Inserts the type name for 'hash_code' into the hash_code to type name map.
284 Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
285 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
286
287 // Returns the type name for the 'hash_code'.
288 // Returns "<unknown>" if a resource with such a type was never inserted into
289 // the container.
290 const char* DebugTypeName(uint64 hash_code) const
291 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
292
293 // Map from type hash_code to type name.
294 std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_);
295
296 TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
297 };
298
299 // Makes a resource handle with the specified type for a given container /
300 // name.
301 ResourceHandle MakeResourceHandle(
302 const std::string& container, const std::string& name,
303 const DeviceBase& device, const TypeIndex& type_index,
304 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
305 const absl::optional<ManagedStackTrace>& definition_stack_trace = {})
306 TF_MUST_USE_RESULT;
307
308 template <typename T>
309 ResourceHandle MakeResourceHandle(
310 OpKernelContext* ctx, const std::string& container, const std::string& name,
311 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
312 const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
313 return MakeResourceHandle(container.empty()
314 ? ctx->resource_manager()->default_container()
315 : container,
316 name, *ctx->device(), TypeIndex::Make<T>(),
317 dtypes_and_shapes, definition_stack_trace);
318 }
319
320 template <typename T>
321 ResourceHandle MakeResourceHandle(
322 OpKernelConstruction* ctx, const std::string& container,
323 const std::string& name,
324 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
325 const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
326 return MakeResourceHandle(container.empty()
327 ? ctx->resource_manager()->default_container()
328 : container,
329 name, *ctx->device(), TypeIndex::Make<T>(),
330 dtypes_and_shapes, definition_stack_trace);
331 }
332
333 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
334 const std::string& container,
335 const std::string& name,
336 const TypeIndex& type_index);
337
338 // Returns a resource handle from a numbered op input.
339 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
340 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
341 ResourceHandle* handle);
342
343 // Create a resource pointed by a given resource handle.
344 //
345 // If successful, the caller transfers the ownership of one ref on `resource` to
346 // `ctx->resource_mgr()`.
347 template <typename T>
348 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
349
350 // Looks up a resource pointed by a given resource handle.
351 //
352 // If the lookup is successful, the caller takes the ownership of one ref on
353 // `*value`, and must call its `Unref()` method when it has finished using it.
354 template <typename T, bool use_dynamic_cast = false>
355 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
356
357 // Looks up a resource pointed by a given resource handle.
358 //
359 // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
360 // requiring the caller to explicitly call `Unref()`.
361 template <typename T>
362 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
363 core::RefCountPtr<T>* value);
364
365 // Looks up multiple resources pointed by a sequence of resource handles. If
366 // p[i] is uninitialized then values[i] is unmodified.
367 template <typename T>
368 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
369 std::vector<core::RefCountPtr<T>>* values);
370
371 // Looks up or creates a resource.
372 //
373 // If successful, the caller takes the ownership of one ref on `*value`, and
374 // must call its `Unref()` method when it has finished using it. If the
375 // `creator` is invoked, its reference on the created resource is transferred
376 // to `ctx->resource_mgr()`.
377 //
378 // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
379 // requiring the caller to explicitly call `Unref()`.
380 template <typename T>
381 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
382 T** value, std::function<Status(T**)> creator);
383
384 // Looks up or creates a resource.
385 template <typename T>
386 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
387 core::RefCountPtr<T>* value,
388 std::function<Status(T**)> creator);
389
390 // Destroys a resource pointed by a given resource handle.
391 template <typename T>
392 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
393
394 // Same as above, but uses the hash code of the type directly.
395 // The type name information will be missing in the debug output when the
396 // resource is not present in the container.
397 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
398
399 // Policy helper to decide which container/shared_name to use for a
400 // stateful kernel that accesses shared resource.
401 class ContainerInfo {
402 public:
403 // Analyze the node attribute of 'ndef' and decides the container and
404 // resource name the kernel should use for accessing the shared
405 // resource.
406 //
407 // 'ndef' is expected to have node attribute "container" and
408 // "shared_name". Returns non-OK if they are not provided or they are
409 // invalid.
410 //
411 // The policy is as following:
412 // * If the attribute "container" is non-empty, it is used as is.
413 // Otherwise, uses the resource manager's default container.
414 // * If the attribute "shared_name" is non-empty, it is used as is.
415 // Otherwise, if "use_node_name_as_default" is true, the kernel's
416 // node name is used as the resource name. Otherwise, a string
417 // unique to this process is used.
418 Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
419 bool use_node_name_as_default);
Init(ResourceMgr * rmgr,const NodeDef & ndef)420 Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
421 return Init(rmgr, ndef, false);
422 }
423
424 // The policy decides that the kernel should access the resource in
425 // resource_manager(), the resource is in the container() and its
426 // name is name(). If resource_is_private_to_kernel() is true, the
427 // kernel should delete the resource when the kernel is deleted.
resource_manager()428 ResourceMgr* resource_manager() const { return rmgr_; }
container()429 const std::string& container() const { return container_; }
name()430 const std::string& name() const { return name_; }
resource_is_private_to_kernel()431 bool resource_is_private_to_kernel() const {
432 return resource_is_private_to_kernel_;
433 }
434
435 // Returns a readable string for *this.
436 std::string DebugString() const;
437
438 private:
439 ResourceMgr* rmgr_ = nullptr;
440 std::string container_;
441 std::string name_;
442 bool resource_is_private_to_kernel_ = false;
443 };
444
445 // Helper for kernels to obtain 'resource' from the
446 // ctx->resource_manager().
447 //
448 // "input_name" specifies the kernel's ref input which gives a string
449 // tensor with two elements, which specifies the container and
450 // resource name.
451 //
452 // Returns OK if the resource is found and transfers one ref of
453 // *resource to the caller. Otherwise, returns an error.
454 template <typename T>
455 Status GetResourceFromContext(OpKernelContext* ctx,
456 const std::string& input_name, T** resource);
457
458 // Utility op kernel to check if a handle to resource type T is initialized.
459 template <typename T>
460 class IsResourceInitialized : public OpKernel {
461 public:
IsResourceInitialized(OpKernelConstruction * c)462 explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
463
464 void Compute(OpKernelContext* ctx) override;
465 };
466
// Registers an op named "<Type>HandleOp" that produces just a resource handle
// to a resource of the given type. The op has optional "container" and
// "shared_name" string attrs and a single scalar "resource" output.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type)   \
  REGISTER_OP(#Type "HandleOp")             \
      .Attr("container: string = ''")       \
      .Attr("shared_name: string = ''")     \
      .Output("resource: resource")         \
      .SetIsStateful()                      \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
478
479 // Utility op kernel to produce a handle to a resource of type T.
480 template <typename T>
481 class ResourceHandleOp : public OpKernel {
482 public:
483 explicit ResourceHandleOp(OpKernelConstruction* context);
484
485 void Compute(OpKernelContext* ctx) override;
486
IsExpensive()487 bool IsExpensive() override { return false; }
488
489 private:
490 std::string container_;
491 std::string name_;
492 mutex mutex_;
493 Tensor resource_;
494 std::atomic<bool> initialized_{false};
495 };
496
497 // Utility op kernel to produce a handle to a resource of type T.
498 template <typename T>
499 class ResourceHandlesOp : public OpKernel {
500 public:
501 explicit ResourceHandlesOp(OpKernelConstruction* context);
502
503 void Compute(OpKernelContext* ctx) override;
504
IsExpensive()505 bool IsExpensive() override { return false; }
506
507 private:
508 std::vector<string> containers_;
509 std::vector<string> names_;
510 mutex mutex_;
511 std::vector<Tensor> resources_;
512 std::atomic<bool> initialized_{false};
513 };
514
// Registers ResourceHandleOp<Type> as the CPU kernel for the op
// "<Type>HandleOp" (the op declared by REGISTER_RESOURCE_HANDLE_OP(Type)).
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
520
521 // This class is used to guarantee that an anonymous resource is deleted
522 // (irrespective of whether a resource deleter op is called explicitly or
523 // the execution encounters an error before the op runs).
524 //
525 // This is achieved by wrapping an instance of this class into a variant
526 // tensor which is passed as an input to a resource deleter op. If the
527 // execution encounters an error before the op runs, the tensor will be
528 // destroyed, essentially triggering the iterator deletion.
529 // NOTE: This is not a feature-complete implementation of the DT_VARIANT
530 // specification. In particular, we cannot serialize the `ResourceMgr`
531 // object, so the `Encode()` and `Decode()` methods are not implemented.
532 class ResourceDeleter {
533 public:
ResourceDeleter()534 ResourceDeleter() : deleter_() {}
535
ResourceDeleter(ResourceHandle handle,ResourceMgr * resource_manager)536 ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
537 : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
538
ResourceDeleter(ResourceDeleter && rhs)539 ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
540 VLOG(3) << "ResourceDeleter move constructor called.";
541 }
542
ResourceDeleter(const ResourceDeleter & rhs)543 ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
544 VLOG(3) << "ResourceDeleter copy constructor called.";
545 }
546
547 ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
548
549 ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
550
~ResourceDeleter()551 virtual ~ResourceDeleter() {
552 VLOG(3) << "ResourceDeleter destructor called.";
553 }
554
Encode(VariantTensorData *)555 void Encode(VariantTensorData*) const {
556 LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
557 "objects.";
558 }
559
Decode(const VariantTensorData &)560 bool Decode(const VariantTensorData&) {
561 LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
562 "objects";
563 return false; // Not supported.
564 }
565
566 private:
567 // Helper that performs reference counting for the parent class and deletes
568 // the iterator resource when the refcount goes to zero.
569 //
570 // NOTE: The object is borrowing a pointer to the resource manager.
571 // Consequently, the tensor containing this object should not escape the
572 // function in which was created (so that it is guaranteed that the resource
573 // manager will outlive it).
574 struct Helper {
HelperHelper575 Helper(ResourceHandle handle, ResourceMgr* resource_manager)
576 : handle(handle), resource_manager(resource_manager) {}
577
578 Helper(const Helper& rhs) = delete;
579 Helper(Helper&& rhs) = delete;
580
~HelperHelper581 ~Helper() {
582 VLOG(3) << "Deleting Resource: " << handle.DebugString();
583 resource_manager->Delete(handle).IgnoreError();
584 }
585
586 ResourceHandle handle;
587 ResourceMgr* resource_manager; // not owned
588 };
589
590 std::shared_ptr<Helper> deleter_;
591 };
592
593 // Implementation details below.
594
595 template <typename T>
CheckDeriveFromResourceBase()596 void CheckDeriveFromResourceBase() {
597 static_assert(std::is_base_of<ResourceBase, T>::value,
598 "T must derive from ResourceBase");
599 }
600
601 template <typename T>
Create(const std::string & container,const std::string & name,T * resource)602 Status ResourceMgr::Create(const std::string& container,
603 const std::string& name, T* resource) {
604 CheckDeriveFromResourceBase<T>();
605 CHECK(resource != nullptr);
606 mutex_lock l(mu_);
607 return DoCreate(container, TypeIndex::Make<T>(), name, resource);
608 }
609
610 template <typename T, bool use_dynamic_cast>
Lookup(const std::string & container,const std::string & name,T ** resource)611 Status ResourceMgr::Lookup(const std::string& container,
612 const std::string& name, T** resource) const {
613 CheckDeriveFromResourceBase<T>();
614 tf_shared_lock l(mu_);
615 return LookupInternal<T, use_dynamic_cast>(container, name, resource);
616 }
617
618 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)619 Status ResourceMgr::LookupMany(
620 absl::Span<std::pair<const string*, const string*> const>
621 containers_and_names,
622 std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
623 CheckDeriveFromResourceBase<T>();
624 tf_shared_lock l(mu_);
625 resources->resize(containers_and_names.size());
626 for (size_t i = 0; i < containers_and_names.size(); ++i) {
627 T* resource;
628 Status s = LookupInternal<T, use_dynamic_cast>(
629 *containers_and_names[i].first, *containers_and_names[i].second,
630 &resource);
631 if (s.ok()) {
632 (*resources)[i].reset(resource);
633 }
634 }
635 return Status::OK();
636 }
637
638 // Simple wrapper to allow conditional dynamic / static casts.
639 template <typename T, bool use_dynamic_cast>
640 struct TypeCastFunctor {
CastTypeCastFunctor641 static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
642 };
643
644 template <typename T>
645 struct TypeCastFunctor<T, true> {
646 static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
647 };
648
649 template <typename T, bool use_dynamic_cast>
650 Status ResourceMgr::LookupInternal(const std::string& container,
651 const std::string& name,
652 T** resource) const {
653 ResourceBase* found = nullptr;
654 Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
655 if (s.ok()) {
656 // It's safe to down cast 'found' to T* since
657 // typeid(T).hash_code() is part of the map key.
658 *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
659 }
660 return s;
661 }
662
663 template <typename T, bool use_dynamic_cast>
664 Status ResourceMgr::LookupOrCreate(const std::string& container,
665 const std::string& name, T** resource,
666 std::function<Status(T**)> creator) {
667 CheckDeriveFromResourceBase<T>();
668 *resource = nullptr;
669 Status s;
670 {
671 tf_shared_lock l(mu_);
672 s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
673 if (s.ok()) return s;
674 }
675 mutex_lock l(mu_);
676 s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
677 if (s.ok()) return s;
678 TF_RETURN_IF_ERROR(creator(resource));
679 s = DoCreate(container, TypeIndex::Make<T>(), name, *resource);
680 if (!s.ok()) {
681 return errors::Internal("LookupOrCreate failed unexpectedly");
682 }
683 (*resource)->Ref();
684 return s;
685 }
686
687 template <typename T>
688 Status ResourceMgr::Delete(const std::string& container,
689 const std::string& name) {
690 CheckDeriveFromResourceBase<T>();
691 return DoDelete(container, TypeIndex::Make<T>(), name);
692 }
693
694 template <typename T>
695 Status GetResourceFromContext(OpKernelContext* ctx,
696 const std::string& input_name, T** resource) {
697 DataType dtype;
698 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
699 if (dtype == DT_RESOURCE) {
700 const Tensor* handle;
701 TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
702 return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
703 }
704 std::string container;
705 std::string shared_name;
706 {
707 mutex* mu;
708 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
709 mutex_lock l(*mu);
710 Tensor tensor;
711 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
712 if (tensor.NumElements() != 2) {
713 return errors::InvalidArgument(
714 "Resource handle must have 2 elements, but had shape: ",
715 tensor.shape().DebugString());
716 }
717 container = tensor.flat<tstring>()(0);
718 shared_name = tensor.flat<tstring>()(1);
719 }
720 return ctx->resource_manager()->Lookup(container, shared_name, resource);
721 }
722
723 namespace internal {
724
725 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
726
727 template <typename T>
728 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
729 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
730 TF_RETURN_IF_ERROR(p.ValidateType<T>());
731 return Status::OK();
732 }
733
734 } // namespace internal
735
736 // Creates the resource pointed at by "p". The caller transfers the ownership of
737 // one ref on "*value" to the resource manager in "ctx", regardless of whether
738 // this operation succeeds or fails.
739 template <typename T>
740 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
741 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
742 return ctx->resource_manager()->Create(p.container(), p.name(), value);
743 }
744
745 // If the resource manager in "ctx" has a resource matching "p", returns it in
746 // "*value" and the caller takes the ownership of one ref on "*value"
747 template <typename T, bool use_dynamic_cast>
748 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
749 T** value) {
750 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
751 return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
752 p.name(), value);
753 }
754
755 // If the resource manager in "ctx" has a resource matching "p", returns it in
756 // "*value" and the caller takes the ownership of one ref on "*value"
757 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
758 ResourceBase** value);
759
760 // If the resource manager in "ctx" has a resource matching "p", returns it in
761 // "*value".
762 template <typename T>
763 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
764 core::RefCountPtr<T>* value) {
765 T* raw_ptr = nullptr;
766 TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
767 value->reset(raw_ptr);
768
769 return Status::OK();
770 }
771
772 // Similar to Lookup, but looks up multiple resources at once, with only a
773 // single lock acquisition.
774 template <typename T>
775 Status LookupResources(OpKernelContext* ctx,
776 absl::Span<ResourceHandle const* const> p,
777 std::vector<core::RefCountPtr<T>>* values) {
778 std::vector<std::pair<const string*, const string*>> containers_and_names(
779 p.size());
780 for (size_t i = 0; i < p.size(); ++i) {
781 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
782 containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
783 }
784 return ctx->resource_manager()->LookupMany(containers_and_names, values);
785 }
786
787 // If the resource manager in "ctx" has a resource pointed at by "p", returns
788 // it in "*value". Otherwise, invokes creator() to create the resource.
789 // The caller takes the ownership of one ref on "*value".
790 //
791 // WARNING: creator() must not call any methods on the resource manager during
792 // its execution, because a non-reentrant lock is held during the creator() call
793 // in order to guarantee atomicity of LookupOrCreateResource().
794 template <typename T>
795 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
796 T** value, std::function<Status(T**)> creator) {
797 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
798 return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
799 creator);
800 }
801
802 // If the resource manager in "ctx" has a resource pointed at by "p", returns
803 // it in "*value". Otherwise, invokes creator() to create the resource.
804 //
805 // WARNING: creator() must not call any methods on the resource manager during
806 // its execution, because a non-reentrant lock is held during the creator() call
807 // in order to guarantee atomicity of LookupOrCreateResource().
808 template <typename T>
809 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
810 core::RefCountPtr<T>* value,
811 std::function<Status(T**)> creator) {
812 T* raw_ptr = nullptr;
813 TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
814 value->reset(raw_ptr);
815
816 return Status::OK();
817 }
818
819 // Deletes the resource pointed by "p", using the resource manager in "ctx".
820 template <typename T>
821 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
822 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
823 return ctx->resource_manager()->Delete<T>(p.container(), p.name());
824 }
825
// Deletes the resource pointed to by "p", using the resource manager in "ctx".
// Non-template overload: unlike DeleteResource<T>, the caller does not name the
// resource type. Declaration only; the definition is not part of this section.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
828
829 template <typename T>
830 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
831 Tensor* output;
832 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
833 T* object;
834 bool found;
835 if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
836 found = true;
837 object->Unref();
838 } else {
839 found = false;
840 }
841
842 output->flat<bool>()(0) = found;
843 }
844
// Reads the "container" and "shared_name" attributes into container_ and
// name_. Construction fails (via OP_REQUIRES_OK) on the first missing or
// ill-typed attribute, in the order listed.
template <typename T>
ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
    : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
  OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
}
851
852 template <typename T>
853 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
854 if (name_ == ResourceHandle::ANONYMOUS_NAME) {
855 AllocatorAttributes attr;
856 attr.set_on_host(true);
857 Tensor handle;
858 OP_REQUIRES_OK(
859 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
860 handle.scalar<ResourceHandle>()() =
861 MakeResourceHandle<T>(ctx, container_, name_);
862 ctx->set_output(0, handle);
863 } else {
864 if (!initialized_.load()) {
865 mutex_lock ml(mutex_);
866 // Checking again to see if another thread has initialized the resource.
867 if (!initialized_.load()) {
868 AllocatorAttributes attr;
869 attr.set_on_host(true);
870 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
871 &resource_, attr));
872 resource_.scalar<ResourceHandle>()() =
873 MakeResourceHandle<T>(ctx, container_, name_);
874 initialized_.store(true);
875 }
876 }
877 ctx->set_output(0, resource_);
878 }
879 }
880
881 template <typename T>
882 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
883 : OpKernel(context) {
884 int n;
885 OP_REQUIRES_OK(context, context->GetAttr("N", &n));
886 OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
887 OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
888 OP_REQUIRES(
889 context, containers_.size() == n,
890 errors::InvalidArgument("Number of containers (", containers_.size(),
891 ") must be equal to N (", n, ")"));
892 OP_REQUIRES(context, names_.size() == n,
893 errors::InvalidArgument("Number of names (", containers_.size(),
894 ") must be equal to N (", n, ")"));
895 resources_.resize(n);
896 }
897
898 template <typename T>
899 void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
900 if (!initialized_.load()) {
901 mutex_lock ml(mutex_);
902 // Checking again to see if another thread has initialized the resource.
903 if (!initialized_.load()) {
904 AllocatorAttributes attr;
905 attr.set_on_host(true);
906 for (size_t i = 0; i < resources_.size(); ++i) {
907 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
908 &resources_[i], attr));
909 ResourceHandle h =
910 MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
911 resources_[i].template scalar<ResourceHandle>()() = h;
912 }
913 initialized_.store(true);
914 }
915 }
916 for (size_t i = 0; i < resources_.size(); ++i) {
917 ctx->set_output(i, resources_[i]);
918 }
919 }
920
921 template <typename T>
922 ResourceHandle ScopedStepContainer::MakeResourceHandle(
923 const std::string& name, const DeviceBase& device) {
924 mutex_lock ml(mu_);
925 dirty_ = true;
926 return tensorflow::MakeResourceHandle(container_, name, device,
927 TypeIndex::Make<T>(), {});
928 }
929
// Looks up resource "name" of type T in this step container's container_ on
// "rm", forwarding directly to ResourceMgr::Lookup. Being const, it neither
// takes mu_ nor touches dirty_.
template <typename T>
Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
                                   T** resource) const {
  return rm->Lookup<T>(container_, name, resource);
}
935
936 template <typename T>
937 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
938 const std::string& name,
939 T** resource,
940 std::function<Status(T**)> creator) {
941 mutex_lock ml(mu_);
942 dirty_ = true;
943 return rm->LookupOrCreate<T>(container_, name, resource, creator);
944 }
945
946 template <typename T>
947 Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
948 T* resource) {
949 mutex_lock ml(mu_);
950 dirty_ = true;
951 return rm->Create<T>(container_, name, resource);
952 }
953
// Deletes resource "name" of type T from this step container on "rm",
// forwarding directly to ResourceMgr::Delete. Unlike Create/LookupOrCreate it
// neither takes mu_ nor sets dirty_.
template <typename T>
Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
  return rm->Delete<T>(container_, name);
}
958
959 } // end namespace tensorflow
960
961 #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
962