/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/function_ops.h"

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;

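// ArgOp implements the function-argument kernels (kArgOp / kDeviceArgOp): it
// reads argument `index` of the enclosing function call from the current call
// frame, consuming it in place when the frame permits, and checks that the
// value's dtype matches the "T" attribute.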
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  const Tensor* val;

  auto validate_type = [this](const Tensor& val) {
    if (val.dtype() == dtype_) {
      return Status::OK();
    } else {
      return errors::InvalidArgument("Type mismatch: actual ",
                                     DataTypeString(val.dtype()),
                                     " vs. expect ", DataTypeString(dtype_));
    }
  };

  if (frame->CanConsumeArg(index_)) {
    Tensor val;
    frame->ConsumeArg(index_, &val);
    OP_REQUIRES_OK(ctx, validate_type(val));
    ctx->set_output(0, std::move(val));
  } else {
    OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
    OP_REQUIRES_OK(ctx, validate_type(*val));
    ctx->set_output(0, *val);
  }
}

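// RetvalOp implements the function-return kernels (kRetOp / kDeviceRetOp): it
// type-checks its single input and stores it as return value `index` in the
// current call frame.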
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

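// On GPUs, _Arg and _Retval kernels are registered per dtype. int32, string,
// and resource-handle values generally live in host memory, so those
// registrations mark the corresponding input/output as HostMemory; the
// _DeviceArg/_DeviceRetval variants keep int32 in device memory.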
#define REGISTER(type)                                               \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
TF_CALL_bool(REGISTER)
REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ArgOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<Variant>("T"), ArgOp);

#define REGISTER(type)                                                  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
REGISTER(Variant)
TF_CALL_bool(REGISTER)
REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .TypeConstraint<int32>("T"),
                        RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);
#undef REGISTER

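// PassOn is the kernel behind "_ListToArray" and "_ArrayToList": it forwards
// each input tensor unchanged to the output at the same position, requiring
// matching input/output counts and dtypes.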
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

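// SymbolicGradientOp instantiates the symbolic gradient of the function named
// by this node's attributes via the function library runtime, runs it
// asynchronously, and forwards the returned tensors as this op's outputs.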
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.collective_executor = ctx->collective_executor();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me("SymbolicGradientOp");
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but get ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, std::move((*rets)[i]));
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);

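// RemoteCallOp runs the function given by the "f" attribute on the device
// named by its "target" input. It caches one instantiation handle per
// (target device, FLR) pair and forwards the function's results as outputs.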
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
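  // Wrap the step's cancellation manager in a per-call child manager so that
  // cancelling this call does not disturb other work sharing the parent; the
  // child is deleted in the done callback below.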
  CancellationManager* cancel_mgr = nullptr;
  if (ctx->cancellation_manager() != nullptr) {
    cancel_mgr = new CancellationManager(ctx->cancellation_manager());
  }
  opts.cancellation_manager = cancel_mgr;
  opts.collective_executor = ctx->collective_executor();
  std::vector<Tensor> args(arguments.begin(), arguments.end());
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    ret_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "RemoteCallOp",
            {{"func_name", func_name}, {"device", target_device}});
      },
      profiler::TraceMeLevel::kInfo);
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx, cancel_mgr,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return profiler::TraceMeEncode(
                  "RemoteCallOpDone",
                  {{"func_name", func_name}, {"device", target_device}});
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete cancel_mgr;
        delete rets;
        done();
      });
}

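// TraceString includes the callee's function name (and, in verbose mode, the
// input shapes) so profiler traces show which function a RemoteCall invoked.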
string RemoteCallOp::TraceString(const OpKernelContext& ctx,
                                 bool verbose) const {
  string trace_string = profiler::TraceMeOp(
      strings::StrCat(name_view(), "__", func_.name()), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
}  // namespace tensorflow