• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 
6     http://www.apache.org/licenses/LICENSE-2.0
7 
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14 
15 #include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
16 
17 #include "absl/strings/ascii.h"
18 #include "absl/strings/escaping.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
22 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
23 #include "tensorflow/core/grappler/clusters/cluster.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/utils/functions.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/casts.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/stacktrace.h"
33 
34 #if GOOGLE_CUDA && GOOGLE_TENSORRT
35 namespace tensorflow {
36 namespace tensorrt {
37 namespace convert {
38 // TODO(sami): Remove VLOG messages once the code matures
39 using absl::AsciiStrToUpper;
40 using absl::StrAppend;
41 using absl::StrCat;
42 
43 namespace {
44 
ValidateValueCase(const AttrValue * attr_value,const AttrValue::ValueCase & value_case)45 Status ValidateValueCase(const AttrValue* attr_value,
46                          const AttrValue::ValueCase& value_case) {
47   if (attr_value->value_case() != value_case) {
48     return errors::InvalidArgument("AttrValue had value with type '",
49                                    attr_value->value_case(), "' when '",
50                                    value_case, "' was expected.");
51   }
52   return Status::OK();
53 }
54 
55 template <typename T>
GetAttrBoolValue(T * value,const AttrValue * attr_value)56 Status GetAttrBoolValue(T* value, const AttrValue* attr_value) {
57   *value = static_cast<T>(attr_value->b());
58   return Status::OK();
59 }
60 
61 template <typename T>
GetAttrIntValue(T * value,const AttrValue * attr_value)62 Status GetAttrIntValue(T* value, const AttrValue* attr_value) {
63   *value = static_cast<T>(attr_value->i());
64   return Status::OK();
65 }
66 
67 template <typename T>
GetAttrStringValue(T * value,const AttrValue * attr_value)68 Status GetAttrStringValue(T* value, const AttrValue* attr_value) {
69   *value = attr_value->s();
70   return Status::OK();
71 }
72 
73 template <typename T>
GetAttrTrtPrecisionModeValue(T * value,const AttrValue * attr_value)74 Status GetAttrTrtPrecisionModeValue(T* value, const AttrValue* attr_value) {
75   return TrtPrecisionModeFromName(attr_value->s(), value);
76 }
77 
78 template <typename T>
GetAttrProfileStrategyValue(T * value,const AttrValue * attr_value)79 Status GetAttrProfileStrategyValue(T* value, const AttrValue* attr_value) {
80   return ProfileStrategyFromName(attr_value->s(), value);
81 }
82 
83 template <AttrValue::ValueCase value_case, typename T, typename F>
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,T * value,F value_getter)84 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
85                     const std::string& attr_name, T* value, F value_getter) {
86   const AttrValue* attr_value = attr_slice.FindByString(attr_name);
87   if (attr_value != nullptr) {
88     TF_RETURN_IF_ERROR(ValidateValueCase(attr_value, value_case));
89     TF_RETURN_IF_ERROR(value_getter(value, attr_value));
90     VLOG(1) << "Updated cp." << attr_name.substr(7) << ".";
91   }
92   return Status::OK();
93 }
94 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,std::string * value)95 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
96                     const std::string& attr_name, std::string* value) {
97   return GetAttrValue<AttrValue::ValueCase::kS, std::string,
98                       Status(std::string*, const AttrValue*)>(
99       attr_slice, attr_name, value, GetAttrStringValue);
100 }
101 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,size_t * value)102 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
103                     const std::string& attr_name, size_t* value) {
104   return GetAttrValue<AttrValue::ValueCase::kI, size_t,
105                       Status(size_t*, const AttrValue*)>(
106       attr_slice, attr_name, value, GetAttrIntValue);
107 }
108 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,int * value)109 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
110                     const std::string& attr_name, int* value) {
111   return GetAttrValue<AttrValue::ValueCase::kI, int,
112                       Status(int*, const AttrValue*)>(attr_slice, attr_name,
113                                                       value, GetAttrIntValue);
114 }
115 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,TrtPrecisionMode * value)116 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
117                     const std::string& attr_name, TrtPrecisionMode* value) {
118   return GetAttrValue<AttrValue::ValueCase::kS, TrtPrecisionMode,
119                       Status(TrtPrecisionMode*, const AttrValue*)>(
120       attr_slice, attr_name, value, GetAttrTrtPrecisionModeValue);
121 }
122 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,int64_t * value)123 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
124                     const std::string& attr_name, int64_t* value) {
125   return GetAttrValue<AttrValue::ValueCase::kI, int64_t,
126                       Status(int64_t*, const AttrValue*)>(
127       attr_slice, attr_name, value, GetAttrIntValue);
128 }
129 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,bool * value)130 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
131                     const std::string& attr_name, bool* value) {
132   return GetAttrValue<AttrValue::ValueCase::kB, bool,
133                       Status(bool*, const AttrValue*)>(attr_slice, attr_name,
134                                                        value, GetAttrBoolValue);
135 }
136 
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,ProfileStrategy * value)137 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
138                     const std::string& attr_name, ProfileStrategy* value) {
139   return GetAttrValue<AttrValue::ValueCase::kS, ProfileStrategy,
140                       Status(ProfileStrategy*, const AttrValue*)>(
141       attr_slice, attr_name, value, GetAttrProfileStrategyValue);
142 }
143 
ShouldConvertFunction(const grappler::GrapplerItem & item)144 StatusOr<bool> ShouldConvertFunction(const grappler::GrapplerItem& item) {
145   if (item.id == "tf_graph") {
146     return false;
147   }
148   const grappler::GrapplerFunctionItem& func_item =
149       tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
150   const tensorflow::AttrSlice& attr = func_item.func_attr();
151   const AttrValue* attr_value = attr.FindByString("_tftrt_convert_function");
152   if (attr_value != nullptr) {
153     TF_RETURN_IF_ERROR(ValidateValueCase(attr_value, AttrValue::ValueCase::kB));
154     return attr_value->b();
155   }
156   VLOG(1) << "Attribute _tftrt_convert_function was not found.";
157   return false;
158 }
159 
CheckForFunctionConversionAttribute(const tensorflow::AttrSlice & attr)160 StatusOr<bool> CheckForFunctionConversionAttribute(
161     const tensorflow::AttrSlice& attr) {
162   const AttrValue* attr_value = attr.FindByString("_tftrt_convert_function");
163   if (attr_value != nullptr) {
164     TF_RETURN_IF_ERROR(ValidateValueCase(attr_value, AttrValue::ValueCase::kB));
165     return attr_value->b();
166   } else {
167     VLOG(1) << "Attribute _tftrt_convert_function was not found.";
168   }
169   return false;
170 }
171 
172 // Converts function conversion attributes to conversion parameters.
UpdateFunctionSpecificConversionParams(ConversionParams & cp,const tensorflow::AttrSlice & attr)173 Status UpdateFunctionSpecificConversionParams(
174     ConversionParams& cp, const tensorflow::AttrSlice& attr) {
175   TF_RETURN_IF_ERROR(
176       GetAttrValue(attr, "_tftrt_trt_logger_name", &cp.trt_logger_name));
177   TF_RETURN_IF_ERROR(
178       GetAttrValue(attr, "_tftrt_max_batch_size", &cp.max_batch_size));
179   TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_max_workspace_size_bytes",
180                                   &cp.max_workspace_size_bytes));
181   TF_RETURN_IF_ERROR(
182       GetAttrValue(attr, "_tftrt_precision_mode", &cp.precision_mode));
183   TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_minimum_segment_size",
184                                   &cp.minimum_segment_size));
185   TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_is_dyn_op", &cp.is_dyn_op));
186   TF_RETURN_IF_ERROR(
187       GetAttrValue(attr, "_tftrt_max_cached_engines", &cp.max_cached_engines));
188   TF_RETURN_IF_ERROR(
189       GetAttrValue(attr, "_tftrt_use_calibration", &cp.use_calibration));
190   TF_RETURN_IF_ERROR(
191       GetAttrValue(attr, "_tftrt_use_implicit_batch", &cp.use_implicit_batch));
192   TF_RETURN_IF_ERROR(
193       GetAttrValue(attr, "_tftrt_profile_strategy", &cp.profile_strategy));
194   TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_allow_build_at_runtime",
195                                   &cp.allow_build_at_runtime));
196   return Status::OK();
197 }
198 
199 }  // namespace
200 
Init(const RewriterConfig_CustomGraphOptimizer * config)201 Status TRTOptimizationPass::Init(
202     const RewriterConfig_CustomGraphOptimizer* config) {
203   VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
204   if (config == nullptr) {
205     return Status::OK();
206   }
207   VLOG(1) << "config = " << config->DebugString();
208   const auto params = config->parameter_map();
209   if (params.count("minimum_segment_size")) {
210     minimum_segment_size_ = params.at("minimum_segment_size").i();
211   }
212   if (params.count("max_batch_size")) {
213     maximum_batch_size_ = params.at("max_batch_size").i();
214   }
215   if (params.count("is_dynamic_op")) {
216     is_dynamic_op_ = params.at("is_dynamic_op").b();
217   }
218   if (params.count("maximum_cached_engines")) {
219     max_cached_batches_ = params.at("maximum_cached_engines").i();
220   }
221   if (params.count("max_workspace_size_bytes")) {
222     max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
223   }
224   if (params.count("precision_mode")) {
225     TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
226         AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
227   }
228   if (params.count("use_calibration")) {
229     use_calibration_ = params.at("use_calibration").b();
230   }
231   if (params.count("trt_logger")) {
232     trt_logger_name_ = params.at("trt_logger").s();
233   }
234   if (params.count("allow_build_at_runtime")) {
235     allow_build_at_runtime_ = params.at("allow_build_at_runtime").b();
236   }
237   if (params.count("use_implicit_batch")) {
238     use_implicit_batch_ = params.at("use_implicit_batch").b();
239   }
240   if (params.count("profile_strategy")) {
241     TF_RETURN_IF_ERROR(ProfileStrategyFromName(
242         params.at("profile_strategy").s(), &profile_strategy_));
243   }
244   return Status::OK();
245 }
246 
PrintDebugInfo(grappler::Cluster * cluster,const grappler::GrapplerItem & item)247 void TRTOptimizationPass::PrintDebugInfo(grappler::Cluster* cluster,
248                                          const grappler::GrapplerItem& item) {
249   LOG(INFO) << "Cluster = " << cluster;
250   string offset("  ");
251   string offset2 = StrCat(offset, offset);
252   string offset3 = StrCat(offset2, offset);
253   string offset4 = StrCat(offset2, offset2);
254 
255   if (cluster) {
256     LOG(INFO) << offset << "type             = " << cluster->type();
257     LOG(INFO) << offset << "num warmup steps = " << cluster->NumWarmupSteps();
258     const auto dev_names = cluster->GetDeviceNames();
259     if (!dev_names.empty()) {
260       LOG(INFO) << offset << " Device names:";
261       for (const auto& s : dev_names) {
262         LOG(INFO) << offset2 << s;
263       }
264     }
265     std::unordered_map<string, uint64> peak_mem;
266     auto status = cluster->GetPeakMemoryUsage(&peak_mem);
267     if (status == Status::OK()) {
268       LOG(INFO) << offset << "Peak Memory Usage :";
269       for (const auto& s : peak_mem) {
270         LOG(INFO) << offset2 << s.first << " = " << s.second;
271       }
272     }
273 
274     const auto dev_props = cluster->GetDevices();
275     if (!dev_props.empty()) {
276       LOG(INFO) << offset << "Device properties:";
277       for (const auto& k : dev_props) {
278         LOG(INFO) << offset2 << k.first;
279         const auto& dt = k.second;
280         LOG(INFO) << offset3 << "type          = " << dt.type();
281         LOG(INFO) << offset3 << "vendor        = " << dt.vendor();
282         LOG(INFO) << offset3 << "model         = " << dt.model();
283         LOG(INFO) << offset3 << "frequency     = " << dt.frequency();
284         LOG(INFO) << offset3 << "num cores     = " << dt.num_cores();
285         LOG(INFO) << offset3 << "num registers = " << dt.num_registers();
286         LOG(INFO) << offset3 << "L1 cache size = " << dt.l1_cache_size();
287         LOG(INFO) << offset3 << "L2 cache size = " << dt.l2_cache_size();
288         LOG(INFO) << offset3 << "L3 cache size = " << dt.l3_cache_size();
289         LOG(INFO) << offset3 << "SHMem per SMP = "
290                   << dt.shared_memory_size_per_multiprocessor();
291         LOG(INFO) << offset3 << "memory size   = " << dt.memory_size();
292         LOG(INFO) << offset3 << "bandwidth     = " << dt.bandwidth();
293         if (dt.environment_size()) {
294           LOG(INFO) << offset3 << "environment   :";
295           for (const auto& e : dt.environment()) {
296             LOG(INFO) << offset4 << e.first << " = " << e.second;
297           }
298         }
299       }
300     }
301 
302     if (cluster->GetDeviceSet()) {
303       for (const auto dev : cluster->GetDeviceSet()->devices()) {
304         LOG(INFO) << "Device name= " << dev->name() << "Pased name= "
305                   << DeviceNameUtils::ParsedNameToString(dev->parsed_name());
306       }
307     }
308   }
309 
310   LOG(INFO) << "item: " << item.id;
311   if (!item.feed.empty()) {
312     LOG(INFO) << offset << "Feeds  :";
313     for (const auto& f : item.feed) {
314       const auto& shape = f.second.shape();
315       LOG(INFO) << offset2 << f.first << " = shaped " << shape.DebugString();
316     }
317   } else {
318     LOG(INFO) << offset << "No Feeds";
319   }
320   if (!item.fetch.empty()) {
321     LOG(INFO) << offset << "Fetches  :";
322     for (const auto& f : item.fetch) {
323       LOG(INFO) << offset2 << f;
324     }
325   } else {
326     LOG(INFO) << offset << "No Fetches";
327   }
328 
329   if (!item.init_ops.empty()) {
330     LOG(INFO) << offset << "init ops  :";
331     for (const auto& f : item.init_ops) {
332       LOG(INFO) << offset2 << f;
333     }
334   } else {
335     LOG(INFO) << offset << "No init ops";
336   }
337   LOG(INFO) << "Save Op = " << item.save_op;
338   LOG(INFO) << "Restore Op = " << item.restore_op;
339   LOG(INFO) << "save_restore_loc_tensor = " << item.save_restore_loc_tensor;
340   if (!item.keep_ops.empty()) {
341     LOG(INFO) << offset << "keep ops  :";
342     for (const auto& f : item.keep_ops) {
343       LOG(INFO) << offset2 << f;
344     }
345   } else {
346     LOG(INFO) << offset << "No keep ops";
347   }
348 }
349 
Optimize(grappler::Cluster * cluster,const grappler::GrapplerItem & item,GraphDef * optimized_graph)350 Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
351                                      const grappler::GrapplerItem& item,
352                                      GraphDef* optimized_graph) {
353   VLOG(1) << "Called TRTOptimization Pass " << name_
354           << " on a grappler item with id=" << item.id;
355   TF_ASSIGN_OR_RETURN(bool do_function_conversion, ShouldConvertFunction(item));
356   if (minimum_segment_size_ == -1 ||
357       (item.id != "tf_graph" && !do_function_conversion)) {
358     VLOG(1) << "Not optimizing this grappler item: " << item.id;
359     *optimized_graph = item.graph;
360     return Status::OK();
361   }
362   if (VLOG_IS_ON(3)) {
363     LOG(INFO) << CurrentStackTrace();
364     PrintDebugInfo(cluster, item);
365   }
366 
367   if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) {
368     VLOG(1) << "Calibration with FP32 or FP16 is not implemented. "
369             << "Falling back to use_calibration = False."
370             << "Note that the default value of use_calibration is True.";
371     use_calibration_ = false;
372   }
373 
374   std::vector<string> nodes_to_preserve;
375   for (const auto& n : item.NodesToPreserve()) {
376     auto tokens = str_util::Split(n, ":");
377     string s = tokens.at(0);
378     for (int i = 1; i < tokens.size() - 1; ++i) {
379       StrAppend(&s, ":", tokens.at(i));
380     }
381     int dumm_port = -1;
382     // If the last token is not an integer, it must be part of the name.
383     // Otherwise it is port number.
384     if (tokens.size() > 1 &&
385         !strings::safe_strto32(tokens.back(), &dumm_port)) {  // non-absl ok
386       StrAppend(&s, ":", tokens.back());
387     }
388     nodes_to_preserve.push_back(s);
389   }
390 
391   ConversionParams cp;
392   cp.grappler_item = &item;
393   cp.output_names = &nodes_to_preserve;
394   cp.trt_logger_name = trt_logger_name_;
395   cp.max_batch_size = maximum_batch_size_;
396   cp.max_workspace_size_bytes = max_workspace_size_bytes_;
397   cp.output_graph_def = optimized_graph;
398   cp.precision_mode = precision_mode_;
399   cp.minimum_segment_size = minimum_segment_size_;
400   cp.cluster = cluster;
401   cp.is_dyn_op = is_dynamic_op_;
402   cp.max_cached_engines = max_cached_batches_;
403   cp.use_calibration = use_calibration_;
404   cp.use_implicit_batch = use_implicit_batch_;
405   cp.profile_strategy = profile_strategy_;
406   cp.allow_build_at_runtime = allow_build_at_runtime_;
407 
408   if (item.id != "tf_graph" && do_function_conversion) {
409     const grappler::GrapplerFunctionItem& func_item =
410         tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
411     TF_RETURN_IF_ERROR(
412         UpdateFunctionSpecificConversionParams(cp, func_item.func_attr()));
413   }
414 
415   auto status = ConvertAfterShapes(cp);
416   VLOG(1) << "Returning from " << name_;
417   return status;
418 }
419 
420 class VerboseCustomGraphOptimizerRegistrar
421     : public grappler::CustomGraphOptimizerRegistrar {
422  public:
VerboseCustomGraphOptimizerRegistrar(const grappler::CustomGraphOptimizerRegistry::Creator & cr,const string & name)423   VerboseCustomGraphOptimizerRegistrar(
424       const grappler::CustomGraphOptimizerRegistry::Creator& cr,
425       const string& name)
426       : grappler::CustomGraphOptimizerRegistrar(cr, name) {
427     VLOG(1) << "Constructing a CustomOptimizationPass registration object for "
428             << name;
429   }
430 };
431 
432 static VerboseCustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar(
__anon9beddcc10202() 433     []() {
434       VLOG(1)
435           << "Instantiating CustomOptimizationPass object TensorRTOptimizer";
436       return new TRTOptimizationPass("TensorRTOptimizer");
437     },
438     ("TensorRTOptimizer"));
439 
440 }  // namespace convert
441 }  // namespace tensorrt
442 }  // namespace tensorflow
443 
444 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
445