• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16 
17 #include <memory>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/composite_device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/function_testlib.h"
24 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/function_testlib.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/resource_var.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/framework/type_index.h"
31 #include "tensorflow/core/framework/types.pb.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/lib/core/threadpool.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/protobuf/config.pb.h"
38 #include "tensorflow/core/public/session_options.h"
39 #include "tensorflow/core/public/version.h"
40 
41 #if GOOGLE_CUDA
42 #include "third_party/gpus/cuda/include/cuda.h"
43 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
44 #elif TENSORFLOW_USE_ROCM
45 #include "rocm/include/hip/hip_runtime.h"
46 #endif  // GOOGLE_CUDA
47 
48 namespace tensorflow {
49 namespace {
50 
51 class TestClusterFLR : public DistributedFunctionLibraryRuntime {
52  public:
TestClusterFLR(DeviceMgr * device_mgr)53   explicit TestClusterFLR(DeviceMgr* device_mgr) : device_mgr_(device_mgr) {}
54 
Instantiate(const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::LocalHandle * handle,FunctionLibraryRuntime::DoneCallback done)55   void Instantiate(const string& function_name,
56                    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
57                    const FunctionLibraryRuntime::InstantiateOptions& options,
58                    FunctionLibraryRuntime::LocalHandle* handle,
59                    FunctionLibraryRuntime::DoneCallback done) override {
60     {
61       mutex_lock l(mu_);
62       *handle = next_handle_;
63       next_handle_++;
64     }
65     done(Status::OK());
66   }
67 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done)68   void Run(const FunctionLibraryRuntime::Options& opts,
69            FunctionLibraryRuntime::LocalHandle handle,
70            gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
71            FunctionLibraryRuntime::DoneCallback done) override {}
72 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done)73   void Run(const FunctionLibraryRuntime::Options& opts,
74            FunctionLibraryRuntime::LocalHandle handle,
75            gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
76            FunctionLibraryRuntime::DoneCallback done) override {}
77 
CleanUp(uint64 step_id,FunctionLibraryRuntime::LocalHandle handle,FunctionLibraryRuntime::DoneCallback done)78   void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
79                FunctionLibraryRuntime::DoneCallback done) override {}
80 
remote_device_mgr() const81   DeviceMgr* remote_device_mgr() const override { return device_mgr_; }
82 
83  private:
84   mutex mu_;
85   int next_handle_ TF_GUARDED_BY(mu_) = 0;
86   DeviceMgr* device_mgr_;
87 };
88 
GenerateSessionMetadata()89 SessionMetadata GenerateSessionMetadata() {
90   SessionMetadata session_metadata;
91   session_metadata.set_name("name");
92   session_metadata.set_version(42);
93   return session_metadata;
94 }
95 
96 // TODO(b/128707168): Tests requiring a GPU device are currently always skipped
97 // because the check for whether a GPU device is present happens before the GPU
98 // device is set up.
99 class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
100  public:
ProcessFunctionLibraryRuntimeTest()101   ProcessFunctionLibraryRuntimeTest() {
102     SessionOptions options;
103     auto* device_count = options.config.mutable_device_count();
104     device_count->insert({"CPU", 3});
105     std::vector<std::unique_ptr<Device>> created_devices;
106     TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
107                                           &created_devices));
108     // Do not add CPU:2 to device manager. Used for removed device testing.
109     device2_ = std::move(created_devices[2]);
110     created_devices.erase(created_devices.begin() + 2);
111 
112     device_mgr_ = std::make_unique<DynamicDeviceMgr>();
113     TF_CHECK_OK(device_mgr_->AddDevices(std::move(created_devices)));
114     TF_CHECK_OK(device_mgr_->LookupDevice(
115         "/job:a/replica:0/task:0/device:CPU:0", &device0_));
116     TF_CHECK_OK(device_mgr_->LookupDevice(
117         "/job:a/replica:0/task:0/device:CPU:1", &device1_));
118     Device* device2_ptr = nullptr;
119     EXPECT_NE(
120         error::OK,
121         device_mgr_
122             ->LookupDevice("/job:a/replica:0/task:0/device:CPU:2", &device2_ptr)
123             .code());
124     // If no GPU is available, gpu_device_ will remain nullptr.
125     Status status = device_mgr_->LookupDevice(
126         "/job:a/replica:0/task:0/device:GPU:0", &gpu_device_);
127     if (!status.ok()) {
128       CHECK_EQ(nullptr, gpu_device_);
129     }
130   }
131 
Init(const std::vector<FunctionDef> & flib,const SessionMetadata * session_metadata=nullptr)132   void Init(const std::vector<FunctionDef>& flib,
133             const SessionMetadata* session_metadata = nullptr) {
134     FunctionDefLibrary proto;
135     for (const auto& fdef : flib) *(proto.add_function()) = fdef;
136     lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
137     OptimizerOptions opts;
138     cluster_flr_.reset(new TestClusterFLR(device_mgr_.get()));
139     proc_flr_.reset(new ProcessFunctionLibraryRuntime(
140         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
141         TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
142         /*thread_pool=*/nullptr, cluster_flr_.get(), session_metadata,
143         Rendezvous::Factory{
144             [this](const int64 step_id, const DeviceMgr* device_mgr,
145                    Rendezvous** r) {
146               *r = new IntraProcessRendezvous(device_mgr);
147               if (rendezvous_ref_counts_.find(step_id) !=
148                   rendezvous_ref_counts_.end()) {
149                 rendezvous_ref_counts_[step_id]++;
150               } else {
151                 rendezvous_ref_counts_[step_id] = 1;
152               }
153               return Status::OK();
154             },
155             [this](const int64 step_id) {
156               CHECK(rendezvous_ref_counts_.find(step_id) !=
157                     rendezvous_ref_counts_.end());
158               rendezvous_ref_counts_[step_id]--;
159               return Status::OK();
160             }}));
161   }
162 
AddCompositeDevice(CompositeDevice * d)163   void AddCompositeDevice(CompositeDevice* d) {
164     proc_flr_->AddCompositeDevice(d);
165   }
166 
Instantiate(const string & name,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,FunctionLibraryRuntime::Handle * handle)167   Status Instantiate(
168       const string& name, test::function::Attrs attrs,
169       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
170       FunctionLibraryRuntime::Handle* handle) {
171     return proc_flr_->Instantiate(name, attrs, instantiate_opts, handle);
172   }
173 
GPUToCPU(const Tensor & device_tensor)174   Tensor GPUToCPU(const Tensor& device_tensor) {
175 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
176     CHECK(gpu_device_);
177     CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
178     DeviceContext* device_context =
179         gpu_device_->tensorflow_gpu_device_info()->default_context;
180 
181     Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape());
182     CHECK(device_context
183               ->CopyDeviceTensorToCPUSync(&device_tensor, "", gpu_device_,
184                                           &cpu_tensor)
185               .ok());
186     return cpu_tensor;
187 #else
188     CHECK(false);
189 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
190   }
191 
CPUToGPU(const Tensor & cpu_tensor)192   Tensor CPUToGPU(const Tensor& cpu_tensor) {
193 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
194     CHECK(gpu_device_);
195     CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
196     DeviceContext* device_context =
197         gpu_device_->tensorflow_gpu_device_info()->default_context;
198 
199     Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(),
200                          cpu_tensor.shape(), {});
201     CHECK(device_context
202               ->CopyCPUTensorToDeviceSync(&cpu_tensor, gpu_device_,
203                                           &device_tensor)
204               .ok());
205     return device_tensor;
206 #else
207     CHECK(false);
208 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
209   }
210 
211   template <typename T, typename K>
RunWithRuntime(const string & name,FunctionLibraryRuntime::Options opts,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,const T & args,std::vector<K * > rets,ProcessFunctionLibraryRuntime * pflr)212   Status RunWithRuntime(
213       const string& name, FunctionLibraryRuntime::Options opts,
214       test::function::Attrs attrs,
215       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
216       const T& args, std::vector<K*> rets,
217       ProcessFunctionLibraryRuntime* pflr) {
218     FunctionLibraryRuntime::Handle handle;
219     Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
220     if (!status.ok()) {
221       return status;
222     }
223     bool is_cross_process = false;
224     TF_CHECK_OK(pflr->IsCrossProcess(handle, &is_cross_process));
225     EXPECT_FALSE(is_cross_process);
226 
227     std::atomic<int32> call_count(0);
228     std::function<void(std::function<void()>)> runner =
229         [&call_count](std::function<void()> fn) {
230           ++call_count;
231           test::function::FunctionTestSchedClosure(fn);
232         };
233 
234     Notification done;
235     opts.runner = &runner;
236     std::vector<K> out;
237     pflr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
238       status = s;
239       done.Notify();
240     });
241     done.WaitForNotification();
242     if (!status.ok()) {
243       return status;
244     }
245     CHECK_EQ(rets.size(), out.size());
246     for (size_t i = 0; i < rets.size(); ++i) {
247       *rets[i] = out[i];
248     }
249 
250     EXPECT_GE(call_count, 1);  // Test runner is used.
251 
252     // Release the handle and then try running the function. It shouldn't
253     // succeed.
254     status = pflr->ReleaseHandle(handle);
255     if (!status.ok()) {
256       return status;
257     }
258     Notification done2;
259     pflr->Run(opts, handle, args, &out, [&status, &done2](const Status& s) {
260       status = s;
261       done2.Notify();
262     });
263     done2.WaitForNotification();
264     EXPECT_TRUE(errors::IsNotFound(status)) << "Actual status: " << status;
265     EXPECT_TRUE(absl::StrContains(status.error_message(), "not found."));
266 
267     return Status::OK();
268   }
269 
Run(const string & name,FunctionLibraryRuntime::Options opts,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,const std::vector<Tensor> & args,std::vector<Tensor * > rets,ProcessFunctionLibraryRuntime * pflr=nullptr)270   Status Run(const string& name, FunctionLibraryRuntime::Options opts,
271              test::function::Attrs attrs,
272              const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
273              const std::vector<Tensor>& args, std::vector<Tensor*> rets,
274              ProcessFunctionLibraryRuntime* pflr = nullptr) {
275     return RunWithRuntime<std::vector<Tensor>, Tensor>(
276         name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
277   }
278 
RunWithPackedArgs(const string & name,FunctionLibraryRuntime::Options opts,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,const FunctionArgsInterface & args,std::vector<FunctionRet * > rets,ProcessFunctionLibraryRuntime * pflr=nullptr)279   Status RunWithPackedArgs(
280       const string& name, FunctionLibraryRuntime::Options opts,
281       test::function::Attrs attrs,
282       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
283       const FunctionArgsInterface& args, std::vector<FunctionRet*> rets,
284       ProcessFunctionLibraryRuntime* pflr = nullptr) {
285     return RunWithRuntime<FunctionArgsInterface, FunctionRet>(
286         name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
287   }
288 
RunInstantiated(FunctionLibraryRuntime::Handle handle,FunctionLibraryRuntime::Options opts,const std::vector<Tensor> & args,std::vector<Tensor * > rets)289   Status RunInstantiated(FunctionLibraryRuntime::Handle handle,
290                          FunctionLibraryRuntime::Options opts,
291                          const std::vector<Tensor>& args,
292                          std::vector<Tensor*> rets) {
293     std::atomic<int32> call_count(0);
294     std::function<void(std::function<void()>)> runner =
295         [&call_count](std::function<void()> fn) {
296           ++call_count;
297           test::function::FunctionTestSchedClosure(fn);
298         };
299 
300     opts.runner = &runner;
301     Status status;
302     Notification done;
303     std::vector<Tensor> out;
304     proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
305       status = s;
306       done.Notify();
307     });
308     done.WaitForNotification();
309     if (!status.ok()) {
310       return status;
311     }
312     CHECK_EQ(rets.size(), out.size());
313     for (size_t i = 0; i < rets.size(); ++i) {
314       *rets[i] = out[i];
315     }
316     EXPECT_GE(call_count, 1);  // Test runner is used.
317     return Status::OK();
318   }
319 
320   std::unique_ptr<DynamicDeviceMgr> device_mgr_;
321   Device* device0_ = nullptr;  // Not owned. (Owned by device_mgr_.)
322   Device* device1_ = nullptr;  // Not owned. (Owned by device_mgr_.)
323   std::unique_ptr<Device> device2_;
324   // Remains as nullptr if no GPU is available.
325   Device* gpu_device_ = nullptr;  // Not owned. (Owned by device_mgr_.)
326   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
327   std::unique_ptr<TestClusterFLR> cluster_flr_;
328   std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
329 
330   // To ensure that we are cleaning up the rendezvous properly.
331   std::unordered_map<int64, int> rendezvous_ref_counts_;
332 };
333 
TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
  // A runtime built without a DeviceMgr must still expose a default FLR.
  FunctionDefLibrary proto;
  auto lib_def =
      std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);
  OptimizerOptions opts;
  auto proc_flr = std::make_unique<ProcessFunctionLibraryRuntime>(
      nullptr /* device_mgr */, Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def.get(), opts);
  FunctionLibraryRuntime* flr =
      proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
  EXPECT_NE(flr, nullptr);
}
347 
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
  Init({});
  // Several equivalent spellings of a device name must resolve to the FLR
  // owning that device.
  const std::vector<std::pair<string, Device*>> expectations = {
      {"/job:a/replica:0/task:0/cpu:0", device0_},
      {"/job:a/replica:0/task:0/device:CPU:0", device0_},
      {"/device:CPU:0", device0_},
      {"/job:a/replica:0/task:0/cpu:1", device1_},
  };
  for (const auto& expectation : expectations) {
    FunctionLibraryRuntime* flr = proc_flr_->GetFLR(expectation.first);
    EXPECT_NE(flr, nullptr);
    EXPECT_EQ(flr->device(), expectation.second);
  }
  // An unparseable device name yields no FLR.
  EXPECT_EQ(proc_flr_->GetFLR("abc"), nullptr);
}
366 
TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
  Init({});
  int64 incarnation;
  // A registered device reports a nonzero (random) incarnation number.
  TF_EXPECT_OK(proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:1",
                                               &incarnation));
  EXPECT_NE(incarnation, 0);
  // An unknown device is rejected with INVALID_ARGUMENT.
  const Status status = proc_flr_->GetDeviceIncarnation(
      "/job:a/replica:0/task:0/cpu:2", &incarnation);
  EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
}
378 
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
  Init({test::function::XTimesTwo()});
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  // Doubling [1, 2, 3, 4] yields [2, 4, 6, 8].
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor result;
  TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, instantiate_opts,
                  {input}, {&result}));
  test::ExpectTensorEqual<float>(result, test::AsTensor<float>({2, 4, 6, 8}));
}
392 
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  Tensor device_name;
  TF_CHECK_OK(
      Run("FindDevice", opts, {}, instantiate_opts, {}, {&device_name}));
  test::ExpectTensorEqual<tstring>(
      device_name,
      test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:0"},
                              TensorShape({})));
  // Exactly one rendezvous was created for this step, and its refcount was
  // fully released afterwards.
  EXPECT_EQ(1, rendezvous_ref_counts_.size());
  EXPECT_EQ(opts.step_id, rendezvous_ref_counts_.begin()->first);
  EXPECT_EQ(0, rendezvous_ref_counts_.begin()->second);
}
409 
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour()});
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor result;
  // Two different functions on the same target device, back to back.
  TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, instantiate_opts,
                  {input}, {&result}));
  test::ExpectTensorEqual<float>(result, test::AsTensor<float>({2, 4, 6, 8}));
  TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}}, instantiate_opts,
                  {input}, {&result}));
  test::ExpectTensorEqual<float>(result,
                                 test::AsTensor<float>({4, 8, 12, 16}));
}
426 
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
  // Repeated runs against the same target must keep reporting CPU:1.
  for (int i = 0; i < 2; ++i) {
    Tensor result;
    TF_CHECK_OK(Run("FindDevice", opts, {}, instantiate_opts, {}, {&result}));
    test::ExpectTensorEqual<tstring>(
        result,
        test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:1"},
                                TensorShape({})));
  }
}
444 
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  // FindDevice must report whichever device it was instantiated on.
  for (const char* target : {"/job:a/replica:0/task:0/device:CPU:0",
                             "/job:a/replica:0/task:0/device:CPU:1"}) {
    FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
    instantiate_opts.target = target;
    Tensor result;
    TF_CHECK_OK(Run("FindDevice", opts, {}, instantiate_opts, {}, {&result}));
    test::ExpectTensorEqual<tstring>(
        result, test::AsTensor<tstring>({target}, TensorShape({})));
  }
}
464 
TEST_F(ProcessFunctionLibraryRuntimeTest, InstantiateFunctionOnRemovedDevice) {
  // Register the held-back CPU:2 device, then remove it again after Init, so
  // the process FLR's device set still holds a stale raw pointer to it.
  Device* removed_device = device2_.get();
  std::vector<std::unique_ptr<Device>> devices;
  devices.emplace_back(std::move(device2_));
  TF_CHECK_OK(device_mgr_->AddDevices(std::move(devices)));

  Init({test::function::FindDevice()});
  std::vector<Device*> remove_devices{removed_device};
  TF_CHECK_OK(device_mgr_->RemoveDevices(std::move(remove_devices)));

  // Since the process FLR device set is not updated yet, it still holds the
  // raw pointer to device2. Make sure that function instantiation with
  // device2 will not lead to segfault.
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  FunctionLibraryRuntime::Handle handle;
  instantiate_opts.target = "/job:a/replica:0/task:0/device:CPU:1";
  instantiate_opts.is_multi_device_function = true;
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:2"}},
                          instantiate_opts, &handle));
}
486 
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
  FunctionLibraryRuntime::Handle handle;

  // First remote instantiation on job b: local handle 0, flagged
  // cross-process.
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
                          instantiate_opts, &handle));
  bool is_cross_process = false;
  TF_CHECK_OK(proc_flr_->IsCrossProcess(handle, &is_cross_process));
  EXPECT_TRUE(is_cross_process);
  EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
                   "/job:b/replica:0/task:0/device:CPU:0", handle));

  // Re-instantiating the same function on the same target reuses handle 0.
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
                          instantiate_opts, &handle));
  EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
                   "/job:b/replica:0/task:0/device:CPU:0", handle));

  // A different target job is assigned the next local handle, 1.
  instantiate_opts.target = "/job:c/replica:0/task:0/device:CPU:0";
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:c/replica:0/task:0/device:CPU:0"}},
                          instantiate_opts, &handle));
  EXPECT_EQ(1, proc_flr_->GetHandleOnDevice(
                   "/job:c/replica:0/task:0/device:CPU:0", handle));
}
515 
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";

  // All concurrent instantiations of the same function/target must resolve
  // to the same local handle (0).
  auto fn = [this, &instantiate_opts]() {
    FunctionLibraryRuntime::Handle h;
    TF_CHECK_OK(Instantiate(
        "FindDevice", {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
        instantiate_opts, &h));
    EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
                     "/job:b/replica:0/task:0/device:CPU:0", h));
  };

  // RAII thread pool replaces the original naked new/delete; the destructor
  // joins all workers, so every scheduled closure completes before the test
  // body (and the captured locals) go out of scope.
  {
    thread::ThreadPool tp(Env::Default(), "test", 4);
    for (int i = 0; i < 100; ++i) {
      tp.Schedule(fn);
    }
  }
}
539 
IsCUDATensor(const Tensor & t)540 bool IsCUDATensor(const Tensor& t) {
541 #if GOOGLE_CUDA
542   cudaPointerAttributes attributes;
543   cudaError_t err =
544       cudaPointerGetAttributes(&attributes, t.tensor_data().data());
545   if (err == cudaErrorInvalidValue) return false;
546   CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
547   return (attributes.type == cudaMemoryTypeDevice);
548 #elif TENSORFLOW_USE_ROCM
549   hipPointerAttribute_t attributes;
550   hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
551   if (err == hipErrorInvalidValue) return false;
552   CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
553   return (attributes.memoryType == hipMemoryTypeDevice);
554 #else
555   CHECK(false)
556       << "IsCUDATensor should not be called when CUDA is not available";
557 #endif  // GOOGLE_CUDA
558 }
559 
TestTwoDeviceMult(ProcessFunctionLibraryRuntimeTest * fixture,const FunctionLibraryRuntime::InstantiateOptions & inst_opts,const string & error="")560 void TestTwoDeviceMult(
561     ProcessFunctionLibraryRuntimeTest* fixture,
562     const FunctionLibraryRuntime::InstantiateOptions& inst_opts,
563     const string& error = "") {
564   fixture->Init({test::function::TwoDeviceMult()});
565   FunctionLibraryRuntime::Options opts;
566   auto x = test::AsTensor<float>({1, 2, 3});
567   Tensor y_cpu;
568   Tensor y_gpu;
569   Status status = fixture->Run("TwoDeviceMult", opts, {{"T", DT_FLOAT}},
570                                inst_opts, {x}, {&y_cpu, &y_gpu});
571   if (!error.empty()) {
572     EXPECT_TRUE(errors::IsInvalidArgument(status))
573         << "Actual status: " << status;
574     EXPECT_TRUE(absl::StrContains(status.error_message(), error))
575         << "Actual error message: " << status.error_message();
576     return;
577   }
578 
579   EXPECT_TRUE(status.ok()) << "Actual status: " << status;
580   EXPECT_FALSE(IsCUDATensor(y_cpu));
581   test::ExpectTensorEqual<float>(y_cpu, test::AsTensor<float>({2, 4, 6}));
582 
583   EXPECT_TRUE(IsCUDATensor(y_gpu));
584   Tensor y_gpu_on_cpu = fixture->GPUToCPU(y_gpu);
585   test::ExpectTensorEqual<float>(y_gpu_on_cpu,
586                                  test::AsTensor<float>({3, 6, 9}));
587 }
588 
TestTwoDeviceInputOutput(ProcessFunctionLibraryRuntimeTest * fixture,const FunctionLibraryRuntime::InstantiateOptions & inst_opts)589 void TestTwoDeviceInputOutput(
590     ProcessFunctionLibraryRuntimeTest* fixture,
591     const FunctionLibraryRuntime::InstantiateOptions& inst_opts) {
592   if (fixture->gpu_device_ == nullptr) {
593     GTEST_SKIP() << "No GPUs available";
594   }
595   fixture->Init({test::function::TwoDeviceInputOutput()});
596 
597   FunctionLibraryRuntime::Options opts;
598   Tensor x1 = test::AsTensor<float>({1, 2});
599   if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
600     x1 = fixture->CPUToGPU(x1);
601   }
602   Tensor x2 = test::AsTensor<float>({10, 20});
603   if (absl::StrContains(inst_opts.input_devices[1], "GPU")) {
604     x2 = fixture->CPUToGPU(x2);
605   }
606   Tensor y1;
607   Tensor y2;
608   TF_CHECK_OK(fixture->Run("TwoDeviceInputOutput", opts, {{"T", DT_FLOAT}},
609                            inst_opts, {x1, x2}, {&y1, &y2}));
610 
611   if (absl::StrContains(inst_opts.output_devices[0], "GPU")) {
612     EXPECT_TRUE(IsCUDATensor(y1));
613     y1 = fixture->GPUToCPU(y1);
614   } else {
615     EXPECT_FALSE(IsCUDATensor(y1));
616   }
617   test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({2, 4}));
618 
619   if (absl::StrContains(inst_opts.output_devices[1], "GPU")) {
620     EXPECT_TRUE(IsCUDATensor(y2));
621     y2 = fixture->GPUToCPU(y2);
622   } else {
623     EXPECT_FALSE(IsCUDATensor(y2));
624   }
625   test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({30, 60}));
626 }
627 
CompleteDevices(const std::vector<string> & v)628 std::vector<string> CompleteDevices(const std::vector<string>& v) {
629   std::vector<string> result;
630   result.reserve(v.size());
631   for (const string& s : v) {
632     result.push_back(strings::StrCat("/job:a/replica:0/task:0/device:", s));
633   }
634   return result;
635 }
636 
MakeOptions(const string & target,const std::vector<string> & input_devices,const std::vector<string> & output_devices)637 FunctionLibraryRuntime::InstantiateOptions MakeOptions(
638     const string& target, const std::vector<string>& input_devices,
639     const std::vector<string>& output_devices) {
640   FunctionLibraryRuntime::InstantiateOptions inst_opts;
641   inst_opts.target = target;
642   inst_opts.input_devices = CompleteDevices(input_devices);
643   inst_opts.output_devices = CompleteDevices(output_devices);
644   inst_opts.is_multi_device_function = true;
645   return inst_opts;
646 }
647 
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ExplicitOutputDevice) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  // Both outputs are pinned explicitly: first to CPU:0, second to GPU:0.
  const auto options = MakeOptions("CPU:0", {"CPU:0"}, {"CPU:0", "GPU:0"});
  TestTwoDeviceMult(this, options);
}
654 
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_InferredOutputDevice) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  // No output devices given: placement is inferred from the function body.
  const auto options = MakeOptions("CPU:0", {"CPU:0"}, {});
  TestTwoDeviceMult(this, options);
}
661 
// An empty input-device list for a one-argument function must be rejected.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenNoInputDevices) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  const auto inst_opts = MakeOptions("CPU:0", {}, {});
  TestTwoDeviceMult(this, inst_opts,
                    "input_devices must have the same length");
}
669 
// Two input devices for a one-argument function must be rejected.
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ErrorWhenTooManyInputDevices) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  const auto inst_opts = MakeOptions("CPU:0", {"CPU:0", "CPU:1"}, {});
  TestTwoDeviceMult(this, inst_opts,
                    "input_devices must have the same length");
}
678 
// Three output devices for a two-output function must be rejected.
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ErrorWhenTooManyOutputDevices) {
  const auto inst_opts =
      MakeOptions("CPU:0", {"CPU:0"}, {"CPU:0", "GPU:0", "CPU:1"});
  TestTwoDeviceMult(
      this, inst_opts,
      "output_devices must either be empty or have the same length");
}
685 
// A nonexistent target device must fail instantiation with a clear error.
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ErrorWhenBadTargetDevice) {
  const auto inst_opts = MakeOptions("GPU:11", {"CPU:0"}, {"CPU:0", "GPU:0"});
  TestTwoDeviceMult(
      this, inst_opts,
      "Cannot instantiate multi-device function with target device GPU:11");
}
692 
// Multi-device instantiation must reject functions with list-typed inputs.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListInput) {
  Init({test::function::FuncWithListInput()});
  const auto inst_opts = MakeOptions("CPU:0", {"CPU:0"}, {});
  FunctionLibraryRuntime::Handle handle;
  Status s = proc_flr_->Instantiate(
      "FuncWithListInput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
      inst_opts, &handle);
  ASSERT_TRUE(errors::IsInvalidArgument(s)) << "Actual status: " << s;
  ASSERT_TRUE(absl::StrContains(
      s.error_message(),
      "FuncWithListInput has an input named \"x1\" that is a list of tensors"))
      << "Actual error message: " << s.error_message();
}
706 
// Multi-device instantiation must reject functions with list-typed outputs.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListOutput) {
  Init({test::function::FuncWithListOutput()});
  const auto inst_opts = MakeOptions("CPU:0", {}, {"CPU:0"});
  FunctionLibraryRuntime::Handle handle;
  Status s = proc_flr_->Instantiate(
      "FuncWithListOutput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
      inst_opts, &handle);
  ASSERT_TRUE(errors::IsInvalidArgument(s)) << "Actual status: " << s;
  ASSERT_TRUE(absl::StrContains(
      s.error_message(),
      "FuncWithListOutput has an output named \"y\" that is a list of tensors"))
      << "Actual error message: " << s.error_message();
}
720 
// Two inputs and two outputs with every placement given explicitly.
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ExplicitMultiInputOutput) {
  const auto inst_opts =
      MakeOptions("CPU:0", {"CPU:0", "GPU:0"}, {"CPU:0", "GPU:0"});
  TestTwoDeviceInputOutput(this, inst_opts);
}
726 
// Same as above but with the input device order reversed.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_FlipInputs) {
  const auto inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "CPU:0"}, {"CPU:0", "GPU:0"});
  TestTwoDeviceInputOutput(this, inst_opts);
}
731 
// Same as above but with the output device order reversed.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_FlipOutputs) {
  const auto inst_opts =
      MakeOptions("CPU:0", {"CPU:0", "GPU:0"}, {"GPU:0", "CPU:0"});
  TestTwoDeviceInputOutput(this, inst_opts);
}
736 
// Both the input and the output device orders reversed.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_FlipBoth) {
  const auto inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "CPU:0"}, {"GPU:0", "CPU:0"});
  TestTwoDeviceInputOutput(this, inst_opts);
}
741 
// EmptyBodySwap returns its two arguments in swapped order with no compute
// nodes in between, so this exercises pure cross-device argument/return
// forwarding: the first input starts on GPU, the second on CPU, and the
// outputs are requested on the opposite devices.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_EmptyBodySwap) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "CPU:0"}, {"CPU:0", "GPU:0"});
  Init({test::function::EmptyBodySwap()});

  Tensor x1 = CPUToGPU(test::AsTensor<float>({1, 2}));
  Tensor x2 = test::AsTensor<float>({10, 20});
  Tensor y1;
  Tensor y2;
  TF_CHECK_OK(Run("EmptyBodySwap", {}, {{"T", DT_FLOAT}}, inst_opts, {x1, x2},
                  {&y1, &y2}));

  // y1 is the swapped second input (x2) and must come back on CPU.
  EXPECT_FALSE(IsCUDATensor(y1));
  test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({10, 20}));

  // y2 is the swapped first input (x1) and must come back on GPU.
  EXPECT_TRUE(IsCUDATensor(y2));
  y2 = GPUToCPU(y2);
  test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2}));
}
764 
GetResourceHandle(const string & var_name,const string & container,const string & device_name)765 Tensor GetResourceHandle(const string& var_name, const string& container,
766                          const string& device_name) {
767   ResourceHandle handle;
768   handle.set_device(device_name);
769   handle.set_container(container);
770   handle.set_name(var_name);
771   handle.set_hash_code(TypeIndex::Make<Var>().hash_code());
772   handle.set_maybe_type_name(TypeIndex::Make<Var>().name());
773   Tensor tensor(DT_RESOURCE, TensorShape({}));
774   tensor.scalar<ResourceHandle>()() = handle;
775   return tensor;
776 }
777 
// Returns a function which adds two variables on different devices.
// The single resource argument `x` is read once on CPU:0 and once on CPU:1
// (so `x` is expected to be a packed/composite handle with a per-device
// variable), and the two read values are summed on CPU:0.
FunctionDef AddVarAcrossDevices() {
  return FunctionDefHelper::Create(
      // Name
      "AddVarAcrossDevices",
      // Args
      {"x: resource"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"read0"},
           "ReadVariableOp",
           {"x"},
           {{"dtype", DT_FLOAT}},
           {},
           "/device:CPU:0"},
          {{"read1"},
           "ReadVariableOp",
           {"x"},
           {{"dtype", DT_FLOAT}},
           {},
           "/device:CPU:1"},
          {{"add"},
           "Add",
           {"read0:value:0", "read1:value:0"},
           {{"T", DT_FLOAT}},
           {},
           "/device:CPU:0"},
      },
      {{"y", "add:z:0"}});
}
812 
813 // An implementation of FunctionArgsInterface for packed inputs.
814 class TestFunctionPackedArgs : public FunctionArgsInterface {
815  public:
TestFunctionPackedArgs(const int index,gtl::InlinedVector<TensorValue,4> && tensor_args)816   TestFunctionPackedArgs(const int index,
817                          gtl::InlinedVector<TensorValue, 4>&& tensor_args) {
818     packed_args_.emplace(index, std::move(tensor_args));
819   }
820 
~TestFunctionPackedArgs()821   ~TestFunctionPackedArgs() override{};
822 
HasRemoteOrPackedInputs() const823   bool HasRemoteOrPackedInputs() const override { return true; };
824 
GetLocalArg(const FunctionArgIndex & index,Tensor * val) const825   Status GetLocalArg(const FunctionArgIndex& index,
826                      Tensor* val) const override {
827     *val = *packed_args_.at(index.index).at(index.sub_index).tensor;
828     return Status::OK();
829   };
830 
GetLocalTensors() const831   std::vector<Tensor> GetLocalTensors() const override { return {}; }
832 
833  private:
834   absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
835 };
836 
// Runs AddVarAcrossDevices with a composite-device argument supplied two
// ways: as packed TensorValues and as a single packed DT_RESOURCE tensor.
// Either way the result should be the element-wise sum of the two variables.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
  Init({AddVarAcrossDevices()});
  // Create two variables on two devices.
  const Tensor initial_resource_value0 = test::AsTensor<float>({10, 20});
  Var* resource0 = new Var(DT_FLOAT);
  *resource0->tensor() = initial_resource_value0;
  resource0->is_initialized = true;
  const Tensor initial_resource_value1 = test::AsTensor<float>({30, 40});
  Var* resource1 = new Var(DT_FLOAT);
  *resource1->tensor() = initial_resource_value1;
  resource1->is_initialized = true;
  ResourceMgr* mgr0 = device0_->resource_manager();
  ResourceMgr* mgr1 = device1_->resource_manager();
  // The resource managers take ownership of resource0/resource1.
  TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0));
  TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1));

  Tensor resource_handle0 =
      GetResourceHandle("var", mgr0->default_container(), device0_->name());
  Tensor resource_handle1 =
      GetResourceHandle("var", mgr1->default_container(), device1_->name());

  // Create a CompositeDevice
  Status s;
  std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
                                  /*unique_device_id=*/0,
                                  device_mgr_->HostCPU()->parsed_name(), &s);
  TF_ASSERT_OK(s);
  AddCompositeDevice(composite_device.get());

  FunctionLibraryRuntime::Options opts;
  // Input 0 is placed on the composite device; map that device name to its
  // underlying per-device list so instantiation can expand the packed arg.
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"});
  inst_opts.composite_devices[composite_device->name()] =
      composite_device->underlying_devices();
  inst_opts.input_resource_dtypes_and_shapes[0] = {
      initial_resource_value0.dtype(), initial_resource_value0.shape()};

  // Packed TensorHandle
  {
    gtl::InlinedVector<TensorValue, 4> handles;
    handles.push_back(TensorValue(&resource_handle0));
    handles.push_back(TensorValue(&resource_handle1));
    TestFunctionPackedArgs args(0, std::move(handles));
    FunctionRet ret;
    TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
                                  {{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
    // index() == 0 means the variant holds a Tensor (not a shape).
    EXPECT_EQ(ret.index(), 0);
    test::ExpectTensorEqual<float>(absl::get<Tensor>(ret),
                                   test::AsTensor<float>({40, 60}));
  }

  // Packed Tensor
  {
    Tensor arg(DT_RESOURCE, TensorShape({2}));
    arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()();
    arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()();

    Tensor ret;
    TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts,
                    {arg}, {&ret}));
    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
  }
}
901 
// ResourceOutput forwards its resource input to its first output and doubles
// its float input into its second. Verifies that a GPU-variable handle
// returned by one function can drive a follow-up read on the right device.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "GPU:0"}, {"GPU:0", "GPU:0"});
  Init({test::function::ResourceOutput(),
        test::function::ReadResourceVariable()});

  // Make resource var
  Tensor resource_value = CPUToGPU(test::AsTensor<float>({10, 20}));
  Var* resource = new Var(DT_FLOAT);
  *resource->tensor() = resource_value;
  resource->is_initialized = true;
  ResourceMgr* mgr = gpu_device_->resource_manager();
  Status status = mgr->Create(mgr->default_container(), "my_gpu_var", resource);
  ASSERT_TRUE(status.ok()) << status.error_message();

  // Run the function taking a resource and outputting it
  FunctionLibraryRuntime::Options opts;
  Tensor x1 = CPUToGPU(test::AsTensor<float>({1, 2}));
  Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
                                "/job:a/replica:0/task:0/device:GPU:0");
  Tensor returned_handle;
  Tensor y2;
  TF_CHECK_OK(Run("ResourceOutput", opts, {{"T", DT_FLOAT}}, inst_opts,
                  {x1, x2}, {&returned_handle, &y2}));

  // The handle tensor itself is expected on host even though the variable
  // lives on GPU; the float result is expected on GPU.
  EXPECT_FALSE(IsCUDATensor(returned_handle));
  EXPECT_TRUE(IsCUDATensor(y2));
  y2 = GPUToCPU(y2);
  test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({2, 4}));

  // Read the variable using the handle returned from previous function to
  // make sure the handle and read value is on the right device.
  inst_opts = MakeOptions("GPU:0", {"GPU:0"}, {"GPU:0"});
  Tensor read_resource;
  TF_CHECK_OK(Run("ReadResourceVariable", opts, {{"T", DT_FLOAT}}, inst_opts,
                  {returned_handle}, {&read_resource}));
  EXPECT_TRUE(IsCUDATensor(read_resource));
  read_resource = GPUToCPU(read_resource);
  test::ExpectTensorEqual<float>(read_resource,
                                 test::AsTensor<float>({10, 20}));
}
946 
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_PlacerError) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  // ResourceOutput forwards its second input (a resource) straight to its
  // first output. Requesting that input on GPU but the output on CPU gives
  // the placer contradictory constraints, so instantiation must fail.
  Init({test::function::ResourceOutput(),
        test::function::ReadResourceVariable()});
  const auto inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "GPU:0"}, {"CPU:0", "GPU:0"});

  FunctionLibraryRuntime::Handle handle;
  const Status status = proc_flr_->Instantiate(
      "ResourceOutput", test::function::Attrs({{"T", DT_FLOAT}}), inst_opts,
      &handle);
  ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
  ASSERT_TRUE(absl::StrContains(status.error_message(), "Cannot place"));
}
966 
// Test-only op whose kernel (BrokenOp below) fails at construction and at
// compute time; used to verify error propagation with eager kernel creation.
REGISTER_OP("BrokenOp")
    .Input("in: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);
// Kernel for BrokenOp that reports an internal error from both its
// constructor and Compute, so instantiation only fails when kernels are
// created eagerly.
class BrokenOp : public OpKernel {
 public:
  explicit BrokenOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // Fail at kernel-construction time.
    ctx->SetStatus(errors::Internal("I am broken"));
  }

  void Compute(OpKernelContext* ctx) override {
    // Fail at execution time as well.
    ctx->SetStatus(errors::Internal("I am broken"));
  }
};
REGISTER_KERNEL_BUILDER(Name("BrokenOp").Device(DEVICE_CPU), BrokenOp);
983 
// Instantiation of a function wrapping BrokenOp should succeed by default
// (kernels are created lazily) but fail with an Internal error once
// create_kernels_eagerly forces kernel construction at instantiation time.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CreateKernelsEagerly) {
  auto T = DT_INT32;
  // "Broken" simply routes its input through BrokenOp, whose kernel
  // constructor always fails.
  FunctionDef broken_func = FunctionDefHelper::Define(
      // Name
      "Broken",
      // Args
      {"x: int32"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {{{"y"}, "BrokenOp", {"x"}, {{"T", T}}}});
  Init({broken_func});

  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"CPU:0"}, {"CPU:0"});

  // Instantiating the broken function should work.
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate("Broken", {{"T", DT_INT32}}, inst_opts, &handle));
  TF_CHECK_OK(proc_flr_->ReleaseHandle(handle));

  // Instantiating the broken function while creating kernels eagerly should
  // fail.
  inst_opts.create_kernels_eagerly = true;
  Status status = Instantiate("Broken", {{"T", DT_INT32}}, inst_opts, &handle);
  EXPECT_TRUE(errors::IsInternal(status));
}
1014 
// Checks how InstantiateOptions::state_handle scopes op state: the same (or
// empty) state handle shares one kernel instance (and thus one RNG stream),
// while a new state handle gets fresh state that restarts the sequence.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_StateHandle) {
  auto T = DT_INT32;
  // The expected sequence of outputs from this function is [6, 4, 0, 1, ...].
  FunctionDef stateful_func = FunctionDefHelper::Define(
      // Name
      "RandomUniformWrapper",
      // Args
      {"x: resource"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {FunctionDefHelper::Const<int32>("shape", gtl::ArraySlice<int32>({1})),
       FunctionDefHelper::Const<int32>("minval", 0),
       {{"maxval"}, "ReadVariableOp", {"x"}, {{"dtype", T}}, {}},
       // A stateful node.
       {{"y"},
        "RandomUniformInt",
        {"shape", "minval", "maxval"},
        {{"seed", 37}, {"seed2", 48}, {"Tout", T}, {"T", T}}}});
  Init({stateful_func});
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }

  // Make resource variables.
  ResourceMgr* mgr = gpu_device_->resource_manager();
  Tensor resource_value = CPUToGPU(test::AsScalar<int>(10));
  Var* resource = new Var(T);
  *resource->tensor() = resource_value;
  resource->is_initialized = true;
  Status status = mgr->Create(mgr->default_container(), "my_gpu_var", resource);
  ASSERT_TRUE(status.ok()) << status.error_message();

  Tensor x = GetResourceHandle("my_gpu_var", mgr->default_container(),
                               "/job:a/replica:0/task:0/device:GPU:0");
  Tensor y;

  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0"}, {"CPU:0"});

  // Instantiate the function with no state handle.
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &handle));
  for (auto expected : {6, 4}) {
    TF_CHECK_OK(RunInstantiated(handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Instantiating the function again with no state handle should result in the
  // same handle.
  FunctionLibraryRuntime::Handle other_handle;
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &other_handle));
  EXPECT_EQ(handle, other_handle);
  // Running the function should yield continuation of the same sequence.
  for (auto expected : {0, 1}) {
    TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Instantiating the function with a state handle should result in a different
  // handle.
  inst_opts.state_handle = "handle_1";
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &other_handle));
  EXPECT_NE(handle, other_handle);
  // Running the function should yield the original sequence.
  for (auto expected : {6, 4, 0, 1}) {
    TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Instantiating the function with a different state handle should result in a
  // different handle.
  inst_opts.state_handle = "handle_2";
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &other_handle));
  EXPECT_NE(handle, other_handle);
  // Running the function should yield the original sequence.
  for (auto expected : {6, 4, 0, 1}) {
    TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Repeatedly instantiating a function and releasing its handle will result in
  // repeating the original sequence.
  inst_opts.state_handle = "handle_3";
  for (int i = 0; i < 2; ++i) {
    TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}},
                            inst_opts, &other_handle));
    EXPECT_NE(handle, other_handle);
    // Running the function should yield the original sequence.
    for (auto expected : {6, 4, 0, 1}) {
      TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
    }
    TF_CHECK_OK(proc_flr_->ReleaseHandle(other_handle));
  }
}
1117 
// Test-only stateful op: ignores its int64 input and emits the DebugString()
// of the session metadata attached to the kernel context ("" when absent).
REGISTER_OP("SessionMetadataReader")
    .Input("x: int64")
    .Output("y: string")
    .SetIsStateful()
    .Doc(R"doc(SessionMetadataReader returns the session metadata.

x: int64
y: string
)doc");
1127 
1128 class SessionMetadataReaderOp : public OpKernel {
1129  public:
SessionMetadataReaderOp(OpKernelConstruction * ctx)1130   explicit SessionMetadataReaderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
Compute(OpKernelContext * ctx)1131   void Compute(OpKernelContext* ctx) override {
1132     Tensor* out_tensor = nullptr;
1133     OP_REQUIRES_OK(ctx,
1134                    ctx->allocate_output("y", TensorShape({}), &out_tensor));
1135     if (ctx->session_metadata() != nullptr) {
1136       out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
1137     } else {
1138       out_tensor->scalar<tstring>()() = "";
1139     }
1140   }
1141 };
1142 REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_CPU),
1143                         SessionMetadataReaderOp);
1144 
// Wraps the SessionMetadataReader op in a one-node function so it can be run
// through the ProcessFunctionLibraryRuntime in the tests below.
FunctionDef SessionMetadataReaderOpFn() {
  return FunctionDefHelper::Define(
      // Name
      "SessionMetadataReaderFn",
      // Args
      {"x: int64"},
      // Return values
      {"y: string"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "SessionMetadataReader", {"x"}, {}}});
}
1158 
// Without session metadata configured, the reader op sees a null metadata
// pointer and must return the empty string.
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
  Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  Tensor y;
  const auto x = test::AsTensor<int64>({17});
  TF_CHECK_OK(
      Run("SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y}));
  EXPECT_EQ("", y.scalar<tstring>()());
}
1172 
// With session metadata configured, the reader op must report it; the output
// is parsed back from textproto and compared field by field.
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
  const SessionMetadata session_metadata = GenerateSessionMetadata();
  Init({SessionMetadataReaderOpFn()}, &session_metadata);
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  FunctionLibraryRuntime::Options opts;
  opts.remote_execution = true;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  Tensor y;
  const auto x = test::AsTensor<int64>({17});
  TF_CHECK_OK(
      Run("SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y}));
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
                                                    &read_metadata));
  EXPECT_EQ(session_metadata.name(), read_metadata.name());
  EXPECT_EQ(session_metadata.version(), read_metadata.version());
}
1191 
// Composite devices registered with the original runtime must remain visible
// (as the same object) in the device set of a cloned runtime.
TEST_F(ProcessFunctionLibraryRuntimeTest, CompositeDevicesAfterCloning) {
  Init({AddVarAcrossDevices()});

  Status s;
  std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
                                  /*unique_device_id=*/0,
                                  device_mgr_->HostCPU()->parsed_name(), &s);
  TF_ASSERT_OK(s);
  AddCompositeDevice(composite_device.get());

  auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
  ASSERT_NE(nullptr, flr);
  std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
  std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_proc_flr;
  FunctionLibraryRuntime* cloned_flr;
  TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
  // The clone should resolve the composite device name to the same pointer.
  EXPECT_EQ(
      cloned_proc_flr->device_set()->FindDeviceByName(composite_device->name()),
      composite_device.get());
}
1213 
// A cloned runtime must preserve the original's session metadata: running the
// reader function through the clone still reports the configured metadata.
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
  const SessionMetadata session_metadata = GenerateSessionMetadata();
  Init({SessionMetadataReaderOpFn()}, &session_metadata);
  auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
  ASSERT_NE(nullptr, flr);
  std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
  std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_proc_flr;
  FunctionLibraryRuntime* cloned_flr;
  TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  const auto x = test::AsTensor<int64>({17});
  Tensor y;
  // Run through the CLONED runtime rather than proc_flr_.
  Status s = RunWithRuntime<std::vector<Tensor>, Tensor>(
      "SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
      cloned_proc_flr.get());
  TF_CHECK_OK(s);
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
                                                    &read_metadata));
  EXPECT_EQ(session_metadata.name(), read_metadata.name());
  EXPECT_EQ(session_metadata.version(), read_metadata.version());
}
1240 
1241 }  // anonymous namespace
1242 }  // namespace tensorflow
1243