/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"

#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#endif  // GOOGLE_CUDA
48 namespace tensorflow {
49 namespace {
50 
51 class TestClusterFLR : public DistributedFunctionLibraryRuntime {
52  public:
TestClusterFLR(DeviceMgr * device_mgr)53   explicit TestClusterFLR(DeviceMgr* device_mgr) : device_mgr_(device_mgr) {}
54 
Instantiate(const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::LocalHandle * handle,FunctionLibraryRuntime::DoneCallback done)55   void Instantiate(const string& function_name,
56                    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
57                    const FunctionLibraryRuntime::InstantiateOptions& options,
58                    FunctionLibraryRuntime::LocalHandle* handle,
59                    FunctionLibraryRuntime::DoneCallback done) override {
60     {
61       mutex_lock l(mu_);
62       *handle = next_handle_;
63       next_handle_++;
64     }
65     done(Status::OK());
66   }
67 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done)68   void Run(const FunctionLibraryRuntime::Options& opts,
69            FunctionLibraryRuntime::LocalHandle handle,
70            gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
71            FunctionLibraryRuntime::DoneCallback done) override {}
72 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done)73   void Run(const FunctionLibraryRuntime::Options& opts,
74            FunctionLibraryRuntime::LocalHandle handle,
75            gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
76            FunctionLibraryRuntime::DoneCallback done) override {}
77 
CleanUp(uint64 step_id,FunctionLibraryRuntime::LocalHandle handle,FunctionLibraryRuntime::DoneCallback done)78   void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
79                FunctionLibraryRuntime::DoneCallback done) override {}
80 
remote_device_mgr() const81   DeviceMgr* remote_device_mgr() const override { return device_mgr_; }
82 
83  private:
84   mutex mu_;
85   int next_handle_ TF_GUARDED_BY(mu_) = 0;
86   DeviceMgr* device_mgr_;
87 };
88 
GenerateSessionMetadata()89 SessionMetadata GenerateSessionMetadata() {
90   SessionMetadata session_metadata;
91   session_metadata.set_name("name");
92   session_metadata.set_version(42);
93   return session_metadata;
94 }
95 
96 // TODO(b/128707168): Tests requiring a GPU device are currently always skipped
97 // because the check for whether a GPU device is present happens before the GPU
98 // device is set up.
99 class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
100  public:
ProcessFunctionLibraryRuntimeTest()101   ProcessFunctionLibraryRuntimeTest() {
102     SessionOptions options;
103     auto* device_count = options.config.mutable_device_count();
104     device_count->insert({"CPU", 3});
105     std::vector<std::unique_ptr<Device>> created_devices;
106     TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
107                                           &created_devices));
108     // Do not add CPU:2 to device manager. Used for removed device testing.
109     device2_ = std::move(created_devices[2]);
110     created_devices.erase(created_devices.begin() + 2);
111 
112     device_mgr_ = std::make_unique<DynamicDeviceMgr>();
113     TF_CHECK_OK(device_mgr_->AddDevices(std::move(created_devices)));
114     TF_CHECK_OK(device_mgr_->LookupDevice(
115         "/job:a/replica:0/task:0/device:CPU:0", &device0_));
116     TF_CHECK_OK(device_mgr_->LookupDevice(
117         "/job:a/replica:0/task:0/device:CPU:1", &device1_));
118     Device* device2_ptr = nullptr;
119     EXPECT_NE(
120         error::OK,
121         device_mgr_
122             ->LookupDevice("/job:a/replica:0/task:0/device:CPU:2", &device2_ptr)
123             .code());
124     // If no GPU is available, gpu_device_ will remain nullptr.
125     Status status = device_mgr_->LookupDevice(
126         "/job:a/replica:0/task:0/device:GPU:0", &gpu_device_);
127     if (!status.ok()) {
128       CHECK_EQ(nullptr, gpu_device_);
129     }
130   }
131 
Init(const std::vector<FunctionDef> & flib,const SessionMetadata * session_metadata=nullptr)132   void Init(const std::vector<FunctionDef>& flib,
133             const SessionMetadata* session_metadata = nullptr) {
134     FunctionDefLibrary proto;
135     for (const auto& fdef : flib) *(proto.add_function()) = fdef;
136     lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
137     OptimizerOptions opts;
138     cluster_flr_.reset(new TestClusterFLR(device_mgr_.get()));
139     proc_flr_.reset(new ProcessFunctionLibraryRuntime(
140         device_mgr_.get(), Env::Default(), /*config=*/nullptr,
141         TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
142         /*thread_pool=*/nullptr, cluster_flr_.get(), session_metadata,
143         Rendezvous::Factory{
144             [this](const int64_t step_id, const DeviceMgr* device_mgr,
145                    Rendezvous** r) {
146               *r = new IntraProcessRendezvous(device_mgr);
147               if (rendezvous_ref_counts_.find(step_id) !=
148                   rendezvous_ref_counts_.end()) {
149                 rendezvous_ref_counts_[step_id]++;
150               } else {
151                 rendezvous_ref_counts_[step_id] = 1;
152               }
153               return Status::OK();
154             },
155             [this](const int64_t step_id) {
156               CHECK(rendezvous_ref_counts_.find(step_id) !=
157                     rendezvous_ref_counts_.end());
158               rendezvous_ref_counts_[step_id]--;
159               return Status::OK();
160             }}));
161   }
162 
AddCompositeDevice(CompositeDevice * d)163   void AddCompositeDevice(CompositeDevice* d) {
164     proc_flr_->AddCompositeDevice(d);
165   }
166 
Instantiate(const string & name,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,FunctionLibraryRuntime::Handle * handle)167   Status Instantiate(
168       const string& name, test::function::Attrs attrs,
169       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
170       FunctionLibraryRuntime::Handle* handle) {
171     return proc_flr_->Instantiate(name, attrs, instantiate_opts, handle);
172   }
173 
GPUToCPU(const Tensor & device_tensor)174   Tensor GPUToCPU(const Tensor& device_tensor) {
175 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
176     CHECK(gpu_device_);
177     CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
178     DeviceContext* device_context =
179         gpu_device_->tensorflow_gpu_device_info()->default_context;
180 
181     Tensor cpu_tensor(device_tensor.dtype(), device_tensor.shape());
182     CHECK(device_context
183               ->CopyDeviceTensorToCPUSync(&device_tensor, "", gpu_device_,
184                                           &cpu_tensor)
185               .ok());
186     return cpu_tensor;
187 #else
188     CHECK(false);
189 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
190   }
191 
CPUToGPU(const Tensor & cpu_tensor)192   Tensor CPUToGPU(const Tensor& cpu_tensor) {
193 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
194     CHECK(gpu_device_);
195     CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
196     DeviceContext* device_context =
197         gpu_device_->tensorflow_gpu_device_info()->default_context;
198 
199     Tensor device_tensor(gpu_device_->GetAllocator({}), cpu_tensor.dtype(),
200                          cpu_tensor.shape(), {});
201     CHECK(device_context
202               ->CopyCPUTensorToDeviceSync(&cpu_tensor, gpu_device_,
203                                           &device_tensor)
204               .ok());
205     return device_tensor;
206 #else
207     CHECK(false);
208 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
209   }
210 
211   template <typename T, typename K>
RunWithRuntime(const string & name,FunctionLibraryRuntime::Options opts,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,const T & args,std::vector<K * > rets,ProcessFunctionLibraryRuntime * pflr)212   Status RunWithRuntime(
213       const string& name, FunctionLibraryRuntime::Options opts,
214       test::function::Attrs attrs,
215       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
216       const T& args, std::vector<K*> rets,
217       ProcessFunctionLibraryRuntime* pflr) {
218     FunctionLibraryRuntime::Handle handle;
219     Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle);
220     if (!status.ok()) {
221       return status;
222     }
223     bool is_cross_process = false;
224     TF_CHECK_OK(pflr->IsCrossProcess(handle, &is_cross_process));
225     EXPECT_FALSE(is_cross_process);
226 
227     std::function<void(std::function<void()>)> runner =
228         [](std::function<void()> fn) {
229           test::function::FunctionTestSchedClosure(fn);
230         };
231     Notification done;
232     opts.runner = &runner;
233     std::vector<K> out;
234     pflr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
235       status = s;
236       done.Notify();
237     });
238     done.WaitForNotification();
239     if (!status.ok()) {
240       return status;
241     }
242     CHECK_EQ(rets.size(), out.size());
243     for (size_t i = 0; i < rets.size(); ++i) {
244       *rets[i] = out[i];
245     }
246 
247     // Release the handle and then try running the function. It shouldn't
248     // succeed.
249     status = pflr->ReleaseHandle(handle);
250     if (!status.ok()) {
251       return status;
252     }
253     Notification done2;
254     pflr->Run(opts, handle, args, &out, [&status, &done2](const Status& s) {
255       status = s;
256       done2.Notify();
257     });
258     done2.WaitForNotification();
259     EXPECT_TRUE(errors::IsNotFound(status)) << "Actual status: " << status;
260     EXPECT_TRUE(absl::StrContains(status.error_message(), "not found."));
261 
262     return Status::OK();
263   }
264 
Run(const string & name,FunctionLibraryRuntime::Options opts,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,const std::vector<Tensor> & args,std::vector<Tensor * > rets,ProcessFunctionLibraryRuntime * pflr=nullptr)265   Status Run(const string& name, FunctionLibraryRuntime::Options opts,
266              test::function::Attrs attrs,
267              const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
268              const std::vector<Tensor>& args, std::vector<Tensor*> rets,
269              ProcessFunctionLibraryRuntime* pflr = nullptr) {
270     return RunWithRuntime<std::vector<Tensor>, Tensor>(
271         name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
272   }
273 
RunWithPackedArgs(const string & name,FunctionLibraryRuntime::Options opts,test::function::Attrs attrs,const FunctionLibraryRuntime::InstantiateOptions & instantiate_opts,const FunctionArgsInterface & args,std::vector<FunctionRet * > rets,ProcessFunctionLibraryRuntime * pflr=nullptr)274   Status RunWithPackedArgs(
275       const string& name, FunctionLibraryRuntime::Options opts,
276       test::function::Attrs attrs,
277       const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
278       const FunctionArgsInterface& args, std::vector<FunctionRet*> rets,
279       ProcessFunctionLibraryRuntime* pflr = nullptr) {
280     return RunWithRuntime<FunctionArgsInterface, FunctionRet>(
281         name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get());
282   }
283 
RunInstantiated(FunctionLibraryRuntime::Handle handle,FunctionLibraryRuntime::Options opts,const std::vector<Tensor> & args,std::vector<Tensor * > rets)284   Status RunInstantiated(FunctionLibraryRuntime::Handle handle,
285                          FunctionLibraryRuntime::Options opts,
286                          const std::vector<Tensor>& args,
287                          std::vector<Tensor*> rets) {
288     std::function<void(std::function<void()>)> runner =
289         [](std::function<void()> fn) {
290           test::function::FunctionTestSchedClosure(fn);
291         };
292 
293     opts.runner = &runner;
294     Status status;
295     Notification done;
296     std::vector<Tensor> out;
297     proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
298       status = s;
299       done.Notify();
300     });
301     done.WaitForNotification();
302     if (!status.ok()) {
303       return status;
304     }
305     CHECK_EQ(rets.size(), out.size());
306     for (size_t i = 0; i < rets.size(); ++i) {
307       *rets[i] = out[i];
308     }
309     return Status::OK();
310   }
311 
312   std::unique_ptr<DynamicDeviceMgr> device_mgr_;
313   Device* device0_ = nullptr;  // Not owned. (Owned by device_mgr_.)
314   Device* device1_ = nullptr;  // Not owned. (Owned by device_mgr_.)
315   std::unique_ptr<Device> device2_;
316   // Remains as nullptr if no GPU is available.
317   Device* gpu_device_ = nullptr;  // Not owned. (Owned by device_mgr_.)
318   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
319   std::unique_ptr<TestClusterFLR> cluster_flr_;
320   std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
321 
322   // To ensure that we are cleaning up the rendezvous properly.
323   std::unordered_map<int64, int> rendezvous_ref_counts_;
324 };
325 
// GetFLR with the default-FLR device name must succeed even when the runtime
// was constructed without a DeviceMgr.
TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
  FunctionDefLibrary proto;
  auto lib_def =
      std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);
  OptimizerOptions opts;
  auto proc_flr = std::make_unique<ProcessFunctionLibraryRuntime>(
      nullptr /* device_mgr */, Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def.get(), opts);
  FunctionLibraryRuntime* flr =
      proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
  EXPECT_NE(flr, nullptr);
}
339 
// GetFLR resolves fully-qualified, legacy ("cpu:N"), and bare device names to
// the right per-device FLR, and returns nullptr for unknown names.
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
  Init({});
  auto expect_flr_for = [this](const string& device_name, Device* expected) {
    FunctionLibraryRuntime* flr = proc_flr_->GetFLR(device_name);
    EXPECT_NE(flr, nullptr);
    EXPECT_EQ(flr->device(), expected);
  };
  expect_flr_for("/job:a/replica:0/task:0/cpu:0", device0_);
  expect_flr_for("/job:a/replica:0/task:0/device:CPU:0", device0_);
  expect_flr_for("/device:CPU:0", device0_);
  expect_flr_for("/job:a/replica:0/task:0/cpu:1", device1_);
  EXPECT_EQ(proc_flr_->GetFLR("abc"), nullptr);
}
358 
// A known device yields a non-zero incarnation; an unknown one fails with
// INVALID_ARGUMENT.
TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
  Init({});
  int64_t incarnation;
  TF_EXPECT_OK(proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:1",
                                               &incarnation));
  // Incarnation is a random number other than 0.
  EXPECT_NE(incarnation, 0);
  const Status lookup_status = proc_flr_->GetDeviceIncarnation(
      "/job:a/replica:0/task:0/cpu:2", &incarnation);
  EXPECT_EQ(lookup_status.code(), error::INVALID_ARGUMENT);
}
370 
// A single XTimesTwo call on CPU:0 doubles every element.
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
  Init({test::function::XTimesTwo()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  auto input = test::AsTensor<float>({1, 2, 3, 4});
  Tensor result;
  TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, instantiate_opts,
                  {input}, {&result}));
  test::ExpectTensorEqual<float>(result, test::AsTensor<float>({2, 4, 6, 8}));
}
384 
// FindDevice reports the device it ran on, and the rendezvous created for the
// step is fully released afterwards (ref count back to 0).
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  Tensor result;
  TF_CHECK_OK(Run("FindDevice", opts, {}, instantiate_opts, {}, {&result}));
  test::ExpectTensorEqual<tstring>(
      result, test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:0"},
                                      TensorShape({})));
  EXPECT_EQ(1, rendezvous_ref_counts_.size());
  EXPECT_EQ(opts.step_id, rendezvous_ref_counts_.begin()->first);
  EXPECT_EQ(0, rendezvous_ref_counts_.begin()->second);
}
401 
// Two different functions run back-to-back on the same device.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
  Init({test::function::XTimesTwo(), test::function::XTimesFour()});
  const auto input = test::AsTensor<float>({1, 2, 3, 4});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  Tensor result;
  TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}}, instantiate_opts,
                  {input}, {&result}));
  test::ExpectTensorEqual<float>(result, test::AsTensor<float>({2, 4, 6, 8}));
  TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}}, instantiate_opts,
                  {input}, {&result}));
  test::ExpectTensorEqual<float>(result,
                                 test::AsTensor<float>({4, 8, 12, 16}));
}
418 
// Running the same function twice on CPU:1 reports CPU:1 both times.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:1";
  const auto expected = test::AsTensor<tstring>(
      {"/job:a/replica:0/task:0/device:CPU:1"}, TensorShape({}));
  for (int call = 0; call < 2; ++call) {
    Tensor result;
    TF_CHECK_OK(Run("FindDevice", opts, {}, instantiate_opts, {}, {&result}));
    test::ExpectTensorEqual<tstring>(result, expected);
  }
}
436 
// The same function instantiated for two different targets runs on each
// target's device.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  Tensor result;

  FunctionLibraryRuntime::InstantiateOptions instantiate_opts_0;
  instantiate_opts_0.target = "/job:a/replica:0/task:0/device:CPU:0";
  TF_CHECK_OK(Run("FindDevice", opts, {}, instantiate_opts_0, {}, {&result}));
  test::ExpectTensorEqual<tstring>(
      result, test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:0"},
                                      TensorShape({})));

  FunctionLibraryRuntime::InstantiateOptions instantiate_opts_1;
  instantiate_opts_1.target = "/job:a/replica:0/task:0/device:CPU:1";
  TF_CHECK_OK(Run("FindDevice", opts, {}, instantiate_opts_1, {}, {&result}));
  test::ExpectTensorEqual<tstring>(
      result, test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:1"},
                                      TensorShape({})));
}
456 
// Instantiating a multi-device function after one of the devices it could
// reference has been removed must not segfault.
TEST_F(ProcessFunctionLibraryRuntimeTest, InstantiateFunctionOnRemovedDevice) {
  // Temporarily register CPU:2 so it can then be removed.
  Device* removed_device = device2_.get();
  std::vector<std::unique_ptr<Device>> added;
  added.emplace_back(std::move(device2_));
  TF_CHECK_OK(device_mgr_->AddDevices(std::move(added)));

  Init({test::function::FindDevice()});
  std::vector<Device*> to_remove{removed_device};
  TF_CHECK_OK(device_mgr_->RemoveDevices(std::move(to_remove)));

  // Since the process FLR device set is not updated yet, it still holds the
  // raw pointer to device2. Make sure that function instantiation with device2
  // will not lead to a segfault.
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  FunctionLibraryRuntime::Handle handle;
  instantiate_opts.target = "/job:a/replica:0/task:0/device:CPU:1";
  instantiate_opts.is_multi_device_function = true;
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:2"}},
                          instantiate_opts, &handle));
}
478 
// Cross-process instantiations on the same remote target reuse handle 0;
// a new target gets the next handle from TestClusterFLR.
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
                          instantiate_opts, &handle));
  bool is_cross_process = false;
  TF_CHECK_OK(proc_flr_->IsCrossProcess(handle, &is_cross_process));
  EXPECT_TRUE(is_cross_process);
  EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
                   "/job:b/replica:0/task:0/device:CPU:0", handle));
  // Re-instantiating on the same target returns the same local handle.
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
                          instantiate_opts, &handle));
  EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
                   "/job:b/replica:0/task:0/device:CPU:0", handle));
  // A different target gets the next handle.
  instantiate_opts.target = "/job:c/replica:0/task:0/device:CPU:0";
  TF_CHECK_OK(Instantiate("FindDevice",
                          {{"_target", "/job:c/replica:0/task:0/device:CPU:0"}},
                          instantiate_opts, &handle));
  EXPECT_EQ(1, proc_flr_->GetHandleOnDevice(
                   "/job:c/replica:0/task:0/device:CPU:0", handle));
}
507 
// 100 concurrent instantiations of the same remote function must all resolve
// to the same local handle (0) without races.
TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
  Init({test::function::FindDevice()});
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";

  // Owned via unique_ptr instead of raw new/delete; destruction joins all
  // scheduled work, which is what makes this test deterministic.
  auto tp = std::make_unique<thread::ThreadPool>(Env::Default(), "test", 4);
  auto fn = [this, &instantiate_opts]() {
    FunctionLibraryRuntime::Handle h;
    TF_CHECK_OK(Instantiate(
        "FindDevice", {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
        instantiate_opts, &h));
    EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
                     "/job:b/replica:0/task:0/device:CPU:0", h));
  };

  for (int i = 0; i < 100; ++i) {
    tp->Schedule(fn);
  }
  tp.reset();  // Waits for all 100 closures to finish.
}
531 
IsCUDATensor(const Tensor & t)532 bool IsCUDATensor(const Tensor& t) {
533 #if GOOGLE_CUDA
534   cudaPointerAttributes attributes;
535   cudaError_t err =
536       cudaPointerGetAttributes(&attributes, t.tensor_data().data());
537   if (err == cudaErrorInvalidValue) return false;
538   CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
539   return (attributes.type == cudaMemoryTypeDevice);
540 #elif TENSORFLOW_USE_ROCM
541   hipPointerAttribute_t attributes;
542   hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
543   if (err == hipErrorInvalidValue) return false;
544   CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
545   return (attributes.memoryType == hipMemoryTypeDevice);
546 #else
547   CHECK(false)
548       << "IsCUDATensor should not be called when CUDA is not available";
549 #endif  // GOOGLE_CUDA
550 }
551 
TestTwoDeviceMult(ProcessFunctionLibraryRuntimeTest * fixture,const FunctionLibraryRuntime::InstantiateOptions & inst_opts,const string & error="")552 void TestTwoDeviceMult(
553     ProcessFunctionLibraryRuntimeTest* fixture,
554     const FunctionLibraryRuntime::InstantiateOptions& inst_opts,
555     const string& error = "") {
556   fixture->Init({test::function::TwoDeviceMult()});
557   FunctionLibraryRuntime::Options opts;
558   auto x = test::AsTensor<float>({1, 2, 3});
559   Tensor y_cpu;
560   Tensor y_gpu;
561   Status status = fixture->Run("TwoDeviceMult", opts, {{"T", DT_FLOAT}},
562                                inst_opts, {x}, {&y_cpu, &y_gpu});
563   if (!error.empty()) {
564     EXPECT_TRUE(errors::IsInvalidArgument(status))
565         << "Actual status: " << status;
566     EXPECT_TRUE(absl::StrContains(status.error_message(), error))
567         << "Actual error message: " << status.error_message();
568     return;
569   }
570 
571   EXPECT_TRUE(status.ok()) << "Actual status: " << status;
572   EXPECT_FALSE(IsCUDATensor(y_cpu));
573   test::ExpectTensorEqual<float>(y_cpu, test::AsTensor<float>({2, 4, 6}));
574 
575   EXPECT_TRUE(IsCUDATensor(y_gpu));
576   Tensor y_gpu_on_cpu = fixture->GPUToCPU(y_gpu);
577   test::ExpectTensorEqual<float>(y_gpu_on_cpu,
578                                  test::AsTensor<float>({3, 6, 9}));
579 }
580 
TestTwoDeviceInputOutput(ProcessFunctionLibraryRuntimeTest * fixture,const FunctionLibraryRuntime::InstantiateOptions & inst_opts)581 void TestTwoDeviceInputOutput(
582     ProcessFunctionLibraryRuntimeTest* fixture,
583     const FunctionLibraryRuntime::InstantiateOptions& inst_opts) {
584   if (fixture->gpu_device_ == nullptr) {
585     GTEST_SKIP() << "No GPUs available";
586   }
587   fixture->Init({test::function::TwoDeviceInputOutput()});
588 
589   FunctionLibraryRuntime::Options opts;
590   Tensor x1 = test::AsTensor<float>({1, 2});
591   if (absl::StrContains(inst_opts.input_devices[0], "GPU")) {
592     x1 = fixture->CPUToGPU(x1);
593   }
594   Tensor x2 = test::AsTensor<float>({10, 20});
595   if (absl::StrContains(inst_opts.input_devices[1], "GPU")) {
596     x2 = fixture->CPUToGPU(x2);
597   }
598   Tensor y1;
599   Tensor y2;
600   TF_CHECK_OK(fixture->Run("TwoDeviceInputOutput", opts, {{"T", DT_FLOAT}},
601                            inst_opts, {x1, x2}, {&y1, &y2}));
602 
603   if (absl::StrContains(inst_opts.output_devices[0], "GPU")) {
604     EXPECT_TRUE(IsCUDATensor(y1));
605     y1 = fixture->GPUToCPU(y1);
606   } else {
607     EXPECT_FALSE(IsCUDATensor(y1));
608   }
609   test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({2, 4}));
610 
611   if (absl::StrContains(inst_opts.output_devices[1], "GPU")) {
612     EXPECT_TRUE(IsCUDATensor(y2));
613     y2 = fixture->GPUToCPU(y2);
614   } else {
615     EXPECT_FALSE(IsCUDATensor(y2));
616   }
617   test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({30, 60}));
618 }
619 
CompleteDevices(const std::vector<string> & v)620 std::vector<string> CompleteDevices(const std::vector<string>& v) {
621   std::vector<string> result;
622   result.reserve(v.size());
623   for (const string& s : v) {
624     result.push_back(strings::StrCat("/job:a/replica:0/task:0/device:", s));
625   }
626   return result;
627 }
628 
MakeOptions(const string & target,const std::vector<string> & input_devices,const std::vector<string> & output_devices)629 FunctionLibraryRuntime::InstantiateOptions MakeOptions(
630     const string& target, const std::vector<string>& input_devices,
631     const std::vector<string>& output_devices) {
632   FunctionLibraryRuntime::InstantiateOptions inst_opts;
633   inst_opts.target = target;
634   inst_opts.input_devices = CompleteDevices(input_devices);
635   inst_opts.output_devices = CompleteDevices(output_devices);
636   inst_opts.is_multi_device_function = true;
637   return inst_opts;
638 }
639 
// Output devices given explicitly: one CPU and one GPU output.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ExplicitOutputDevice) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  TestTwoDeviceMult(this, MakeOptions("CPU:0", {"CPU:0"}, {"CPU:0", "GPU:0"}));
}
646 
// No output devices given: placement is inferred from the function body.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_InferredOutputDevice) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  TestTwoDeviceMult(this, MakeOptions("CPU:0", {"CPU:0"}, {}));
}
653 
// Zero input devices for a one-input function is an instantiation error.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenNoInputDevices) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  TestTwoDeviceMult(this, MakeOptions("CPU:0", {}, {}),
                    "input_devices must have the same length");
}
661 
// Instantiation must fail when more input devices are supplied than the
// function has inputs.
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ErrorWhenTooManyInputDevices) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  TestTwoDeviceMult(this, MakeOptions("CPU:0", {"CPU:0", "CPU:1"}, {}),
                    "input_devices must have the same length");
}
670 
// Instantiation must fail when more output devices are supplied than the
// function has outputs. The error fires before placement, so no GPU skip
// guard is needed here.
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ErrorWhenTooManyOutputDevices) {
  TestTwoDeviceMult(
      this, MakeOptions("CPU:0", {"CPU:0"}, {"CPU:0", "GPU:0", "CPU:1"}),
      "output_devices must either be empty or have the same length");
}
677 
// Instantiation must fail when the target device does not exist (GPU:11).
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ErrorWhenBadTargetDevice) {
  TestTwoDeviceMult(
      this, MakeOptions("GPU:11", {"CPU:0"}, {"CPU:0", "GPU:0"}),
      "Cannot instantiate multi-device function with target device GPU:11");
}
684 
// Multi-device functions do not support list-typed inputs; instantiating a
// function whose signature contains one must fail with InvalidArgument.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListInput) {
  Init({test::function::FuncWithListInput()});
  FunctionLibraryRuntime::Handle func_handle;
  const Status instantiate_status = proc_flr_->Instantiate(
      "FuncWithListInput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
      MakeOptions("CPU:0", {"CPU:0"}, {}), &func_handle);
  ASSERT_TRUE(errors::IsInvalidArgument(instantiate_status))
      << "Actual status: " << instantiate_status;
  ASSERT_TRUE(absl::StrContains(
      instantiate_status.error_message(),
      "FuncWithListInput has an input named \"x1\" that is a list of tensors"))
      << "Actual error message: " << instantiate_status.error_message();
}
698 
// Multi-device functions do not support list-typed outputs; instantiating a
// function whose signature contains one must fail with InvalidArgument.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListOutput) {
  Init({test::function::FuncWithListOutput()});
  FunctionLibraryRuntime::Handle func_handle;
  const Status instantiate_status = proc_flr_->Instantiate(
      "FuncWithListOutput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}),
      MakeOptions("CPU:0", {}, {"CPU:0"}), &func_handle);
  ASSERT_TRUE(errors::IsInvalidArgument(instantiate_status))
      << "Actual status: " << instantiate_status;
  ASSERT_TRUE(absl::StrContains(
      instantiate_status.error_message(),
      "FuncWithListOutput has an output named \"y\" that is a list of tensors"))
      << "Actual error message: " << instantiate_status.error_message();
}
712 
// Two inputs and two outputs, all devices given explicitly: inputs on
// (CPU, GPU), outputs on (CPU, GPU).
TEST_F(ProcessFunctionLibraryRuntimeTest,
       MultiDevice_ExplicitMultiInputOutput) {
  TestTwoDeviceInputOutput(
      this, MakeOptions("CPU:0", {"CPU:0", "GPU:0"}, {"CPU:0", "GPU:0"}));
}
718 
// Same as above but with the input devices swapped: (GPU, CPU) in,
// (CPU, GPU) out.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_FlipInputs) {
  TestTwoDeviceInputOutput(
      this, MakeOptions("CPU:0", {"GPU:0", "CPU:0"}, {"CPU:0", "GPU:0"}));
}
723 
// Output devices swapped relative to the explicit case: (CPU, GPU) in,
// (GPU, CPU) out.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_FlipOutputs) {
  TestTwoDeviceInputOutput(
      this, MakeOptions("CPU:0", {"CPU:0", "GPU:0"}, {"GPU:0", "CPU:0"}));
}
728 
// Both input and output device lists swapped: (GPU, CPU) in, (GPU, CPU) out.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_FlipBoth) {
  TestTwoDeviceInputOutput(
      this, MakeOptions("CPU:0", {"GPU:0", "CPU:0"}, {"GPU:0", "CPU:0"}));
}
733 
// EmptyBodySwap forwards its two inputs in swapped order; verify each output
// lands on the device requested by the instantiate options.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_EmptyBodySwap) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  Init({test::function::EmptyBodySwap()});
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "CPU:0"}, {"CPU:0", "GPU:0"});

  Tensor gpu_arg = CPUToGPU(test::AsTensor<float>({1, 2}));
  Tensor cpu_arg = test::AsTensor<float>({10, 20});
  Tensor out0;
  Tensor out1;
  TF_CHECK_OK(Run("EmptyBodySwap", {}, {{"T", DT_FLOAT}}, inst_opts,
                  {gpu_arg, cpu_arg}, {&out0, &out1}));

  // First output is the second (CPU) input, placed on CPU.
  EXPECT_FALSE(IsCUDATensor(out0));
  test::ExpectTensorEqual<float>(out0, test::AsTensor<float>({10, 20}));

  // Second output is the first (GPU) input, placed on GPU; copy back to host
  // before comparing values.
  EXPECT_TRUE(IsCUDATensor(out1));
  out1 = GPUToCPU(out1);
  test::ExpectTensorEqual<float>(out1, test::AsTensor<float>({1, 2}));
}
756 
GetResourceHandle(const string & var_name,const string & container,const string & device_name)757 Tensor GetResourceHandle(const string& var_name, const string& container,
758                          const string& device_name) {
759   ResourceHandle handle;
760   handle.set_device(device_name);
761   handle.set_container(container);
762   handle.set_name(var_name);
763   handle.set_hash_code(TypeIndex::Make<Var>().hash_code());
764   handle.set_maybe_type_name(TypeIndex::Make<Var>().name());
765   Tensor tensor(DT_RESOURCE, TensorShape({}));
766   tensor.scalar<ResourceHandle>()() = handle;
767   return tensor;
768 }
769 
// Returns a function which adds two variables on different devices.
// The single resource input `x` is read once on CPU:0 and once on CPU:1,
// and the two read values are added on CPU:0 to produce output `y`.
FunctionDef AddVarAcrossDevices() {
  return FunctionDefHelper::Create(
      // Name
      "AddVarAcrossDevices",
      // Args
      {"x: resource"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"read0"},
           "ReadVariableOp",
           {"x"},
           {{"dtype", DT_FLOAT}},
           {},
           "/device:CPU:0"},
          {{"read1"},
           "ReadVariableOp",
           {"x"},
           {{"dtype", DT_FLOAT}},
           {},
           "/device:CPU:1"},
          {{"add"},
           "Add",
           {"read0:value:0", "read1:value:0"},
           {{"T", DT_FLOAT}},
           {},
           "/device:CPU:0"},
      },
      {{"y", "add:z:0"}});
}
804 
805 // An implementation of FunctionArgsInterface for packed inputs.
806 class TestFunctionPackedArgs : public FunctionArgsInterface {
807  public:
TestFunctionPackedArgs(const int index,gtl::InlinedVector<TensorValue,4> && tensor_args)808   TestFunctionPackedArgs(const int index,
809                          gtl::InlinedVector<TensorValue, 4>&& tensor_args) {
810     packed_args_.emplace(index, std::move(tensor_args));
811   }
812 
~TestFunctionPackedArgs()813   ~TestFunctionPackedArgs() override{};
814 
HasRemoteOrPackedInputs() const815   bool HasRemoteOrPackedInputs() const override { return true; };
816 
GetLocalArg(const FunctionArgIndex & index,Tensor * val) const817   Status GetLocalArg(const FunctionArgIndex& index,
818                      Tensor* val) const override {
819     *val = *packed_args_.at(index.index).at(index.sub_index).tensor;
820     return Status::OK();
821   };
822 
GetLocalTensors() const823   std::vector<Tensor> GetLocalTensors() const override { return {}; }
824 
825  private:
826   absl::flat_hash_map<int, gtl::InlinedVector<TensorValue, 4>> packed_args_;
827 };
828 
// Runs AddVarAcrossDevices with one resource argument placed on a COMPOSITE
// device backed by two CPU devices, exercising both packed-argument paths:
// a FunctionArgsInterface carrying per-device handles, and a single packed
// DT_RESOURCE tensor. Both runs should add the two per-device variables,
// yielding {10,20} + {30,40} = {40,60}.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) {
  Init({AddVarAcrossDevices()});
  // Create two variables on two devices.
  const Tensor initial_resource_value0 = test::AsTensor<float>({10, 20});
  Var* resource0 = new Var(DT_FLOAT);
  *resource0->tensor() = initial_resource_value0;
  resource0->is_initialized = true;
  const Tensor initial_resource_value1 = test::AsTensor<float>({30, 40});
  Var* resource1 = new Var(DT_FLOAT);
  *resource1->tensor() = initial_resource_value1;
  resource1->is_initialized = true;
  ResourceMgr* mgr0 = device0_->resource_manager();
  ResourceMgr* mgr1 = device1_->resource_manager();
  // The resource managers take ownership of resource0/resource1 here.
  TF_ASSERT_OK(mgr0->Create(mgr0->default_container(), "var", resource0));
  TF_ASSERT_OK(mgr1->Create(mgr1->default_container(), "var", resource1));

  Tensor resource_handle0 =
      GetResourceHandle("var", mgr0->default_container(), device0_->name());
  Tensor resource_handle1 =
      GetResourceHandle("var", mgr1->default_container(), device1_->name());

  // Create a CompositeDevice
  Status s;
  std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
                                  /*unique_device_id=*/0,
                                  device_mgr_->HostCPU()->parsed_name(), &s);
  TF_ASSERT_OK(s);
  AddCompositeDevice(composite_device.get());

  FunctionLibraryRuntime::Options opts;
  // Input 0 lives on the composite device; tell the runtime which underlying
  // devices it maps to and the dtype/shape of the resource it refers to.
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"COMPOSITE:0"}, {"CPU:0"});
  inst_opts.composite_devices[composite_device->name()] =
      composite_device->underlying_devices();
  inst_opts.input_resource_dtypes_and_shapes[0] = {
      initial_resource_value0.dtype(), initial_resource_value0.shape()};

  // Packed TensorHandle
  {
    gtl::InlinedVector<TensorValue, 4> handles;
    handles.push_back(TensorValue(&resource_handle0));
    handles.push_back(TensorValue(&resource_handle1));
    TestFunctionPackedArgs args(0, std::move(handles));
    FunctionRet ret;
    TF_CHECK_OK(RunWithPackedArgs("AddVarAcrossDevices", opts,
                                  {{"T", DT_FLOAT}}, inst_opts, args, {&ret}));
    // index() == 0 means the FunctionRet variant holds a Tensor.
    EXPECT_EQ(ret.index(), 0);
    test::ExpectTensorEqual<float>(absl::get<Tensor>(ret),
                                   test::AsTensor<float>({40, 60}));
  }

  // Packed Tensor
  {
    Tensor arg(DT_RESOURCE, TensorShape({2}));
    arg.flat<ResourceHandle>()(0) = resource_handle0.scalar<ResourceHandle>()();
    arg.flat<ResourceHandle>()(1) = resource_handle1.scalar<ResourceHandle>()();

    Tensor ret;
    TF_CHECK_OK(Run("AddVarAcrossDevices", opts, {{"T", DT_FLOAT}}, inst_opts,
                    {arg}, {&ret}));
    test::ExpectTensorEqual<float>(ret, test::AsTensor<float>({40, 60}));
  }
}
893 
// Verifies that a resource handle flowing through a multi-device function
// keeps pointing at the right device: ResourceOutput forwards the handle,
// and a follow-up ReadResourceVariable through the returned handle must
// yield the GPU-resident variable's value.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "GPU:0"}, {"GPU:0", "GPU:0"});
  Init({test::function::ResourceOutput(),
        test::function::ReadResourceVariable()});

  // Make resource var
  Tensor resource_value = CPUToGPU(test::AsTensor<float>({10, 20}));
  Var* resource = new Var(DT_FLOAT);
  *resource->tensor() = resource_value;
  resource->is_initialized = true;
  // The resource manager takes ownership of `resource`.
  ResourceMgr* mgr = gpu_device_->resource_manager();
  Status status = mgr->Create(mgr->default_container(), "my_gpu_var", resource);
  ASSERT_TRUE(status.ok()) << status.error_message();

  // Run the function taking a resource and outputting it
  FunctionLibraryRuntime::Options opts;
  Tensor x1 = CPUToGPU(test::AsTensor<float>({1, 2}));
  Tensor x2 = GetResourceHandle("my_gpu_var", mgr->default_container(),
                                "/job:a/replica:0/task:0/device:GPU:0");
  Tensor returned_handle;
  Tensor y2;
  TF_CHECK_OK(Run("ResourceOutput", opts, {{"T", DT_FLOAT}}, inst_opts,
                  {x1, x2}, {&returned_handle, &y2}));

  // Resource-handle tensors stay on the host even for GPU resources.
  EXPECT_FALSE(IsCUDATensor(returned_handle));
  EXPECT_TRUE(IsCUDATensor(y2));
  y2 = GPUToCPU(y2);
  test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({2, 4}));

  // Read the variable using the handle returned from previous function to
  // make sure the handle and read value is on the right device.
  inst_opts = MakeOptions("GPU:0", {"GPU:0"}, {"GPU:0"});
  Tensor read_resource;
  TF_CHECK_OK(Run("ReadResourceVariable", opts, {{"T", DT_FLOAT}}, inst_opts,
                  {returned_handle}, {&read_resource}));
  EXPECT_TRUE(IsCUDATensor(read_resource));
  read_resource = GPUToCPU(read_resource);
  test::ExpectTensorEqual<float>(read_resource,
                                 test::AsTensor<float>({10, 20}));
}
938 
// Instantiation must surface a Placer failure as InvalidArgument.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_PlacerError) {
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }
  // ResourceOutput forwards second input to first output. Both are resources.
  // Placer should not be able to place this graph because we ask it to place
  // second input on GPU but first output to CPU.
  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0", "GPU:0"}, {"CPU:0", "GPU:0"});
  Init({test::function::ResourceOutput(),
        test::function::ReadResourceVariable()});

  FunctionLibraryRuntime::Handle handle;
  Status status = proc_flr_->Instantiate(
      "ResourceOutput", test::function::Attrs({{"T", DT_FLOAT}}), inst_opts,
      &handle);
  ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status;
  ASSERT_TRUE(absl::StrContains(status.error_message(), "Cannot place"));
}
958 
// An op whose kernel fails both at construction time and in Compute; used
// below to check when kernel-construction errors are reported.
REGISTER_OP("BrokenOp")
    .Input("in: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);
// Kernel for BrokenOp: sets an Internal error in the constructor (so eager
// kernel creation fails) and again in Compute (so lazy creation also fails
// at run time).
class BrokenOp : public OpKernel {
 public:
  explicit BrokenOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    ctx->SetStatus(errors::Internal("I am broken"));
  }

  void Compute(OpKernelContext* ctx) override {
    ctx->SetStatus(errors::Internal("I am broken"));
  }
};
REGISTER_KERNEL_BUILDER(Name("BrokenOp").Device(DEVICE_CPU), BrokenOp);
975 
// With create_kernels_eagerly unset, instantiating a function whose kernel
// constructor fails should still succeed (kernels are built lazily); with it
// set, the constructor's Internal error must surface at instantiation.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CreateKernelsEagerly) {
  auto T = DT_INT32;
  // A function whose single node is BrokenOp, which errors in its kernel
  // constructor and in Compute.
  FunctionDef broken_func = FunctionDefHelper::Define(
      // Name
      "Broken",
      // Args
      {"x: int32"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {{{"y"}, "BrokenOp", {"x"}, {{"T", T}}}});
  Init({broken_func});

  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"CPU:0"}, {"CPU:0"});

  // Instantiating the broken function should work.
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate("Broken", {{"T", DT_INT32}}, inst_opts, &handle));
  TF_CHECK_OK(proc_flr_->ReleaseHandle(handle));

  // Instantiating the broken function while creating kernels eagerly should
  // fail.
  inst_opts.create_kernels_eagerly = true;
  Status status = Instantiate("Broken", {{"T", DT_INT32}}, inst_opts, &handle);
  EXPECT_TRUE(errors::IsInternal(status));
}
1006 
// Verifies state_handle semantics for stateful multi-device functions: the
// same (empty) state handle shares one instantiation (and thus one RNG
// state), while distinct state handles get independent instantiations whose
// RNG sequences restart from the beginning.
TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_StateHandle) {
  auto T = DT_INT32;
  // The expected sequence of outputs from this function is [6, 4, 0, 1, ...].
  FunctionDef stateful_func = FunctionDefHelper::Define(
      // Name
      "RandomUniformWrapper",
      // Args
      {"x: resource"},
      // Return values
      {"y: int32"},
      // Attrs
      {},
      // Nodes
      {FunctionDefHelper::Const<int32>("shape", gtl::ArraySlice<int32>({1})),
       FunctionDefHelper::Const<int32>("minval", 0),
       {{"maxval"}, "ReadVariableOp", {"x"}, {{"dtype", T}}, {}},
       // A stateful node.
       {{"y"},
        "RandomUniformInt",
        {"shape", "minval", "maxval"},
        {{"seed", 37}, {"seed2", 48}, {"Tout", T}, {"T", T}}}});
  Init({stateful_func});
  if (gpu_device_ == nullptr) {
    GTEST_SKIP() << "No GPUs available";
  }

  // Make resource variables.
  // The ResourceMgr takes ownership of `resource`.
  ResourceMgr* mgr = gpu_device_->resource_manager();
  Tensor resource_value = CPUToGPU(test::AsScalar<int>(10));
  Var* resource = new Var(T);
  *resource->tensor() = resource_value;
  resource->is_initialized = true;
  Status status = mgr->Create(mgr->default_container(), "my_gpu_var", resource);
  ASSERT_TRUE(status.ok()) << status.error_message();

  Tensor x = GetResourceHandle("my_gpu_var", mgr->default_container(),
                               "/job:a/replica:0/task:0/device:GPU:0");
  Tensor y;

  FunctionLibraryRuntime::InstantiateOptions inst_opts =
      MakeOptions("CPU:0", {"GPU:0"}, {"CPU:0"});

  // Instantiate the function with no state handle.
  FunctionLibraryRuntime::Handle handle;
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &handle));
  for (auto expected : {6, 4}) {
    TF_CHECK_OK(RunInstantiated(handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Instantiating the function again with no state handle should result in the
  // same handle.
  FunctionLibraryRuntime::Handle other_handle;
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &other_handle));
  EXPECT_EQ(handle, other_handle);
  // Running the function should yield continuation of the same sequence.
  for (auto expected : {0, 1}) {
    TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Instantiating the function with a state handle should result in a different
  // handle.
  inst_opts.state_handle = "handle_1";
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &other_handle));
  EXPECT_NE(handle, other_handle);
  // Running the function should yield the original sequence.
  for (auto expected : {6, 4, 0, 1}) {
    TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Instantiating the function with a different state handle should result in a
  // different handle.
  inst_opts.state_handle = "handle_2";
  TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}}, inst_opts,
                          &other_handle));
  EXPECT_NE(handle, other_handle);
  // Running the function should yield the original sequence.
  for (auto expected : {6, 4, 0, 1}) {
    TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
    test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
  }

  // Repeatedly instantiating a function and releasing its handle will result in
  // repeating the original sequence.
  inst_opts.state_handle = "handle_3";
  for (int i = 0; i < 2; ++i) {
    TF_CHECK_OK(Instantiate("RandomUniformWrapper", {{"T", DT_INT32}},
                            inst_opts, &other_handle));
    EXPECT_NE(handle, other_handle);
    // Running the function should yield the original sequence.
    for (auto expected : {6, 4, 0, 1}) {
      TF_CHECK_OK(RunInstantiated(other_handle, {}, {x}, {&y}));
      test::ExpectTensorEqual<int>(y, test::AsTensor<int>({expected}));
    }
    TF_CHECK_OK(proc_flr_->ReleaseHandle(other_handle));
  }
}
1109 
// A stateful op that echoes the session metadata (if any) as a string;
// used by the SessionMetadata* tests below.
REGISTER_OP("SessionMetadataReader")
    .Input("x: int64")
    .Output("y: string")
    .SetIsStateful()
    .Doc(R"doc(SessionMetadataReader returns the session metadata.

x: int64
y: string
)doc");
1119 
// Kernel for SessionMetadataReader: writes the DebugString of the context's
// session metadata to its scalar string output, or "" when no metadata is
// attached. The int64 input is ignored.
class SessionMetadataReaderOp : public OpKernel {
 public:
  explicit SessionMetadataReaderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    Tensor* out_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("y", TensorShape({}), &out_tensor));
    if (ctx->session_metadata() != nullptr) {
      out_tensor->scalar<tstring>()() = ctx->session_metadata()->DebugString();
    } else {
      out_tensor->scalar<tstring>()() = "";
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SessionMetadataReader").Device(DEVICE_CPU),
                        SessionMetadataReaderOp);
1136 
// Returns a single-node function wrapping SessionMetadataReader, so the
// tests below can run it through the function runtime.
FunctionDef SessionMetadataReaderOpFn() {
  return FunctionDefHelper::Define(
      // Name
      "SessionMetadataReaderFn",
      // Args
      {"x: int64"},
      // Return values
      {"y: string"},
      // Attr def
      {},
      // Nodes
      {{{"y"}, "SessionMetadataReader", {"x"}, {}}});
}
1150 
// With no session metadata configured, SessionMetadataReader sees a null
// session_metadata() and must return the empty string.
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataAbsent) {
  Init({SessionMetadataReaderOpFn()}, /*session_metadata=*/nullptr);
  FunctionLibraryRuntime::Options run_opts;
  run_opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  run_opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.target = "/job:a/replica:0/task:0/cpu:0";
  const Tensor input = test::AsTensor<int64>({17});
  Tensor output;
  TF_CHECK_OK(Run("SessionMetadataReaderFn", run_opts, {}, inst_opts, {input},
                  {&output}));
  EXPECT_EQ("", output.scalar<tstring>()());
}
1164 
// With session metadata attached, SessionMetadataReader must emit a textproto
// that parses back to the same name and version.
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) {
  const SessionMetadata session_metadata = GenerateSessionMetadata();
  Init({SessionMetadataReaderOpFn()}, &session_metadata);
  FunctionLibraryRuntime::Options run_opts;
  run_opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  run_opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.target = "/job:a/replica:0/task:0/cpu:0";
  const Tensor input = test::AsTensor<int64>({17});
  Tensor output;
  TF_CHECK_OK(Run("SessionMetadataReaderFn", run_opts, {}, inst_opts, {input},
                  {&output}));
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(output.scalar<tstring>()(),
                                                    &read_metadata));
  EXPECT_EQ(session_metadata.name(), read_metadata.name());
  EXPECT_EQ(session_metadata.version(), read_metadata.version());
}
1183 
// A composite device registered with the runtime must still be discoverable
// in the device set of a cloned ProcessFunctionLibraryRuntime.
TEST_F(ProcessFunctionLibraryRuntimeTest, CompositeDevicesAfterCloning) {
  Init({AddVarAcrossDevices()});

  Status s;
  std::unique_ptr<CompositeDevice> composite_device =
      CompositeDevice::MakeDevice({device0_->name(), device1_->name()},
                                  /*unique_device_id=*/0,
                                  device_mgr_->HostCPU()->parsed_name(), &s);
  TF_ASSERT_OK(s);
  AddCompositeDevice(composite_device.get());

  auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
  ASSERT_NE(nullptr, flr);
  std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
  std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_proc_flr;
  FunctionLibraryRuntime* cloned_flr;
  TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
  // The clone's device set must resolve the composite device's name back to
  // the very same device object.
  EXPECT_EQ(
      cloned_proc_flr->device_set()->FindDeviceByName(composite_device->name()),
      composite_device.get());
}
1205 
// Session metadata must survive cloning: running SessionMetadataReaderFn on
// a cloned runtime should still observe the metadata set at Init time.
TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) {
  const SessionMetadata session_metadata = GenerateSessionMetadata();
  Init({SessionMetadataReaderOpFn()}, &session_metadata);
  auto* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
  ASSERT_NE(nullptr, flr);
  std::unique_ptr<FunctionLibraryDefinition> cloned_lib_def;
  std::unique_ptr<ProcessFunctionLibraryRuntime> cloned_proc_flr;
  FunctionLibraryRuntime* cloned_flr;
  TF_ASSERT_OK(flr->Clone(&cloned_lib_def, &cloned_proc_flr, &cloned_flr));
  FunctionLibraryRuntime::Options opts;
  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
  opts.remote_execution = true;
  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0";
  const auto x = test::AsTensor<int64>({17});
  Tensor y;
  // Run through the CLONED runtime, not the fixture's proc_flr_.
  Status s = RunWithRuntime<std::vector<Tensor>, Tensor>(
      "SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y},
      cloned_proc_flr.get());
  TF_CHECK_OK(s);
  SessionMetadata read_metadata;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(y.scalar<tstring>()(),
                                                    &read_metadata));
  EXPECT_EQ(session_metadata.name(), read_metadata.name());
  EXPECT_EQ(session_metadata.version(), read_metadata.version());
}
1232 
1233 }  // anonymous namespace
1234 }  // namespace tensorflow
1235