1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file extends/implements core graph optimizer base classes in terms of
16 // the C API defined in grappler.h. A class "CSomething" represents a
17 // "Something" that can be manipulated via calls in the C interface and a C
18 // struct called "TP_Something".
19
#include "tensorflow/c/experimental/grappler/grappler.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
36
37 namespace {
38
39 #define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
40 do { \
41 if (STRUCT_OBJ.struct_size == 0) { \
42 return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
43 "struct_size field in " #STRUCT_NAME \
44 " must be set to " #SIZE_VALUE_NAME "."); \
45 } \
46 } while (0)
47
48 #define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \
49 do { \
50 if (STRUCT_OBJ.NAME == 0) { \
51 return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
52 "'" #NAME "' field in " #STRUCT_NAME \
53 " must be set."); \
54 } \
55 } while (0)
56
ValidateTPOptimizerRegistrationParams(const TP_OptimizerRegistrationParams & params)57 tensorflow::Status ValidateTPOptimizerRegistrationParams(
58 const TP_OptimizerRegistrationParams& params) {
59 VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params,
60 TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE);
61 VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type);
62 return tensorflow::Status::OK();
63 }
64
ValidateTPOptimizer(const TP_Optimizer & optimizer)65 tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) {
66 VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE);
67 VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func);
68 return tensorflow::Status::OK();
69 }
70
ValidateTPOptimizerConfigs(const TP_OptimizerConfigs & configs)71 tensorflow::Status ValidateTPOptimizerConfigs(
72 const TP_OptimizerConfigs& configs) {
73 VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs,
74 TP_OPTIMIZER_CONFIGS_STRUCT_SIZE);
75 return tensorflow::Status::OK();
76 }
77
78 #undef VALIDATE_MEMBER
79 #undef VALIDATE_STRUCT_SIZE
80 } // namespace
81
82 namespace tensorflow {
83 namespace grappler {
84
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph_def)85 Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
86 GraphDef* optimized_graph_def) {
87 OwnedTFStatus c_status(TF_NewStatus());
88 OwnedTFBuffer graph_buf(TF_NewBuffer());
89 OwnedTFBuffer optimized_graph_buf(TF_NewBuffer());
90 TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get()));
91
92 optimizer_.optimize_func(c_optimizer_, graph_buf.get(),
93 reinterpret_cast<const TF_GrapplerItem*>(&item),
94 optimized_graph_buf.get(), c_status.get());
95 TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
96 TF_RETURN_IF_ERROR(
97 BufferToMessage(optimized_graph_buf.get(), optimized_graph_def));
98
99 return Status::OK();
100 }
101
102 #define CONFIG_TOGGLE(optimizer) \
103 if (tp_configs.optimizer == TF_TriState_Off) \
104 configs.toggle_config[#optimizer] = RewriterConfig::OFF; \
105 else \
106 configs.toggle_config[#optimizer] = RewriterConfig::ON;
107
CGraphOptimizerRegister(const PluginGraphOptimizerRegistry::Creator & creator,const TP_OptimizerConfigs tp_configs,const char * device_type)108 void CGraphOptimizerRegister(
109 const PluginGraphOptimizerRegistry::Creator& creator,
110 const TP_OptimizerConfigs tp_configs, const char* device_type) {
111 ConfigList configs;
112 // disable_model_pruning is turned off by default.
113 if (tp_configs.disable_model_pruning == TF_TriState_On)
114 configs.disable_model_pruning = true;
115 else
116 configs.disable_model_pruning = false;
117 // The other configs are turned on by default.
118 CONFIG_TOGGLE(implementation_selector);
119 CONFIG_TOGGLE(function_optimization);
120 CONFIG_TOGGLE(common_subgraph_elimination);
121 CONFIG_TOGGLE(arithmetic_optimization);
122 CONFIG_TOGGLE(debug_stripper);
123 CONFIG_TOGGLE(constant_folding);
124 CONFIG_TOGGLE(shape_optimization);
125 CONFIG_TOGGLE(auto_mixed_precision);
126 CONFIG_TOGGLE(auto_mixed_precision_mkl);
127 CONFIG_TOGGLE(pin_to_host_optimization);
128 CONFIG_TOGGLE(layout_optimizer);
129 CONFIG_TOGGLE(remapping);
130 CONFIG_TOGGLE(loop_optimization);
131 CONFIG_TOGGLE(dependency_optimization);
132 CONFIG_TOGGLE(auto_parallel);
133 CONFIG_TOGGLE(memory_optimization);
134 CONFIG_TOGGLE(scoped_allocator_optimization);
135 PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
136 creator, device_type, configs);
137 }
138
139 #undef CONFIG_TOGGLE
140
InitGraphPlugin(void * dso_handle)141 tensorflow::Status InitGraphPlugin(void* dso_handle) {
142 tensorflow::Env* env = tensorflow::Env::Default();
143
144 // Step 1: Load symbol for `TF_InitPlugin`
145 void* dso_symbol;
146 TF_RETURN_IF_ERROR(
147 env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol));
148
149 // Step 2: Call `TF_InitPlugin`
150 auto init_fn = reinterpret_cast<TFInitGraphPluginFn>(dso_symbol);
151 return InitGraphPlugin(init_fn);
152 }
153
InitGraphPlugin(TFInitGraphPluginFn init_fn)154 tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) {
155 TP_OptimizerRegistrationParams params{
156 TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
157 TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
158 TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
159 params.major_version = GO_MAJOR;
160 params.minor_version = GO_MINOR;
161 params.patch_version = GO_PATCH;
162 params.optimizer = &optimizer;
163 params.optimizer_configs = &optimizer_configs;
164
165 OwnedTFStatus c_status(TF_NewStatus());
166 init_fn(¶ms, c_status.get());
167 TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
168 TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params));
169 TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer));
170 TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs));
171
172 CGraphOptimizerRegister(
173 [=]() { return new CGraphOptimizer(optimizer, params.device_type); },
174 optimizer_configs, params.device_type);
175
176 return Status::OK();
177 }
178
179 } // namespace grappler
180 } // namespace tensorflow
181
TF_GetNodesToPreserveListSize(const TF_GrapplerItem * item,int * num_values,size_t * storage_size,TF_Status * status)182 void TF_GetNodesToPreserveListSize(const TF_GrapplerItem* item, int* num_values,
183 size_t* storage_size, TF_Status* status) {
184 TF_SetStatus(status, TF_OK, "");
185 const std::unordered_set<std::string>& nodes =
186 reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
187 ->NodesToPreserve();
188 *num_values = nodes.size();
189 *storage_size = 0;
190 for (const std::string& str : nodes) {
191 *storage_size += str.size();
192 }
193 }
194
TF_GetNodesToPreserveList(const TF_GrapplerItem * item,char ** values,size_t * lengths,int num_values,void * storage,size_t storage_size,TF_Status * status)195 void TF_GetNodesToPreserveList(const TF_GrapplerItem* item, char** values,
196 size_t* lengths, int num_values, void* storage,
197 size_t storage_size, TF_Status* status) {
198 TF_SetStatus(status, TF_OK, "");
199 const std::unordered_set<std::string>& nodes =
200 reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
201 ->NodesToPreserve();
202 char* p = static_cast<char*>(storage);
203
204 int index = 0;
205 for (const std::string& s : nodes) {
206 if (index >= num_values) break;
207 values[index] = p;
208 lengths[index] = s.size();
209 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
210 status->status = tensorflow::errors::InvalidArgument(
211 "Not enough storage to hold the requested list of nodes");
212 return;
213 }
214 memcpy(values[index], s.data(), s.size());
215 p += s.size();
216 index++;
217 }
218 }
219
TF_GetFetchNodesListSize(const TF_GrapplerItem * item,int * num_values,size_t * storage_size,TF_Status * status)220 void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, int* num_values,
221 size_t* storage_size, TF_Status* status) {
222 TF_SetStatus(status, TF_OK, "");
223 const std::vector<std::string>& nodes =
224 reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
225 *num_values = nodes.size();
226 *storage_size = 0;
227 for (const std::string& str : nodes) {
228 *storage_size += str.size();
229 }
230 }
231
TF_GetFetchNodesList(const TF_GrapplerItem * item,char ** values,size_t * lengths,int num_values,void * storage,size_t storage_size,TF_Status * status)232 void TF_GetFetchNodesList(const TF_GrapplerItem* item, char** values,
233 size_t* lengths, int num_values, void* storage,
234 size_t storage_size, TF_Status* status) {
235 TF_SetStatus(status, TF_OK, "");
236 const std::vector<std::string>& nodes =
237 reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
238
239 const int len = std::min(num_values, static_cast<int>(nodes.size()));
240 char* p = static_cast<char*>(storage);
241 for (int index = 0; index < len; ++index) {
242 const std::string& s = nodes[index];
243 values[index] = p;
244 lengths[index] = s.size();
245 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
246 status->status = tensorflow::errors::InvalidArgument(
247 "Not enough storage to hold the requested list of nodes");
248 return;
249 }
250 memcpy(values[index], s.data(), s.size());
251 p += s.size();
252 }
253 }
254
TF_NewGraphProperties(const TF_GrapplerItem * item)255 TF_GraphProperties* TF_NewGraphProperties(const TF_GrapplerItem* item) {
256 return reinterpret_cast<TF_GraphProperties*>(
257 new tensorflow::grappler::GraphProperties(
258 *reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)));
259 }
260
TF_DeleteGraphProperties(TF_GraphProperties * graph_properties)261 void TF_DeleteGraphProperties(TF_GraphProperties* graph_properties) {
262 if (graph_properties == nullptr) return;
263 delete reinterpret_cast<tensorflow::grappler::GraphProperties*>(
264 graph_properties);
265 }
266
TF_InferStatically(TF_GraphProperties * graph_properties,TF_Bool assume_valid_feeds,TF_Bool aggressive_shape_inference,TF_Bool include_input_tensor_values,TF_Bool include_output_tensor_values,TF_Status * status)267 void TF_InferStatically(TF_GraphProperties* graph_properties,
268 TF_Bool assume_valid_feeds,
269 TF_Bool aggressive_shape_inference,
270 TF_Bool include_input_tensor_values,
271 TF_Bool include_output_tensor_values,
272 TF_Status* status) {
273 TF_SetStatus(status, TF_OK, "");
274 tensorflow::Status s =
275 reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
276 ->InferStatically(assume_valid_feeds, aggressive_shape_inference,
277 include_input_tensor_values,
278 include_output_tensor_values);
279 if (!s.ok()) {
280 ::tensorflow::Set_TF_Status_from_Status(status, s);
281 }
282 }
283
TF_GetInputPropertiesListSize(TF_GraphProperties * graph_properties,const char * name,int * num_values,TF_Status * status)284 void TF_GetInputPropertiesListSize(TF_GraphProperties* graph_properties,
285 const char* name, int* num_values,
286 TF_Status* status) {
287 TF_SetStatus(status, TF_OK, "");
288 *num_values =
289 reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
290 ->GetInputProperties(name)
291 .size();
292 }
293
TF_GetOutputPropertiesListSize(TF_GraphProperties * graph_properties,const char * name,int * num_values,TF_Status * status)294 void TF_GetOutputPropertiesListSize(TF_GraphProperties* graph_properties,
295 const char* name, int* num_values,
296 TF_Status* status) {
297 TF_SetStatus(status, TF_OK, "");
298 *num_values =
299 reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
300 ->GetOutputProperties(name)
301 .size();
302 }
303
TF_GetInputPropertiesList(TF_GraphProperties * graph_properties,const char * name,TF_Buffer ** properties,int num_values,TF_Status * status)304 void TF_GetInputPropertiesList(TF_GraphProperties* graph_properties,
305 const char* name, TF_Buffer** properties,
306 int num_values, TF_Status* status) {
307 TF_SetStatus(status, TF_OK, "");
308 const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
309 reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
310 ->GetInputProperties(name);
311 const int len =
312 std::min(num_values, static_cast<int>(tensor_properties.size()));
313 for (int i = 0; i < len; ++i) {
314 tensorflow::Status s =
315 tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
316 if (!s.ok()) {
317 ::tensorflow::Set_TF_Status_from_Status(status, s);
318 return;
319 }
320 }
321 }
322
TF_GetOutputPropertiesList(TF_GraphProperties * graph_properties,const char * name,TF_Buffer ** properties,int num_values,TF_Status * status)323 void TF_GetOutputPropertiesList(TF_GraphProperties* graph_properties,
324 const char* name, TF_Buffer** properties,
325 int num_values, TF_Status* status) {
326 TF_SetStatus(status, TF_OK, "");
327 const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
328 reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
329 ->GetOutputProperties(name);
330 const int len =
331 std::min(num_values, static_cast<int>(tensor_properties.size()));
332 for (int i = 0; i < len; ++i) {
333 tensorflow::Status s =
334 tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
335 if (!s.ok()) {
336 ::tensorflow::Set_TF_Status_from_Status(status, s);
337 return;
338 }
339 }
340 }
341
TF_NewFunctionLibraryDefinition(TF_Buffer * graph_buf,TF_Status * status)342 TF_FunctionLibraryDefinition* TF_NewFunctionLibraryDefinition(
343 TF_Buffer* graph_buf, TF_Status* status) {
344 TF_SetStatus(status, TF_OK, "");
345 tensorflow::GraphDef graph_def;
346 tensorflow::Status s = tensorflow::BufferToMessage(graph_buf, &graph_def);
347 if (!s.ok()) {
348 ::tensorflow::Set_TF_Status_from_Status(status, s);
349 return nullptr;
350 }
351 return reinterpret_cast<TF_FunctionLibraryDefinition*>(
352 new tensorflow::FunctionLibraryDefinition(
353 tensorflow::OpRegistry::Global(), graph_def.library()));
354 }
355
TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition * fn_lib)356 void TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition* fn_lib) {
357 if (fn_lib == nullptr) return;
358 delete reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib);
359 }
360
TF_LookUpOpDef(TF_FunctionLibraryDefinition * fn_lib,const char * name,TF_Buffer * buf,TF_Status * status)361 void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, const char* name,
362 TF_Buffer* buf, TF_Status* status) {
363 TF_SetStatus(status, TF_OK, "");
364 const tensorflow::OpDef* op_def_ptr = nullptr;
365 tensorflow::Status s =
366 reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib)
367 ->LookUpOpDef(name, &op_def_ptr);
368 if (!s.ok()) {
369 ::tensorflow::Set_TF_Status_from_Status(status, s);
370 return;
371 }
372
373 s = tensorflow::MessageToBuffer(*op_def_ptr, buf);
374 if (!s.ok()) {
375 ::tensorflow::Set_TF_Status_from_Status(status, s);
376 return;
377 }
378 }
379