1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
24
25 #include "tensorflow/core/framework/common_shape_fns.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/resource_handle.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor_shape.h"
31 #include "tensorflow/core/framework/tensor_types.h"
32 #include "tensorflow/core/framework/type_index.h"
33 #include "tensorflow/core/framework/variant_tensor_data.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/refcount.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/lib/hash/hash.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/macros.h"
40 #include "tensorflow/core/platform/mutex.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42
43 namespace tensorflow {
44
// A ResourceMgr instance keeps track of named and typed resources
// grouped into containers.
//
// Each resource must be represented as a sub-class of ResourceBase,
// which is reference counted explicitly. Each named resource is
// registered with ResourceMgr under a named "container". At any
// time, there is at most one instance of a resource given the container
// name, the resource type and the resource name.
//
// All resources for a given container can be dropped by one call of
// Cleanup().
//
// E.g.,
//   struct MyVar : public ResourceBase {
//     mutex mu;
//     Tensor val;
//     // DebugString() is pure virtual in ResourceBase and must be overridden.
//     std::string DebugString() const override { return "MyVar"; }
//   };
//
//   ResourceMgr rm;
//
//   // Create a var. `rm` takes ownership of the initial ref on `my_var`.
//   MyVar* my_var = new MyVar;
//   my_var->val = Tensor(DT_FLOAT, my_shape);
//   my_var->val.flat<float>().setZero();  // 0 initialized.
//   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     my_var->val.flat<float>() += grad;
//     my_var->Unref();  // Or use ScopedUnref(). Only valid if the lookup
//                       // succeeded; on failure `my_var` stays nullptr.
//   }
//   ctx->SetStatus(s);
class ResourceBase : public core::RefCounted {
 public:
  // Returns a debug string for *this.
  virtual std::string DebugString() const = 0;

  // Returns memory used by this resource. The base implementation reports 0;
  // subclasses may override it to report their own usage.
  virtual int64 MemoryUsed() const { return 0; }
};
87
88 // Container used for per-step resources.
89 class ScopedStepContainer {
90 public:
91 // step_id: the unique ID of this step. Doesn't have to be sequential, just
92 // has to be unique.
93 // cleanup: callback to delete a container of this name.
94 // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup)95 ScopedStepContainer(const int64 step_id,
96 std::function<void(const string&)> cleanup)
97 : step_id_(step_id),
98 container_(strings::StrCat("__per_step_", step_id)),
99 cleanup_(cleanup),
100 dirty_(false) {}
101
ScopedStepContainer(const int64 step_id,std::function<void (const string &)> cleanup,const std::string & prefix)102 ScopedStepContainer(const int64 step_id,
103 std::function<void(const string&)> cleanup,
104 const std::string& prefix)
105 : step_id_(step_id),
106 container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
107 cleanup_(cleanup),
108 dirty_(false) {}
109
~ScopedStepContainer()110 ~ScopedStepContainer() { CleanUp(); }
111
CleanUp()112 void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS {
113 // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
114 // clean.
115 if (dirty_) {
116 mutex_lock ml(mu_);
117 cleanup_(container_);
118 dirty_ = false;
119 }
120 }
121
122 // Pass through functions for resource lookup and creation. We do this to
123 // ensure that we can appropriately set the dirty_ bit in the
124 // ScopedStepContainer if the name of the container is used to create
125 // resources.
126
127 // Pass through to MakeResourceHandle with the container name
128 template <typename T>
129 ResourceHandle MakeResourceHandle(
130 const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
131 // Pass through to ResourceMgr::Create with the container name
132 template <typename T>
133 Status Create(ResourceMgr* rm, const std::string& name,
134 T* resource) TF_MUST_USE_RESULT;
135 // Pass through to ResourceMgr::Delete with the container name
136 template <typename T>
137 Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
138 // Pass through to ResourceMgr::Lookup with the container name
139 template <typename T>
140 Status Lookup(ResourceMgr* rm, const std::string& name,
141 T** resource) const TF_MUST_USE_RESULT;
142 // Pass through to ResourceMgr::LookupOrCreate with the container name
143 template <typename T>
144 Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
145 std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
StepId()146 int64 StepId() const { return step_id_; }
147
148 private:
149 const int64 step_id_;
150 const std::string container_;
151 const std::function<void(const string&)> cleanup_;
152 mutex mu_;
153 mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
154 };
155
// Registry of reference-counted resources (subclasses of ResourceBase) keyed by
// (container, type, name). The template methods defined in this header take
// `mu_` in shared mode for lookups and in exclusive mode for creation.
class ResourceMgr {
 public:
  ResourceMgr();
  explicit ResourceMgr(const std::string& default_container);
  ~ResourceMgr();

  // Returns the default container name for *this.
  const std::string& default_container() const { return default_container_; }

  // Creates a resource "name" in the "container". The caller transfers
  // the ownership of one ref on "resource" to *this, regardless of whether this
  // operation succeeds or fails.
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  // REQUIRES: resource != nullptr (CHECK-enforced).
  template <typename T>
  Status Create(const std::string& container, const std::string& name,
                T* resource) TF_MUST_USE_RESULT;

  // If "container" has a resource "name", returns it in "*resource" and
  // the caller takes the ownership of one ref on "*resource".
  //
  // When `use_dynamic_cast` is true, the stored ResourceBase* is converted to
  // T* with dynamic_cast instead of static_cast.
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  // REQUIRES: resource != nullptr
  template <typename T, bool use_dynamic_cast = false>
  Status Lookup(const std::string& container, const std::string& name,
                T** resource) const TF_MUST_USE_RESULT;

  // Similar to Lookup, but looks up multiple resources at once, with only a
  // single lock acquisition. `*resources` is resized to
  // containers_and_names.size(). If the lookup of containers_and_names[i]
  // fails for any reason, resources[i] is left unmodified. The returned
  // status is always OK.
  template <typename T, bool use_dynamic_cast = false>
  Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
                        containers_and_names,
                    std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
                        resources) const TF_MUST_USE_RESULT;

  // If "container" has a resource "name", returns it in
  // "*resource". Otherwise, invokes creator() to create the resource.
  // The caller takes the ownership of one ref on "*resource".
  //
  // WARNING: creator() must not call any methods on ResourceMgr during its
  // execution, because a non-reentrant lock is held during the creator() call
  // in order to guarantee atomicity of LookupOrCreate().
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  // REQUIRES: resource != nullptr
  template <typename T, bool use_dynamic_cast = false>
  Status LookupOrCreate(const std::string& container, const std::string& name,
                        T** resource,
                        std::function<Status(T**)> creator) TF_MUST_USE_RESULT;

  // Deletes the resource "name" from the "container".
  //
  // REQUIRES: std::is_base_of<ResourceBase, T>
  template <typename T>
  Status Delete(const std::string& container,
                const std::string& name) TF_MUST_USE_RESULT;

  // Deletes the resource pointed by "handle".
  Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;

  // Deletes all resources from the "container" and removes the container.
  Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;

  // Deletes all resources in all containers.
  void Clear();

  // Returns a text description for all resources.
  std::string DebugString() const;

 private:
  // Key of a Container entry. The second element is the resource name; the
  // first is the resource type's hash code (LookupInternal() relies on the
  // type being part of the key) and also seeds the name hash in KeyHash.
  typedef std::pair<uint64, StringPiece> Key;
  struct KeyHash {
    std::size_t operator()(const Key& k) const {
      return Hash64(k.second.data(), k.second.size(), k.first);
    }
  };
  struct KeyEqual {
    bool operator()(const Key& x, const Key& y) const {
      return (x.second == y.second) && (x.first == y.first);
    }
  };
  // Value of a Container entry: one ref on the resource plus its name.
  // NOTE(review): the name is heap-allocated, presumably so the Key's
  // StringPiece can point at it and stay valid when this struct is moved —
  // confirm in resource_mgr.cc.
  struct ResourceAndName {
    core::RefCountPtr<ResourceBase> resource;
    std::unique_ptr<string> name;

    ResourceAndName();
    ResourceAndName(ResourceBase* resource, std::string name);
    ResourceAndName(ResourceAndName&& other) noexcept;
    ~ResourceAndName();

    ResourceAndName& operator=(ResourceAndName&&) noexcept;

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
  };
  // All resources of a single container.
  typedef std::unordered_map<Key, ResourceAndName, KeyHash, KeyEqual> Container;

  const std::string default_container_;
  mutable mutex mu_;
  // Container name -> container. Raw pointers; presumably owned by *this and
  // freed in Cleanup()/Clear()/~ResourceMgr() — confirm in resource_mgr.cc.
  std::unordered_map<string, Container*> containers_ TF_GUARDED_BY(mu_);

  // Lookup shared by Lookup(), LookupMany() and LookupOrCreate(); the caller
  // must already hold mu_ (at least shared).
  template <typename T, bool use_dynamic_cast = false>
  Status LookupInternal(const std::string& container, const std::string& name,
                        T** resource) const
      TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;

  Status DoCreate(const std::string& container, TypeIndex type,
                  const std::string& name, ResourceBase* resource)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;

  Status DoLookup(const std::string& container, TypeIndex type,
                  const std::string& name, ResourceBase** resource) const
      TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;

  // Unlike DoCreate/DoLookup these carry no lock annotation; Delete<T>() calls
  // them without holding mu_.
  Status DoDelete(const std::string& container, uint64 type_hash_code,
                  const std::string& resource_name,
                  const std::string& type_name) TF_MUST_USE_RESULT;
  Status DoDelete(const std::string& container, TypeIndex type,
                  const std::string& resource_name) TF_MUST_USE_RESULT;

  // Inserts the type name for 'hash_code' into the hash_code to type name map.
  Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;

  // Returns the type name for the 'hash_code'.
  // Returns "<unknown>" if a resource with such a type was never inserted into
  // the container.
  // NOTE(review): declared const, yet annotated as requiring mu_ exclusively.
  const char* DebugTypeName(uint64 hash_code) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Map from type hash_code to type name.
  std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
};
293
// Makes a resource handle with the specified type for a given container /
// name on `device`. `dtypes_and_shapes` and `definition_stack_trace` are
// optional extra metadata for the handle (presumably stored on it verbatim —
// confirm in resource_mgr.cc).
ResourceHandle MakeResourceHandle(
    const std::string& container, const std::string& name,
    const DeviceBase& device, const TypeIndex& type_index,
    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
    const absl::optional<ManagedStackTrace>& definition_stack_trace = {})
    TF_MUST_USE_RESULT;
302
303 template <typename T>
304 ResourceHandle MakeResourceHandle(
305 OpKernelContext* ctx, const std::string& container, const std::string& name,
306 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
307 const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
308 return MakeResourceHandle(container.empty()
309 ? ctx->resource_manager()->default_container()
310 : container,
311 name, *ctx->device(), TypeIndex::Make<T>(),
312 dtypes_and_shapes, definition_stack_trace);
313 }
314
315 template <typename T>
316 ResourceHandle MakeResourceHandle(
317 OpKernelConstruction* ctx, const std::string& container,
318 const std::string& name,
319 const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
320 const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
321 return MakeResourceHandle(container.empty()
322 ? ctx->resource_manager()->default_container()
323 : container,
324 name, *ctx->device(), TypeIndex::Make<T>(),
325 dtypes_and_shapes, definition_stack_trace);
326 }
327
// Presumably allocates output `output_index` of `context` and fills it with a
// handle to `container`/`name` of type `type_index`. The definition is not in
// this header — confirm in resource_mgr.cc.
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
                                  const std::string& container,
                                  const std::string& name,
                                  const TypeIndex& type_index);

// Returns a resource handle from a numbered op input.
const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
// Same, but selects the input by name. The handle is written to `*handle`,
// and failures (presumably e.g. an unknown input name) are reported through
// the returned Status.
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
                       ResourceHandle* handle);
337
// Create a resource pointed by a given resource handle.
//
// The handle's device and type are first validated against `ctx` and T. If
// validation fails, the caller keeps its ref on `value`. Once validation
// passes, the caller transfers the ownership of one ref on `value` to
// `ctx->resource_manager()`, even if the creation itself then fails (see
// ResourceMgr::Create()).
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
344
// Looks up a resource pointed by a given resource handle.
//
// The handle's device and type are validated against `ctx` and T before the
// lookup. When `use_dynamic_cast` is true, the stored resource is converted to
// T* with dynamic_cast instead of static_cast.
//
// If the lookup is successful, the caller takes the ownership of one ref on
// `*value`, and must call its `Unref()` method when it has finished using it.
// Prefer the overload taking `core::RefCountPtr` to avoid requiring the caller
// to explicitly call `Unref()`.
template <typename T, bool use_dynamic_cast = false>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

// Looks up a resource pointed by a given resource handle.
//
// On success `*value` owns the caller's ref, so no explicit `Unref()` is
// needed. On failure `*value` is left unchanged.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
                      core::RefCountPtr<T>* value);
359
360 // Looks up multiple resources pointed by a sequence of resource handles. If
361 // p[i] is uninitialized then values[i] is unmodified.
362 template <typename T>
363 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
364 std::vector<core::RefCountPtr<T>>* values);
365
// Looks up or creates a resource.
//
// The handle's device and type are validated against `ctx` and T first.
// If successful, the caller takes the ownership of one ref on `*value`, and
// must call its `Unref()` method when it has finished using it. If the
// `creator` is invoked, its reference on the created resource is transferred
// to `ctx->resource_manager()`. `creator` runs while the resource manager's
// lock is held, so it must not call back into the ResourceMgr.
//
// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
// requiring the caller to explicitly call `Unref()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
                              T** value, std::function<Status(T**)> creator);

// Looks up or creates a resource. On success `*value` owns the caller's ref;
// on failure `*value` is left unchanged.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
                              core::RefCountPtr<T>* value,
                              std::function<Status(T**)> creator);
384
// Destroys a resource pointed by a given resource handle, after validating
// the handle's device and type against `ctx` and T.
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);

// Same as above, but uses the hash code of the type directly.
// The type name information will be missing in the debug output when the
// resource is not present in the container.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
393
// Policy helper to decide which container/shared_name to use for a
// stateful kernel that accesses a shared resource.
class ContainerInfo {
 public:
  // Analyzes the node attributes of 'ndef' and decides the container and
  // resource name the kernel should use for accessing the shared
  // resource.
  //
  // 'ndef' is expected to have node attributes "container" and
  // "shared_name". Returns non-OK if they are not provided or they are
  // invalid.
  //
  // The policy is as follows:
  //   * If the attribute "container" is non-empty, it is used as is.
  //     Otherwise, uses the resource manager's default container.
  //   * If the attribute "shared_name" is non-empty, it is used as is.
  //     Otherwise, if "use_node_name_as_default" is true, the kernel's
  //     node name is used as the resource name. Otherwise, a string
  //     unique to this process is used.
  Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
              bool use_node_name_as_default);
  // Equivalent to Init(rmgr, ndef, /*use_node_name_as_default=*/false).
  Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
    return Init(rmgr, ndef, false);
  }

  // The policy decides that the kernel should access the resource in
  // resource_manager(), the resource is in the container() and its
  // name is name(). If resource_is_private_to_kernel() is true, the
  // kernel should delete the resource when the kernel is deleted.
  ResourceMgr* resource_manager() const { return rmgr_; }
  const std::string& container() const { return container_; }
  const std::string& name() const { return name_; }
  bool resource_is_private_to_kernel() const {
    return resource_is_private_to_kernel_;
  }

  // Returns a readable string for *this.
  std::string DebugString() const;

 private:
  // Set by Init(); presumably not owned — confirm in resource_mgr.cc.
  ResourceMgr* rmgr_ = nullptr;
  std::string container_;
  std::string name_;
  bool resource_is_private_to_kernel_ = false;
};
439
// Helper for kernels to obtain 'resource' from the
// ctx->resource_manager().
//
// "input_name" specifies the kernel input that identifies the resource. It is
// either a DT_RESOURCE input holding a ResourceHandle (looked up via
// LookupResource()), or a ref input giving a string tensor with two
// elements, which specifies the container and resource name.
//
// Returns OK if the resource is found and transfers one ref of
// *resource to the caller. Otherwise, returns an error.
template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx,
                              const std::string& input_name, T** resource);
452
// Utility op kernel to check if a handle to resource type T is initialized.
// Compute() writes a scalar bool to output 0: true iff looking up the handle
// in input 0 succeeds.
template <typename T>
class IsResourceInitialized : public OpKernel {
 public:
  explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override;
};
461
// Registers an op which produces just a resource handle to a resource of the
// specified type. The type will be a part of the generated op name: e.g.
// REGISTER_RESOURCE_HANDLE_OP(Foo) registers "FooHandleOp", with optional
// "container" / "shared_name" string attrs and one scalar resource output.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type) \
  REGISTER_OP(#Type "HandleOp")           \
      .Attr("container: string = ''")     \
      .Attr("shared_name: string = ''")   \
      .Output("resource: resource")       \
      .SetIsStateful()                    \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
473
// Utility op kernel to produce a handle to a resource of type T.
//
// For a non-anonymous shared_name, the handle tensor is built on the first
// Compute() and that cached tensor is emitted on every later call.
template <typename T>
class ResourceHandleOp : public OpKernel {
 public:
  // Reads the "container" and "shared_name" attrs.
  explicit ResourceHandleOp(OpKernelConstruction* context);

  void Compute(OpKernelContext* ctx) override;

  bool IsExpensive() override { return false; }

 private:
  std::string container_;
  std::string name_;
  // Serializes the one-time construction of resource_.
  mutex mutex_;
  // Cached handle tensor; written under mutex_ before initialized_ is set.
  Tensor resource_;
  // Becomes true once resource_ holds the cached handle.
  std::atomic<bool> initialized_{false};
};
491
// Utility op kernel to produce handles to resources of type T.
// The constructor reads the "N", "containers" and "shared_names" attrs;
// presumably Compute() emits one handle per (containers_[i], names_[i]) pair.
template <typename T>
class ResourceHandlesOp : public OpKernel {
 public:
  explicit ResourceHandlesOp(OpKernelConstruction* context);

  void Compute(OpKernelContext* ctx) override;

  bool IsExpensive() override { return false; }

 private:
  std::vector<string> containers_;
  std::vector<string> names_;
  mutex mutex_;
  // Presumably the cached handle tensors guarded by mutex_/initialized_, as in
  // ResourceHandleOp — confirm against Compute().
  std::vector<Tensor> resources_;
  std::atomic<bool> initialized_{false};
};
509
// Registers a kernel for an op which produces a handle to a resource of the
// specified type: ResourceHandleOp<Type> is registered as the CPU kernel of
// "<Type>HandleOp" (see REGISTER_RESOURCE_HANDLE_OP).
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
515
// This class is used to guarantee that an anonymous resource is deleted
// (irrespective of whether a resource deleter op is called explicitly or
// the execution encounters an error before the op runs).
//
// This is achieved by wrapping an instance of this class into a variant
// tensor which is passed as an input to a resource deleter op. If the
// execution encounters an error before the op runs, the tensor will be
// destroyed, essentially triggering the resource deletion.
//
// Copies of a ResourceDeleter share one Helper through a std::shared_ptr; the
// resource is deleted from the ResourceMgr when the last copy goes away. A
// default-constructed ResourceDeleter holds no Helper and deletes nothing.
//
// NOTE: This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot serialize the `ResourceMgr`
// object, so the `Encode()` and `Decode()` methods are not implemented.
class ResourceDeleter {
 public:
  ResourceDeleter() : deleter_() {}

  ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
      : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}

  // Takes over `rhs`'s Helper; `rhs` is left holding none.
  ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
    VLOG(3) << "ResourceDeleter move constructor called.";
  }

  // Shares `rhs`'s Helper; deletion waits until every sharer is destroyed.
  ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
    VLOG(3) << "ResourceDeleter copy constructor called.";
  }

  ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;

  ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;

  virtual ~ResourceDeleter() {
    VLOG(3) << "ResourceDeleter destructor called.";
  }

  // Unsupported: only logs an error.
  void Encode(VariantTensorData*) const {
    LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
                  "objects.";
  }

  // Unsupported: logs an error and always returns false.
  bool Decode(const VariantTensorData&) {
    LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
                  "objects";
    return false;  // Not supported.
  }

 private:
  // Helper that performs reference counting for the parent class and deletes
  // the resource when the refcount goes to zero.
  //
  // NOTE: The object is borrowing a pointer to the resource manager.
  // Consequently, the tensor containing this object should not escape the
  // function in which it was created (so that it is guaranteed that the
  // resource manager will outlive it).
  struct Helper {
    Helper(ResourceHandle handle, ResourceMgr* resource_manager)
        : handle(handle), resource_manager(resource_manager) {}

    Helper(const Helper& rhs) = delete;
    Helper(Helper&& rhs) = delete;

    ~Helper() {
      VLOG(3) << "Deleting Resource: " << handle.DebugString();
      // Best effort: a failed deletion (presumably e.g. the resource was
      // already deleted explicitly) is deliberately ignored.
      resource_manager->Delete(handle).IgnoreError();
    }

    ResourceHandle handle;
    ResourceMgr* resource_manager;  // not owned
  };

  std::shared_ptr<Helper> deleter_;
};
587
588 // Implementation details below.
589
590 template <typename T>
CheckDeriveFromResourceBase()591 void CheckDeriveFromResourceBase() {
592 static_assert(std::is_base_of<ResourceBase, T>::value,
593 "T must derive from ResourceBase");
594 }
595
596 template <typename T>
Create(const std::string & container,const std::string & name,T * resource)597 Status ResourceMgr::Create(const std::string& container,
598 const std::string& name, T* resource) {
599 CheckDeriveFromResourceBase<T>();
600 CHECK(resource != nullptr);
601 mutex_lock l(mu_);
602 return DoCreate(container, TypeIndex::Make<T>(), name, resource);
603 }
604
605 template <typename T, bool use_dynamic_cast>
Lookup(const std::string & container,const std::string & name,T ** resource)606 Status ResourceMgr::Lookup(const std::string& container,
607 const std::string& name, T** resource) const {
608 CheckDeriveFromResourceBase<T>();
609 tf_shared_lock l(mu_);
610 return LookupInternal<T, use_dynamic_cast>(container, name, resource);
611 }
612
613 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)614 Status ResourceMgr::LookupMany(
615 absl::Span<std::pair<const string*, const string*> const>
616 containers_and_names,
617 std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
618 CheckDeriveFromResourceBase<T>();
619 tf_shared_lock l(mu_);
620 resources->resize(containers_and_names.size());
621 for (size_t i = 0; i < containers_and_names.size(); ++i) {
622 T* resource;
623 Status s = LookupInternal<T, use_dynamic_cast>(
624 *containers_and_names[i].first, *containers_and_names[i].second,
625 &resource);
626 if (s.ok()) {
627 (*resources)[i].reset(resource);
628 }
629 }
630 return Status::OK();
631 }
632
633 // Simple wrapper to allow conditional dynamic / static casts.
634 template <typename T, bool use_dynamic_cast>
635 struct TypeCastFunctor {
CastTypeCastFunctor636 static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
637 };
638
639 template <typename T>
640 struct TypeCastFunctor<T, true> {
641 static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
642 };
643
644 template <typename T, bool use_dynamic_cast>
645 Status ResourceMgr::LookupInternal(const std::string& container,
646 const std::string& name,
647 T** resource) const {
648 ResourceBase* found = nullptr;
649 Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
650 if (s.ok()) {
651 // It's safe to down cast 'found' to T* since
652 // typeid(T).hash_code() is part of the map key.
653 *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
654 }
655 return s;
656 }
657
// Looks up (container, T, name), creating it with `creator` if absent. See the
// declaration for the ownership contract.
template <typename T, bool use_dynamic_cast>
Status ResourceMgr::LookupOrCreate(const std::string& container,
                                   const std::string& name, T** resource,
                                   std::function<Status(T**)> creator) {
  CheckDeriveFromResourceBase<T>();
  *resource = nullptr;
  Status s;
  // Fast path: if the resource already exists, a shared lock is enough.
  {
    tf_shared_lock l(mu_);
    s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
    if (s.ok()) return s;
  }
  // Slow path: take the exclusive lock and look again, because another thread
  // may have created the resource after the shared lock was released.
  mutex_lock l(mu_);
  s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
  if (s.ok()) return s;
  // creator() runs with mu_ held exclusively (hence the no-reentrancy warning
  // on the declaration).
  TF_RETURN_IF_ERROR(creator(resource));
  // DoCreate() takes over the creator's ref on *resource.
  // NOTE(review): on this error path *resource still points at the object
  // whose ref was handed to DoCreate(); callers should not use it.
  s = DoCreate(container, TypeIndex::Make<T>(), name, *resource);
  if (!s.ok()) {
    return errors::Internal("LookupOrCreate failed unexpectedly");
  }
  // Extra ref for the caller; the manager keeps the creator's ref.
  (*resource)->Ref();
  return s;
}
681
682 template <typename T>
683 Status ResourceMgr::Delete(const std::string& container,
684 const std::string& name) {
685 CheckDeriveFromResourceBase<T>();
686 return DoDelete(container, TypeIndex::Make<T>(), name);
687 }
688
// Resolves `input_name` to a resource of type T; see the declaration above.
template <typename T>
Status GetResourceFromContext(OpKernelContext* ctx,
                              const std::string& input_name, T** resource) {
  DataType dtype;
  TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
  // Resource-handle input: look the handle up directly. LookupResource() also
  // validates the handle's device and type.
  if (dtype == DT_RESOURCE) {
    const Tensor* handle;
    TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
    return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
  }
  // Ref input: a 2-element string tensor holding (container, shared_name).
  // Copy both strings out while holding the input's ref mutex, then do the
  // lookup after the mutex is released.
  std::string container;
  std::string shared_name;
  {
    mutex* mu;
    TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
    mutex_lock l(*mu);
    Tensor tensor;
    // The trailing `true` presumably tells mutable_input() that the ref mutex
    // is already held (it is, just above).
    TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
    if (tensor.NumElements() != 2) {
      return errors::InvalidArgument(
          "Resource handle must have 2 elements, but had shape: ",
          tensor.shape().DebugString());
    }
    // NOTE(review): only the element count is checked; flat<tstring>()
    // assumes the tensor really holds strings.
    container = tensor.flat<tstring>()(0);
    shared_name = tensor.flat<tstring>()(1);
  }
  return ctx->resource_manager()->Lookup(container, shared_name, resource);
}
717
718 namespace internal {
719
720 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
721
722 template <typename T>
723 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
724 TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
725 auto type_index = TypeIndex::Make<T>();
726 if (type_index.hash_code() != p.hash_code()) {
727 return errors::InvalidArgument(
728 "Trying to access resource using the wrong type. Expected ",
729 p.maybe_type_name(), " got ", type_index.name());
730 }
731 return Status::OK();
732 }
733
734 } // namespace internal
735
736 template <typename T>
737 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
738 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
739 return ctx->resource_manager()->Create(p.container(), p.name(), value);
740 }
741
742 template <typename T, bool use_dynamic_cast>
743 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
744 T** value) {
745 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
746 return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
747 p.name(), value);
748 }
749
750 template <typename T>
751 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
752 core::RefCountPtr<T>* value) {
753 T* raw_ptr = nullptr;
754 TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
755 value->reset(raw_ptr);
756
757 return Status::OK();
758 }
759
760 template <typename T>
761 Status LookupResources(OpKernelContext* ctx,
762 absl::Span<ResourceHandle const* const> p,
763 std::vector<core::RefCountPtr<T>>* values) {
764 std::vector<std::pair<const string*, const string*>> containers_and_names(
765 p.size());
766 for (size_t i = 0; i < p.size(); ++i) {
767 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
768 containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
769 }
770 return ctx->resource_manager()->LookupMany(containers_and_names, values);
771 }
772
773 template <typename T>
774 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
775 T** value, std::function<Status(T**)> creator) {
776 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
777 return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
778 creator);
779 }
780
781 template <typename T>
782 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
783 core::RefCountPtr<T>* value,
784 std::function<Status(T**)> creator) {
785 T* raw_ptr = nullptr;
786 TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
787 value->reset(raw_ptr);
788
789 return Status::OK();
790 }
791
792 template <typename T>
793 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
794 TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
795 return ctx->resource_manager()->Delete<T>(p.container(), p.name());
796 }
797
// Redeclaration of the non-template DeleteResource() declared earlier in this
// header (the overload that uses the handle's type hash code directly). Its
// definition is not in this header — presumably in resource_mgr.cc.
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
799
800 template <typename T>
801 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
802 Tensor* output;
803 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
804 T* object;
805 bool found;
806 if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
807 found = true;
808 object->Unref();
809 } else {
810 found = false;
811 }
812
813 output->flat<bool>()(0) = found;
814 }
815
816 template <typename T>
817 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
818 : OpKernel(context) {
819 OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
820 OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
821 }
822
823 template <typename T>
824 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
825 if (name_ == ResourceHandle::ANONYMOUS_NAME) {
826 AllocatorAttributes attr;
827 attr.set_on_host(true);
828 Tensor handle;
829 OP_REQUIRES_OK(
830 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
831 handle.scalar<ResourceHandle>()() =
832 MakeResourceHandle<T>(ctx, container_, name_);
833 ctx->set_output(0, handle);
834 } else {
835 if (!initialized_.load()) {
836 mutex_lock ml(mutex_);
837 // Checking again to see if another thread has initialized the resource.
838 if (!initialized_.load()) {
839 AllocatorAttributes attr;
840 attr.set_on_host(true);
841 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
842 &resource_, attr));
843 resource_.scalar<ResourceHandle>()() =
844 MakeResourceHandle<T>(ctx, container_, name_);
845 initialized_.store(true);
846 }
847 }
848 ctx->set_output(0, resource_);
849 }
850 }
851
852 template <typename T>
853 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
854 : OpKernel(context) {
855 int n;
856 OP_REQUIRES_OK(context, context->GetAttr("N", &n));
857 OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
858 OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
859 OP_REQUIRES(
860 context, containers_.size() == n,
861 errors::InvalidArgument("Number of containers (", containers_.size(),
862 ") must be equal to N (", n, ")"));
863 OP_REQUIRES(context, names_.size() == n,
864 errors::InvalidArgument("Number of names (", containers_.size(),
865 ") must be equal to N (", n, ")"));
866 resources_.resize(n);
867 }
868
869 template <typename T>
870 void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
871 if (!initialized_.load()) {
872 mutex_lock ml(mutex_);
873 // Checking again to see if another thread has initialized the resource.
874 if (!initialized_.load()) {
875 AllocatorAttributes attr;
876 attr.set_on_host(true);
877 for (size_t i = 0; i < resources_.size(); ++i) {
878 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
879 &resources_[i], attr));
880 ResourceHandle h =
881 MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
882 resources_[i].template scalar<ResourceHandle>()() = h;
883 }
884 initialized_.store(true);
885 }
886 }
887 for (size_t i = 0; i < resources_.size(); ++i) {
888 ctx->set_output(i, resources_[i]);
889 }
890 }
891
892 template <typename T>
893 ResourceHandle ScopedStepContainer::MakeResourceHandle(
894 const std::string& name, const DeviceBase& device) {
895 mutex_lock ml(mu_);
896 dirty_ = true;
897 return tensorflow::MakeResourceHandle(container_, name, device,
898 TypeIndex::Make<T>(), {});
899 }
900
901 template <typename T>
902 Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
903 T** resource) const {
904 return rm->Lookup<T>(container_, name, resource);
905 }
906
907 template <typename T>
908 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
909 const std::string& name,
910 T** resource,
911 std::function<Status(T**)> creator) {
912 mutex_lock ml(mu_);
913 dirty_ = true;
914 return rm->LookupOrCreate<T>(container_, name, resource, creator);
915 }
916
917 template <typename T>
918 Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
919 T* resource) {
920 mutex_lock ml(mu_);
921 dirty_ = true;
922 return rm->Create<T>(container_, name, resource);
923 }
924
925 template <typename T>
926 Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
927 return rm->Delete<T>(container_, name);
928 }
929
930 } // end namespace tensorflow
931
932 #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
933