1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/framework/tensor_shape.h" 24 #include "tensorflow/core/platform/logging.h" 25 #include "tensorflow/core/platform/mutex.h" 26 #include "tensorflow/core/platform/thread_annotations.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 31 // ResourceOpKernel<T> is a virtual base class for resource op implementing 32 // interface type T. The inherited op looks up the resource name (determined by 33 // ContainerInfo), and creates a new resource if necessary. 34 // 35 // Requirements: 36 // - Op must be marked as stateful. 37 // - Op must have `container` and `shared_name` attributes. Empty `container` 38 // means using the default container. Empty `shared_name` means private 39 // resource. 40 // - Subclass must override CreateResource(). 41 // - Subclass is encouraged to override VerifyResource(). 42 template <typename T> 43 class ResourceOpKernel : public OpKernel { 44 public: ResourceOpKernel(OpKernelConstruction * context)45 explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { 46 has_resource_type_ = (context->output_type(0) == DT_RESOURCE); 47 if (!has_resource_type_) { 48 // The resource variant of the op may be placed on non-CPU devices, but 49 // this allocation is always on the host. Fortunately we don't need it in 50 // the resource case. 51 OP_REQUIRES_OK(context, 52 context->allocate_persistent(DT_STRING, TensorShape({2}), 53 &handle_, nullptr)); 54 } 55 } 56 57 // The resource is deleted from the resource manager only when it is private 58 // to kernel. Ideally the resource should be deleted when it is no longer held 59 // by anyone, but it would break backward compatibility. ~ResourceOpKernel()60 ~ResourceOpKernel() override { 61 if (resource_ != nullptr) { 62 resource_->Unref(); 63 if (cinfo_.resource_is_private_to_kernel()) { 64 if (!cinfo_.resource_manager() 65 ->template Delete<T>(cinfo_.container(), cinfo_.name()) 66 .ok()) { 67 // Do nothing; the resource can have been deleted by session resets. 68 } 69 } 70 } 71 } 72 Compute(OpKernelContext * context)73 void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { 74 mutex_lock l(mu_); 75 if (resource_ == nullptr) { 76 ResourceMgr* mgr = context->resource_manager(); 77 OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); 78 79 T* resource; 80 OP_REQUIRES_OK( 81 context, 82 mgr->LookupOrCreate<T>(cinfo_.container(), cinfo_.name(), &resource, 83 [this](T** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 84 Status s = CreateResource(ret); 85 if (!s.ok() && *ret != nullptr) { 86 CHECK((*ret)->Unref()); 87 } 88 return s; 89 })); 90 91 Status s = VerifyResource(resource); 92 if (TF_PREDICT_FALSE(!s.ok())) { 93 resource->Unref(); 94 context->SetStatus(s); 95 return; 96 } 97 98 if (!has_resource_type_) { 99 auto h = handle_.AccessTensor(context)->template flat<string>(); 100 h(0) = cinfo_.container(); 101 h(1) = cinfo_.name(); 102 } 103 resource_ = resource; 104 } 105 if (has_resource_type_) { 106 OP_REQUIRES_OK(context, MakeResourceHandleToOutput( 107 context, 0, cinfo_.container(), cinfo_.name(), 108 MakeTypeIndex<T>())); 109 } else { 110 context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); 111 } 112 } 113 114 protected: 115 // Variables accessible from subclasses. 116 mutex mu_; 117 ContainerInfo cinfo_ GUARDED_BY(mu_); 118 T* resource_ GUARDED_BY(mu_) = nullptr; 119 120 private: 121 // Must return a T descendant allocated with new that ResourceOpKernel will 122 // take ownership of. 123 virtual Status CreateResource(T** resource) EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; 124 125 // During the first Compute(), resource is either created or looked up using 126 // shared_name. In the latter case, the resource found should be verified if 127 // it is compatible with this op's configuration. The verification may fail in 128 // cases such as two graphs asking queues of the same shared name to have 129 // inconsistent capacities. VerifyResource(T * resource)130 virtual Status VerifyResource(T* resource) { return Status::OK(); } 131 132 PersistentTensor handle_ GUARDED_BY(mu_); 133 134 // Is the output of the operator of type DT_RESOURCE? 135 bool has_resource_type_; 136 }; 137 } // namespace tensorflow 138 139 #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ 140