/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

using ::tensorflow::string;
28 
// Builds a serialized FunctionDef named "MatMulFunction" computing
// m = MatMul(a, b) for two DT_FLOAT inputs.
//
// The single MatMul node is pinned to `matmul_device`; pass an empty string
// to leave device placement to the runtime. CHECK-fails if the text proto
// fails to parse (it is a compile-time-constant template, so a parse failure
// indicates a programming error, not bad input).
string MatMulFunction(const string& matmul_device) {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      absl::StrCat("    signature {"
                   "      name: 'MatMulFunction'"
                   "      input_arg {"
                   "        name: 'a'"
                   "        type: DT_FLOAT"
                   "      }"
                   "      input_arg {"
                   "        name: 'b'"
                   "        type: DT_FLOAT"
                   "      }"
                   "      output_arg {"
                   "        name: 'm'"
                   "        type: DT_FLOAT"
                   "      }"
                   "    }"
                   "    node_def {"
                   "      name: 'matmul'"
                   "      op: 'MatMul'"
                   "      input: 'a'"
                   "      input: 'b'"
                   // Splice the caller-chosen device into the node.
                   "      device: '",
                   matmul_device, "'",
                   "      attr {"
                   "        key: 'T'"
                   "        value {"
                   "          type: DT_FLOAT"
                   "        }"
                   "      }"
                   "    }"
                   "    ret {"
                   "      key: 'm'"
                   "      value: 'matmul:product'"
                   "    }"),
      &def));
  return def.SerializeAsString();
}
68 
TestRemoteExecuteSilentCopies(bool async,bool remote,bool func,bool heavy_load_on_streaming_rpc,bool remote_func_outputs,bool has_packed_input)69 void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
70                                    bool heavy_load_on_streaming_rpc,
71                                    bool remote_func_outputs,
72                                    bool has_packed_input) {
73   CHECK(!has_packed_input || func);
74   tensorflow::ServerDef server_def = GetServerDef(3);
75 
76   // This server def has the task index set to 0.
77   string serialized = server_def.SerializeAsString();
78 
79   server_def.set_task_index(1);
80   std::unique_ptr<tensorflow::GrpcServer> worker_server1;
81   ASSERT_TRUE(tensorflow::GrpcServer::Create(
82                   server_def, tensorflow::Env::Default(), &worker_server1)
83                   .ok());
84   ASSERT_TRUE(worker_server1->Start().ok());
85 
86   server_def.set_task_index(2);
87   std::unique_ptr<tensorflow::GrpcServer> worker_server2;
88   ASSERT_TRUE(tensorflow::GrpcServer::Create(
89                   server_def, tensorflow::Env::Default(), &worker_server2)
90                   .ok());
91   ASSERT_TRUE(worker_server2->Start().ok());
92 
93   TF_Status* status = TF_NewStatus();
94   TFE_ContextOptions* opts = TFE_NewContextOptions();
95   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
96   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
97   TFE_Context* ctx = TFE_NewContext(opts, status);
98   EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
99   TFE_DeleteContextOptions(opts);
100 
101   TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
102   EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
103 
104   TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
105   TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
106   std::vector<TFE_TensorHandle*> handles_task0;
107   if (heavy_load_on_streaming_rpc) {
108     // Send 50 tensor copy requests to simulate that there have been some RPC
109     // requests been enqueued.
110     for (int i = 0; i < 50; ++i) {
111       handles_task0.push_back(TestMatrixTensorHandle(ctx));
112     }
113   }
114   const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
115   const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
116 
117   std::vector<TFE_TensorHandle*> handles_task2;
118   for (auto* h_task0 : handles_task0) {
119     handles_task2.push_back(
120         TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
121     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
122   }
123 
124   auto* h1_task2 =
125       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
126   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
127 
128   TFE_TensorHandle* packed_handle = nullptr;
129   if (has_packed_input) {
130     int num_replicas = 1;
131     std::vector<TFE_TensorHandle*> packed_handles = {h1_task2};
132     packed_handle = TFE_CreatePackedTensorHandle(ctx, packed_handles.data(),
133                                                  &num_replicas, status);
134     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
135   }
136 
137   TFE_Op* matmul = nullptr;
138   if (func) {
139     const string matmul_device = remote_func_outputs ? task2_name : "";
140     string function_def = MatMulFunction(matmul_device);
141     TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
142                               status);
143     CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
144 
145     matmul = TFE_NewOp(ctx, "MatMulFunction", status);
146     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
147     TFE_OpAddInput(matmul, h0_task0, status);
148     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
149     TFE_OpAddInput(matmul, has_packed_input ? packed_handle : h1_task2, status);
150     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
151   } else {
152     // Handles are on task0 (local), and task2, but op is on task1.
153     matmul = MatMulOp(ctx, h0_task0, h1_task2);
154   }
155   if (remote) {
156     TFE_OpSetDevice(matmul, task1_name, status);
157     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
158   } else if (!async) {
159     // Set the local device to CPU to easily validate mirroring
160     string cpu_device_name;
161     ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
162     TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
163     EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
164     auto remote_arg =
165         tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
166     // The input handles should never change since they have been mirrored.
167     ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
168   }
169 
170   TFE_TensorHandle* retvals[1];
171   int num_retvals = 1;
172   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
173   EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
174 
175   // TODO(gjn): Add support for waiting on async local mirrors
176   if (!remote && !async && !remote_func_outputs) {
177     auto remote_arg =
178         tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
179     // The input handles should never change since they have been mirrored.
180     ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
181   }
182 
183   if (remote_func_outputs) {
184     const string backing_device =
185         TFE_TensorHandleBackingDeviceName(retvals[0], status);
186     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
187     EXPECT_EQ(backing_device, task2_name);
188   }
189 
190   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
191       retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
192   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
193 
194   TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
195   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
196   TFE_DeleteTensorHandle(retval_task0);
197   float product[4] = {0};
198   EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
199   memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
200   TF_DeleteTensor(t);
201   EXPECT_EQ(7, product[0]);
202   EXPECT_EQ(10, product[1]);
203   EXPECT_EQ(15, product[2]);
204   EXPECT_EQ(22, product[3]);
205 
206   TFE_DeleteTensorHandle(h0_task0);
207   TFE_DeleteTensorHandle(h1_task0);
208   if (packed_handle) {
209     TFE_DeleteTensorHandle(packed_handle);
210   }
211   TFE_DeleteTensorHandle(h1_task2);
212   TFE_DeleteTensorHandle(retvals[0]);
213   for (auto* h : handles_task0) {
214     TFE_DeleteTensorHandle(h);
215   }
216   for (auto* h : handles_task2) {
217     TFE_DeleteTensorHandle(h);
218   }
219 
220   TFE_DeleteOp(matmul);
221 
222   TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
223   TFE_ExecutorWaitForAllPendingNodes(executor, status);
224   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
225   TFE_DeleteExecutor(executor);
226   if (func) {
227     TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
228   }
229   TFE_DeleteContext(ctx);
230 
231   TF_DeleteStatus(status);
232 
233   // TODO(b/136478427): Figure out how to correctly shut the server down.
234   worker_server1.release();
235   worker_server2.release();
236 }
237