/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SyclDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace functor {

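// Writes `value` into row `loc` of `output`, viewing both tensors as 2-D
// with all but the first dimension flattened. The row index is reduced
// modulo the number of rows, so out-of-range values wrap around.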
template <typename Device, typename T>
Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc,
                              Tensor* output) {
  auto Tvalue = value.shaped<T, 2>({1, value.NumElements()});
  auto Toutput = output->flat_outer_dims<T>();
  auto nrows = Toutput.dimension(0);
  auto r = (loc % nrows + nrows) % nrows;  // Guard index range.
  Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(0);
  return Status::OK();
}

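// CPU specialization: dispatches on the runtime dtype to the typed
// DoParallelConcatUpdate above.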
template <>
Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc,
                        Tensor* output) {
  CHECK_EQ(value.dtype(), output->dtype());
  switch (value.dtype()) {
#define CASE(type)                  \
  case DataTypeToEnum<type>::value: \
    return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output);
    TF_CALL_POD_TYPES(CASE);
    TF_CALL_string(CASE);
    TF_CALL_variant(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
}

#ifdef TENSORFLOW_USE_SYCL
template <>
Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc,
                        Tensor* output) {
  CHECK_EQ(value.dtype(), output->dtype());
  switch (value.dtype()) {
#define CASE(type)                  \
  case DataTypeToEnum<type>::value: \
    return DoParallelConcatUpdate<SyclDevice, type>(d, value, loc, output);
    TF_CALL_GPU_NUMBER_TYPES_NO_HALF(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
}
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace functor

namespace {

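// Sets row `loc` of its first input to the (single-row) second input and
// returns the first input as an alias. The two inputs must agree in every
// dimension except the first, and the update must have exactly one row.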
template <typename Device>
class ParallelConcatUpdate : public OpKernel {
 public:
  explicit ParallelConcatUpdate(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("loc", &loc_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto value = ctx->input(0);
    auto update = ctx->input(1);

    OP_REQUIRES(
        ctx, value.dims() == update.dims(),
        errors::InvalidArgument("value and update shape doesn't match: ",
                                value.shape().DebugString(), " vs. ",
                                update.shape().DebugString()));
    for (int i = 1; i < value.dims(); ++i) {
      OP_REQUIRES(
          ctx, value.dim_size(i) == update.dim_size(i),
          errors::InvalidArgument("value and update shape doesn't match ",
                                  value.shape().DebugString(), " vs. ",
                                  update.shape().DebugString()));
    }
    OP_REQUIRES(ctx, 1 == update.dim_size(0),
                errors::InvalidArgument("update shape doesn't match: ",
                                        update.shape().DebugString()));

    Tensor output = value;  // This creates an alias intentionally.
    const auto& d = ctx->eigen_device<Device>();
    OP_REQUIRES_OK(
        ctx, ::tensorflow::functor::DoParallelConcat(d, update, loc_, &output));
    ctx->set_output(0, output);
  }

 private:
  int32 loc_;
};

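// Allocates an uninitialized output tensor with the statically known shape
// given by the "shape" attribute; its rows are expected to be filled in
// afterwards by _ParallelConcatUpdate ops.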
template <typename Device, typename T>
class ParallelConcatStart : public OpKernel {
 public:
  explicit ParallelConcatStart(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* out = nullptr;
    // We do not know whether the output will be used on GPU. Setting it to be
    // gpu-compatible for now.
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape_, &out, attr));
  }

 private:
  TensorShape shape_;
};

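// ParallelConcat is expected to be rewritten away before execution; this
// kernel fails at construction time so that any surviving instance reports a
// clear error instead of running.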
class FailureKernel : public OpKernel {
 public:
  explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   errors::Internal("Found instance of parallel_stack which "
                                    "could not be properly replaced."));
  }

  void Compute(OpKernelContext*) override {}
};

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<CPUDevice>);
TF_CALL_POD_STRING_TYPES(REGISTER)
#undef REGISTER

#define REGISTER_EMPTY(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY)
#undef REGISTER_EMPTY

#define REGISTER_PARALLEL_CONCAT(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_EMPTY(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_SYCL)            \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<SyclDevice, type>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_EMPTY)
#undef REGISTER_EMPTY

#define REGISTER_PARALLEL_CONCAT(type)                                      \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ParallelConcat").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<SyclDevice>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER)
#undef REGISTER

// Register versions that operate on int32 data on the CPU even though the op
// has been placed on the SYCL device.

REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
                            .Device(DEVICE_SYCL)
                            .HostMemory("value")
                            .HostMemory("update")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ParallelConcatUpdate<CPUDevice>);
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

#define REGISTER_PARALLEL_CONCAT_START(type)                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_GPU)             \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT_START)
#undef REGISTER_PARALLEL_CONCAT_START

#define REGISTER_PARALLEL_CONCAT(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<GPUDevice>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER)
#undef REGISTER

// Register versions that operate on int32 data on the CPU even though the op
// has been placed on the GPU

REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("value")
                            .HostMemory("update")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ParallelConcatUpdate<CPUDevice>);
#endif

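// Base class for InplaceUpdate/InplaceAdd/InplaceSub. Validates that i is a
// vector of row indices into x, that x and v agree in every dimension but
// the first, and that v supplies one row per index; then forwards to
// DoCompute on an alias of x.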
class InplaceOpBase : public OpKernel {
 public:
  explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    auto x = ctx->input(0);
    auto i = ctx->input(1);
    auto v = ctx->input(2);

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(i.shape()),
                errors::InvalidArgument("i must be a vector. ",
                                        i.shape().DebugString()));
    OP_REQUIRES(ctx, x.dims() == v.dims(),
                errors::InvalidArgument(
                    "x and v shape doesn't match (ranks differ): ",
                    x.shape().DebugString(), " vs. ", v.shape().DebugString()));
    for (int i = 1; i < x.dims(); ++i) {
      OP_REQUIRES(
          ctx, x.dim_size(i) == v.dim_size(i),
          errors::InvalidArgument("x and v shape doesn't match at index ", i,
                                  " : ", x.shape().DebugString(), " vs. ",
                                  v.shape().DebugString()));
    }
    OP_REQUIRES(ctx, i.dim_size(0) == v.dim_size(0),
                errors::InvalidArgument(
                    "i and v shape doesn't match at index 0: ",
                    i.shape().DebugString(), " vs. ", v.shape().DebugString()));

    Tensor y = x;  // This creates an alias intentionally.
    OP_REQUIRES_OK(ctx, DoCompute(ctx, i, v, &y));
    ctx->set_output(0, y);
  }

 protected:
  virtual Status DoCompute(OpKernelContext* ctx, const Tensor& i,
                           const Tensor& v, Tensor* y) = 0;
};

}  // end namespace

namespace functor {

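// Applies `op` row by row: for each j, row i(j) of y is overwritten with,
// incremented by, or decremented by row j of v. Indices are wrapped into
// [0, nrows) to guard against out-of-range values.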
template <typename T>
void DoInplaceOp(const CPUDevice& d, InplaceOpType op, const Tensor& i,
                 const Tensor& v, Tensor* y) {
  auto Ti = i.flat<int32>();
  auto Tv = v.flat_outer_dims<T>();
  auto Ty = y->flat_outer_dims<T>();
  auto nrows = Ty.dimension(0);
  for (int64 j = 0; j < Ti.size(); ++j) {
    auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
    switch (op) {
      case I_UPDATE:
        Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
        break;
      case I_ADD:
        Ty.template chip<0>(r).device(d) += Tv.template chip<0>(j);
        break;
      case I_SUB:
        Ty.template chip<0>(r).device(d) -= Tv.template chip<0>(j);
        break;
    }
  }
}

// String type only supports inplace update.
void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i,
                             const Tensor& v, Tensor* y) {
  auto Ti = i.flat<int32>();
  auto Tv = v.flat_outer_dims<string>();
  auto Ty = y->flat_outer_dims<string>();
  auto nrows = Ty.dimension(0);
  for (int64 j = 0; j < Ti.size(); ++j) {
    auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
    Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
  }
}

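// CPU specialization of DoInplace: string and bool are only supported for
// I_UPDATE; the remaining numeric types dispatch to the typed DoInplaceOp
// above.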
template <>
Status DoInplace(const CPUDevice& device, InplaceOpType op, const Tensor& i,
                 const Tensor& v, Tensor* y) {
  CHECK_EQ(v.dtype(), y->dtype());
  if (op == I_UPDATE) {
    if (v.dtype() == DT_STRING) {
      DoInplaceStringUpdateOp(device, i, v, y);
      return Status::OK();
    } else if (v.dtype() == DT_BOOL) {
      DoInplaceOp<bool>(device, op, i, v, y);
      return Status::OK();
    }
  }
  switch (v.dtype()) {
#define CASE(type)                          \
  case DataTypeToEnum<type>::value:         \
    DoInplaceOp<type>(device, op, i, v, y); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(v.dtype()));
  }
  return Status::OK();
}

}  // end namespace functor

namespace {
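// Kernel wrapper that forwards InplaceOpBase::DoCompute to the
// device-specific DoInplace functor, with the op type (update/add/sub)
// fixed at compile time.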
template <typename Device, functor::InplaceOpType op>
class InplaceOp : public InplaceOpBase {
 public:
  explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {}

 protected:
  Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v,
                   Tensor* y) override {
    const auto& d = ctx->eigen_device<Device>();
    return ::tensorflow::functor::DoInplace(d, op, i, v, y);
  }
};

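// DeepCopy: allocates a fresh output with the same shape as x and fills it
// via the device-specific DoCopy functor.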
class CopyOpBase : public OpKernel {
 public:
  explicit CopyOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    auto x = ctx->input(0);
    Tensor* y;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
    OP_REQUIRES_OK(ctx, DoCompute(ctx, x, y));
  }

 protected:
  virtual Status DoCompute(OpKernelContext* ctx, const Tensor& x,
                           Tensor* y) = 0;
};

template <typename Device>
class CopyOp : public CopyOpBase {
 public:
  explicit CopyOp(OpKernelConstruction* ctx) : CopyOpBase(ctx) {}

 protected:
  Status DoCompute(OpKernelContext* ctx, const Tensor& x, Tensor* y) override {
    const auto& d = ctx->eigen_device<Device>();
    return ::tensorflow::functor::DoCopy(d, x, y);
  }
};

}  // end namespace

namespace functor {

typedef Eigen::ThreadPoolDevice CPUDevice;

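// CPU specialization of DoCopy: element-wise copy of x into y, dispatched on
// the runtime dtype.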
template <>
Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) {
  CHECK_EQ(x.dtype(), y->dtype());
  switch (x.dtype()) {
#define CASE(type)                                   \
  case DataTypeToEnum<type>::value:                  \
    y->flat<type>().device(device) = x.flat<type>(); \
    break;

    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_bool(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(x.dtype()));
  }
  return Status::OK();
}

}  // end namespace functor

namespace {
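// Empty: builds a tensor of the shape given by the int32 "shape" input.
// The buffer is zero-filled only when the "init" attribute is true;
// otherwise it is left uninitialized.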
template <typename Device, typename T>
class EmptyOp : public OpKernel {
 public:
  explicit EmptyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape.shape()),
        errors::InvalidArgument("shape must be a vector of int32, got shape ",
                                shape.shape().DebugString()));
    auto dims = shape.flat<int32>();
    TensorShape out_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                            reinterpret_cast<const int32*>(dims.data()),
                            dims.size(), &out_shape));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (init_) {
      functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                           out->flat<T>());
    }
  }

 private:
  bool init_;
};

REGISTER_KERNEL_BUILDER(Name("InplaceUpdate").Device(DEVICE_CPU),
                        InplaceOp<CPUDevice, functor::I_UPDATE>);
REGISTER_KERNEL_BUILDER(Name("InplaceAdd").Device(DEVICE_CPU),
                        InplaceOp<CPUDevice, functor::I_ADD>);
REGISTER_KERNEL_BUILDER(Name("InplaceSub").Device(DEVICE_CPU),
                        InplaceOp<CPUDevice, functor::I_SUB>);
REGISTER_KERNEL_BUILDER(Name("DeepCopy").Device(DEVICE_CPU), CopyOp<CPUDevice>);

#define REGISTER_EMPTY(type, dev)                             \
  REGISTER_KERNEL_BUILDER(Name("Empty")                       \
                              .Device(DEVICE_##dev)           \
                              .HostMemory("shape")            \
                              .TypeConstraint<type>("dtype"), \
                          EmptyOp<dev##Device, type>)

REGISTER_EMPTY(float, CPU)
REGISTER_EMPTY(double, CPU)
REGISTER_EMPTY(Eigen::half, CPU)
REGISTER_EMPTY(string, CPU)
REGISTER_EMPTY(int32, CPU)
REGISTER_EMPTY(int64, CPU)
REGISTER_EMPTY(bool, CPU)
REGISTER_EMPTY(uint8, CPU)

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      InplaceOp<GPUDevice, functor::I_UPDATE>);                           \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),    \
      InplaceOp<GPUDevice, functor::I_ADD>);                              \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("InplaceSub").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),    \
      InplaceOp<GPUDevice, functor::I_SUB>);                              \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DeepCopy").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),      \
      CopyOp<GPUDevice>);

REGISTER_KERNEL_BUILDER(
    Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<bool>("T"),
    InplaceOp<GPUDevice, functor::I_UPDATE>);
REGISTER(float);
REGISTER(double);
REGISTER(Eigen::half);
REGISTER(int64);

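// Register versions that operate on int32 data on the CPU even though the op
// has been placed on the GPU.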
REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("i")
                            .HostMemory("v")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        InplaceOp<CPUDevice, functor::I_UPDATE>);
REGISTER_KERNEL_BUILDER(Name("InplaceAdd")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("i")
                            .HostMemory("v")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        InplaceOp<CPUDevice, functor::I_ADD>);
REGISTER_KERNEL_BUILDER(Name("InplaceSub")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("i")
                            .HostMemory("v")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        InplaceOp<CPUDevice, functor::I_SUB>);

REGISTER_KERNEL_BUILDER(Name("DeepCopy")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        CopyOp<CPUDevice>);
REGISTER_EMPTY(float, GPU);
REGISTER_EMPTY(double, GPU);
REGISTER_EMPTY(Eigen::half, GPU);
REGISTER_EMPTY(int64, GPU);
REGISTER_EMPTY(int32, GPU);

#endif  // GOOGLE_CUDA

}  // end namespace
}  // end namespace tensorflow