/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <deque>
#include <vector>

#include "tensorflow/core/kernels/function_ops.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp;

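// ArgOp reads the index_-th argument of the currently executing function
// from the call frame and emits it as this kernel's single output.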
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  Tensor val;
  OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  ctx->set_output(0, val);
}

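// RetvalOp checks that its input has the declared type and records it as the
// index_-th return value in the enclosing function's call frame.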
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

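// _Arg and _Retval are registered as system kernels so that they remain
// available even in builds that use selective kernel registration.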
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

#if TENSORFLOW_USE_SYCL
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_SYCL)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
#undef REGISTER
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_SYCL)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
#undef REGISTER
#endif

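// int32 tensors are kept in host memory on non-CPU devices, so the int32
// specializations below pin their buffers with HostMemory; the same applies
// to the string and resource-handle registrations further down.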
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<string>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<Variant>("T"), ArgOp);

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
REGISTER(Variant)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<string>("T")
                            .HostMemory("input"),
                        RetvalOp);
#undef REGISTER

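// PassOn forwards each input tensor unchanged to the output with the same
// index. It implements the internal _ListToArray and _ArrayToList ops, which
// only repackage a list of tensors without touching their contents.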
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PassOn);                                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);

#undef REGISTER_SYCL_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
#endif  // TENSORFLOW_USE_SYCL

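// SymbolicGradientOp instantiates the gradient function for the function
// named in its node attributes and runs it asynchronously through the
// FunctionLibraryRuntime, forwarding the gradient tensors as its outputs.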
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    // Run the instantiated gradient function with the same execution state
    // (rendezvous, cancellation, step container, ...) as the enclosing step.
    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.runner = ctx->runner();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    opts.collective_executor = ctx->collective_executor();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but got ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, (*rets)[i]);
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
                        SymbolicGradientOp);
#endif  // TENSORFLOW_USE_SYCL

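// RemoteCallOp runs the function named in its attributes on a target device,
// which may belong to a different process. Instantiated function handles are
// cached per (target device, runtime) pair so that repeated calls reuse them.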
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  // The target device is given as a scalar string input; canonicalize it
  // relative to the device this kernel is running on.
  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
  string target_device;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
                                              source_device, &target_device),
      done);

  AttrValueMap attr_values = func_.attr();
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = target_device;

  FunctionTarget function_target = {target_device, lib};

  // Reuse a cached handle for this target if one exists; otherwise
  // instantiate the function on the target device and cache the handle.
  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_.name() << " on " << target_device;
      tracing::ScopedActivity activity(strings::StrCat(
          "RemoteCall: Instantiate: ", func_.name(), " on ", target_device));
      OP_REQUIRES_OK_ASYNC(
          ctx,
          lib->Instantiate(func_.name(), AttrSlice(&attr_values),
                           instantiate_opts, &handle),
          done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_.name() << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.step_id = ctx->step_id();
  opts.runner = ctx->runner();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  std::vector<Tensor> args;
  args.reserve(arguments.size());
  for (const Tensor& argument : arguments) {
    args.push_back(argument);
  }
  // Types that always live in host memory (e.g. int32, string) need their
  // argument and return buffers marked as host-resident.
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    if (DataTypeAlwaysOnHost(dtype)) {
      arg_alloc_attrs.set_on_host(true);
    }
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    if (DataTypeAlwaysOnHost(dtype)) {
      ret_alloc_attrs.set_on_host(true);
    }
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  auto* activity = new tracing::ScopedActivity(strings::StrCat(
      "RemoteCall: Run: ", func_.name(), " on ", target_device));
  VLOG(1) << "Running " << func_.name() << " on " << target_device
          << " with handle: " << handle;
  lib->Run(opts, handle, args, rets,
           [rets, activity, done, ctx](const Status& status) {
             if (!status.ok()) {
               ctx->SetStatus(status);
             } else {
               for (size_t i = 0; i < rets->size(); ++i) {
                 ctx->set_output(i, (*rets)[i]);
               }
             }
             delete rets;
             delete activity;
             done();
           });
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow