1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
41
42 namespace tensorflow {
43
44 // A ResourceMgr instance keeps track of named and typed resources
45 // grouped into containers.
46 //
47 // Each resource must be represented as a sub-class of ResourceBase,
48 // which is reference counted explicitly. Each named resource is
49 // registered with ResourceMgr under a named "container" name. At any
50 // time, there is at most one instance of a resource given the container
51 // name, the resource type and the resource name.
52 //
53 // All resources for a given container can be dropped by one call of
54 // Cleanup().
55 //
56 // E.g.,
57 // struct MyVar : public ResourceBase {
58 // mutex mu;
59 // Tensor val;
60 // }
61 //
62 // ResourceMgr rm;
63 //
64 // // Create a var.
65 // MyVar* my_var = new MyVar;
66 // my_var->val = Tensor(DT_FLOAT, my_shape);
67 // my_var->val.flat<float>().setZeros(); // 0 initialized.
68 // ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
69 //
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     my_var->val.flat<float>() += grad;
//     // Release the ref only when the lookup succeeded; my_var is still
//     // nullptr otherwise. (Or use ScopedUnref().)
//     my_var->Unref();
//   }
//   ctx->SetStatus(s);
78 class ResourceBase : public core::RefCounted {
79 public:
80 // Returns a debug string for *this.
81 virtual string DebugString() const = 0;
82
83 // Returns memory used by this resource.
MemoryUsed()84 virtual int64 MemoryUsed() const { return 0; }
85 };
86
87 // Container used for per-step resources.
88 class ScopedStepContainer {
89 public:
90 // step_id: the unique ID of this step. Doesn't have to be sequential, just
91 // has to be unique.
92 // cleanup: callback to delete a container of this name.
93 // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup)94 ScopedStepContainer(const int64 step_id,
95 std::function<void(const string&)> cleanup)
96 : container_(strings::StrCat("__per_step_", step_id)),
97 step_id_(step_id),
98 cleanup_(cleanup),
99 dirty_(false) {}
100
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup,const string & prefix)101 ScopedStepContainer(const int64 step_id,
102 std::function<void(const string&)> cleanup,
103 const string& prefix)
104 : container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
105 step_id_(step_id),
106 cleanup_(cleanup),
107 dirty_(false) {}
108
~ScopedStepContainer()109 ~ScopedStepContainer() { CleanUp(); }
110
CleanUp()111 void CleanUp() NO_THREAD_SAFETY_ANALYSIS {
112 // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
113 // clean.
114 if (dirty_) {
115 mutex_lock ml(mu_);
116 cleanup_(container_);
117 dirty_ = false;
118 }
119 }
120
121 // Pass through functions for resource lookup and creation. We do this to
122 // ensure that we can appropriately set the dirty_ bit in the
123 // ScopedStepContainer if the name of the container is used to create
124 // resources.
125
126 // Pass through to MakeResourceHandle with the container name
127 template <typename T>
128 ResourceHandle MakeResourceHandle(
129 const string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
130 // Pass through to ResourceMgr::Create with the container name
131 template <typename T>
132 Status Create(ResourceMgr* rm, const string& name,
133 T* resource) TF_MUST_USE_RESULT;
134 // Pass through to ResourceMgr::Delete with the container name
135 template <typename T>
136 Status Delete(ResourceMgr* rm, const string& name) TF_MUST_USE_RESULT;
137 // Pass through to ResourceMgr::Lookup with the container name
138 template <typename T>
139 Status Lookup(ResourceMgr* rm, const string& name,
140 T** resource) const TF_MUST_USE_RESULT;
141 // Pass through to ResourceMgr::LookupOrCreate with the container name
142 template <typename T>
143 Status LookupOrCreate(ResourceMgr* rm, const string& name, T** resource,
144 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
145
step_id()146 const int64 step_id() const { return step_id_; }
147
148 private:
149 const string container_;
150 const int64 step_id_;
151 const std::function<void(const string&)> cleanup_;
152 mutex mu_;
153 mutable std::atomic<bool> dirty_ GUARDED_BY(mu_);
154 };
155
156 class ResourceMgr {
157 public:
158 ResourceMgr();
159 explicit ResourceMgr(const string& default_container);
160 ~ResourceMgr();
161
162 // Returns the default container name for *this.
default_container()163 const string& default_container() const { return default_container_; }
164
165 // Creates a resource "name" in the "container". The caller transfers
166 // the ownership of one ref on "resource" to *this, regardless of whether this
167 // operation succeeds or fails.
168 //
169 // REQUIRES: std::is_base_of<ResourceBase, T>
170 // REQUIRES: resource != nullptr.
171 template <typename T>
172 Status Create(const string& container, const string& name,
173 T* resource) TF_MUST_USE_RESULT;
174
175 // If "container" has a resource "name", returns it in "*resource" and
176 // the caller takes the ownership of one ref on "*resource".
177 //
178 // REQUIRES: std::is_base_of<ResourceBase, T>
179 // REQUIRES: resource != nullptr
180 template <typename T, bool use_dynamic_cast = false>
181 Status Lookup(const string& container, const string& name,
182 T** resource) const TF_MUST_USE_RESULT;
183
184 // Similar to Lookup, but looks up multiple resources at once, with only a
185 // single lock acquisition. If containers_and_names[i] is uninitialized
186 // then this function does not modify resources[i].
187 template <typename T, bool use_dynamic_cast = false>
188 Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
189 containers_and_names,
190 std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
191 resources) const TF_MUST_USE_RESULT;
192
193 // If "container" has a resource "name", returns it in
194 // "*resource". Otherwise, invokes creator() to create the resource.
195 // The caller takes the ownership of one ref on "*resource".
196 //
197 // WARNING: creator() must not call any methods on ResourceMgr during its
198 // execution, because a non-reentrant lock is held during the creator() call
199 // in order to guarantee atomicity of LookupOrCreate().
200 //
201 // REQUIRES: std::is_base_of<ResourceBase, T>
202 // REQUIRES: resource != nullptr
203 template <typename T, bool use_dynamic_cast = false>
204 Status LookupOrCreate(const string& container, const string& name,
205 T** resource,
206 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
207
208 // Deletes the resource "name" from the "container".
209 //
210 // REQUIRES: std::is_base_of<ResourceBase, T>
211 template <typename T>
212 Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
213
214 // Deletes the resource pointed by "handle".
215 Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
216
217 // Deletes all resources from the "container" and removes the container.
218 Status Cleanup(const string& container) TF_MUST_USE_RESULT;
219
220 // Deletes all resources in all containers.
221 void Clear();
222
223 // Returns a text description for all resources.
224 string DebugString() const;
225
226 private:
227 typedef std::pair<uint64, StringPiece> Key;
228 struct KeyHash {
operatorKeyHash229 std::size_t operator()(const Key& k) const {
230 return Hash64(k.second.data(), k.second.size(), k.first);
231 }
232 };
233 struct KeyEqual {
operatorKeyEqual234 bool operator()(const Key& x, const Key& y) const {
235 return (x.second == y.second) && (x.first == y.first);
236 }
237 };
238 struct ResourceAndName {
239 core::RefCountPtr<ResourceBase> resource;
240 std::unique_ptr<string> name;
241
242 ResourceAndName();
243 ResourceAndName(ResourceBase* resource, string name);
244 ResourceAndName(ResourceAndName&& other) noexcept;
245 ~ResourceAndName();
246
247 ResourceAndName& operator=(ResourceAndName&&) noexcept;
248
249 private:
250 TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
251 };
252 typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;
253
254 const string default_container_;
255 mutable mutex mu_;
256 std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
257
258 template <typename T, bool use_dynamic_cast = false>
259 Status LookupInternal(const string& container, const string& name,
260 T** resource) const
261 SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
262
263 Status DoCreate(const string& container, TypeIndex type, const string& name,
264 ResourceBase* resource)
265 EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
266
267 Status DoLookup(const string& container, TypeIndex type, const string& name,
268 ResourceBase** resource) const
269 SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
270
271 Status DoDelete(const string& container, uint64 type_hash_code,
272 const string& resource_name,
273 const string& type_name) TF_MUST_USE_RESULT;
274 Status DoDelete(const string& container, TypeIndex type,
275 const string& resource_name) TF_MUST_USE_RESULT;
276
277 // Inserts the type name for 'hash_code' into the hash_code to type name map.
278 Status InsertDebugTypeName(uint64 hash_code, const string& type_name)
279 EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
280
281 // Returns the type name for the 'hash_code'.
282 // Returns "<unknown>" if a resource with such a type was never inserted into
283 // the container.
284 const char* DebugTypeName(uint64 hash_code) const
285 EXCLUSIVE_LOCKS_REQUIRED(mu_);
286
287 // Map from type hash_code to type name.
288 std::unordered_map<uint64, string> debug_type_names_ GUARDED_BY(mu_);
289
290 TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
291 };
292
293 // Makes a resource handle with the specified type for a given container /
294 // name.
295 ResourceHandle MakeResourceHandle(
296 const string& container, const string& name, const DeviceBase& device,
297 const TypeIndex& type_index,
298 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {})
299 TF_MUST_USE_RESULT;
300
301 template <typename T>
302 ResourceHandle MakeResourceHandle(
303 OpKernelContext* ctx, const string& container, const string& name,
304 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
305 return MakeResourceHandle(
306 container.empty() ? ctx->resource_manager()->default_container()
307 : container,
308 name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
309 }
310
311 template <typename T>
312 ResourceHandle MakeResourceHandle(
313 OpKernelConstruction* ctx, const string& container, const string& name,
314 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
315 return MakeResourceHandle(
316 container.empty() ? ctx->resource_manager()->default_container()
317 : container,
318 name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
319 }
320
321 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
322 const string& container, const string& name,
323 const TypeIndex& type_index);
324
325 // Returns a resource handle from a numbered op input.
326 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
327 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
328 ResourceHandle* handle);
329
330 // Create a resource pointed by a given resource handle.
331 //
332 // If successful, the caller transfers the ownership of one ref on `resource` to
333 // `ctx->resource_mgr()`.
334 template <typename T>
335 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
336
337 // Looks up a resource pointed by a given resource handle.
338 //
339 // If the lookup is successful, the caller takes the ownership of one ref on
340 // `*value`, and must call its `Unref()` method when it has finished using it.
341 template <typename T, bool use_dynamic_cast = false>
342 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
343
344 // Looks up a resource pointed by a given resource handle.
345 //
346 // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
347 // requiring the caller to explicitly call `Unref()`.
348 template <typename T>
349 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
350 core::RefCountPtr<T>* value);
351
352 // Looks up multiple resources pointed by a sequence of resource handles. If
353 // p[i] is uninitialized then values[i] is unmodified.
354 template <typename T>
355 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
356 std::vector<core::RefCountPtr<T>>* values);
357
358 // Looks up or creates a resource.
359 //
360 // If successful, the caller takes the ownership of one ref on `*value`, and
361 // must call its `Unref()` method when it has finished using it. If the
362 // `creator` is invoked, its reference on the created resource is transferred
363 // to `ctx->resource_mgr()`.
364 //
365 // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
366 // requiring the caller to explicitly call `Unref()`.
367 template <typename T>
368 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
369 T** value, std::function<Status(T**)> creator);
370
371 // Looks up or creates a resource.
372 template <typename T>
373 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
374 core::RefCountPtr<T>* value,
375 std::function<Status(T**)> creator);
376
377 // Destroys a resource pointed by a given resource handle.
378 template <typename T>
379 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
380
381 // Same as above, but uses the hash code of the type directly.
382 // The type name information will be missing in the debug output when the
383 // resource is not present in the container.
384 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
385
386 // Policy helper to decide which container/shared_name to use for a
387 // stateful kernel that accesses shared resource.
388 class ContainerInfo {
389 public:
390 // Analyze the node attribute of 'ndef' and decides the container and
391 // resource name the kernel should use for accessing the shared
392 // resource.
393 //
394 // 'ndef' is expected to have node attribute "container" and
395 // "shared_name". Returns non-OK if they are not provided or they are
396 // invalid.
397 //
398 // The policy is as following:
399 // * If the attribute "container" is non-empty, it is used as is.
400 // Otherwise, uses the resource manager's default container.
401 // * If the attribute "shared_name" is non-empty, it is used as is.
402 // Otherwise, if "use_node_name_as_default" is true, the kernel's
403 // node name is used as the resource name. Otherwise, a string
404 // unique to this process is used.
405 Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
406 bool use_node_name_as_default);
Init(ResourceMgr * rmgr,const NodeDef & ndef)407 Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
408 return Init(rmgr, ndef, false);
409 }
410
411 // The policy decides that the kernel should access the resource in
412 // resource_manager(), the resource is in the container() and its
413 // name is name(). If resource_is_private_to_kernel() is true, the
414 // kernel should delete the resource when the kernel is deleted.
resource_manager()415 ResourceMgr* resource_manager() const { return rmgr_; }
container()416 const string& container() const { return container_; }
name()417 const string& name() const { return name_; }
resource_is_private_to_kernel()418 bool resource_is_private_to_kernel() const {
419 return resource_is_private_to_kernel_;
420 }
421
422 // Returns a readable string for *this.
423 string DebugString() const;
424
425 private:
426 ResourceMgr* rmgr_ = nullptr;
427 string container_;
428 string name_;
429 bool resource_is_private_to_kernel_ = false;
430 };
431
432 // Helper for kernels to obtain 'resource' from the
433 // ctx->resource_manager().
434 //
435 // "input_name" specifies the kernel's ref input which gives a string
436 // tensor with two elements, which specifies the container and
437 // resource name.
438 //
439 // Returns OK if the resource is found and transfers one ref of
440 // *resource to the caller. Otherwise, returns an error.
441 template <typename T>
442 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
443 T** resource);
444
445 // Utility op kernel to check if a handle to resource type T is initialized.
446 template <typename T>
447 class IsResourceInitialized : public OpKernel {
448 public:
IsResourceInitialized(OpKernelConstruction * c)449 explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
450
451 void Compute(OpKernelContext* ctx) override;
452 };
453
// Registers an op whose only output is a resource handle to a resource of
// the specified type. The type becomes part of the generated op name
// ("<Type>HandleOp").
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type)                   \
  REGISTER_OP(#Type "HandleOp")                             \
      .Attr("container: string = ''")                       \
      .Attr("shared_name: string = ''")                     \
      .Output("resource: resource")                         \
      .SetIsStateful()                                      \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
465
466 // Utility op kernel to produce a handle to a resource of type T.
467 template <typename T>
468 class ResourceHandleOp : public OpKernel {
469 public:
470 explicit ResourceHandleOp(OpKernelConstruction* context);
471
472 void Compute(OpKernelContext* ctx) override;
473
IsExpensive()474 bool IsExpensive() override { return false; }
475
476 private:
477 string container_;
478 string name_;
479 mutex mutex_;
480 Tensor resource_;
481 std::atomic<bool> initialized_{false};
482 };
483
484 // Utility op kernel to produce a handle to a resource of type T.
485 template <typename T>
486 class ResourceHandlesOp : public OpKernel {
487 public:
488 explicit ResourceHandlesOp(OpKernelConstruction* context);
489
490 void Compute(OpKernelContext* ctx) override;
491
IsExpensive()492 bool IsExpensive() override { return false; }
493
494 private:
495 std::vector<string> containers_;
496 std::vector<string> names_;
497 mutex mutex_;
498 std::vector<Tensor> resources_;
499 std::atomic<bool> initialized_{false};
500 };
501
502 Status ResourceHandlesShape(shape_inference::InferenceContext* c);
503
// Registers, for DEVICE_CPU, the kernel of the "<Type>HandleOp" op, i.e. the
// op that outputs a handle to a resource of the specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                         \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU),  \
                          ResourceHandleOp<Type>)
509
510 // This class is used to guarantee that an anonymous resource is deleted
511 // (irrespective of whether a resource deleter op is called explicitly or
512 // the execution encounters an error before the op runs).
513 //
514 // This is achieved by wrapping an instance of this class into a variant
515 // tensor which is passed as an input to a resource deleter op. If the
516 // execution encounters an error before the op runs, the tensor will be
517 // destroyed, essentially triggering the iterator deletion.
518 // NOTE: This is not a feature-complete implementation of the DT_VARIANT
519 // specification. In particular, we cannot serialize the `ResourceMgr`
520 // object, so the `Encode()` and `Decode()` methods are not implemented.
521 class ResourceDeleter {
522 public:
ResourceDeleter()523 ResourceDeleter() : deleter_() {}
524
ResourceDeleter(ResourceHandle handle,ResourceMgr * resource_manager)525 ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
526 : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
527
ResourceDeleter(ResourceDeleter && rhs)528 ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
529 VLOG(3) << "ResourceDeleter move constructor called.";
530 }
531
ResourceDeleter(const ResourceDeleter & rhs)532 ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
533 VLOG(3) << "ResourceDeleter copy constructor called.";
534 }
535
536 ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
537
538 ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
539
~ResourceDeleter()540 virtual ~ResourceDeleter() {
541 VLOG(3) << "ResourceDeleter destructor called.";
542 }
543
Encode(VariantTensorData *)544 void Encode(VariantTensorData*) const {
545 LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
546 "objects.";
547 }
548
Decode(const VariantTensorData &)549 bool Decode(const VariantTensorData&) {
550 LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
551 "objects";
552 return false; // Not supported.
553 }
554
555 private:
556 // Helper that performs reference counting for the parent class and deletes
557 // the iterator resource when the refcount goes to zero.
558 //
559 // NOTE: The object is borrowing a pointer to the resource manager.
560 // Consequently, the tensor containing this object should not escape the
561 // function in which was created (so that it is guaranteed that the resource
562 // manager will outlive it).
563 struct Helper {
HelperHelper564 Helper(ResourceHandle handle, ResourceMgr* resource_manager)
565 : handle(handle), resource_manager(resource_manager) {}
566
567 Helper(const Helper& rhs) = delete;
568 Helper(Helper&& rhs) = delete;
569
~HelperHelper570 ~Helper() {
571 VLOG(3) << "Deleting Resource: " << handle.DebugString();
572 resource_manager->Delete(handle).IgnoreError();
573 }
574
575 ResourceHandle handle;
576 ResourceMgr* resource_manager; // not owned
577 };
578
579 std::shared_ptr<Helper> deleter_;
580 };
581
582 // Implementation details below.
583
584 template <typename T>
CheckDeriveFromResourceBase()585 void CheckDeriveFromResourceBase() {
586 static_assert(std::is_base_of<ResourceBase, T>::value,
587 "T must derive from ResourceBase");
588 }
589
590 template <typename T>
Create(const string & container,const string & name,T * resource)591 Status ResourceMgr::Create(const string& container, const string& name,
592 T* resource) {
593 CheckDeriveFromResourceBase<T>();
594 CHECK(resource != nullptr);
595 mutex_lock l(mu_);
596 return DoCreate(container, MakeTypeIndex<T>(), name, resource);
597 }
598
599 template <typename T, bool use_dynamic_cast>
Lookup(const string & container,const string & name,T ** resource)600 Status ResourceMgr::Lookup(const string& container, const string& name,
601 T** resource) const {
602 CheckDeriveFromResourceBase<T>();
603 tf_shared_lock l(mu_);
604 return LookupInternal<T, use_dynamic_cast>(container, name, resource);
605 }
606
607 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)608 Status ResourceMgr::LookupMany(
609 absl::Span<std::pair<const string*, const string*> const>
610 containers_and_names,
611 std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
612 CheckDeriveFromResourceBase<T>();
613 tf_shared_lock l(mu_);
614 resources->resize(containers_and_names.size());
615 for (size_t i = 0; i < containers_and_names.size(); ++i) {
616 T* resource;
617 Status s = LookupInternal<T, use_dynamic_cast>(
618 *containers_and_names[i].first, *containers_and_names[i].second,
619 &resource);
620 if (s.ok()) {
621 (*resources)[i].reset(resource);
622 }
623 }
624 return Status::OK();
625 }
626
627 // Simple wrapper to allow conditional dynamic / static casts.
628 template <typename T, bool use_dynamic_cast>
629 struct TypeCastFunctor {
CastTypeCastFunctor630 static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
631 };
632
633 template <typename T>
634 struct TypeCastFunctor<T, true> {
635 static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
636 };
637
638 template <typename T, bool use_dynamic_cast>
639 Status ResourceMgr::LookupInternal(const string& container, const string& name,
640 T** resource) const {
641 ResourceBase* found = nullptr;
642 Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
643 if (s.ok()) {
644 // It's safe to down cast 'found' to T* since
645 // typeid(T).hash_code() is part of the map key.
646 *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
647 }
648 return s;
649 }
650
651 template <typename T, bool use_dynamic_cast>
652 Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
653 T** resource,
654 std::function<Status(T**)> creator) {
655 CheckDeriveFromResourceBase<T>();
656 *resource = nullptr;
657 Status s;
658 {
659 tf_shared_lock l(mu_);
660 s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
661 if (s.ok()) return s;
662 }
663 mutex_lock l(mu_);
664 s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
665 if (s.ok()) return s;
666 TF_RETURN_IF_ERROR(creator(resource));
667 s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
668 if (!s.ok()) {
669 return errors::Internal("LookupOrCreate failed unexpectedly");
670 }
671 (*resource)->Ref();
672 return s;
673 }
674
675 template <typename T>
676 Status ResourceMgr::Delete(const string& container, const string& name) {
677 CheckDeriveFromResourceBase<T>();
678 return DoDelete(container, MakeTypeIndex<T>(), name);
679 }
680
681 template <typename T>
682 Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
683 T** resource) {
684 DataType dtype;
685 TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
686 if (dtype == DT_RESOURCE) {
687 const Tensor* handle;
688 TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
689 return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
690 }
691 string container;
692 string shared_name;
693 {
694 mutex* mu;
695 TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
696 mutex_lock l(*mu);
697 Tensor tensor;
698 TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
699 if (tensor.NumElements() != 2) {
700 return errors::InvalidArgument(
701 "Resource handle must have 2 elements, but had shape: ",
702 tensor.shape().DebugString());
703 }
704 container = tensor.flat<tstring>()(0);
705 shared_name = tensor.flat<tstring>()(1);
706 }
707 return ctx->resource_manager()->Lookup(container, shared_name, resource);
708 }
709
710 namespace internal {
711
712 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
713
714 template <typename T>
715 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
716 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
717 auto type_index = MakeTypeIndex<T>();
718 if (type_index.hash_code() != p.hash_code()) {
719 return errors::InvalidArgument(
720 "Trying to access resource using the wrong type. Expected ",
721 p.maybe_type_name(), " got ", type_index.name());
722 }
723 return Status::OK();
724 }
725
726 } // namespace internal
727
728 template <typename T>
729 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
730 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
731 return ctx->resource_manager()->Create(p.container(), p.name(), value);
732 }
733
734 template <typename T, bool use_dynamic_cast>
735 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
736 T** value) {
737 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
738 return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
739 p.name(), value);
740 }
741
742 template <typename T>
743 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
744 core::RefCountPtr<T>* value) {
745 T* raw_ptr = nullptr;
746 TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
747 value->reset(raw_ptr);
748
749 return Status::OK();
750 }
751
752 template <typename T>
753 Status LookupResources(OpKernelContext* ctx,
754 absl::Span<ResourceHandle const* const> p,
755 std::vector<core::RefCountPtr<T>>* values) {
756 std::vector<std::pair<const string*, const string*>> containers_and_names(
757 p.size());
758 for (size_t i = 0; i < p.size(); ++i) {
759 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
760 containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
761 }
762 return ctx->resource_manager()->LookupMany(containers_and_names, values);
763 }
764
765 template <typename T>
766 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
767 T** value, std::function<Status(T**)> creator) {
768 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
769 return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
770 creator);
771 }
772
773 template <typename T>
774 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
775 core::RefCountPtr<T>* value,
776 std::function<Status(T**)> creator) {
777 T* raw_ptr = nullptr;
778 TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
779 value->reset(raw_ptr);
780
781 return Status::OK();
782 }
783
784 template <typename T>
785 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
786 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
787 return ctx->resource_manager()->Delete<T>(p.container(), p.name());
788 }
789
790 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
791
792 template <typename T>
793 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
794 Tensor* output;
795 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
796 T* object;
797 bool found;
798 if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
799 found = true;
800 object->Unref();
801 } else {
802 found = false;
803 }
804
805 output->flat<bool>()(0) = found;
806 }
807
808 template <typename T>
809 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
810 : OpKernel(context) {
811 OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
812 OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
813 }
814
815 template <typename T>
816 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
817 if (name_ == ResourceHandle::ANONYMOUS_NAME) {
818 AllocatorAttributes attr;
819 attr.set_on_host(true);
820 Tensor handle;
821 OP_REQUIRES_OK(
822 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
823 handle.scalar<ResourceHandle>()() =
824 MakeResourceHandle<T>(ctx, container_, name_);
825 ctx->set_output(0, handle);
826 } else {
827 if (!initialized_.load()) {
828 mutex_lock ml(mutex_);
829 // Checking again to see if another thread has initialized the resource.
830 if (!initialized_.load()) {
831 AllocatorAttributes attr;
832 attr.set_on_host(true);
833 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
834 &resource_, attr));
835 resource_.scalar<ResourceHandle>()() =
836 MakeResourceHandle<T>(ctx, container_, name_);
837 initialized_.store(true);
838 }
839 }
840 ctx->set_output(0, resource_);
841 }
842 }
843
844 template <typename T>
845 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
846 : OpKernel(context) {
847 int n;
848 OP_REQUIRES_OK(context, context->GetAttr("N", &n));
849 OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
850 OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
851 OP_REQUIRES(
852 context, containers_.size() == n,
853 errors::InvalidArgument("Number of containers (", containers_.size(),
854 ") must be equal to N (", n, ")"));
855 OP_REQUIRES(context, names_.size() == n,
856 errors::InvalidArgument("Number of names (", containers_.size(),
857 ") must be equal to N (", n, ")"));
858 resources_.resize(n);
859 }
860
861 template <typename T>
862 void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
863 if (!initialized_.load()) {
864 mutex_lock ml(mutex_);
865 // Checking again to see if another thread has initialized the resource.
866 if (!initialized_.load()) {
867 AllocatorAttributes attr;
868 attr.set_on_host(true);
869 for (size_t i = 0; i < resources_.size(); ++i) {
870 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
871 &resources_[i], attr));
872 ResourceHandle h =
873 MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
874 resources_[i].template scalar<ResourceHandle>()() = h;
875 }
876 initialized_.store(true);
877 }
878 }
879 for (size_t i = 0; i < resources_.size(); ++i) {
880 ctx->set_output(i, resources_[i]);
881 }
882 }
883
884 template <typename T>
885 ResourceHandle ScopedStepContainer::MakeResourceHandle(
886 const string& name, const DeviceBase& device) {
887 mutex_lock ml(mu_);
888 dirty_ = true;
889 return tensorflow::MakeResourceHandle(container_, name, device,
890 MakeTypeIndex<T>(), {});
891 }
892
893 template <typename T>
894 Status ScopedStepContainer::Lookup(ResourceMgr* rm, const string& name,
895 T** resource) const {
896 return rm->Lookup<T>(container_, name, resource);
897 }
898
899 template <typename T>
900 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm, const string& name,
901 T** resource,
902 std::function<Status(T**)> creator) {
903 mutex_lock ml(mu_);
904 dirty_ = true;
905 return rm->LookupOrCreate<T>(container_, name, resource, creator);
906 }
907
908 template <typename T>
909 Status ScopedStepContainer::Create(ResourceMgr* rm, const string& name,
910 T* resource) {
911 mutex_lock ml(mu_);
912 dirty_ = true;
913 return rm->Create<T>(container_, name, resource);
914 }
915
916 template <typename T>
917 Status ScopedStepContainer::Delete(ResourceMgr* rm, const string& name) {
918 return rm->Delete<T>(container_, name);
919 }
920
921 } // end namespace tensorflow
922
923 #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
924