/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"

#include <memory>

#include "absl/algorithm/container.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
namespace convert {
using absl::AsciiStrToUpper;
using absl::StrAppend;
using absl::StrCat;

namespace {

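// Returns true if explicit-precision (QDQ) conversion should be used: the
// TensorRT version must be at least 8.0 and the graph must contain at least
// one of the ops listed in kExplicitQuantizationOpNames.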
bool ShouldUseExplicitPrecision(const GraphDef& gdef) {
  if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
    return false;
  }
  return absl::c_any_of(gdef.node(), [](const auto& node) {
    return (absl::c_find(kExplicitQuantizationOpNames, node.op()) !=
            kExplicitQuantizationOpNames.end());
  });
}

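// Returns whether a grappler function item is marked for TF-TRT conversion via
// its `_tftrt_convert_function` attribute. The top-level graph
// (item.id == "tf_graph") is never converted by this path.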
StatusOr<bool> ShouldConvertFunction(const grappler::GrapplerItem& item) {
  if (item.id == "tf_graph") {
    return false;
  }
  const auto& func_item =
      tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
  const AttrSlice& attr = func_item.func_attr();
  const AttrValue* attr_value = attr.FindByString("_tftrt_convert_function");
  if (attr_value != nullptr) {
    bool result = false;
    TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_convert_function", &result));
    return result;
  }
  VLOG(1) << "Attribute _tftrt_convert_function was not found.";
  return false;
}

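// The `_tftrt_*` attributes consumed below are expected to be attached to the
// FunctionDef by the TF-TRT conversion pipeline. A hypothetical textproto
// sketch of such attributes (values are illustrative only):
//
//   attr { key: "_tftrt_convert_function" value { b: true } }
//   attr { key: "_tftrt_precision_mode" value { s: "FP16" } }
//   attr { key: "_tftrt_max_batch_size" value { i: 8 } }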
// Converts function conversion attributes to conversion parameters.
Status UpdateFunctionSpecificConversionParams(
    TRTOptimizationPass::ConversionParams& cp,
    const tensorflow::AttrSlice& attr) {
  auto get_size_attr = [](const AttrSlice& attr, absl::string_view name,
                          size_t* dst) -> Status {
    int tmp = 0;
    TF_RETURN_IF_ERROR(GetNodeAttr(attr, name, &tmp));
    *dst = static_cast<size_t>(tmp);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(
      GetNodeAttr(attr, "_tftrt_trt_logger_name", &cp.trt_logger_name));
  TF_RETURN_IF_ERROR(
      get_size_attr(attr, "_tftrt_max_batch_size", &cp.max_batch_size));
  TF_RETURN_IF_ERROR(get_size_attr(attr, "_tftrt_max_workspace_size_bytes",
                                   &cp.max_workspace_size_bytes));
  std::string precision_mode;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(attr, "_tftrt_precision_mode", &precision_mode));
  TF_RETURN_IF_ERROR(
      TrtPrecisionModeFromName(precision_mode, &cp.precision_mode));
  TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_minimum_segment_size",
                                 &cp.minimum_segment_size));
  TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_is_dyn_op", &cp.is_dynamic_op));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(attr, "_tftrt_max_cached_engines", &cp.max_cached_engines));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(attr, "_tftrt_use_calibration", &cp.use_calibration));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(attr, "_tftrt_use_implicit_batch", &cp.use_implicit_batch));
  std::string profile_strategy;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(attr, "_tftrt_profile_strategy", &profile_strategy));
  TF_RETURN_IF_ERROR(
      ProfileStrategyFromName(profile_strategy, &cp.profile_strategy));
  TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_allow_build_at_runtime",
                                 &cp.allow_build_at_runtime));
  return Status::OK();
}
}  // namespace

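// Reads the optimizer's parameters from the pass's RewriterConfig custom graph
// optimizer entry; only keys present in parameter_map override the defaults
// already stored in params_.
//
// A minimal sketch of populating that entry from C++, assuming the standard
// protobuf-generated accessors for RewriterConfig (the key values below are
// illustrative, not defaults):
//
//   RewriterConfig rewriter_config;
//   auto* trt = rewriter_config.add_custom_optimizers();
//   trt->set_name("TensorRTOptimizer");
//   (*trt->mutable_parameter_map())["precision_mode"].set_s("FP16");
//   (*trt->mutable_parameter_map())["max_batch_size"].set_i(8);
//   (*trt->mutable_parameter_map())["is_dynamic_op"].set_b(true);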
Status TRTOptimizationPass::Init(
    const RewriterConfig_CustomGraphOptimizer* config) {
  if (config == nullptr) {
    return Status::OK();
  }
  const auto params = config->parameter_map();
  if (params.count("minimum_segment_size")) {
    params_.minimum_segment_size = params.at("minimum_segment_size").i();
  }
  if (params.count("max_batch_size")) {
    params_.max_batch_size = params.at("max_batch_size").i();
  }
  if (params.count("is_dynamic_op")) {
    params_.is_dynamic_op = params.at("is_dynamic_op").b();
  }
  if (params.count("maximum_cached_engines")) {
    params_.max_cached_engines = params.at("maximum_cached_engines").i();
  }
  if (params.count("max_workspace_size_bytes")) {
    params_.max_workspace_size_bytes =
        params.at("max_workspace_size_bytes").i();
  }
  if (params.count("precision_mode")) {
    TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
        AsciiStrToUpper(params.at("precision_mode").s()),
        &params_.precision_mode));
  }
  if (params.count("use_calibration")) {
    params_.use_calibration = params.at("use_calibration").b();
  }
  if (params.count("trt_logger")) {
    params_.trt_logger_name = params.at("trt_logger").s();
  }
  if (params.count("allow_build_at_runtime")) {
    params_.allow_build_at_runtime = params.at("allow_build_at_runtime").b();
  }
  if (params.count("use_implicit_batch")) {
    params_.use_implicit_batch = params.at("use_implicit_batch").b();
  }
  if (params.count("profile_strategy")) {
    TF_RETURN_IF_ERROR(ProfileStrategyFromName(
        params.at("profile_strategy").s(), &params_.profile_strategy));
  }
  return Status::OK();
}

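// Returns true when the linked TensorRT version (>= 8.0) supports explicit
// precision, i.e. graphs carrying explicit quantize/dequantize (QDQ) nodes.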
static bool ExplicitPrecisionModePolicy() {
  return IS_TRT_VERSION_GE(8, 0, 0, 0);
}

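// Grappler entry point: skips items that should not be converted, reconciles
// calibration and precision settings, collects the node names to preserve,
// applies any function-specific conversion attributes, and delegates the
// actual segmentation and engine building to ConvertGraph().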
Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
                                     const grappler::GrapplerItem& item,
                                     GraphDef* optimized_graph) {
  VLOG(1) << "Called TRTOptimization Pass " << name_
          << " on a grappler item with id=" << item.id;
  TF_ASSIGN_OR_RETURN(bool do_function_conversion, ShouldConvertFunction(item));
  // Optimizing the main graph (identified by `item.id == "tf_graph"`) with
  // `minimum_segment_size == -1` indicates skipping main graph conversion.
  if ((params_.minimum_segment_size == -1 && item.id == "tf_graph") ||
      (item.id != "tf_graph" && !do_function_conversion)) {
    VLOG(1) << "Not optimizing this grappler item: " << item.id;
    *optimized_graph = item.graph;
    return Status::OK();
  }

  if (params_.use_calibration &&
      params_.precision_mode != TrtPrecisionMode::INT8) {
    LOG(WARNING) << "Calibration with FP32 or FP16 is not implemented. "
                 << "Falling back to use_calibration = False. "
                 << "Note that the default value of use_calibration is True.";
    params_.use_calibration = false;
  }

  params_.use_explicit_precision = ShouldUseExplicitPrecision(item.graph);
  if (params_.use_explicit_precision) {
    LOG(INFO) << "[TF-TRT] Using explicit QDQ mode";
    if (params_.precision_mode != TrtPrecisionMode::INT8 ||
        params_.use_calibration) {
      LOG(WARNING)
          << "Explicit precision mode with calibration or FP32/FP16 mode is "
             "not supported."
          << " Setting precision mode to INT8 and calibration to false.";
      params_.precision_mode = TrtPrecisionMode::INT8;
      params_.use_calibration = false;
    }
  }

  // Create a copy of the graph to optimize.
  grappler::GrapplerItem optimized_item(item);

  std::vector<string> nodes_to_preserve;
  const auto& old_nodes_to_preserve = item.NodesToPreserve();
  nodes_to_preserve.reserve(old_nodes_to_preserve.size());
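  // Names in NodesToPreserve() may carry an output port suffix ("node:1").
  // Strip a trailing numeric port so that only node names are passed on; a
  // final ':'-separated token that is not an integer is kept as part of the
  // name.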
  for (const auto& n : old_nodes_to_preserve) {
    auto tokens = str_util::Split(n, ":");
    string s = tokens.at(0);
    for (int i = 1; i < tokens.size() - 1; ++i) {
      StrAppend(&s, ":", tokens.at(i));
    }
    int dummy_port = -1;
    // If the last token is not an integer, it must be part of the name.
    // Otherwise it is the port number.
    if (tokens.size() > 1 &&
        !strings::safe_strto32(tokens.back(), &dummy_port)) {  // non-absl ok
      StrAppend(&s, ":", tokens.back());
    }
    nodes_to_preserve.push_back(s);
  }

  if (item.id != "tf_graph" && do_function_conversion) {
    const grappler::GrapplerFunctionItem& func_item =
        tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
    TF_RETURN_IF_ERROR(
        UpdateFunctionSpecificConversionParams(params_, func_item.func_attr()));
  }

  return ConvertGraph(params_, optimized_item, nodes_to_preserve, cluster,
                      optimized_graph);
}

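// Registers the pass with grappler's custom graph optimizer registry under the
// name "TensorRTOptimizer", so it can be enabled through a RewriterConfig
// custom optimizer entry with that name.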
static grappler::CustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar(
    []() {
      VLOG(1)
          << "Instantiating CustomOptimizationPass object TensorRTOptimizer";
      return new TRTOptimizationPass("TensorRTOptimizer");
    },
    ("TensorRTOptimizer"));

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT