• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0(the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/grappler/grappler.h"
16 
17 #include "tensorflow/c/c_api_internal.h"
18 #include "tensorflow/c/experimental/grappler/grappler_internal.h"
19 #include "tensorflow/c/tf_buffer_internal.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/grappler/clusters/single_machine.h"
22 #include "tensorflow/core/grappler/costs/graph_properties.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace {
31 
// Do-nothing optimize callback installed by PopulateDefaultParam. The
// registration tests only need `optimize_func` to be non-null, so every
// argument is ignored and neither `optimized_graph_buf` nor `tf_status` is
// written.
void optimize_func(void* /*optimizer*/, const TF_Buffer* /*graph_buf*/,
                   const TF_GrapplerItem* /*item*/,
                   TF_Buffer* /*optimized_graph_buf*/,
                   TF_Status* /*tf_status*/) {
  // Intentionally empty.
}
35 
PopulateDefaultParam(TP_OptimizerRegistrationParams * params)36 void PopulateDefaultParam(TP_OptimizerRegistrationParams* params) {
37   params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
38   params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE;
39   params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
40   params->optimizer->create_func = nullptr;
41   params->optimizer->optimize_func = optimize_func;
42   params->optimizer->destroy_func = nullptr;
43 }
44 
// Registers a single plugin for device type "Success" with remapping turned
// off. Then checks that one optimizer is created for that device and that the
// plugin configs report the remapping toggle as OFF.
TEST(Grappler, SuccessfulRegistration) {
  auto init_fn = [](TP_OptimizerRegistrationParams* const params,
                    TF_Status* const status) {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "Success";
    params->optimizer_configs->remapping = TF_TriState_Off;
  };
  TF_ASSERT_OK(InitGraphPlugin(init_fn));

  const std::set<string> devices{"Success"};
  // The created optimizers are a temporary; only their count is inspected.
  const size_t num_optimizers =
      PluginGraphOptimizerRegistry::CreateOptimizers(devices).size();
  ASSERT_EQ(num_optimizers, 1);

  ConfigList config =
      PluginGraphOptimizerRegistry::GetPluginConfigs(true, devices);
  ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF);
}
63 
// Registers two independent plugins, one for "Device0" and one for "Device1".
// Requesting both device types must then yield one optimizer per plugin.
TEST(Grappler, MultiplePluginRegistration) {
  auto init_device0 = [](TP_OptimizerRegistrationParams* const params,
                         TF_Status* const status) {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "Device0";
  };
  auto init_device1 = [](TP_OptimizerRegistrationParams* const params,
                         TF_Status* const status) {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "Device1";
  };

  TF_ASSERT_OK(InitGraphPlugin(init_device0));
  TF_ASSERT_OK(InitGraphPlugin(init_device1));

  const std::set<string> requested_devices{"Device0", "Device1"};
  const size_t optimizer_count =
      PluginGraphOptimizerRegistry::CreateOptimizers(requested_devices).size();
  ASSERT_EQ(optimizer_count, 2);
}
85 
// A plugin that leaves `device_type` null must be rejected with
// FAILED_PRECONDITION and an error message naming the missing field.
TEST(Grappler, DeviceTypeNotSet) {
  auto init_without_device = [](TP_OptimizerRegistrationParams* const params,
                                TF_Status* const status) {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = nullptr;
  };

  const tensorflow::Status result = InitGraphPlugin(init_without_device);
  ASSERT_EQ(result.code(), tensorflow::error::FAILED_PRECONDITION);
  ASSERT_EQ(
      result.error_message(),
      "'device_type' field in TP_OptimizerRegistrationParams must be set.");
}
100 
// A plugin whose optimizer has a null `optimize_func` must be rejected with
// FAILED_PRECONDITION and an error message naming the missing callback.
TEST(Grappler, OptimizeFuncNotSet) {
  auto init_without_optimize = [](TP_OptimizerRegistrationParams* const params,
                                  TF_Status* const status) {
    TF_SetStatus(status, TF_OK, "");
    PopulateDefaultParam(params);
    params->device_type = "FuncNotSet";
    // Undo the default callback installed by PopulateDefaultParam.
    params->optimizer->optimize_func = nullptr;
  };

  const tensorflow::Status result = InitGraphPlugin(init_without_optimize);
  ASSERT_EQ(result.code(), tensorflow::error::FAILED_PRECONDITION);
  ASSERT_EQ(result.error_message(),
            "'optimize_func' field in TP_Optimizer must be set.");
}
115 
// Checks that the C API mirrors GrapplerItem::NodesToPreserve():
// - TF_GetNodesToPreserveListSize reports the number of names and their total
//   byte length;
// - TF_GetNodesToPreserveList returns exactly the same set of names.
TEST(TF_GrapplerItem, NodesToPreserve) {
  GrapplerItem item;
  item.fetch = std::vector<string>{"Conv", "BiasAdd"};
  std::unordered_set<string> nodes_preserved = item.NodesToPreserve();
  TF_GrapplerItem* c_item = reinterpret_cast<TF_GrapplerItem*>(&item);

  // size_t, so that comparing against `storage_size` is not signed/unsigned.
  size_t list_total_size = 0;
  for (const string& s : nodes_preserved) {
    list_total_size += s.size();
  }

  // Owned, so it is released even when an ASSERT_* below returns early.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  size_t storage_size = 0;
  int num_values = 0;
  TF_GetNodesToPreserveListSize(c_item, &num_values, &storage_size,
                                status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  EXPECT_EQ(static_cast<int>(nodes_preserved.size()), num_values);
  EXPECT_EQ(list_total_size, storage_size);

  const size_t num_expected = nodes_preserved.size();
  // make_unique value-initializes, so no slot is ever read uninitialized.
  auto values = std::make_unique<char*[]>(num_expected);
  auto lens = std::make_unique<size_t[]>(num_expected);
  auto storage = std::make_unique<char[]>(storage_size);
  TF_GetNodesToPreserveList(c_item, values.get(), lens.get(),
                            static_cast<int>(num_expected), storage.get(),
                            storage_size, status.get());
  // Fatal: `values`/`lens` are only meaningful if the call succeeded.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Compare as sets. A membership-only check would accept one name returned
  // twice while another name is missing.
  std::unordered_set<string> returned;
  for (size_t i = 0; i < num_expected; ++i) {
    returned.emplace(values[i], lens[i]);
  }
  EXPECT_EQ(nodes_preserved, returned);
}
150 
// Checks that the C API mirrors GrapplerItem::fetch, in order:
// - TF_GetFetchNodesListSize reports the count and total byte length;
// - TF_GetFetchNodesList returns each fetch name at its original position.
TEST(TF_GrapplerItem, FetchNodes) {
  GrapplerItem item;
  item.fetch = std::vector<string>{"Conv", "BiasAdd"};
  TF_GrapplerItem* c_item = reinterpret_cast<TF_GrapplerItem*>(&item);

  // size_t, so that comparing against `storage_size` is not signed/unsigned.
  size_t list_total_size = 0;
  for (const string& s : item.fetch) {
    list_total_size += s.size();
  }

  // Owned, so it is released even when an ASSERT_* below returns early.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  size_t storage_size = 0;
  int num_values = 0;
  TF_GetFetchNodesListSize(c_item, &num_values, &storage_size, status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  EXPECT_EQ(static_cast<int>(item.fetch.size()), num_values);
  EXPECT_EQ(list_total_size, storage_size);

  const size_t num_fetch = item.fetch.size();
  // make_unique value-initializes, so no slot is ever read uninitialized.
  auto values = std::make_unique<char*[]>(num_fetch);
  auto lens = std::make_unique<size_t[]>(num_fetch);
  auto storage = std::make_unique<char[]>(storage_size);
  TF_GetFetchNodesList(c_item, values.get(), lens.get(),
                       static_cast<int>(num_fetch), storage.get(),
                       storage_size, status.get());
  // Fatal: `values`/`lens` are only meaningful if the call succeeded.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  for (size_t i = 0; i < num_fetch; ++i) {
    EXPECT_EQ(item.fetch[i].size(), lens[i]) << i;
    EXPECT_EQ(item.fetch[i], string(values[i], lens[i])) << i;
  }
}
184 
// Runs static shape inference through the C API on a trivial test graph. For
// every AddN node, checks the reported input properties: exactly one input,
// of type float, with static shape [10, 1].
TEST(TF_GraphProperties, InputProperties) {
  std::unique_ptr<SingleMachine> cluster(new SingleMachine(5 * 60, 3, 0));
  TF_ASSERT_OK(cluster->Provision());

  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  // Owning wrappers, so the C handles are released even when an ASSERT_*
  // below returns early. `item` is declared earlier and so outlives
  // `graph_properties`.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_GraphProperties, decltype(&TF_DeleteGraphProperties)>
      graph_properties(
          TF_NewGraphProperties(reinterpret_cast<TF_GrapplerItem*>(&item)),
          TF_DeleteGraphProperties);
  TF_InferStatically(graph_properties.get(), true, false, false, false,
                     status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  using BufferPtr = std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)>;
  int num_addn_nodes = 0;
  for (const NodeDef& node : item.graph.node()) {
    if (node.op() != "AddN") continue;
    ++num_addn_nodes;

    int num_values = 0;
    TF_GetInputPropertiesListSize(graph_properties.get(), node.name().c_str(),
                                  &num_values, status.get());
    EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    // Fatal: the checks below read in_props_buf[0].
    ASSERT_EQ(num_values, 1);

    // One distinct buffer per slot. `std::vector<TF_Buffer*>(n, TF_NewBuffer())`
    // would store a single buffer n times and free it n times.
    std::vector<BufferPtr> owned_bufs;
    std::vector<TF_Buffer*> in_props_buf;
    owned_bufs.reserve(num_values);
    in_props_buf.reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      owned_bufs.emplace_back(TF_NewBuffer(), TF_DeleteBuffer);
      in_props_buf.push_back(owned_bufs.back().get());
    }

    TF_GetInputPropertiesList(graph_properties.get(), node.name().c_str(),
                              in_props_buf.data(), num_values, status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

    tensorflow::OpInfo::TensorProperties in_props;
    TF_ASSERT_OK(tensorflow::BufferToMessage(in_props_buf[0], &in_props));

    EXPECT_EQ(DT_FLOAT, in_props.dtype());
    EXPECT_FALSE(in_props.shape().unknown_rank());
    EXPECT_EQ(2, in_props.shape().dim_size());
    EXPECT_EQ(10, in_props.shape().dim(0).size());
    EXPECT_EQ(1, in_props.shape().dim(1).size());
  }
  // Without this, a graph with no AddN node would pass without checking
  // anything.
  EXPECT_GT(num_addn_nodes, 0);

  // Release the C handles before shutting the cluster down.
  graph_properties.reset();
  status.reset();
  TF_ASSERT_OK(cluster->Shutdown());
}
232 
// Runs static shape inference through the C API on a trivial test graph. For
// every AddN node, checks the reported output properties: exactly one output,
// of type float, with static shape [10, 1].
TEST(TF_GraphProperties, OutputProperties) {
  std::unique_ptr<SingleMachine> cluster(new SingleMachine(5 * 60, 3, 0));
  TF_ASSERT_OK(cluster->Provision());

  TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                          cluster->GetDeviceNames());
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  // Owning wrappers, so the C handles are released even when an ASSERT_*
  // below returns early. `item` is declared earlier and so outlives
  // `graph_properties`.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_GraphProperties, decltype(&TF_DeleteGraphProperties)>
      graph_properties(
          TF_NewGraphProperties(reinterpret_cast<TF_GrapplerItem*>(&item)),
          TF_DeleteGraphProperties);
  TF_InferStatically(graph_properties.get(), true, false, false, false,
                     status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  using BufferPtr = std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)>;
  int num_addn_nodes = 0;
  for (const NodeDef& node : item.graph.node()) {
    if (node.op() != "AddN") continue;
    ++num_addn_nodes;

    int num_values = 0;
    TF_GetOutputPropertiesListSize(graph_properties.get(), node.name().c_str(),
                                   &num_values, status.get());
    EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    // Fatal: the checks below read out_props_buf[0].
    ASSERT_EQ(num_values, 1);

    // One distinct buffer per slot. `std::vector<TF_Buffer*>(n, TF_NewBuffer())`
    // would store a single buffer n times and free it n times.
    std::vector<BufferPtr> owned_bufs;
    std::vector<TF_Buffer*> out_props_buf;
    owned_bufs.reserve(num_values);
    out_props_buf.reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      owned_bufs.emplace_back(TF_NewBuffer(), TF_DeleteBuffer);
      out_props_buf.push_back(owned_bufs.back().get());
    }

    TF_GetOutputPropertiesList(graph_properties.get(), node.name().c_str(),
                               out_props_buf.data(), num_values, status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

    tensorflow::OpInfo::TensorProperties out_props;
    TF_ASSERT_OK(tensorflow::BufferToMessage(out_props_buf[0], &out_props));

    EXPECT_EQ(DT_FLOAT, out_props.dtype());
    EXPECT_FALSE(out_props.shape().unknown_rank());
    EXPECT_EQ(2, out_props.shape().dim_size());
    EXPECT_EQ(10, out_props.shape().dim(0).size());
    EXPECT_EQ(1, out_props.shape().dim(1).size());
  }
  // Without this, a graph with no AddN node would pass without checking
  // anything.
  EXPECT_GT(num_addn_nodes, 0);

  // Release the C handles before shutting the cluster down.
  graph_properties.reset();
  status.reset();
  TF_ASSERT_OK(cluster->Shutdown());
}
280 
// Builds a TF_FunctionLibraryDefinition from an empty GraphDef. Checks that
// TF_LookUpOpDef("Add") returns the same serialized OpDef that the global op
// registry holds for "Add".
TEST(TF_FunctionLibraryDefinition, LookUpOpDef) {
  // Owning wrappers, so every C handle is released even when an ASSERT_*
  // below returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> g_buf(
      TF_NewBuffer(), TF_DeleteBuffer);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> op_buf(
      TF_NewBuffer(), TF_DeleteBuffer);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  GraphDef g_def;
  TF_ASSERT_OK(MessageToBuffer(g_def, g_buf.get()));
  std::unique_ptr<TF_FunctionLibraryDefinition,
                  decltype(&TF_DeleteFunctionLibraryDefinition)>
      func(TF_NewFunctionLibraryDefinition(g_buf.get(), status.get()),
           TF_DeleteFunctionLibraryDefinition);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_LookUpOpDef(func.get(), "Add", op_buf.get(), status.get());
  // Check the status before touching op_buf's contents.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  const string actual_string(reinterpret_cast<const char*>(op_buf->data),
                             op_buf->length);

  const OpDef* expected_op_def = nullptr;
  TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
  string expected_serialized;
  expected_op_def->SerializeToString(&expected_serialized);
  EXPECT_EQ(expected_serialized, actual_string);
}
306 
307 }  // namespace
308 }  // namespace grappler
309 }  // namespace tensorflow
310