1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5
6 http://www.apache.org/licenses/LICENSE-2.0
7
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14
15 #include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
16
17 #include <memory>
18
19 #include "absl/strings/ascii.h"
20 #include "absl/strings/escaping.h"
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
24 #include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"
25 #include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h"
26 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
27 #include "tensorflow/core/grappler/clusters/cluster.h"
28 #include "tensorflow/core/grappler/grappler_item.h"
29 #include "tensorflow/core/grappler/op_types.h"
30 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
31 #include "tensorflow/core/grappler/utils/functions.h"
32 #include "tensorflow/core/grappler/utils/topological_sort.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/casts.h"
37 #include "tensorflow/core/platform/logging.h"
38
39 #if GOOGLE_CUDA && GOOGLE_TENSORRT
40 namespace tensorflow {
41 namespace tensorrt {
42 namespace convert {
43 using absl::AsciiStrToUpper;
44 using absl::StrAppend;
45 using absl::StrCat;
46
47 namespace {
48
ShouldUseExplicitPrecision(const GraphDef & gdef)49 bool ShouldUseExplicitPrecision(const GraphDef& gdef) {
50 if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) {
51 return false;
52 }
53 return absl::c_any_of(gdef.node(), [](const auto& node) {
54 return (absl::c_find(kExplicitQuantizationOpNames, node.op()) !=
55 kExplicitQuantizationOpNames.end());
56 });
57 }
58
ShouldConvertFunction(const grappler::GrapplerItem & item)59 StatusOr<bool> ShouldConvertFunction(const grappler::GrapplerItem& item) {
60 if (item.id == "tf_graph") {
61 return false;
62 }
63 const auto& func_item =
64 tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
65 const AttrSlice& attr = func_item.func_attr();
66 const AttrValue* attr_value = attr.FindByString("_tftrt_convert_function");
67 if (attr_value != nullptr) {
68 bool result = false;
69 TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_convert_function", &result));
70 return result;
71 }
72 VLOG(1) << "Attribute _tftrt_convert_function was not found.";
73 return false;
74 }
75
76 // Converts function conversion attributes to conversion parameters.
UpdateFunctionSpecificConversionParams(TRTOptimizationPass::ConversionParams & cp,const tensorflow::AttrSlice & attr)77 Status UpdateFunctionSpecificConversionParams(
78 TRTOptimizationPass::ConversionParams& cp,
79 const tensorflow::AttrSlice& attr) {
80 auto get_size_attr = [](const AttrSlice& attr, absl::string_view name,
81 size_t* dst) -> Status {
82 int tmp = 0;
83 TF_RETURN_IF_ERROR(GetNodeAttr(attr, name, &tmp));
84 *dst = static_cast<size_t>(tmp);
85 return Status::OK();
86 };
87
88 TF_RETURN_IF_ERROR(
89 GetNodeAttr(attr, "_tftrt_trt_logger_name", &cp.trt_logger_name));
90 TF_RETURN_IF_ERROR(
91 get_size_attr(attr, "_tftrt_max_batch_size", &cp.max_batch_size));
92 TF_RETURN_IF_ERROR(get_size_attr(attr, "_tftrt_max_workspace_size_bytes",
93 &cp.max_workspace_size_bytes));
94 std::string precision_mode;
95 TF_RETURN_IF_ERROR(
96 GetNodeAttr(attr, "_tftrt_precision_mode", &precision_mode));
97 TF_RETURN_IF_ERROR(
98 TrtPrecisionModeFromName(precision_mode, &cp.precision_mode));
99 TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_minimum_segment_size",
100 &cp.minimum_segment_size));
101 TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_is_dyn_op", &cp.is_dynamic_op));
102 TF_RETURN_IF_ERROR(
103 GetNodeAttr(attr, "_tftrt_max_cached_engines", &cp.max_cached_engines));
104 TF_RETURN_IF_ERROR(
105 GetNodeAttr(attr, "_tftrt_use_calibration", &cp.use_calibration));
106 TF_RETURN_IF_ERROR(
107 GetNodeAttr(attr, "_tftrt_use_implicit_batch", &cp.use_implicit_batch));
108 std::string profile_strategy;
109 TF_RETURN_IF_ERROR(
110 GetNodeAttr(attr, "_tftrt_profile_strategy", &profile_strategy));
111 TF_RETURN_IF_ERROR(
112 ProfileStrategyFromName(profile_strategy, &cp.profile_strategy));
113 TF_RETURN_IF_ERROR(GetNodeAttr(attr, "_tftrt_allow_build_at_runtime",
114 &cp.allow_build_at_runtime));
115 return Status::OK();
116 }
117 } // namespace
118
Init(const RewriterConfig_CustomGraphOptimizer * config)119 Status TRTOptimizationPass::Init(
120 const RewriterConfig_CustomGraphOptimizer* config) {
121 if (config == nullptr) {
122 return Status::OK();
123 }
124 const auto params = config->parameter_map();
125 if (params.count("minimum_segment_size")) {
126 params_.minimum_segment_size = params.at("minimum_segment_size").i();
127 }
128 if (params.count("max_batch_size")) {
129 params_.max_batch_size = params.at("max_batch_size").i();
130 }
131 if (params.count("is_dynamic_op")) {
132 params_.is_dynamic_op = params.at("is_dynamic_op").b();
133 }
134 if (params.count("maximum_cached_engines")) {
135 params_.max_cached_engines = params.at("maximum_cached_engines").i();
136 }
137 if (params.count("max_workspace_size_bytes")) {
138 params_.max_workspace_size_bytes =
139 params.at("max_workspace_size_bytes").i();
140 }
141 if (params.count("precision_mode")) {
142 TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
143 AsciiStrToUpper(params.at("precision_mode").s()),
144 ¶ms_.precision_mode));
145 }
146 if (params.count("use_calibration")) {
147 params_.use_calibration = params.at("use_calibration").b();
148 }
149 if (params.count("trt_logger")) {
150 params_.trt_logger_name = params.at("trt_logger").s();
151 }
152 if (params.count("allow_build_at_runtime")) {
153 params_.allow_build_at_runtime = params.at("allow_build_at_runtime").b();
154 }
155 if (params.count("use_implicit_batch")) {
156 params_.use_implicit_batch = params.at("use_implicit_batch").b();
157 }
158 if (params.count("profile_strategy")) {
159 TF_RETURN_IF_ERROR(ProfileStrategyFromName(
160 params.at("profile_strategy").s(), ¶ms_.profile_strategy));
161 }
162 return Status::OK();
163 }
164
ExplicitPrecisionModePolicy()165 static bool ExplicitPrecisionModePolicy() {
166 return IS_TRT_VERSION_GE(8, 0, 0, 0);
167 }
168
Optimize(grappler::Cluster * cluster,const grappler::GrapplerItem & item,GraphDef * optimized_graph)169 Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
170 const grappler::GrapplerItem& item,
171 GraphDef* optimized_graph) {
172 VLOG(1) << "Called TRTOptimization Pass " << name_
173 << " on a grappler item with id=" << item.id;
174 TF_ASSIGN_OR_RETURN(bool do_function_conversion, ShouldConvertFunction(item));
175 // Optimizing the main graph(identified with `item.id == "tf_graph"`) with
176 // `minimim_segment_size == -1` indicates skipping main graph conversion.
177 if ((params_.minimum_segment_size == -1 && item.id == "tf_graph") ||
178 (item.id != "tf_graph" && !do_function_conversion)) {
179 VLOG(1) << "Not optimizing this grappler item: " << item.id;
180 *optimized_graph = item.graph;
181 return Status::OK();
182 }
183
184 if (params_.use_calibration &&
185 params_.precision_mode != TrtPrecisionMode::INT8) {
186 LOG(WARNING) << "Calibration with FP32 or FP16 is not implemented. "
187 << "Falling back to use_calibration = False."
188 << "Note that the default value of use_calibration is True.";
189 params_.use_calibration = false;
190 }
191
192 params_.use_explicit_precision = ShouldUseExplicitPrecision(item.graph);
193 if (params_.use_explicit_precision) {
194 LOG(INFO) << "[TF-TRT] Using explicit QDQ mode";
195 if (params_.precision_mode != TrtPrecisionMode::INT8 ||
196 params_.use_calibration) {
197 LOG(WARNING)
198 << "Explicit precision mode with calibration or FP32/FP16 mode is "
199 "not supported."
200 << " Setting precision mode to INT8 and calibration to false.";
201 params_.precision_mode = TrtPrecisionMode::INT8;
202 params_.use_calibration = false;
203 }
204 }
205
206 // Create a copy of the graph to optimize.
207 grappler::GrapplerItem optimized_item(item);
208
209 std::vector<string> nodes_to_preserve;
210 const auto& old_nodes_to_preserve = item.NodesToPreserve();
211 nodes_to_preserve.reserve(old_nodes_to_preserve.size());
212 for (const auto& n : old_nodes_to_preserve) {
213 auto tokens = str_util::Split(n, ":");
214 string s = tokens.at(0);
215 for (int i = 1; i < tokens.size() - 1; ++i) {
216 StrAppend(&s, ":", tokens.at(i));
217 }
218 int dumm_port = -1;
219 // If the last token is not an integer, it must be part of the name.
220 // Otherwise it is port number.
221 if (tokens.size() > 1 &&
222 !strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok
223 StrAppend(&s, ":", tokens.back());
224 }
225 nodes_to_preserve.push_back(s);
226 }
227
228 if (item.id != "tf_graph" && do_function_conversion) {
229 const grappler::GrapplerFunctionItem& func_item =
230 tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
231 TF_RETURN_IF_ERROR(
232 UpdateFunctionSpecificConversionParams(params_, func_item.func_attr()));
233 }
234
235 return ConvertGraph(params_, optimized_item, nodes_to_preserve, cluster,
236 optimized_graph);
237 }
238
239 static grappler::CustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar(
__anon539e0a290402() 240 []() {
241 VLOG(1)
242 << "Instantiating CustomOptimizationPass object TensorRTOptimizer";
243 return new TRTOptimizationPass("TensorRTOptimizer");
244 },
245 ("TensorRTOptimizer"));
246
247 } // namespace convert
248 } // namespace tensorrt
249 } // namespace tensorflow
250
251 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
252