/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/function_ops.h"

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;

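// ArgOp fetches the index-th argument of the currently executing function
// from the call frame and checks that its dtype matches the "T" attr.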
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

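// If the call frame allows it, the argument is consumed (moved) into the
// output to avoid keeping an extra reference; otherwise the frame's copy is
// forwarded as-is.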
void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  const Tensor* val;

  auto validate_type = [this](const Tensor& val) {
    if (val.dtype() == dtype_) {
      return Status::OK();
    } else {
      return errors::InvalidArgument("Type mismatch: actual ",
                                     DataTypeString(val.dtype()),
                                     " vs. expect ", DataTypeString(dtype_));
    }
  };

  if (frame->CanConsumeArg(index_)) {
    Tensor val;
    frame->ConsumeArg(index_, &val);
    OP_REQUIRES_OK(ctx, validate_type(val));
    ctx->set_output(0, std::move(val));
  } else {
    OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
    OP_REQUIRES_OK(ctx, validate_type(*val));
    ctx->set_output(0, *val);
  }
}

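// RetvalOp records its input as the index-th return value of the currently
// executing function, after checking that the dtype matches the "T" attr.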
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

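// _Arg/_Retval are system kernels, so they stay registered even when
// selective kernel registration is enabled.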
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

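// On GPU, _Arg is registered per dtype; int32, string, and resource handle
// arguments are produced in host memory (note the HostMemory("output")
// constraints below).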
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<Variant>("T"), ArgOp);

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
REGISTER(Variant)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);
#undef REGISTER

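// PassOn forwards each input tensor unchanged to the corresponding output.
// It implements _ListToArray and _ArrayToList, which merely reinterpret a
// list of tensors as an array (and vice versa).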
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

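// SymbolicGradientOp instantiates the SymbolicGradient function with this
// node's attributes, runs it through the FunctionLibraryRuntime, and forwards
// the resulting tensors to this kernel's outputs.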
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.collective_executor = ctx->collective_executor();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me("SymbolicGradientOp");
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but get ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, std::move((*rets)[i]));
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);

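// RemoteCallOp runs the function `f` on the device named by the "target"
// input, which may be in a different address space. Instantiation handles
// are cached per (target device, FunctionLibraryRuntime) pair.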
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

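  // Reuse a cached instantiation of the function for this target if one
  // exists; otherwise instantiate it on the target device and cache the
  // resulting handle.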
  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

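  // Set up per-call options. A child CancellationManager is created so that
  // cancellation of this step propagates to the remote call; it is deleted in
  // the done callback below. Argument and return buffers are allocated on
  // host or device according to their dtype.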
  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  CancellationManager* cancel_mgr = nullptr;
  if (ctx->cancellation_manager() != nullptr) {
    cancel_mgr = new CancellationManager(ctx->cancellation_manager());
  }
  opts.cancellation_manager = cancel_mgr;
  opts.collective_executor = ctx->collective_executor();
  std::vector<Tensor> args(arguments.begin(), arguments.end());
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    ret_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "RemoteCallOp",
            {{"func_name", func_name}, {"device", target_device}});
      },
      profiler::TraceMeLevel::kInfo);
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx, cancel_mgr,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return profiler::TraceMeEncode(
                  "RemoteCallOpDone",
                  {{"func_name", func_name}, {"device", target_device}});
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete cancel_mgr;
        delete rets;
        done();
      });
}

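// TraceString tags profiler events with both the op name and the callee
// function name, plus input shapes when verbose tracing is on.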
string RemoteCallOp::TraceString(const OpKernelContext& ctx,
                                 bool verbose) const {
  string trace_string = profiler::TraceMeOp(
      strings::StrCat(name_view(), "__", func_.name()), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
}  // namespace tensorflow