/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/permuter.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

static int64 kStepId = 123;

// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
// TODO(b/113171733): factor out of this file and ring_reducer_test.cc
// into a single common source.
class FailTestRMA : public CollectiveRemoteAccessLocal {
 public:
  FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
              int64 step_id, int fail_after)
      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
        fail_after_(fail_after) {}

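  // Decrements fail_after_ under the lock; when it reaches zero, aborts the
  // buf_rendezvous shortly afterwards and reports the error through done.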
  bool MaybeFail(const StatusCallback& done) {
    bool fail_now = false;
    {
      mutex_lock l(mu_);
      if (fail_after_ > 0) {
        fail_now = (--fail_after_ == 0);
      }
    }
    if (fail_now) {
      auto error = errors::Internal("Deliberate failure");
      LOG(INFO) << "triggering failure " << error;
      SchedNonBlockingClosureAfter(
          1000, [this, error] { buf_rendezvous()->StartAbort(error); });
      done(error);
      return true;
    }
    return false;
  }

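  // Each override first gives MaybeFail a chance to inject the error; if it
  // does not fire, the call is forwarded to the real local implementation.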
  void RecvFromPeer(const string& peer_device, const string& peer_task,
                    bool peer_is_local, const string& key, Device* to_device,
                    DeviceContext* to_device_ctx,
                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
                    const DeviceLocality& client_locality, int stream_index,
                    CancellationManager* cancellation_manager,
                    const StatusCallback& done) override {
    if (MaybeFail(done)) return;
    CollectiveRemoteAccessLocal::RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, stream_index,
        cancellation_manager, done);
  }

  void PostToPeer(const string& peer_device, const string& peer_task,
                  const string& key, Device* from_device,
                  DeviceContext* from_device_ctx,
                  const AllocatorAttributes& from_alloc_attr,
                  const Tensor* from_tensor,
                  const DeviceLocality& client_locality,
                  CancellationManager* cancellation_manager,
                  const StatusCallback& done) override {
    if (MaybeFail(done)) return;
    CollectiveRemoteAccessLocal::PostToPeer(
        peer_device, peer_task, key, from_device, from_device_ctx,
        from_alloc_attr, from_tensor, client_locality, cancellation_manager,
        done);
  }

  mutex mu_;
  int fail_after_ TF_GUARDED_BY(mu_);
};

class PermuterTest : public ::testing::Test {
 protected:
  PermuterTest()
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}

  ~PermuterTest() override {
    stop_ = true;
    for (auto i : instances_) delete i;
    if (col_exec_) col_exec_->Unref();
    if (col_params_) col_params_->Unref();
  }

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  void InitGPUDevices() {
    auto device_factory = DeviceFactory::GetFactory("GPU");
    CHECK(device_factory);
    SessionOptions options;
    Status s = device_factory->CreateDevices(
        options, "/job:worker/replica:0/task:0", &gpu_devices_);
    CHECK(s.ok());
  }
#endif

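  // Builds num_workers * num_devices_per_worker local devices plus the fake
  // collective executor, RMA and CollectiveParams shared by all instances.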
  void Init(int num_workers, int num_devices_per_worker, DataType dtype,
            const DeviceType& device_type, int fail_after) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    InitGPUDevices();
#endif
    device_type_ = device_type;
    std::vector<std::unique_ptr<Device>> local_devices;
    SessionOptions sess_opts;
    sess_opts.env = Env::Default();
    Bytes mem_limit(4 << 20);
    DeviceLocality dev_locality;
    for (int wi = 0; wi < num_workers; ++wi) {
      for (int di = 0; di < num_devices_per_worker; ++di) {
        if (device_type == DEVICE_CPU) {
          string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
                                            "/device:CPU:", di);
          local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
              sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
        } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
          int dev_idx = (wi * num_devices_per_worker) + di;
          if (dev_idx >= static_cast<int>(gpu_devices_.size())) {
            LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
                         "than one ring node.";
          } else {
            local_devices.push_back(std::move(gpu_devices_[dev_idx]));
          }
        } else {
          LOG(FATAL) << "Unsupported device_type " << device_type;
        }
      }
    }
    if (!dev_mgr_ || device_type == DEVICE_CPU) {
      dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
    }
    if (!gpu_ring_order_) {
      gpu_ring_order_ = absl::make_unique<string>();
    }
    dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
                           fail_after);
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
                                           dev_mgr_.get(),
                                           gpu_ring_order_.get(), work_queue_);
    col_params_ = new CollectiveParams();
    col_params_->name = "test_collective";
    col_params_->instance.data_type = dtype;
    static const int kInstanceKey = 18;
    col_params_->instance.instance_key = kInstanceKey;
    col_params_->group.device_type = device_type;
    col_params_->instance.type = PERMUTE_COLLECTIVE;

    // Set up all the fake device contexts.
    for (int wi = 0; wi < num_workers; wi++) {
      for (int di = 0; di < num_devices_per_worker; di++) {
        string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
        string dev_name;
        if (device_type == DEVICE_GPU) {
          dev_name = strings::StrCat(task_name, "/device:GPU:0");
        } else {
          dev_name = strings::StrCat(task_name, "/device:CPU:", di);
        }
        col_params_->group.device_names.push_back(dev_name);
        col_params_->instance.devices.push_back(dev_name);
        int default_rank = wi * num_devices_per_worker + di;
        permutation_.push_back(default_rank);
        col_params_->group.task_names.push_back(task_name);
        col_params_->task.is_local.push_back(true);
      }
    }

    // Generate a permutation by permuting every two instances.
    // E.g. [0,1] becomes [1,0]
    //      [0,1,2,3] becomes [1,0,3,2]
    for (int i = 0; i < permutation_.size(); i += 2) {
      // If the total number of instances is odd,
      // swap the last instance with the first.
      // E.g. [0,1,2] becomes [2,0,1]
      if (permutation_.size() == i + 1) {
        std::swap(permutation_[i], permutation_[0]);
        continue;
      }
      std::next_permutation(permutation_.begin() + i,
                            permutation_.begin() + i + 2);
    }
    col_params_->instance.permutation = permutation_;

    for (int wi = 0; wi < num_workers; wi++) {
      for (int di = 0; di < num_devices_per_worker; di++) {
        int default_rank = wi * num_devices_per_worker + di;
        instances_.push_back(new DeviceInstance(
            default_rank, col_params_->group.device_names[default_rank],
            device_type, this));
      }
    }
  }

  typedef std::function<void(Tensor*)> InitFunc;

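  // Runs DoPermute on every instance concurrently and waits until all of them
  // have finished (or stop_ has been set).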
  void Permute(int fail_after) {
    std::atomic<int> done(0);
    for (auto di : instances_) {
      SchedClosure([di, &done] {
        di->DoPermute();
        ++done;
      });
      if (fail_after > 0) {
        // Stagger the op execution starts.
        Env::Default()->SleepForMicroseconds(100);
      }
    }
    while (done < instances_.size()) {
      if (stop_) break;
      Env::Default()->SleepForMicroseconds(1000);
    }
  }

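  // Initializes the group, fills each instance's input tensor with values
  // that are distinct per device, runs the permute, and verifies that every
  // output either matches the expected permuted values or reports the
  // injected failure.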
  template <typename T>
  void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
               int num_devices, int tensor_len, int fail_after) {
    Init(num_workers, num_devices, dtype, device_type, fail_after);
    std::vector<T> expected(tensor_len * num_devices * num_workers, 0.0);
    // Initialize each instance tensor with distinct values.
    for (int di = 0; di < instances_.size(); ++di) {
      DeviceInstance* instance = instances_[di];
      instance->InitTensor(
          dtype, TensorShape({tensor_len}),
          [this, &expected, di, tensor_len](Tensor* t) {
            for (size_t i = 0; i < t->NumElements(); ++i) {
              // The cast is necessary to prevent clang-tidy from insisting
              // that a faster non-open source function be substituted.
              float value = pow(10, static_cast<double>(di)) * i;
              t->flat<T>()(i) = value;
              expected[permutation_[di] * tensor_len + i] = value;
            }
          });
    }

    Permute(fail_after);

    // At this point all of the ops have terminated.
    for (int di = 0; di < instances_.size(); ++di) {
      if (!instances_[di]->status_.ok()) {
        ASSERT_GT(fail_after, 0);
        ASSERT_NE(
            instances_[di]->status_.error_message().find("Deliberate failure"),
            string::npos);
        continue;
      }
      TF_EXPECT_OK(instances_[di]->status_);
      Tensor* inst = &instances_[di]->tensor_output_;
      Tensor actual(dtype, TensorShape({tensor_len}));
      if (device_type_ == DEVICE_CPU) {
        CHECK(actual.CopyFrom(*inst, inst->shape()));
      } else if (device_type_ == DEVICE_GPU) {
        Device* dev = instances_[di]->device_;
        auto* dev_info = dev->tensorflow_gpu_device_info();
        CHECK(dev_info);
        TF_CHECK_OK(dev_info->default_context->CopyDeviceTensorToCPUSync(
            inst, "" /*tensor_name*/, dev, &actual));
      }
      for (int i = 0; i < tensor_len; ++i) {
        switch (dtype) {
          case DT_FLOAT:
            EXPECT_FLOAT_EQ(expected[(di * tensor_len) + i],
                            actual.template flat<T>()(i))
                << "Mismatch at device " << di << " index " << i;
            break;
          case DT_DOUBLE:
            EXPECT_DOUBLE_EQ(expected[(di * tensor_len) + i],
                             actual.template flat<T>()(i))
                << "Mismatch at device " << di << " index " << i;
            break;
          case DT_BOOL:
          case DT_INT32:
          case DT_INT64:
            EXPECT_EQ(expected[(di * tensor_len) + i],
                      actual.template flat<T>()(i))
                << "Mismatch at device " << di << " index " << i;
            break;
          default:
            LOG(FATAL) << "unimplemented";
        }
      }
    }
  }

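  // One per participating device: owns that device's input and output tensors
  // and runs a single Permuter against the shared collective executor.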
  class DeviceInstance {
   public:
    DeviceInstance(int rank, const string& dev_name,
                   const DeviceType& device_type, PermuterTest* parent)
        : parent_(parent),
          dev_name_(dev_name),
          device_type_(device_type),
          rank_(rank),
          col_params_(new CollectiveParams()) {
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
      col_params_->name = parent_->col_params_->name;
      col_params_->instance.data_type =
          parent_->col_params_->instance.data_type;
      col_params_->instance.instance_key =
          parent_->col_params_->instance.instance_key;
      col_params_->group.device_type = parent_->col_params_->group.device_type;
      col_params_->group.device_names =
          parent_->col_params_->group.device_names;
      col_params_->instance.devices = parent_->col_params_->instance.devices;
      col_params_->instance.permutation =
          parent->col_params_->instance.permutation;
      col_params_->group.task_names = parent_->col_params_->group.task_names;
      col_params_->task.is_local = parent_->col_params_->task.is_local;
      CHECK_EQ(col_params_->instance.devices.size(),
               col_params_->group.device_names.size());
      // Default rank is order in device_names.
      col_params_->default_rank = rank;
    }

    ~DeviceInstance() { col_params_->Unref(); }

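    // Allocates the input and output tensors on this device and fills the
    // input via f, staging through a host tensor when the device is a GPU.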
    void InitTensor(DataType dtype, const TensorShape& shape,
                    const InitFunc& f) {
      tensor_input_ =
          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
      tensor_output_ =
          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
      if (device_type_ == DEVICE_CPU) {
        f(&tensor_input_);
      } else if (device_type_ == DEVICE_GPU) {
        Tensor cpu_tensor(dtype, shape);
        f(&cpu_tensor);
        auto* dev_info = device_->tensorflow_gpu_device_info();
        CHECK(dev_info);
        TF_CHECK_OK(dev_info->default_context->CopyCPUTensorToDeviceSync(
            &cpu_tensor, device_, &tensor_input_));
      } else {
        LOG(FATAL) << "Unsupported device_type " << device_type_;
      }
    }

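    // Builds an OpKernelContext and a CollectiveContext for this device, then
    // runs the Permuter and blocks until its done callback fires.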
    void DoPermute() {
      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      op_params.step_id = parent_->step_id_;
      op_params.device = device_;
      op_params.cancellation_manager = &parent_->cancellation_manager_;
      gtl::InlinedVector<TensorValue, 4> inputs;
      inputs.push_back(TensorValue(&tensor_input_));
      op_params.inputs = &inputs;
      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
          {AllocatorAttributes()});
      op_params.input_alloc_attrs = &input_aa;
      DeviceContext* dev_ctx = nullptr;
      auto* dev_info = device_->tensorflow_gpu_device_info();
      if (dev_info) {
        dev_ctx = dev_info->default_context;
        dev_ctx->Ref();
      } else {
        dev_ctx = new DeviceContext;
      }
      op_params.op_device_context = dev_ctx;
      AllocatorAttributes generic_alloc_attr;
      op_params.output_attr_array = &generic_alloc_attr;
      OpKernelContext ctx(&op_params, 1);

      // Prepare a Permuter instance.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      Permuter* permuter = new Permuter;
      core::ScopedUnref unref(permuter);
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, /*nccl_communicator*/ nullptr,
          parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
          kStepId, &tensor_input_, &tensor_output_);
      TF_CHECK_OK(permuter->InitializeCollectiveContext(col_ctx));
      Notification note;
      // Run the permute.
      permuter->Run([this, &note](Status s) {
        status_ = s;
        note.Notify();
      });
      note.WaitForNotification();
      dev_ctx->Unref();
    }

    PermuterTest* parent_;
    string dev_name_;
    DeviceType device_type_ = DEVICE_CPU;
    int rank_;
    Tensor tensor_input_;
    Tensor tensor_output_;
    Device* device_;
    CollectiveParams* col_params_;
    Status status_;
  };  // class DeviceInstance

  bool stop_ = false;
  int64 step_id_ = kStepId;
  DeviceType device_type_;
  TestCollectiveExecutorMgr col_exec_mgr_;
  CollectiveExecutor* col_exec_ = nullptr;
  CollectiveRemoteAccessLocal* rma_;
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  std::vector<DeviceInstance*> instances_;
  CollectiveParams* col_params_;
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
  std::unique_ptr<string> gpu_ring_order_;
  mutex mu_;
  int permute_counter_ TF_GUARDED_BY(mu_) = 0;
  std::vector<int> permutation_;
  CancellationManager cancellation_manager_;
};

// TODO(b/113171733): change to use TEST_P.
// Tests of full permute algorithm, with different device and
// data types.
// B = data element type
// T = device type
// W = number of workers
// D = number of devices per worker
// L = tensor length
// A = abort after count
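// For example, DEF_TEST(FLOAT, CPU, 2, 4, 128, 0) runs a float permute across
// 2 workers with 4 CPU devices each, on tensors of length 128, with no abort.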
#define DEF_TEST(B, T, W, D, L, A)                                   \
  TEST_F(PermuterTest,                                               \
         DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abrt##A) {  \
    DataType dtype = DT_##B;                                         \
    switch (dtype) {                                                 \
      case DT_BOOL: {                                                \
        RunTest<bool>(dtype, DEVICE_##T, W, D, L, A);                \
      } break;                                                       \
      case DT_FLOAT: {                                               \
        RunTest<float>(dtype, DEVICE_##T, W, D, L, A);               \
      } break;                                                       \
      case DT_DOUBLE: {                                              \
        RunTest<double>(dtype, DEVICE_##T, W, D, L, A);              \
      } break;                                                       \
      case DT_INT32: {                                               \
        RunTest<int32>(dtype, DEVICE_##T, W, D, L, A);               \
      } break;                                                       \
      case DT_INT64: {                                               \
        RunTest<int64>(dtype, DEVICE_##T, W, D, L, A);               \
      } break;                                                       \
      default:                                                       \
        LOG(FATAL) << "Unimplemented";                               \
    }                                                                \
  }

#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
//       B      T    W  D  L  A
DEF_TEST(FLOAT, CPU, 1, 2, 1, 0)
DEF_TEST(FLOAT, CPU, 1, 3, 3, 0)
DEF_TEST(FLOAT, CPU, 1, 7, 3, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1001, 0)
DEF_TEST(FLOAT, CPU, 2, 2, 3, 0)
DEF_TEST(FLOAT, CPU, 2, 1, 128, 0)
DEF_TEST(FLOAT, CPU, 2, 4, 128, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 4095, 0)
DEF_TEST(FLOAT, CPU, 4, 4, 1045991, 0)

DEF_TEST(BOOL, CPU, 1, 4, 1, 0)
DEF_TEST(BOOL, CPU, 2, 4, 1, 0)
DEF_TEST(BOOL, CPU, 2, 4, 1001, 0)

DEF_TEST(DOUBLE, CPU, 2, 4, 128, 0)
DEF_TEST(INT32, CPU, 2, 4, 128, 0)
DEF_TEST(INT64, CPU, 2, 4, 128, 0)

// Failure cases
DEF_TEST(FLOAT, CPU, 1, 2, 1, 1)
DEF_TEST(FLOAT, CPU, 2, 4, 128, 1)
DEF_TEST(FLOAT, CPU, 2, 4, 128, 5)
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Can only set W=1 for GPU tests.
//       B      T    W  D  L  A
DEF_TEST(FLOAT, GPU, 1, 2, 1, 0)
DEF_TEST(FLOAT, GPU, 1, 7, 3, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 33, 0)
DEF_TEST(FLOAT, GPU, 1, 3, 64, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 1001, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 4095, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 1045991, 0)

DEF_TEST(BOOL, GPU, 1, 4, 1, 0)
DEF_TEST(BOOL, GPU, 1, 4, 1001, 0)

DEF_TEST(DOUBLE, GPU, 1, 8, 1001, 0)
DEF_TEST(INT64, GPU, 1, 8, 1001, 0)

// Failure cases
DEF_TEST(FLOAT, GPU, 1, 8, 128, 6)
#endif

}  // namespace
}  // namespace tensorflow