• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/data_flow_ops.cc.
17 
18 #include <limits.h>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/framework/device_base.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 
38 class GetSessionHandleOp : public OpKernel {
39  public:
GetSessionHandleOp(OpKernelConstruction * context)40   explicit GetSessionHandleOp(OpKernelConstruction* context)
41       : OpKernel(context) {}
42 
Compute(OpKernelContext * ctx)43   void Compute(OpKernelContext* ctx) override {
44     const Tensor& val = ctx->input(0);
45     int64 id = ctx->session_state()->GetNewId();
46     TensorStore::TensorAndKey tk{val, id, requested_device()};
47     OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
48 
49     Tensor* handle = nullptr;
50     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
51     if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
52       ResourceHandle resource_handle = MakeResourceHandle<Tensor>(
53           ctx, SessionState::kTensorHandleResourceTypeName,
54           tk.GetHandle(name()));
55       resource_handle.set_maybe_type_name(
56           SessionState::kTensorHandleResourceTypeName);
57       handle->scalar<ResourceHandle>()() = resource_handle;
58     } else {
59       // Legacy behavior in V1.
60       handle->flat<string>().setConstant(tk.GetHandle(name()));
61     }
62   }
63 
64   TF_DISALLOW_COPY_AND_ASSIGN(GetSessionHandleOp);
65 };
66 
67 REGISTER_KERNEL_BUILDER(Name("GetSessionHandle").Device(DEVICE_CPU),
68                         GetSessionHandleOp);
69 REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2").Device(DEVICE_CPU),
70                         GetSessionHandleOp);
71 
72 #define REGISTER_GPU_KERNEL(type)                         \
73   REGISTER_KERNEL_BUILDER(Name("GetSessionHandle")        \
74                               .Device(DEVICE_GPU)         \
75                               .HostMemory("handle")       \
76                               .TypeConstraint<type>("T"), \
77                           GetSessionHandleOp)             \
78   REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2")      \
79                               .Device(DEVICE_GPU)         \
80                               .HostMemory("handle")       \
81                               .TypeConstraint<type>("T"), \
82                           GetSessionHandleOp)
83 
84 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
85 REGISTER_GPU_KERNEL(bool);
86 #undef REGISTER_GPU_KERNEL
87 
88 #ifdef TENSORFLOW_USE_SYCL
89 #define REGISTER_SYCL_KERNEL(type)                        \
90   REGISTER_KERNEL_BUILDER(Name("GetSessionHandle")        \
91                               .Device(DEVICE_SYCL)        \
92                               .HostMemory("handle")       \
93                               .TypeConstraint<type>("T"), \
94                           GetSessionHandleOp)             \
95   REGISTER_KERNEL_BUILDER(Name("GetSessionHandleV2")      \
96                               .Device(DEVICE_SYCL)        \
97                               .HostMemory("handle")       \
98                               .TypeConstraint<type>("T"), \
99                           GetSessionHandleOp)
100 
101 TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
102 REGISTER_SYCL_KERNEL(bool);
103 #undef REGISTER_SYCL_KERNEL
104 #endif  // TENSORFLOW_USE_SYCL
105 
106 class GetSessionTensorOp : public OpKernel {
107  public:
GetSessionTensorOp(OpKernelConstruction * context)108   explicit GetSessionTensorOp(OpKernelConstruction* context)
109       : OpKernel(context) {}
110 
Compute(OpKernelContext * ctx)111   void Compute(OpKernelContext* ctx) override {
112     const Tensor& handle = ctx->input(0);
113     const string& name = handle.scalar<string>()();
114     Tensor val;
115     OP_REQUIRES_OK(ctx, ctx->session_state()->GetTensor(name, &val));
116     ctx->set_output(0, val);
117   }
118 
119   TF_DISALLOW_COPY_AND_ASSIGN(GetSessionTensorOp);
120 };
121 
122 REGISTER_KERNEL_BUILDER(Name("GetSessionTensor").Device(DEVICE_CPU),
123                         GetSessionTensorOp);
124 
125 #define REGISTER_GPU_KERNEL(type)                             \
126   REGISTER_KERNEL_BUILDER(Name("GetSessionTensor")            \
127                               .Device(DEVICE_GPU)             \
128                               .HostMemory("handle")           \
129                               .TypeConstraint<type>("dtype"), \
130                           GetSessionTensorOp)
131 
132 TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
133 REGISTER_GPU_KERNEL(bool);
134 #undef REGISTER_GPU_KERNEL
135 
136 #ifdef TENSORFLOW_USE_SYCL
137 #define REGISTER_SYCL_KERNEL(type)                            \
138   REGISTER_KERNEL_BUILDER(Name("GetSessionTensor")            \
139                               .Device(DEVICE_SYCL)            \
140                               .HostMemory("handle")           \
141                               .TypeConstraint<type>("dtype"), \
142                           GetSessionTensorOp)
143 
144 TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
145 REGISTER_SYCL_KERNEL(bool);
146 #undef REGISTER_SYCL_KERNEL
147 #endif  // TENSORFLOW_USE_SYCL
148 
149 class DeleteSessionTensorOp : public OpKernel {
150  public:
DeleteSessionTensorOp(OpKernelConstruction * context)151   explicit DeleteSessionTensorOp(OpKernelConstruction* context)
152       : OpKernel(context) {}
153 
Compute(OpKernelContext * ctx)154   void Compute(OpKernelContext* ctx) override {
155     const Tensor& handle = ctx->input(0);
156     const string& name = handle.scalar<string>()();
157     OP_REQUIRES_OK(ctx, ctx->session_state()->DeleteTensor(name));
158   }
159 
160   TF_DISALLOW_COPY_AND_ASSIGN(DeleteSessionTensorOp);
161 };
162 
163 REGISTER_KERNEL_BUILDER(Name("DeleteSessionTensor").Device(DEVICE_CPU),
164                         DeleteSessionTensorOp);
165 REGISTER_KERNEL_BUILDER(
166     Name("DeleteSessionTensor").Device(DEVICE_GPU).HostMemory("handle"),
167     DeleteSessionTensorOp);
168 
169 #ifdef TENSORFLOW_USE_SYCL
170 REGISTER_KERNEL_BUILDER(
171     Name("DeleteSessionTensor").Device(DEVICE_SYCL).HostMemory("handle"),
172     DeleteSessionTensorOp);
173 #endif  // TENSORFLOW_USE_SYCL
174 }  // namespace tensorflow
175