/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <deque>
#include <vector>

#include "tensorflow/core/kernels/function_ops.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp;

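// _Arg kernel: fetches the function argument at `index_` from the current
// call frame and emits it as the op's single output, after checking that the
// tensor's dtype matches the "T" attr.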
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  Tensor val;
  OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  ctx->set_output(0, val);
}

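// _Retval kernel: checks its single input against the "T" attr and stores it
// as return value `index_` in the current call frame.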
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

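// SYCL registrations for _Arg/_Retval: all numeric and bool types, with int32
// pinned to host memory.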
#if TENSORFLOW_USE_SYCL
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_SYCL)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
#undef REGISTER
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_SYCL)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
#undef REGISTER
#endif  // TENSORFLOW_USE_SYCL

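// GPU registrations for _Arg: int32, string, and resource-handle arguments
// are produced in host memory; Variant arguments and _DeviceArg int32 stay on
// the device.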
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<string>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<Variant>("T"), ArgOp);

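// GPU registrations for _Retval, mirroring the _Arg set: int32, string, and
// resource-handle return values are consumed from host memory.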
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
REGISTER(Variant)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<string>("T")
                            .HostMemory("input"),
                        RetvalOp);
#undef REGISTER

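// PassOn implements the internal _ListToArray and _ArrayToList ops: it simply
// forwards each input tensor to the corresponding output, after verifying in
// the constructor that input and output counts and types line up.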
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

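// GPU (and, below, SYCL) kernels for the pass-through ops; int32 lists are
// kept in host memory.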
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint<type>("T"),  \
      PassOn);                                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint<type>("T"),  \
      PassOn);

REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);

#undef REGISTER_SYCL_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
#endif  // TENSORFLOW_USE_SYCL

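// SymbolicGradient: instantiates the gradient function recorded in this
// node's attrs and runs it asynchronously through the function library
// runtime, passing this op's inputs as arguments and forwarding the function's
// results as this op's outputs.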
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.runner = ctx->runner();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    opts.collective_executor = ctx->collective_executor();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but get ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, (*rets)[i]);
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
                        SymbolicGradientOp);

#endif  // TENSORFLOW_USE_SYCL

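// RemoteCall: invokes a function on the device named by the "target" input,
// which may belong to a different process. The constructor reads the function
// attr and the declared input/output dtypes ("Tin"/"Tout").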
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

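// Canonicalizes the target device name, instantiates the function for that
// target (caching the handle per (device, FLR) pair under mu_), pins
// host-only dtypes to host memory, and then runs the function asynchronously,
// forwarding its results as this op's outputs.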
void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
  string target_device;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
                                              source_device, &target_device),
      done);

  AttrValueMap attr_values = func_.attr();
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = target_device;

  FunctionTarget function_target = {target_device, lib};

  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_.name() << " on " << target_device;
      tracing::ScopedActivity activity(strings::StrCat(
          "RemoteCall: Instantiate: ", func_.name(), " on ", target_device));
      OP_REQUIRES_OK_ASYNC(
          ctx,
          lib->Instantiate(func_.name(), AttrSlice(&attr_values),
                           instantiate_opts, &handle),
          done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_.name() << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.step_id = ctx->step_id();
  opts.runner = ctx->runner();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  std::vector<Tensor> args;
  args.reserve(arguments.size());
  for (const Tensor& argument : arguments) {
    args.push_back(argument);
  }
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    if (DataTypeAlwaysOnHost(dtype)) {
      arg_alloc_attrs.set_on_host(true);
    }
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    if (DataTypeAlwaysOnHost(dtype)) {
      ret_alloc_attrs.set_on_host(true);
    }
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  auto* activity = new tracing::ScopedActivity(strings::StrCat(
      "RemoteCall: Run: ", func_.name(), " on ", target_device));
  VLOG(1) << "Running " << func_.name() << " on " << target_device
          << " with handle: " << handle;
  lib->Run(opts, handle, args, rets,
           [rets, activity, done, ctx](const Status& status) {
             if (!status.ok()) {
               ctx->SetStatus(status);
             } else {
               for (size_t i = 0; i < rets->size(); ++i) {
                 ctx->set_output(i, (*rets)[i]);
               }
             }
             delete rets;
             delete activity;
             done();
           });
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow