• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file extends/implements core graph optimizer base classes in terms of
16 // the C API defined in grappler.h. A class "CSomething" represents a
17 // "Something" that can be manipulated via calls in the C interface and a C
18 // struct called "TP_Something".
19 
20 #include "tensorflow/c/experimental/grappler/grappler.h"
21 
22 #include <memory>
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/experimental/grappler/grappler_internal.h"
29 #include "tensorflow/c/tf_status_helper.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/grappler/costs/graph_properties.h"
32 #include "tensorflow/core/grappler/grappler_item.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/status.h"
36 
37 namespace {
38 
39 #define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME)    \
40   do {                                                                    \
41     if (STRUCT_OBJ.struct_size == 0) {                                    \
42       return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION,   \
43                                 "struct_size field in " #STRUCT_NAME      \
44                                 " must be set to " #SIZE_VALUE_NAME "."); \
45     }                                                                     \
46   } while (0)
47 
48 #define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME)                  \
49   do {                                                                  \
50     if (STRUCT_OBJ.NAME == 0) {                                         \
51       return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
52                                 "'" #NAME "' field in " #STRUCT_NAME    \
53                                 " must be set.");                       \
54     }                                                                   \
55   } while (0)
56 
ValidateTPOptimizerRegistrationParams(const TP_OptimizerRegistrationParams & params)57 tensorflow::Status ValidateTPOptimizerRegistrationParams(
58     const TP_OptimizerRegistrationParams& params) {
59   VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params,
60                        TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE);
61   VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type);
62   return tensorflow::Status::OK();
63 }
64 
ValidateTPOptimizer(const TP_Optimizer & optimizer)65 tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) {
66   VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE);
67   VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func);
68   return tensorflow::Status::OK();
69 }
70 
ValidateTPOptimizerConfigs(const TP_OptimizerConfigs & configs)71 tensorflow::Status ValidateTPOptimizerConfigs(
72     const TP_OptimizerConfigs& configs) {
73   VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs,
74                        TP_OPTIMIZER_CONFIGS_STRUCT_SIZE);
75   return tensorflow::Status::OK();
76 }
77 
78 #undef VALIDATE_MEMBER
79 #undef VALIDATE_STRUCT_SIZE
80 }  // namespace
81 
82 namespace tensorflow {
83 namespace grappler {
84 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph_def)85 Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
86                                  GraphDef* optimized_graph_def) {
87   OwnedTFStatus c_status(TF_NewStatus());
88   OwnedTFBuffer graph_buf(TF_NewBuffer());
89   OwnedTFBuffer optimized_graph_buf(TF_NewBuffer());
90   TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get()));
91 
92   optimizer_.optimize_func(c_optimizer_, graph_buf.get(),
93                            reinterpret_cast<const TF_GrapplerItem*>(&item),
94                            optimized_graph_buf.get(), c_status.get());
95   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
96   TF_RETURN_IF_ERROR(
97       BufferToMessage(optimized_graph_buf.get(), optimized_graph_def));
98 
99   return Status::OK();
100 }
101 
102 #define CONFIG_TOGGLE(optimizer)                             \
103   if (tp_configs.optimizer == TF_TriState_Off)               \
104     configs.toggle_config[#optimizer] = RewriterConfig::OFF; \
105   else                                                       \
106     configs.toggle_config[#optimizer] = RewriterConfig::ON;
107 
CGraphOptimizerRegister(const PluginGraphOptimizerRegistry::Creator & creator,const TP_OptimizerConfigs tp_configs,const char * device_type)108 void CGraphOptimizerRegister(
109     const PluginGraphOptimizerRegistry::Creator& creator,
110     const TP_OptimizerConfigs tp_configs, const char* device_type) {
111   ConfigList configs;
112   // disable_model_pruning is turned off by default.
113   if (tp_configs.disable_model_pruning == TF_TriState_On)
114     configs.disable_model_pruning = true;
115   else
116     configs.disable_model_pruning = false;
117   // The other configs are turned on by default.
118   CONFIG_TOGGLE(implementation_selector);
119   CONFIG_TOGGLE(function_optimization);
120   CONFIG_TOGGLE(common_subgraph_elimination);
121   CONFIG_TOGGLE(arithmetic_optimization);
122   CONFIG_TOGGLE(debug_stripper);
123   CONFIG_TOGGLE(constant_folding);
124   CONFIG_TOGGLE(shape_optimization);
125   CONFIG_TOGGLE(auto_mixed_precision);
126   CONFIG_TOGGLE(auto_mixed_precision_mkl);
127   CONFIG_TOGGLE(pin_to_host_optimization);
128   CONFIG_TOGGLE(layout_optimizer);
129   CONFIG_TOGGLE(remapping);
130   CONFIG_TOGGLE(loop_optimization);
131   CONFIG_TOGGLE(dependency_optimization);
132   CONFIG_TOGGLE(auto_parallel);
133   CONFIG_TOGGLE(memory_optimization);
134   CONFIG_TOGGLE(scoped_allocator_optimization);
135   PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
136       creator, device_type, configs);
137 }
138 
139 #undef CONFIG_TOGGLE
140 
InitGraphPlugin(void * dso_handle)141 tensorflow::Status InitGraphPlugin(void* dso_handle) {
142   tensorflow::Env* env = tensorflow::Env::Default();
143 
144   // Step 1: Load symbol for `TF_InitPlugin`
145   void* dso_symbol;
146   TF_RETURN_IF_ERROR(
147       env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol));
148 
149   // Step 2: Call `TF_InitPlugin`
150   auto init_fn = reinterpret_cast<TFInitGraphPluginFn>(dso_symbol);
151   return InitGraphPlugin(init_fn);
152 }
153 
InitGraphPlugin(TFInitGraphPluginFn init_fn)154 tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) {
155   TP_OptimizerRegistrationParams params{
156       TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
157   TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
158   TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
159   params.major_version = GO_MAJOR;
160   params.minor_version = GO_MINOR;
161   params.patch_version = GO_PATCH;
162   params.optimizer = &optimizer;
163   params.optimizer_configs = &optimizer_configs;
164 
165   OwnedTFStatus c_status(TF_NewStatus());
166   init_fn(&params, c_status.get());
167   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
168   TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params));
169   TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer));
170   TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs));
171 
172   CGraphOptimizerRegister(
173       [=]() { return new CGraphOptimizer(optimizer, params.device_type); },
174       optimizer_configs, params.device_type);
175 
176   return Status::OK();
177 }
178 
179 }  // namespace grappler
180 }  // namespace tensorflow
181 
TF_GetNodesToPreserveListSize(const TF_GrapplerItem * item,int * num_values,size_t * storage_size,TF_Status * status)182 void TF_GetNodesToPreserveListSize(const TF_GrapplerItem* item, int* num_values,
183                                    size_t* storage_size, TF_Status* status) {
184   TF_SetStatus(status, TF_OK, "");
185   const std::unordered_set<std::string>& nodes =
186       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
187           ->NodesToPreserve();
188   *num_values = nodes.size();
189   *storage_size = 0;
190   for (const std::string& str : nodes) {
191     *storage_size += str.size();
192   }
193 }
194 
TF_GetNodesToPreserveList(const TF_GrapplerItem * item,char ** values,size_t * lengths,int num_values,void * storage,size_t storage_size,TF_Status * status)195 void TF_GetNodesToPreserveList(const TF_GrapplerItem* item, char** values,
196                                size_t* lengths, int num_values, void* storage,
197                                size_t storage_size, TF_Status* status) {
198   TF_SetStatus(status, TF_OK, "");
199   const std::unordered_set<std::string>& nodes =
200       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
201           ->NodesToPreserve();
202   char* p = static_cast<char*>(storage);
203 
204   int index = 0;
205   for (const std::string& s : nodes) {
206     if (index >= num_values) break;
207     values[index] = p;
208     lengths[index] = s.size();
209     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
210       status->status = tensorflow::errors::InvalidArgument(
211           "Not enough storage to hold the requested list of nodes");
212       return;
213     }
214     memcpy(values[index], s.data(), s.size());
215     p += s.size();
216     index++;
217   }
218 }
219 
TF_GetFetchNodesListSize(const TF_GrapplerItem * item,int * num_values,size_t * storage_size,TF_Status * status)220 void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, int* num_values,
221                               size_t* storage_size, TF_Status* status) {
222   TF_SetStatus(status, TF_OK, "");
223   const std::vector<std::string>& nodes =
224       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
225   *num_values = nodes.size();
226   *storage_size = 0;
227   for (const std::string& str : nodes) {
228     *storage_size += str.size();
229   }
230 }
231 
TF_GetFetchNodesList(const TF_GrapplerItem * item,char ** values,size_t * lengths,int num_values,void * storage,size_t storage_size,TF_Status * status)232 void TF_GetFetchNodesList(const TF_GrapplerItem* item, char** values,
233                           size_t* lengths, int num_values, void* storage,
234                           size_t storage_size, TF_Status* status) {
235   TF_SetStatus(status, TF_OK, "");
236   const std::vector<std::string>& nodes =
237       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
238 
239   const int len = std::min(num_values, static_cast<int>(nodes.size()));
240   char* p = static_cast<char*>(storage);
241   for (int index = 0; index < len; ++index) {
242     const std::string& s = nodes[index];
243     values[index] = p;
244     lengths[index] = s.size();
245     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
246       status->status = tensorflow::errors::InvalidArgument(
247           "Not enough storage to hold the requested list of nodes");
248       return;
249     }
250     memcpy(values[index], s.data(), s.size());
251     p += s.size();
252   }
253 }
254 
TF_NewGraphProperties(const TF_GrapplerItem * item)255 TF_GraphProperties* TF_NewGraphProperties(const TF_GrapplerItem* item) {
256   return reinterpret_cast<TF_GraphProperties*>(
257       new tensorflow::grappler::GraphProperties(
258           *reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)));
259 }
260 
TF_DeleteGraphProperties(TF_GraphProperties * graph_properties)261 void TF_DeleteGraphProperties(TF_GraphProperties* graph_properties) {
262   if (graph_properties == nullptr) return;
263   delete reinterpret_cast<tensorflow::grappler::GraphProperties*>(
264       graph_properties);
265 }
266 
TF_InferStatically(TF_GraphProperties * graph_properties,TF_Bool assume_valid_feeds,TF_Bool aggressive_shape_inference,TF_Bool include_input_tensor_values,TF_Bool include_output_tensor_values,TF_Status * status)267 void TF_InferStatically(TF_GraphProperties* graph_properties,
268                         TF_Bool assume_valid_feeds,
269                         TF_Bool aggressive_shape_inference,
270                         TF_Bool include_input_tensor_values,
271                         TF_Bool include_output_tensor_values,
272                         TF_Status* status) {
273   TF_SetStatus(status, TF_OK, "");
274   tensorflow::Status s =
275       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
276           ->InferStatically(assume_valid_feeds, aggressive_shape_inference,
277                             include_input_tensor_values,
278                             include_output_tensor_values);
279   if (!s.ok()) {
280     ::tensorflow::Set_TF_Status_from_Status(status, s);
281   }
282 }
283 
TF_GetInputPropertiesListSize(TF_GraphProperties * graph_properties,const char * name,int * num_values,TF_Status * status)284 void TF_GetInputPropertiesListSize(TF_GraphProperties* graph_properties,
285                                    const char* name, int* num_values,
286                                    TF_Status* status) {
287   TF_SetStatus(status, TF_OK, "");
288   *num_values =
289       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
290           ->GetInputProperties(name)
291           .size();
292 }
293 
TF_GetOutputPropertiesListSize(TF_GraphProperties * graph_properties,const char * name,int * num_values,TF_Status * status)294 void TF_GetOutputPropertiesListSize(TF_GraphProperties* graph_properties,
295                                     const char* name, int* num_values,
296                                     TF_Status* status) {
297   TF_SetStatus(status, TF_OK, "");
298   *num_values =
299       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
300           ->GetOutputProperties(name)
301           .size();
302 }
303 
TF_GetInputPropertiesList(TF_GraphProperties * graph_properties,const char * name,TF_Buffer ** properties,int num_values,TF_Status * status)304 void TF_GetInputPropertiesList(TF_GraphProperties* graph_properties,
305                                const char* name, TF_Buffer** properties,
306                                int num_values, TF_Status* status) {
307   TF_SetStatus(status, TF_OK, "");
308   const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
309       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
310           ->GetInputProperties(name);
311   const int len =
312       std::min(num_values, static_cast<int>(tensor_properties.size()));
313   for (int i = 0; i < len; ++i) {
314     tensorflow::Status s =
315         tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
316     if (!s.ok()) {
317       ::tensorflow::Set_TF_Status_from_Status(status, s);
318       return;
319     }
320   }
321 }
322 
TF_GetOutputPropertiesList(TF_GraphProperties * graph_properties,const char * name,TF_Buffer ** properties,int num_values,TF_Status * status)323 void TF_GetOutputPropertiesList(TF_GraphProperties* graph_properties,
324                                 const char* name, TF_Buffer** properties,
325                                 int num_values, TF_Status* status) {
326   TF_SetStatus(status, TF_OK, "");
327   const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
328       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
329           ->GetOutputProperties(name);
330   const int len =
331       std::min(num_values, static_cast<int>(tensor_properties.size()));
332   for (int i = 0; i < len; ++i) {
333     tensorflow::Status s =
334         tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
335     if (!s.ok()) {
336       ::tensorflow::Set_TF_Status_from_Status(status, s);
337       return;
338     }
339   }
340 }
341 
TF_NewFunctionLibraryDefinition(TF_Buffer * graph_buf,TF_Status * status)342 TF_FunctionLibraryDefinition* TF_NewFunctionLibraryDefinition(
343     TF_Buffer* graph_buf, TF_Status* status) {
344   TF_SetStatus(status, TF_OK, "");
345   tensorflow::GraphDef graph_def;
346   tensorflow::Status s = tensorflow::BufferToMessage(graph_buf, &graph_def);
347   if (!s.ok()) {
348     ::tensorflow::Set_TF_Status_from_Status(status, s);
349     return nullptr;
350   }
351   return reinterpret_cast<TF_FunctionLibraryDefinition*>(
352       new tensorflow::FunctionLibraryDefinition(
353           tensorflow::OpRegistry::Global(), graph_def.library()));
354 }
355 
TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition * fn_lib)356 void TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition* fn_lib) {
357   if (fn_lib == nullptr) return;
358   delete reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib);
359 }
360 
TF_LookUpOpDef(TF_FunctionLibraryDefinition * fn_lib,const char * name,TF_Buffer * buf,TF_Status * status)361 void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, const char* name,
362                     TF_Buffer* buf, TF_Status* status) {
363   TF_SetStatus(status, TF_OK, "");
364   const tensorflow::OpDef* op_def_ptr = nullptr;
365   tensorflow::Status s =
366       reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib)
367           ->LookUpOpDef(name, &op_def_ptr);
368   if (!s.ok()) {
369     ::tensorflow::Set_TF_Status_from_Status(status, s);
370     return;
371   }
372 
373   s = tensorflow::MessageToBuffer(*op_def_ptr, buf);
374   if (!s.ok()) {
375     ::tensorflow::Set_TF_Status_from_Status(status, s);
376     return;
377   }
378 }
379