1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5
6 http://www.apache.org/licenses/LICENSE-2.0
7
8 Unless required by applicable law or agreed to in writing, software
9 distributed under the License is distributed on an "AS IS" BASIS,
10 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 See the License for the specific language governing permissions and
12 limitations under the License.
13 ==============================================================================*/
14
15 #include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h"
16
17 #include "absl/strings/ascii.h"
18 #include "absl/strings/escaping.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
22 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
23 #include "tensorflow/core/grappler/clusters/cluster.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/utils/functions.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/casts.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/stacktrace.h"
33
34 #if GOOGLE_CUDA && GOOGLE_TENSORRT
35 namespace tensorflow {
36 namespace tensorrt {
37 namespace convert {
38 // TODO(sami): Remove VLOG messages once the code matures
39 using absl::AsciiStrToUpper;
40 using absl::StrAppend;
41 using absl::StrCat;
42
43 namespace {
44
ValidateValueCase(const AttrValue * attr_value,const AttrValue::ValueCase & value_case)45 Status ValidateValueCase(const AttrValue* attr_value,
46 const AttrValue::ValueCase& value_case) {
47 if (attr_value->value_case() != value_case) {
48 return errors::InvalidArgument("AttrValue had value with type '",
49 attr_value->value_case(), "' when '",
50 value_case, "' was expected.");
51 }
52 return Status::OK();
53 }
54
55 template <typename T>
GetAttrBoolValue(T * value,const AttrValue * attr_value)56 Status GetAttrBoolValue(T* value, const AttrValue* attr_value) {
57 *value = static_cast<T>(attr_value->b());
58 return Status::OK();
59 }
60
61 template <typename T>
GetAttrIntValue(T * value,const AttrValue * attr_value)62 Status GetAttrIntValue(T* value, const AttrValue* attr_value) {
63 *value = static_cast<T>(attr_value->i());
64 return Status::OK();
65 }
66
67 template <typename T>
GetAttrStringValue(T * value,const AttrValue * attr_value)68 Status GetAttrStringValue(T* value, const AttrValue* attr_value) {
69 *value = attr_value->s();
70 return Status::OK();
71 }
72
73 template <typename T>
GetAttrTrtPrecisionModeValue(T * value,const AttrValue * attr_value)74 Status GetAttrTrtPrecisionModeValue(T* value, const AttrValue* attr_value) {
75 return TrtPrecisionModeFromName(attr_value->s(), value);
76 }
77
78 template <typename T>
GetAttrProfileStrategyValue(T * value,const AttrValue * attr_value)79 Status GetAttrProfileStrategyValue(T* value, const AttrValue* attr_value) {
80 return ProfileStrategyFromName(attr_value->s(), value);
81 }
82
83 template <AttrValue::ValueCase value_case, typename T, typename F>
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,T * value,F value_getter)84 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
85 const std::string& attr_name, T* value, F value_getter) {
86 const AttrValue* attr_value = attr_slice.FindByString(attr_name);
87 if (attr_value != nullptr) {
88 TF_RETURN_IF_ERROR(ValidateValueCase(attr_value, value_case));
89 TF_RETURN_IF_ERROR(value_getter(value, attr_value));
90 VLOG(1) << "Updated cp." << attr_name.substr(7) << ".";
91 }
92 return Status::OK();
93 }
94
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,std::string * value)95 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
96 const std::string& attr_name, std::string* value) {
97 return GetAttrValue<AttrValue::ValueCase::kS, std::string,
98 Status(std::string*, const AttrValue*)>(
99 attr_slice, attr_name, value, GetAttrStringValue);
100 }
101
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,size_t * value)102 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
103 const std::string& attr_name, size_t* value) {
104 return GetAttrValue<AttrValue::ValueCase::kI, size_t,
105 Status(size_t*, const AttrValue*)>(
106 attr_slice, attr_name, value, GetAttrIntValue);
107 }
108
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,int * value)109 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
110 const std::string& attr_name, int* value) {
111 return GetAttrValue<AttrValue::ValueCase::kI, int,
112 Status(int*, const AttrValue*)>(attr_slice, attr_name,
113 value, GetAttrIntValue);
114 }
115
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,TrtPrecisionMode * value)116 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
117 const std::string& attr_name, TrtPrecisionMode* value) {
118 return GetAttrValue<AttrValue::ValueCase::kS, TrtPrecisionMode,
119 Status(TrtPrecisionMode*, const AttrValue*)>(
120 attr_slice, attr_name, value, GetAttrTrtPrecisionModeValue);
121 }
122
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,int64_t * value)123 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
124 const std::string& attr_name, int64_t* value) {
125 return GetAttrValue<AttrValue::ValueCase::kI, int64_t,
126 Status(int64_t*, const AttrValue*)>(
127 attr_slice, attr_name, value, GetAttrIntValue);
128 }
129
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,bool * value)130 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
131 const std::string& attr_name, bool* value) {
132 return GetAttrValue<AttrValue::ValueCase::kB, bool,
133 Status(bool*, const AttrValue*)>(attr_slice, attr_name,
134 value, GetAttrBoolValue);
135 }
136
GetAttrValue(const tensorflow::AttrSlice & attr_slice,const std::string & attr_name,ProfileStrategy * value)137 Status GetAttrValue(const tensorflow::AttrSlice& attr_slice,
138 const std::string& attr_name, ProfileStrategy* value) {
139 return GetAttrValue<AttrValue::ValueCase::kS, ProfileStrategy,
140 Status(ProfileStrategy*, const AttrValue*)>(
141 attr_slice, attr_name, value, GetAttrProfileStrategyValue);
142 }
143
ShouldConvertFunction(const grappler::GrapplerItem & item)144 StatusOr<bool> ShouldConvertFunction(const grappler::GrapplerItem& item) {
145 if (item.id == "tf_graph") {
146 return false;
147 }
148 const grappler::GrapplerFunctionItem& func_item =
149 tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
150 const tensorflow::AttrSlice& attr = func_item.func_attr();
151 const AttrValue* attr_value = attr.FindByString("_tftrt_convert_function");
152 if (attr_value != nullptr) {
153 TF_RETURN_IF_ERROR(ValidateValueCase(attr_value, AttrValue::ValueCase::kB));
154 return attr_value->b();
155 }
156 VLOG(1) << "Attribute _tftrt_convert_function was not found.";
157 return false;
158 }
159
CheckForFunctionConversionAttribute(const tensorflow::AttrSlice & attr)160 StatusOr<bool> CheckForFunctionConversionAttribute(
161 const tensorflow::AttrSlice& attr) {
162 const AttrValue* attr_value = attr.FindByString("_tftrt_convert_function");
163 if (attr_value != nullptr) {
164 TF_RETURN_IF_ERROR(ValidateValueCase(attr_value, AttrValue::ValueCase::kB));
165 return attr_value->b();
166 } else {
167 VLOG(1) << "Attribute _tftrt_convert_function was not found.";
168 }
169 return false;
170 }
171
// Converts function conversion attributes to conversion parameters.
// Each "_tftrt_*" attribute present in `attr` overrides the matching
// field of `cp`; absent attributes leave the field at its current value
// (see GetAttrValue, which treats a missing attribute as a no-op).
// Returns an error if an attribute exists but has the wrong value type.
Status UpdateFunctionSpecificConversionParams(
    ConversionParams& cp, const tensorflow::AttrSlice& attr) {
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_trt_logger_name", &cp.trt_logger_name));
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_max_batch_size", &cp.max_batch_size));
  TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_max_workspace_size_bytes",
                                  &cp.max_workspace_size_bytes));
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_precision_mode", &cp.precision_mode));
  TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_minimum_segment_size",
                                  &cp.minimum_segment_size));
  TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_is_dyn_op", &cp.is_dyn_op));
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_max_cached_engines", &cp.max_cached_engines));
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_use_calibration", &cp.use_calibration));
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_use_implicit_batch", &cp.use_implicit_batch));
  TF_RETURN_IF_ERROR(
      GetAttrValue(attr, "_tftrt_profile_strategy", &cp.profile_strategy));
  TF_RETURN_IF_ERROR(GetAttrValue(attr, "_tftrt_allow_build_at_runtime",
                                  &cp.allow_build_at_runtime));
  return Status::OK();
}
198
199 } // namespace
200
Init(const RewriterConfig_CustomGraphOptimizer * config)201 Status TRTOptimizationPass::Init(
202 const RewriterConfig_CustomGraphOptimizer* config) {
203 VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
204 if (config == nullptr) {
205 return Status::OK();
206 }
207 VLOG(1) << "config = " << config->DebugString();
208 const auto params = config->parameter_map();
209 if (params.count("minimum_segment_size")) {
210 minimum_segment_size_ = params.at("minimum_segment_size").i();
211 }
212 if (params.count("max_batch_size")) {
213 maximum_batch_size_ = params.at("max_batch_size").i();
214 }
215 if (params.count("is_dynamic_op")) {
216 is_dynamic_op_ = params.at("is_dynamic_op").b();
217 }
218 if (params.count("maximum_cached_engines")) {
219 max_cached_batches_ = params.at("maximum_cached_engines").i();
220 }
221 if (params.count("max_workspace_size_bytes")) {
222 max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
223 }
224 if (params.count("precision_mode")) {
225 TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
226 AsciiStrToUpper(params.at("precision_mode").s()), &precision_mode_));
227 }
228 if (params.count("use_calibration")) {
229 use_calibration_ = params.at("use_calibration").b();
230 }
231 if (params.count("trt_logger")) {
232 trt_logger_name_ = params.at("trt_logger").s();
233 }
234 if (params.count("allow_build_at_runtime")) {
235 allow_build_at_runtime_ = params.at("allow_build_at_runtime").b();
236 }
237 if (params.count("use_implicit_batch")) {
238 use_implicit_batch_ = params.at("use_implicit_batch").b();
239 }
240 if (params.count("profile_strategy")) {
241 TF_RETURN_IF_ERROR(ProfileStrategyFromName(
242 params.at("profile_strategy").s(), &profile_strategy_));
243 }
244 return Status::OK();
245 }
246
PrintDebugInfo(grappler::Cluster * cluster,const grappler::GrapplerItem & item)247 void TRTOptimizationPass::PrintDebugInfo(grappler::Cluster* cluster,
248 const grappler::GrapplerItem& item) {
249 LOG(INFO) << "Cluster = " << cluster;
250 string offset(" ");
251 string offset2 = StrCat(offset, offset);
252 string offset3 = StrCat(offset2, offset);
253 string offset4 = StrCat(offset2, offset2);
254
255 if (cluster) {
256 LOG(INFO) << offset << "type = " << cluster->type();
257 LOG(INFO) << offset << "num warmup steps = " << cluster->NumWarmupSteps();
258 const auto dev_names = cluster->GetDeviceNames();
259 if (!dev_names.empty()) {
260 LOG(INFO) << offset << " Device names:";
261 for (const auto& s : dev_names) {
262 LOG(INFO) << offset2 << s;
263 }
264 }
265 std::unordered_map<string, uint64> peak_mem;
266 auto status = cluster->GetPeakMemoryUsage(&peak_mem);
267 if (status == Status::OK()) {
268 LOG(INFO) << offset << "Peak Memory Usage :";
269 for (const auto& s : peak_mem) {
270 LOG(INFO) << offset2 << s.first << " = " << s.second;
271 }
272 }
273
274 const auto dev_props = cluster->GetDevices();
275 if (!dev_props.empty()) {
276 LOG(INFO) << offset << "Device properties:";
277 for (const auto& k : dev_props) {
278 LOG(INFO) << offset2 << k.first;
279 const auto& dt = k.second;
280 LOG(INFO) << offset3 << "type = " << dt.type();
281 LOG(INFO) << offset3 << "vendor = " << dt.vendor();
282 LOG(INFO) << offset3 << "model = " << dt.model();
283 LOG(INFO) << offset3 << "frequency = " << dt.frequency();
284 LOG(INFO) << offset3 << "num cores = " << dt.num_cores();
285 LOG(INFO) << offset3 << "num registers = " << dt.num_registers();
286 LOG(INFO) << offset3 << "L1 cache size = " << dt.l1_cache_size();
287 LOG(INFO) << offset3 << "L2 cache size = " << dt.l2_cache_size();
288 LOG(INFO) << offset3 << "L3 cache size = " << dt.l3_cache_size();
289 LOG(INFO) << offset3 << "SHMem per SMP = "
290 << dt.shared_memory_size_per_multiprocessor();
291 LOG(INFO) << offset3 << "memory size = " << dt.memory_size();
292 LOG(INFO) << offset3 << "bandwidth = " << dt.bandwidth();
293 if (dt.environment_size()) {
294 LOG(INFO) << offset3 << "environment :";
295 for (const auto& e : dt.environment()) {
296 LOG(INFO) << offset4 << e.first << " = " << e.second;
297 }
298 }
299 }
300 }
301
302 if (cluster->GetDeviceSet()) {
303 for (const auto dev : cluster->GetDeviceSet()->devices()) {
304 LOG(INFO) << "Device name= " << dev->name() << "Pased name= "
305 << DeviceNameUtils::ParsedNameToString(dev->parsed_name());
306 }
307 }
308 }
309
310 LOG(INFO) << "item: " << item.id;
311 if (!item.feed.empty()) {
312 LOG(INFO) << offset << "Feeds :";
313 for (const auto& f : item.feed) {
314 const auto& shape = f.second.shape();
315 LOG(INFO) << offset2 << f.first << " = shaped " << shape.DebugString();
316 }
317 } else {
318 LOG(INFO) << offset << "No Feeds";
319 }
320 if (!item.fetch.empty()) {
321 LOG(INFO) << offset << "Fetches :";
322 for (const auto& f : item.fetch) {
323 LOG(INFO) << offset2 << f;
324 }
325 } else {
326 LOG(INFO) << offset << "No Fetches";
327 }
328
329 if (!item.init_ops.empty()) {
330 LOG(INFO) << offset << "init ops :";
331 for (const auto& f : item.init_ops) {
332 LOG(INFO) << offset2 << f;
333 }
334 } else {
335 LOG(INFO) << offset << "No init ops";
336 }
337 LOG(INFO) << "Save Op = " << item.save_op;
338 LOG(INFO) << "Restore Op = " << item.restore_op;
339 LOG(INFO) << "save_restore_loc_tensor = " << item.save_restore_loc_tensor;
340 if (!item.keep_ops.empty()) {
341 LOG(INFO) << offset << "keep ops :";
342 for (const auto& f : item.keep_ops) {
343 LOG(INFO) << offset2 << f;
344 }
345 } else {
346 LOG(INFO) << offset << "No keep ops";
347 }
348 }
349
Optimize(grappler::Cluster * cluster,const grappler::GrapplerItem & item,GraphDef * optimized_graph)350 Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
351 const grappler::GrapplerItem& item,
352 GraphDef* optimized_graph) {
353 VLOG(1) << "Called TRTOptimization Pass " << name_
354 << " on a grappler item with id=" << item.id;
355 TF_ASSIGN_OR_RETURN(bool do_function_conversion, ShouldConvertFunction(item));
356 if (minimum_segment_size_ == -1 ||
357 (item.id != "tf_graph" && !do_function_conversion)) {
358 VLOG(1) << "Not optimizing this grappler item: " << item.id;
359 *optimized_graph = item.graph;
360 return Status::OK();
361 }
362 if (VLOG_IS_ON(3)) {
363 LOG(INFO) << CurrentStackTrace();
364 PrintDebugInfo(cluster, item);
365 }
366
367 if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) {
368 VLOG(1) << "Calibration with FP32 or FP16 is not implemented. "
369 << "Falling back to use_calibration = False."
370 << "Note that the default value of use_calibration is True.";
371 use_calibration_ = false;
372 }
373
374 std::vector<string> nodes_to_preserve;
375 for (const auto& n : item.NodesToPreserve()) {
376 auto tokens = str_util::Split(n, ":");
377 string s = tokens.at(0);
378 for (int i = 1; i < tokens.size() - 1; ++i) {
379 StrAppend(&s, ":", tokens.at(i));
380 }
381 int dumm_port = -1;
382 // If the last token is not an integer, it must be part of the name.
383 // Otherwise it is port number.
384 if (tokens.size() > 1 &&
385 !strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok
386 StrAppend(&s, ":", tokens.back());
387 }
388 nodes_to_preserve.push_back(s);
389 }
390
391 ConversionParams cp;
392 cp.grappler_item = &item;
393 cp.output_names = &nodes_to_preserve;
394 cp.trt_logger_name = trt_logger_name_;
395 cp.max_batch_size = maximum_batch_size_;
396 cp.max_workspace_size_bytes = max_workspace_size_bytes_;
397 cp.output_graph_def = optimized_graph;
398 cp.precision_mode = precision_mode_;
399 cp.minimum_segment_size = minimum_segment_size_;
400 cp.cluster = cluster;
401 cp.is_dyn_op = is_dynamic_op_;
402 cp.max_cached_engines = max_cached_batches_;
403 cp.use_calibration = use_calibration_;
404 cp.use_implicit_batch = use_implicit_batch_;
405 cp.profile_strategy = profile_strategy_;
406 cp.allow_build_at_runtime = allow_build_at_runtime_;
407
408 if (item.id != "tf_graph" && do_function_conversion) {
409 const grappler::GrapplerFunctionItem& func_item =
410 tensorflow::down_cast<const grappler::GrapplerFunctionItem&>(item);
411 TF_RETURN_IF_ERROR(
412 UpdateFunctionSpecificConversionParams(cp, func_item.func_attr()));
413 }
414
415 auto status = ConvertAfterShapes(cp);
416 VLOG(1) << "Returning from " << name_;
417 return status;
418 }
419
// A CustomGraphOptimizerRegistrar that additionally VLOGs when the
// registration object is constructed, to make optimizer registration
// visible when debugging.
class VerboseCustomGraphOptimizerRegistrar
    : public grappler::CustomGraphOptimizerRegistrar {
 public:
  VerboseCustomGraphOptimizerRegistrar(
      const grappler::CustomGraphOptimizerRegistry::Creator& cr,
      const string& name)
      : grappler::CustomGraphOptimizerRegistrar(cr, name) {
    VLOG(1) << "Constructing a CustomOptimizationPass registration object for "
            << name;
  }
};
431
// Static registration: makes the pass available to grappler under the
// name "TensorRTOptimizer"; the lambda lazily creates the pass object.
static VerboseCustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar(
    []() {
      VLOG(1)
          << "Instantiating CustomOptimizationPass object TensorRTOptimizer";
      return new TRTOptimizationPass("TensorRTOptimizer");
    },
    ("TensorRTOptimizer"));
439
440 } // namespace convert
441 } // namespace tensorrt
442 } // namespace tensorflow
443
444 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
445