/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/grappler/grappler.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

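// No-op optimize callback handed to the plugin registration params; these
// tests only exercise registration and validation, so no graph rewriting is
// performed.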
void optimize_func(void* optimizer, const TF_Buffer* graph_buf,
                   const TF_GrapplerItem* item, TF_Buffer* optimized_graph_buf,
                   TF_Status* tf_status) {}

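// Fills in the registration params with the expected struct sizes and the
// no-op optimize callback; individual tests override fields as needed.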
void PopulateDefaultParam(TP_OptimizerRegistrationParams* params) {
  params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
  params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE;
  params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
  params->optimizer->create_func = nullptr;
  params->optimizer->optimize_func = optimize_func;
  params->optimizer->destroy_func = nullptr;
}

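// Registers a plugin optimizer for the "Success" device type with remapping
// disabled, then checks that exactly one optimizer is created and that the
// remapping toggle is reported as OFF.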
TEST(Grappler, SuccessfulRegistration) {
  auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
                        TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "Success";
    params->optimizer_configs->remapping = TF_TriState_Off;
  };

  TF_ASSERT_OK(InitGraphPlugin(plugin_init));
  ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
                std::set<string>{"Success"})
                .size(),
            1);
  ConfigList config = PluginGraphOptimizerRegistry::GetPluginConfigs(
      true, std::set<string>{"Success"});
  ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF);
}

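// Registers two plugins for distinct device types and checks that the
// registry creates one optimizer per device type.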
TEST(Grappler, MultiplePluginRegistration) {
  auto plugin_init_0 = [](TP_OptimizerRegistrationParams* const params,
                          TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "Device0";
  };
  auto plugin_init_1 = [](TP_OptimizerRegistrationParams* const params,
                          TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "Device1";
  };

  TF_ASSERT_OK(InitGraphPlugin(plugin_init_0));
  TF_ASSERT_OK(InitGraphPlugin(plugin_init_1));
  ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
                std::set<string>{"Device0", "Device1"})
                .size(),
            2);
}

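// A plugin that leaves `device_type` unset must be rejected with
// FAILED_PRECONDITION.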
TEST(Grappler, DeviceTypeNotSet) {
  auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
                        TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = nullptr;
  };

  tensorflow::Status status = InitGraphPlugin(plugin_init);
  ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
  ASSERT_EQ(
      status.error_message(),
      "'device_type' field in TP_OptimizerRegistrationParams must be set.");
}

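// A plugin that clears `optimize_func` must be rejected with
// FAILED_PRECONDITION.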
TEST(Grappler, OptimizeFuncNotSet) {
  auto plugin_init = [](TP_OptimizerRegistrationParams* const params,
                        TF_Status* const status) -> void {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "FuncNotSet";
    params->optimizer->optimize_func = nullptr;
  };

  tensorflow::Status status = InitGraphPlugin(plugin_init);
  ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
  ASSERT_EQ(status.error_message(),
            "'optimize_func' field in TP_Optimizer must be set.");
}

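// Checks that TF_GetNodesToPreserveListSize and TF_GetNodesToPreserveList
// report the same node names and total storage size as
// GrapplerItem::NodesToPreserve().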
TEST(TF_GrapplerItem, NodesToPreserve) {
  GrapplerItem item;
  item.fetch = std::vector<string>{"Conv", "BiasAdd"};
  std::unordered_set<string> nodes_preserved = item.NodesToPreserve();
  TF_GrapplerItem* c_item = reinterpret_cast<TF_GrapplerItem*>(&item);

  int list_total_size = 0;
  for (const string& s : nodes_preserved) {
    list_total_size += s.size();
  }

  size_t storage_size = 0;
  int num_values = 0;
  TF_Status* status = TF_NewStatus();
  TF_GetNodesToPreserveListSize(c_item, &num_values, &storage_size, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(nodes_preserved.size(), num_values);
  EXPECT_EQ(list_total_size, storage_size);

  std::unique_ptr<char*[]> values(new char*[nodes_preserved.size()]);
  std::unique_ptr<size_t[]> lens(new size_t[nodes_preserved.size()]);
  std::unique_ptr<char[]> storage(new char[storage_size]);
  TF_GetNodesToPreserveList(c_item, values.get(), lens.get(),
                            nodes_preserved.size(), storage.get(), storage_size,
                            status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  for (size_t i = 0; i < nodes_preserved.size(); ++i) {
    EXPECT_EQ(nodes_preserved.find(string(static_cast<const char*>(values[i]),
                                          lens[i])) != nodes_preserved.end(),
              true);
  }
  TF_DeleteStatus(status);
}

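// Checks that TF_GetFetchNodesListSize and TF_GetFetchNodesList round-trip
// the item's fetch node names, preserving order and lengths.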
TEST(TF_GrapplerItem, FetchNodes) {
  GrapplerItem item;
  item.fetch = std::vector<string>{"Conv", "BiasAdd"};
  TF_GrapplerItem* c_item = reinterpret_cast<TF_GrapplerItem*>(&item);

  int list_total_size = 0;
  for (const string& s : item.fetch) {
    list_total_size += s.size();
  }

  size_t storage_size = 0;
  int num_values = 0;
  TF_Status* status = TF_NewStatus();
  TF_GetFetchNodesListSize(c_item, &num_values, &storage_size, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(item.fetch.size(), num_values);
  EXPECT_EQ(list_total_size, storage_size);

  std::unique_ptr<char*[]> values(new char*[item.fetch.size()]);
  std::unique_ptr<size_t[]> lens(new size_t[item.fetch.size()]);
  std::unique_ptr<char[]> storage(new char[storage_size]);
  TF_GetFetchNodesList(c_item, values.get(), lens.get(), item.fetch.size(),
                       storage.get(), storage_size, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  for (size_t i = 0; i < item.fetch.size(); ++i) {
    EXPECT_EQ(item.fetch[i].size(), lens[i]) << i;
    EXPECT_EQ(item.fetch[i],
              string(static_cast<const char*>(values[i]), lens[i]))
        << i;
  }
  TF_DeleteStatus(status);
}

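// Builds a trivial test graph on a SingleMachine cluster, runs static shape
// inference through the C API, and verifies the inferred input properties
// (dtype and 10x1 shape) of the AddN node.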
TEST(TF_GraphProperties, InputProperties) {
  std::unique_ptr<SingleMachine> cluster(new SingleMachine(5 * 60, 3, 0));
  TF_ASSERT_OK(cluster->Provision());

  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  TF_Status* status = TF_NewStatus();
  TF_GraphProperties* graph_properties =
      TF_NewGraphProperties(reinterpret_cast<TF_GrapplerItem*>(&item));
  TF_InferStatically(graph_properties, true, false, false, false, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  for (const NodeDef& node : item.graph.node()) {
    if (node.op() == "AddN") {
      int num_values = 0;
      TF_GetInputPropertiesListSize(graph_properties, node.name().c_str(),
                                    &num_values, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      EXPECT_EQ(num_values, 1);

      std::vector<TF_Buffer*> in_props_buf(num_values, TF_NewBuffer());

      TF_GetInputPropertiesList(graph_properties, node.name().c_str(),
                                in_props_buf.data(), num_values, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

      tensorflow::OpInfo::TensorProperties in_props;
      Status s = tensorflow::BufferToMessage(in_props_buf[0], &in_props);
      TF_ASSERT_OK(s);

      EXPECT_EQ(DT_FLOAT, in_props.dtype());
      EXPECT_FALSE(in_props.shape().unknown_rank());
      EXPECT_EQ(2, in_props.shape().dim_size());
      EXPECT_EQ(10, in_props.shape().dim(0).size());
      EXPECT_EQ(1, in_props.shape().dim(1).size());

      for (int i = 0; i < in_props_buf.size(); i++)
        TF_DeleteBuffer(in_props_buf[i]);
    }
  }
  TF_DeleteGraphProperties(graph_properties);
  TF_DeleteStatus(status);
  TF_ASSERT_OK(cluster->Shutdown());
}

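// Same setup as InputProperties above, but verifies the statically inferred
// output properties of the AddN node.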
TEST(TF_GraphProperties, OutputProperties) {
  std::unique_ptr<SingleMachine> cluster(new SingleMachine(5 * 60, 3, 0));
  TF_ASSERT_OK(cluster->Provision());

  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  TF_Status* status = TF_NewStatus();
  TF_GraphProperties* graph_properties =
      TF_NewGraphProperties(reinterpret_cast<TF_GrapplerItem*>(&item));
  TF_InferStatically(graph_properties, true, false, false, false, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  for (const NodeDef& node : item.graph.node()) {
    if (node.op() == "AddN") {
      int num_values = 0;
      TF_GetOutputPropertiesListSize(graph_properties, node.name().c_str(),
                                     &num_values, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      EXPECT_EQ(num_values, 1);

      std::vector<TF_Buffer*> out_props_buf(num_values, TF_NewBuffer());

      TF_GetOutputPropertiesList(graph_properties, node.name().c_str(),
                                 out_props_buf.data(), num_values, status);
      EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

      tensorflow::OpInfo::TensorProperties out_props;
      Status s = tensorflow::BufferToMessage(out_props_buf[0], &out_props);
      TF_ASSERT_OK(s);

      EXPECT_EQ(DT_FLOAT, out_props.dtype());
      EXPECT_FALSE(out_props.shape().unknown_rank());
      EXPECT_EQ(2, out_props.shape().dim_size());
      EXPECT_EQ(10, out_props.shape().dim(0).size());
      EXPECT_EQ(1, out_props.shape().dim(1).size());

      for (int i = 0; i < out_props_buf.size(); i++)
        TF_DeleteBuffer(out_props_buf[i]);
    }
  }
  TF_DeleteStatus(status);
  TF_DeleteGraphProperties(graph_properties);
  TF_ASSERT_OK(cluster->Shutdown());
}

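// Wraps an empty GraphDef in a TF_FunctionLibraryDefinition and checks that
// TF_LookUpOpDef returns the serialized OpDef for "Add" from the global op
// registry.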
TEST(TF_FunctionLibraryDefinition, LookUpOpDef) {
  TF_Buffer* g_buf = TF_NewBuffer();
  TF_Buffer* op_buf = TF_NewBuffer();
  TF_Status* status = TF_NewStatus();
  GraphDef g_def;
  Status s = MessageToBuffer(g_def, g_buf);
  TF_ASSERT_OK(s);
  TF_FunctionLibraryDefinition* func =
      TF_NewFunctionLibraryDefinition(g_buf, status);

  TF_LookUpOpDef(func, "Add", op_buf, status);
  string actual_string(reinterpret_cast<const char*>(op_buf->data),
                       op_buf->length);
  ASSERT_EQ(TF_OK, TF_GetCode(status));

  const OpDef* expected_op_def;
  TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
  string expected_serialized;
  expected_op_def->SerializeToString(&expected_serialized);
  EXPECT_EQ(expected_serialized, actual_string);
  TF_DeleteBuffer(g_buf);
  TF_DeleteBuffer(op_buf);
  TF_DeleteStatus(status);
  TF_DeleteFunctionLibraryDefinition(func);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow