1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ 17 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ 18 19 #include <set> 20 #include <string> 21 #include <unordered_map> 22 #include <utility> 23 #include <vector> 24 25 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" 26 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" 27 #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" 28 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" 29 #include "tensorflow/core/framework/graph.pb.h" 30 #include "tensorflow/core/graph/graph.h" 31 #include "tensorflow/core/grappler/costs/graph_properties.h" 32 #include "tensorflow/core/lib/core/status.h" 33 #include "tensorflow/stream_executor/lib/statusor.h" 34 35 #if GOOGLE_CUDA 36 #if GOOGLE_TENSORRT 37 #include "third_party/tensorrt/NvInfer.h" 38 39 namespace tensorflow { 40 namespace tensorrt { 41 42 namespace convert { 43 using ::stream_executor::port::StatusOr; 44 45 struct EngineConnection { 46 // Constructs a non-control edge. EngineConnectionEngineConnection47 EngineConnection(const string& outside, int out_id, int out_port, 48 const string& inside, int in_id, int in_port, 49 bool input_edge, int port) 50 : outside_node_name(outside), 51 outside_id(out_id), 52 outside_port(out_port), 53 inside_node_name(inside), 54 inside_id(in_id), 55 inside_port(in_port), 56 is_input_edge(input_edge), 57 port_number(port) {} 58 59 // Constructs a control edge. EngineConnectionEngineConnection60 EngineConnection(const string& outside, int out_id, const string& inside, 61 int in_id, bool input_edge) 62 : outside_node_name(outside), 63 outside_id(out_id), 64 outside_port(Graph::kControlSlot), 65 inside_node_name(inside), 66 inside_id(in_id), 67 inside_port(Graph::kControlSlot), 68 is_input_edge(input_edge), 69 port_number(Graph::kControlSlot) {} 70 is_control_edgeEngineConnection71 bool is_control_edge() const { return port_number == Graph::kControlSlot; } 72 73 const string outside_node_name; 74 const int outside_id; 75 const int outside_port; 76 PartialTensorShape outside_shape; // Only set for input edge. 77 78 const string inside_node_name; 79 const int inside_id; 80 const int inside_port; 81 PartialTensorShape inside_shape; // Only set for output edge. 82 83 DataType connection_type; 84 const bool is_input_edge; 85 86 // The port number of the TRT node connected with this edge. 87 const int port_number; 88 }; 89 90 struct EngineInfo { EngineInfoEngineInfo91 EngineInfo() 92 : engine_type(EngineType::TRTStatic), 93 max_workspace_size_bytes(0), 94 precision_mode(TrtPrecisionMode::FP32), 95 use_calibration(true) {} 96 97 string engine_name; 98 string device; 99 GraphDef segment_graph_def; 100 101 // Non-control input connections inside this vector are sorted in a way such 102 // that, the segment nodes connecting to them are topological sorted. 103 // In addition, for non-control connections, there must be no duplicates. 104 std::vector<EngineConnection> connections; 105 106 enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; 107 EngineType engine_type; 108 int64 max_workspace_size_bytes; 109 int maximum_cached_engines; 110 TrtPrecisionMode precision_mode; 111 bool use_calibration; 112 }; 113 114 // Constructs a graphdef from the segment in the given graph. Adds _Arg 115 // nodes for input edges (InputPH_*) and _Retval nodes for output edges 116 // (OutputPH_*). This function needs to be called before TensorRT nodes 117 // inserted in order to correctly get sizes from the original graph. 118 // 119 // - subgraph_node_names: the node names of the subgraph. 120 // - subgraph_node_ids: the node ids of the subgraph, must be sorted in 121 // topological order. 122 // - segment_def: the output GraphDef, whose non-input/output nodedefs will be 123 // sorted in topological order. 124 // - scope_name: the name of the scope where the TRTEngineOp will be placed. 125 // 126 // TODO(aaroey): add tests to validate these properties. 127 Status ConvertSegmentToGraphDef( 128 const Graph* graph, const grappler::GraphProperties& graph_properties, 129 const std::vector<const Node*>& subgraph_nodes, 130 std::vector<EngineConnection>* connections, GraphDef* segment_def, 131 string* scope_name); 132 133 // Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff 134 // 'builder' successfully build the engine. If the result is not ok, 'engine' 135 // will be set to nullptr 136 // Once returned, 'builder' is not needed any more and can be safely destroyed. 137 // 138 // - convert_successfully: indicates whether the conversion to TensorRT network 139 // is successful. This is different than successfully building the engine: 140 // building can still fail afterwards. 141 Status ConvertGraphDefToEngine( 142 const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size, 143 size_t max_workspace_size_bytes, 144 const std::vector<PartialTensorShape>& input_shapes, 145 nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator, 146 TRTInt8Calibrator* calibrator, 147 TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration, 148 const bool use_implicit_batch, bool* convert_successfully); 149 150 // Helper class for the segmenter to determine whether an output edge from the 151 // TRT segment is valid. 152 class OutputEdgeValidator { 153 public: 154 // Return true if the specified edge is eligible to be an output edge of the 155 // TRT segment. 156 bool operator()(const Edge* out_edge) const; 157 }; 158 159 int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims); 160 int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims); 161 162 // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight. 163 class TRT_ShapedWeights { 164 public: 165 explicit TRT_ShapedWeights( 166 nvinfer1::DataType type = nvinfer1::DataType::kFLOAT); 167 168 // Copy from another weights. 169 // 170 // NOTE: this does not copy the underlying buffer but only increase its 171 // reference count. 172 TRT_ShapedWeights(const TRT_ShapedWeights& rhs); 173 174 nvinfer1::Weights GetTrtWeights() const; 175 GetTensor()176 const Tensor& GetTensor() const { return tensor_; } 177 178 // Returns the raw pointer to the underlying buffer which holds the weights 179 // value. GetValues()180 void* GetValues() const { 181 return const_cast<char*>(tensor_.tensor_data().data()); 182 } 183 184 int64_t count() const; 185 186 size_t size_bytes() const; 187 188 string DebugString() const; 189 190 template <typename T> GetSpan()191 absl::Span<const T> GetSpan() const { 192 return absl::Span<const T>(tensor_.flat<T>().data(), count()); 193 } 194 195 template <typename T> ToVector()196 std::vector<T> ToVector() const { 197 auto span = GetSpan<T>(); 198 return std::vector<T>(span.data(), span.data() + span.size()); 199 } 200 TrtDType()201 nvinfer1::DataType TrtDType() const { return type_; } 202 203 // TODO(aaroey): make these private. 204 nvinfer1::Dims shape_; // Note: shape.type[] is not used. 205 206 private: 207 // This constructor is only used by TrtWeightStore, which creates the 208 // underlying buffer. 209 TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims, 210 Tensor tensor); 211 212 nvinfer1::DataType type_; 213 214 // All weights should be stored inside TrtWeightStore to make sure lifetime of 215 // all the underlying tensors are available until the engine is built. For 216 // this reason, tensor_ should never be reassigned to a different value that 217 // is not already present in the TrtWeightStore. 218 Tensor tensor_; 219 220 friend class TrtWeightStore; 221 }; 222 223 // Container for TRT_ShapedWeights. We need this container because, TRT doesn't 224 // manage the lifetime of the weights buffer, it only keeps a pointer to it and 225 // requires that the data referenced by the pointer be available until the 226 // building of engine is complete. For more information see 227 // https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html 228 // 229 // TODO(laigd): consider adding garbage collection to the unused weights. 230 class TrtWeightStore { 231 public: 232 // Get a TRT_ShapedWeights with 'type' and 'dims'. 233 TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type, 234 const nvinfer1::Dims& dims); 235 236 // Get a TRT_ShapedWeights with the same data type and dimensions as 237 // 'weights'. GetTempWeights(const TRT_ShapedWeights & weights)238 TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) { 239 return GetTempWeights(weights.TrtDType(), weights.shape_); 240 } 241 242 private: 243 // The backend storage of the TRT_ShapedWeights. 244 std::vector<Tensor> store_; 245 }; 246 247 // Represents a TRT-style input to a TF node, it can be either a 248 // nvinfer1::ITensor, or TRT_ShapedWeights which is compile-time constant. 249 // 250 // TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument. 251 class TRT_TensorOrWeights { 252 public: TRT_TensorOrWeights()253 TRT_TensorOrWeights() {} 254 255 // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'. 256 // This is used by Converter when building the TRT network, where the ITensor 257 // is owned by the TRT network being built. See comment for 'tensor_' below. 258 explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1); 259 260 // Constructor that makes it an ITensor by creating one using provided data 261 // type and shape, and takes ownership of the created ITensor. This is used by 262 // TrtNodeValidator to encapsulate the type and shape information for 263 // validation of graph nodes, and the created ITensor is fake and temporary, 264 // and should not be used to build any TRT network. See comment for 265 // 'simple_itensor_' below. 266 explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, 267 const nvinfer1::Dims& trt_dims, int batch_size); 268 269 // Constructor that makes it a TRT_TensorOrWeights. 270 explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights); 271 272 TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs); 273 274 void operator=(const TRT_TensorOrWeights& rhs); 275 is_tensor()276 bool is_tensor() const { return initialized_ && is_tensor_; } is_weights()277 bool is_weights() const { return initialized_ && !is_tensor_; } 278 279 nvinfer1::ITensor* tensor() const; 280 weights()281 TRT_ShapedWeights& weights() { 282 CHECK(is_weights()); 283 return weights_; 284 } 285 weights()286 const TRT_ShapedWeights& weights() const { 287 CHECK(is_weights()); 288 return weights_; 289 } 290 291 nvinfer1::Dims GetTrtDims() const; 292 batch_size()293 int batch_size() const { return batch_size_; } 294 295 string DebugString() const; 296 297 private: 298 class SimpleITensor; 299 set_batch_size(int batch_size)300 void set_batch_size(int batch_size) { batch_size_ = batch_size; } 301 302 // When it represents an ITensor, the ITensor can be either passed by the 303 // caller via the constructor that takes an ITensor* as parameter, or be 304 // created as a SimpleITensor. 305 // 306 // In the first case, the ITensor pointer is stored in 'tensor_' below, and 307 // the ITensor itself is not owned by this class. This method is used by 308 // Converter (e.g. AddInputTensor) and op converters during TRT network 309 // construction, where the TRT network owns the ITensor. 310 // 311 // In the second case, the created SimpleITensor is stored in 312 // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake 313 // implementation of ITensor and is used only by TrtNodeValidator to validate 314 // the graph nodes. 315 nvinfer1::ITensor* tensor_ = nullptr; // Not owned. 316 std::shared_ptr<SimpleITensor> simple_itensor_ = nullptr; 317 318 // First dimension of the TF tensor (NOT tensor_) that is represented by 319 // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s 320 // dimensions (obtained via tensor_->getDimensions()) do not contain the batch 321 // dimension. For example, when a TF tensor with shape (A,B,C) is represented 322 // in TRT, tensor_->getDimensions() will be (B,C) and batch_size_ will be A. 323 // 324 // This requires that all tensors in the subgraph that is converted to a TRT 325 // engine have the same batch size are represented by the first dimension of 326 // their shape, and Converter will verify this during conversion. The drawback 327 // is that currently it cannot convert a graph that doesn't have the batch 328 // size represented in the shapes or the batch sizes are different. See 329 // b/118387490 for more details. 330 // 331 // If use_implicit_batch is false, batch_size_ is unused and 332 // tensor_->getDimensions() will contain the entire shape (A,B,C). 333 int batch_size_ = -1; 334 335 TRT_ShapedWeights weights_; 336 bool initialized_ = false; 337 bool is_tensor_ = false; 338 339 friend class Converter; 340 }; 341 342 class Converter; 343 344 // Parameters for each op converter. 345 struct OpConverterParams { 346 // Constructor used for validation only. 347 OpConverterParams(const NodeDef& node_def, 348 const std::vector<TRT_TensorOrWeights>& inputs, 349 std::vector<TRT_TensorOrWeights>* outputs, 350 TrtWeightStore* weight_store, 351 TrtPrecisionMode precision_mode, bool use_calibration, 352 bool use_implicit_batch); 353 354 // Constructor used for conversion. 355 OpConverterParams(Converter* converter, const NodeDef& node_def, 356 const std::vector<TRT_TensorOrWeights>& inputs, 357 std::vector<TRT_TensorOrWeights>* outputs, 358 TrtWeightStore* weight_store); 359 360 Converter* converter = nullptr; 361 const NodeDef& node_def; 362 const std::vector<TRT_TensorOrWeights>& inputs; 363 std::vector<TRT_TensorOrWeights>* outputs; 364 const bool validation_only; 365 TrtWeightStore* weight_store; 366 const TrtPrecisionMode precision_mode; 367 const bool use_calibration; 368 const bool use_implicit_batch; 369 }; 370 371 using OpConverter = std::function<Status(OpConverterParams*)>; 372 373 // Class to verify if specific TF node is supported by TRT. 374 class TrtNodeValidator { 375 public: 376 // 'graph_properties' is the GraphProperties of the graph whose nodes will be 377 // checked by IsTensorRTCandidate() later. It is used to get the shape and 378 // data type information of a tensor for validation purpose. 379 TrtNodeValidator(const grappler::GraphProperties& graph_properties, 380 TrtPrecisionMode precision_mode, bool use_calibration, 381 bool use_implicit_batch); 382 383 // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added 384 // to TRT subgraph and later converted into TRT engine. 385 Status IsTensorRTCandidate(const Node* node); 386 387 private: 388 static const std::set<string>* quantize_ops; 389 390 void RegisterOpValidators(); 391 392 // Convert a Const node to a TRT_TensorOrWeights. 393 Status ConvertConstToWeights(const NodeDef& const_node_def, 394 const std::vector<TRT_TensorOrWeights>& inputs, 395 TRT_TensorOrWeights* output); 396 397 // Convert the output tensor at 'output_port' of 'node_def' to a 398 // TRT_TensorOrWeights which will be later used as an input to other nodes and 399 // passed to ValidateNode() below. 400 Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port, 401 TRT_TensorOrWeights* tensor_or_weights); 402 403 // Stores all the validators by op type. If no validator is registered for 404 // specific op, it means no validation is needed and ValidateNode() will 405 // return OK. 406 std::unordered_map<string, OpConverter> op_validators_; 407 408 // Store the weights added during validation. Some validations (e.g. 409 // validation for Const node) may produce weights. 410 TrtWeightStore weight_store_; 411 412 // GraphProperties of the graph whose nodes are to be validated by 413 // IsTensorRTCandidate(). 414 const grappler::GraphProperties& graph_properties_; 415 416 // Quantization ops are only converted when using quantized precisions. 417 const TrtPrecisionMode precision_mode_; 418 419 const bool use_calibration_; 420 421 const bool use_implicit_batch_; 422 423 friend class ValidatorTest; 424 friend class OpConverterTest; 425 }; 426 427 // Class to convert TF nodes to TRT network. 428 class Converter { 429 public: 430 // Used for Converter::RenameAndMarkOutputTensors() 431 struct EngineOutputInfo { 432 // The TRT tensor name which produces the output. 433 string source_tensor_name; 434 // The TensorFlow node name which is receiving the output from the TRT 435 // engine. This should always be the Identity node created in 436 // ConvertSegmentToGraphDef. 437 string dest_node_name; 438 // Output type. TensorRT requires this to be explicitly set for engine 439 // outputs. 440 nvinfer1::DataType trt_dtype; 441 }; 442 443 static StatusOr<std::unique_ptr<Converter>> Create( 444 TrtPrecisionMode precision_mode, bool use_calibration, 445 nvinfer1::ILogger* trt_logger, const bool use_implicit_batch); 446 447 ////////////////////////////////////////////////////////////////////////////// 448 // Methods used by the TRT engine builder to build a TRT network from a TF 449 // function/subgraph. 450 451 // Convert the node to TRT network. 452 Status ConvertNode(const NodeDef& node_def); 453 454 // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and 455 // 'batch_size'. 456 Status AddInputTensor(const string& name, nvinfer1::DataType dtype, 457 const nvinfer1::Dims& dims, int batch_size); 458 459 // Mark the tensors with names specified by source_tensor_name as output of 460 // the TRT network, and set their names in the TRT network as dest_node_name. 461 Status RenameAndMarkOutputTensors( 462 const std::vector<EngineOutputInfo>& output_tensors); 463 464 // Build a TRT engine using the created network. 465 Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, 466 int max_batch_size, size_t max_workspace_size_bytes, 467 nvinfer1::IGpuAllocator* allocator, 468 TRTInt8Calibrator* calibrator); 469 470 ////////////////////////////////////////////////////////////////////////////// 471 // Methods used by op converters to convert individual TF node and add layers 472 // to the TRT network. 473 474 // Op converters (e.g. ConvertReshape) need to access the TRT network in order 475 // to add TRT layers. network()476 nvinfer1::INetworkDefinition* network() { return trt_network_.get(); } 477 478 // What precision are we targeting? precision_mode()479 TrtPrecisionMode precision_mode() const { return precision_mode_; } 480 481 // Calibration will be or was previously performed on this network? use_calibration()482 bool use_calibration() const { return use_calibration_; } 483 484 // Whether implicit batch mode is enabled use_implicit_batch()485 bool use_implicit_batch() const { return use_implicit_batch_; } 486 487 // This should be called on the inputs and outputs of any layer we create 488 // where we know that the quantization range does not change during that 489 // operation. (e.g. Reshape, Transpose, Identity, MaxPool). 490 void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, 491 nvinfer1::ITensor* output); 492 493 // This function should be called when we know the quantization range of a 494 // tensor, either from a quantize/dequantize node or when the output is a 495 // fixed range (e.g. SoftMax, Relu6, Sigmoid). 496 void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range, 497 float max_range); 498 499 // Should be called when full TRT network has been constructed and before 500 // building the engine. 501 void MaybeApplyQuantizationRanges(); 502 503 // Below are helper methods for op converters to add different layers to the 504 // TRT network. 505 506 // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to 507 // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch 508 // dimension which should always be 0. 509 Status TransposeTensor(nvinfer1::ITensor* input_tensor, 510 const std::vector<int>& order_with_batch_dim, 511 absl::string_view name, 512 nvinfer1::ITensor** output_tensor); 513 514 // Converts 'input' into 'tensor' with shape specified by 'dims' (which 515 // doesn't contain the batch dimension). 516 // 517 // If validation_only is true, it doesn't do the conversion but only do some 518 // minimum validation for the eligibility of the conversion, and *tensor will 519 // be set to nullptr. 520 Status PrepareTensorForShape(const TRT_TensorOrWeights& input, 521 const nvinfer1::Dims& dims, 522 const bool validation_only, 523 nvinfer1::ITensor** tensor); 524 525 // Creates an IConstantLayer using 'weights' whose dimensions are specified by 526 // 'dims', and returns the output ITensor. 527 nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, 528 const nvinfer1::Dims& dims); 529 530 private: 531 Converter(TrtPrecisionMode precision_mode, bool use_calibration, 532 nvinfer1::ILogger* trt_logger, const bool use_implicit_batch); 533 534 Status Init(nvinfer1::ILogger* trt_logger); 535 536 // Verify the provided batch_size is consistent with batch_size_ and update it 537 // if necessary. 538 Status MaybeUpdateBatchSize(int batch_size); 539 540 // Add the provided tensor/weights to the map trt_tensors_. 541 Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input); 542 543 // Get the tensor/weights from trt_tensors_ by 'name'. 544 Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output); 545 546 // Get the inputs of 'node_def' from trt_tensors_. 547 Status GetInputs(const NodeDef& node_def, 548 std::vector<TRT_TensorOrWeights>* inputs) const; 549 550 void RegisterOpConverters(); 551 552 void PropagateQuantizationRanges(); 553 554 // Gets the min and max value in a TRT_ShapedWeights 555 Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, 556 float* out_max) const; 557 558 // Registered op converters by op type. 559 std::unordered_map<string, OpConverter> op_registry_; 560 561 // Tensors/weights added during construction of trt_network_. 562 std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_; 563 564 // The TRT builder used to create the network and build the engine. Not owned. 565 TrtUniquePtrType<nvinfer1::IBuilder> trt_builder_; 566 567 // The TRT network being built. 568 TrtUniquePtrType<nvinfer1::INetworkDefinition> trt_network_; 569 570 // Store the weights added during construction of trt_network_. 571 TrtWeightStore weight_store_; 572 573 // During conversion, this table is populated with quantization ranges per 574 // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT 575 // quantization ranges. Since TRT only supports symmetric ranges, we will 576 // store the range as a single float = max(abs(min_range), abs(max_range)). 577 // Range refers to the floating point values, e.g. min_range = 0.0f, max_range 578 // = 6.0f for Relu6. 579 std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_; 580 581 // Edges where quantization ranges can be inferred (copied) across ops - from 582 // first tensor to second tensor. PropagateQuantizationRanges() will propagate 583 // known ranges from quantization_ranges_ across these edges, adding the new 584 // ranges to quantization_ranges_ so that they can be applied in 585 // MaybeApplyQuantizationRanges(). 586 std::vector<std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*>> 587 quantization_infer_; 588 589 const TrtPrecisionMode precision_mode_; 590 591 const bool use_calibration_; 592 593 // If this is false, all dimensions including the batch dimension are 594 // set explicitely. 595 const bool use_implicit_batch_; 596 597 // Batch size of inputs to trt_network_ added by AddInputTensor(). During 598 // network construction it will update this, use it to verify the batch 599 // size of all inputs are compatible, and make sure individual TF node is 600 // acceptable by TRT. 601 int batch_size_ = -1; 602 603 friend class ConverterTest; 604 friend class OpConverterTest; 605 }; 606 607 // Return OK if the broadcast scheme is supported and compute the shapes after 608 // broadcasting. check_feasibility can be set to false in cases where dimensions 609 // do not need to match exactly (as in the case of BatchMatMulV2). 610 Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, 611 const TRT_TensorOrWeights& operand_r, 612 const bool check_feasibility, 613 const bool use_implicit_batch, 614 nvinfer1::Dims* operand_l_new_dims, 615 nvinfer1::Dims* operand_r_new_dims); 616 617 // Map of all supported UnaryOperations 618 const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap(); 619 // Map of all supported ActivationTypes 620 const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap(); 621 622 } // namespace convert 623 } // namespace tensorrt 624 } // namespace tensorflow 625 626 #endif // GOOGLE_TENSORRT 627 #endif // GOOGLE_CUDA 628 629 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ 630