• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
34 
35 #if GOOGLE_CUDA
36 #if GOOGLE_TENSORRT
37 #include "third_party/tensorrt/NvInfer.h"
38 
39 namespace tensorflow {
40 namespace tensorrt {
41 
42 namespace convert {
43 using ::stream_executor::port::StatusOr;
44 
45 struct EngineConnection {
46   // Constructs a non-control edge.
EngineConnectionEngineConnection47   EngineConnection(const string& outside, int out_id, int out_port,
48                    const string& inside, int in_id, int in_port,
49                    bool input_edge, int port)
50       : outside_node_name(outside),
51         outside_id(out_id),
52         outside_port(out_port),
53         inside_node_name(inside),
54         inside_id(in_id),
55         inside_port(in_port),
56         is_input_edge(input_edge),
57         port_number(port) {}
58 
59   // Constructs a control edge.
EngineConnectionEngineConnection60   EngineConnection(const string& outside, int out_id, const string& inside,
61                    int in_id, bool input_edge)
62       : outside_node_name(outside),
63         outside_id(out_id),
64         outside_port(Graph::kControlSlot),
65         inside_node_name(inside),
66         inside_id(in_id),
67         inside_port(Graph::kControlSlot),
68         is_input_edge(input_edge),
69         port_number(Graph::kControlSlot) {}
70 
is_control_edgeEngineConnection71   bool is_control_edge() const { return port_number == Graph::kControlSlot; }
72 
73   const string outside_node_name;
74   const int outside_id;
75   const int outside_port;
76   PartialTensorShape outside_shape;  // Only set for input edge.
77 
78   const string inside_node_name;
79   const int inside_id;
80   const int inside_port;
81   PartialTensorShape inside_shape;  // Only set for output edge.
82 
83   DataType connection_type;
84   const bool is_input_edge;
85 
86   // The port number of the TRT node connected with this edge.
87   const int port_number;
88 };
89 
90 struct EngineInfo {
EngineInfoEngineInfo91   EngineInfo()
92       : engine_type(EngineType::TRTStatic),
93         max_workspace_size_bytes(0),
94         precision_mode(TrtPrecisionMode::FP32),
95         use_calibration(true) {}
96 
97   string engine_name;
98   string device;
99   GraphDef segment_graph_def;
100 
101   // Non-control input connections inside this vector are sorted in a way such
102   // that, the segment nodes connecting to them are topological sorted.
103   // In addition, for non-control connections, there must be no duplicates.
104   std::vector<EngineConnection> connections;
105 
106   enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
107   EngineType engine_type;
108   int64 max_workspace_size_bytes;
109   int maximum_cached_engines;
110   TrtPrecisionMode precision_mode;
111   bool use_calibration;
112 };
113 
114 // Constructs a graphdef from the segment in the given graph. Adds _Arg
115 // nodes for input edges (InputPH_*) and _Retval nodes for output edges
116 // (OutputPH_*). This function needs to be called before TensorRT nodes
117 // inserted in order to correctly get sizes from the original graph.
118 //
119 // - subgraph_node_names: the node names of the subgraph.
120 // - subgraph_node_ids: the node ids of the subgraph, must be sorted in
121 //   topological order.
122 // - segment_def: the output GraphDef, whose non-input/output nodedefs will be
123 //   sorted in topological order.
124 // - scope_name: the name of the scope where the TRTEngineOp will be placed.
125 //
126 // TODO(aaroey): add tests to validate these properties.
127 Status ConvertSegmentToGraphDef(
128     const Graph* graph, const grappler::GraphProperties& graph_properties,
129     const std::vector<const Node*>& subgraph_nodes,
130     std::vector<EngineConnection>* connections, GraphDef* segment_def,
131     string* scope_name);
132 
133 // Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff
134 // 'builder' successfully build the engine. If the result is not ok, 'engine'
135 // will be set to nullptr
136 // Once returned, 'builder' is not needed any more and can be safely destroyed.
137 //
138 // - convert_successfully: indicates whether the conversion to TensorRT network
139 //   is successful. This is different than successfully building the engine:
140 //   building can still fail afterwards.
141 Status ConvertGraphDefToEngine(
142     const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
143     size_t max_workspace_size_bytes,
144     const std::vector<PartialTensorShape>& input_shapes,
145     nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
146     TRTInt8Calibrator* calibrator,
147     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
148     const bool use_implicit_batch, bool* convert_successfully);
149 
// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid. Used as a function object (see operator() below).
class OutputEdgeValidator {
 public:
  // Returns true if the specified edge is eligible to be an output edge of the
  // TRT segment.
  bool operator()(const Edge* out_edge) const;
};
158 
// Element-count helpers for TRT dimensions. NOTE(review): the definitions are
// not visible in this header; presumably both return the number of elements
// described by 'dims' and differ in how weight vs. tensor dimensions are
// treated — confirm against the implementation.
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
161 
162 // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
163 class TRT_ShapedWeights {
164  public:
165   explicit TRT_ShapedWeights(
166       nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);
167 
168   // Copy from another weights.
169   //
170   // NOTE: this does not copy the underlying buffer but only increase its
171   // reference count.
172   TRT_ShapedWeights(const TRT_ShapedWeights& rhs);
173 
174   nvinfer1::Weights GetTrtWeights() const;
175 
GetTensor()176   const Tensor& GetTensor() const { return tensor_; }
177 
178   // Returns the raw pointer to the underlying buffer which holds the weights
179   // value.
GetValues()180   void* GetValues() const {
181     return const_cast<char*>(tensor_.tensor_data().data());
182   }
183 
184   int64_t count() const;
185 
186   size_t size_bytes() const;
187 
188   string DebugString() const;
189 
190   template <typename T>
GetSpan()191   absl::Span<const T> GetSpan() const {
192     return absl::Span<const T>(tensor_.flat<T>().data(), count());
193   }
194 
195   template <typename T>
ToVector()196   std::vector<T> ToVector() const {
197     auto span = GetSpan<T>();
198     return std::vector<T>(span.data(), span.data() + span.size());
199   }
200 
TrtDType()201   nvinfer1::DataType TrtDType() const { return type_; }
202 
203   // TODO(aaroey): make these private.
204   nvinfer1::Dims shape_;  // Note: shape.type[] is not used.
205 
206  private:
207   // This constructor is only used by TrtWeightStore, which creates the
208   // underlying buffer.
209   TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims,
210                     Tensor tensor);
211 
212   nvinfer1::DataType type_;
213 
214   // All weights should be stored inside TrtWeightStore to make sure lifetime of
215   // all the underlying tensors are available until the engine is built. For
216   // this reason, tensor_ should never be reassigned to a different value that
217   // is not already present in the TrtWeightStore.
218   Tensor tensor_;
219 
220   friend class TrtWeightStore;
221 };
222 
223 // Container for TRT_ShapedWeights. We need this container because, TRT doesn't
224 // manage the lifetime of the weights buffer, it only keeps a pointer to it and
225 // requires that the data referenced by the pointer be available until the
226 // building of engine is complete. For more information see
227 // https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html
228 //
229 // TODO(laigd): consider adding garbage collection to the unused weights.
230 class TrtWeightStore {
231  public:
232   // Get a TRT_ShapedWeights with 'type' and 'dims'.
233   TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
234                                    const nvinfer1::Dims& dims);
235 
236   // Get a TRT_ShapedWeights with the same data type and dimensions as
237   // 'weights'.
GetTempWeights(const TRT_ShapedWeights & weights)238   TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
239     return GetTempWeights(weights.TrtDType(), weights.shape_);
240   }
241 
242  private:
243   // The backend storage of the TRT_ShapedWeights.
244   std::vector<Tensor> store_;
245 };
246 
247 // Represents a TRT-style input to a TF node, it can be either a
248 // nvinfer1::ITensor, or TRT_ShapedWeights which is compile-time constant.
249 //
250 // TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument.
251 class TRT_TensorOrWeights {
252  public:
TRT_TensorOrWeights()253   TRT_TensorOrWeights() {}
254 
255   // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'.
256   // This is used by Converter when building the TRT network, where the ITensor
257   // is owned by the TRT network being built. See comment for 'tensor_' below.
258   explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);
259 
260   // Constructor that makes it an ITensor by creating one using provided data
261   // type and shape, and takes ownership of the created ITensor. This is used by
262   // TrtNodeValidator to encapsulate the type and shape information for
263   // validation of graph nodes, and the created ITensor is fake and temporary,
264   // and should not be used to build any TRT network. See comment for
265   // 'simple_itensor_' below.
266   explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
267                                const nvinfer1::Dims& trt_dims, int batch_size);
268 
269   // Constructor that makes it a TRT_TensorOrWeights.
270   explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);
271 
272   TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);
273 
274   void operator=(const TRT_TensorOrWeights& rhs);
275 
is_tensor()276   bool is_tensor() const { return initialized_ && is_tensor_; }
is_weights()277   bool is_weights() const { return initialized_ && !is_tensor_; }
278 
279   nvinfer1::ITensor* tensor() const;
280 
weights()281   TRT_ShapedWeights& weights() {
282     CHECK(is_weights());
283     return weights_;
284   }
285 
weights()286   const TRT_ShapedWeights& weights() const {
287     CHECK(is_weights());
288     return weights_;
289   }
290 
291   nvinfer1::Dims GetTrtDims() const;
292 
batch_size()293   int batch_size() const { return batch_size_; }
294 
295   string DebugString() const;
296 
297  private:
298   class SimpleITensor;
299 
set_batch_size(int batch_size)300   void set_batch_size(int batch_size) { batch_size_ = batch_size; }
301 
302   // When it represents an ITensor, the ITensor can be either passed by the
303   // caller via the constructor that takes an ITensor* as parameter, or be
304   // created as a SimpleITensor.
305   //
306   // In the first case, the ITensor pointer is stored in 'tensor_' below, and
307   // the ITensor itself is not owned by this class. This method is used by
308   // Converter (e.g. AddInputTensor) and op converters during TRT network
309   // construction, where the TRT network owns the ITensor.
310   //
311   // In the second case, the created SimpleITensor is stored in
312   // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake
313   // implementation of ITensor and is used only by TrtNodeValidator to validate
314   // the graph nodes.
315   nvinfer1::ITensor* tensor_ = nullptr;  // Not owned.
316   std::shared_ptr<SimpleITensor> simple_itensor_ = nullptr;
317 
318   // First dimension of the TF tensor (NOT tensor_) that is represented by
319   // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
320   // dimensions (obtained via tensor_->getDimensions()) do not contain the batch
321   // dimension. For example, when a TF tensor with shape (A,B,C) is represented
322   // in TRT, tensor_->getDimensions() will be (B,C) and batch_size_ will be A.
323   //
324   // This requires that all tensors in the subgraph that is converted to a TRT
325   // engine have the same batch size are represented by the first dimension of
326   // their shape, and Converter will verify this during conversion. The drawback
327   // is that currently it cannot convert a graph that doesn't have the batch
328   // size represented in the shapes or the batch sizes are different. See
329   // b/118387490 for more details.
330   //
331   // If use_implicit_batch is false, batch_size_ is unused and
332   // tensor_->getDimensions() will contain the entire shape (A,B,C).
333   int batch_size_ = -1;
334 
335   TRT_ShapedWeights weights_;
336   bool initialized_ = false;
337   bool is_tensor_ = false;
338 
339   friend class Converter;
340 };
341 
class Converter;

// Parameters passed to each op converter (and op validator).
struct OpConverterParams {
  // Constructor used for validation only.
  OpConverterParams(const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store,
                    TrtPrecisionMode precision_mode, bool use_calibration,
                    bool use_implicit_batch);

  // Constructor used for conversion. NOTE(review): this overload takes no
  // precision_mode / use_calibration / use_implicit_batch; presumably they are
  // read from 'converter' — confirm against the implementation.
  OpConverterParams(Converter* converter, const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store);

  // The converter building the TRT network. NOTE(review): presumably left as
  // nullptr by the validation-only constructor.
  Converter* converter = nullptr;
  // 'node_def' and 'inputs' are held by reference, so the referenced objects
  // must outlive this struct.
  const NodeDef& node_def;
  const std::vector<TRT_TensorOrWeights>& inputs;
  // Destination for the tensors/weights produced by the op converter.
  std::vector<TRT_TensorOrWeights>* outputs;
  // NOTE(review): presumably true iff the validation-only constructor was used.
  const bool validation_only;
  TrtWeightStore* weight_store;
  const TrtPrecisionMode precision_mode;
  const bool use_calibration;
  const bool use_implicit_batch;
};

// Signature shared by op converters and op validators. Converter and
// TrtNodeValidator each keep a map from op type to a function of this type.
using OpConverter = std::function<Status(OpConverterParams*)>;
372 
373 // Class to verify if specific TF node is supported by TRT.
374 class TrtNodeValidator {
375  public:
376   // 'graph_properties' is the GraphProperties of the graph whose nodes will be
377   // checked by IsTensorRTCandidate() later. It is used to get the shape and
378   // data type information of a tensor for validation purpose.
379   TrtNodeValidator(const grappler::GraphProperties& graph_properties,
380                    TrtPrecisionMode precision_mode, bool use_calibration,
381                    bool use_implicit_batch);
382 
383   // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
384   // to TRT subgraph and later converted into TRT engine.
385   Status IsTensorRTCandidate(const Node* node);
386 
387  private:
388   static const std::set<string>* quantize_ops;
389 
390   void RegisterOpValidators();
391 
392   // Convert a Const node to a TRT_TensorOrWeights.
393   Status ConvertConstToWeights(const NodeDef& const_node_def,
394                                const std::vector<TRT_TensorOrWeights>& inputs,
395                                TRT_TensorOrWeights* output);
396 
397   // Convert the output tensor at 'output_port' of 'node_def' to a
398   // TRT_TensorOrWeights which will be later used as an input to other nodes and
399   // passed to ValidateNode() below.
400   Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port,
401                                   TRT_TensorOrWeights* tensor_or_weights);
402 
403   // Stores all the validators by op type. If no validator is registered for
404   // specific op, it means no validation is needed and ValidateNode() will
405   // return OK.
406   std::unordered_map<string, OpConverter> op_validators_;
407 
408   // Store the weights added during validation. Some validations (e.g.
409   // validation for Const node) may produce weights.
410   TrtWeightStore weight_store_;
411 
412   // GraphProperties of the graph whose nodes are to be validated by
413   // IsTensorRTCandidate().
414   const grappler::GraphProperties& graph_properties_;
415 
416   // Quantization ops are only converted when using quantized precisions.
417   const TrtPrecisionMode precision_mode_;
418 
419   const bool use_calibration_;
420 
421   const bool use_implicit_batch_;
422 
423   friend class ValidatorTest;
424   friend class OpConverterTest;
425 };
426 
427 // Class to convert TF nodes to TRT network.
428 class Converter {
429  public:
430   // Used for Converter::RenameAndMarkOutputTensors()
431   struct EngineOutputInfo {
432     // The TRT tensor name which produces the output.
433     string source_tensor_name;
434     // The TensorFlow node name which is receiving the output from the TRT
435     // engine. This should always be the Identity node created in
436     // ConvertSegmentToGraphDef.
437     string dest_node_name;
438     // Output type. TensorRT requires this to be explicitly set for engine
439     // outputs.
440     nvinfer1::DataType trt_dtype;
441   };
442 
443   static StatusOr<std::unique_ptr<Converter>> Create(
444       TrtPrecisionMode precision_mode, bool use_calibration,
445       nvinfer1::ILogger* trt_logger, const bool use_implicit_batch);
446 
447   //////////////////////////////////////////////////////////////////////////////
448   // Methods used by the TRT engine builder to build a TRT network from a TF
449   // function/subgraph.
450 
451   // Convert the node to TRT network.
452   Status ConvertNode(const NodeDef& node_def);
453 
454   // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and
455   // 'batch_size'.
456   Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
457                         const nvinfer1::Dims& dims, int batch_size);
458 
459   // Mark the tensors with names specified by source_tensor_name as output of
460   // the TRT network, and set their names in the TRT network as dest_node_name.
461   Status RenameAndMarkOutputTensors(
462       const std::vector<EngineOutputInfo>& output_tensors);
463 
464   // Build a TRT engine using the created network.
465   Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
466                          int max_batch_size, size_t max_workspace_size_bytes,
467                          nvinfer1::IGpuAllocator* allocator,
468                          TRTInt8Calibrator* calibrator);
469 
470   //////////////////////////////////////////////////////////////////////////////
471   // Methods used by op converters to convert individual TF node and add layers
472   // to the TRT network.
473 
474   // Op converters (e.g. ConvertReshape) need to access the TRT network in order
475   // to add TRT layers.
network()476   nvinfer1::INetworkDefinition* network() { return trt_network_.get(); }
477 
478   // What precision are we targeting?
precision_mode()479   TrtPrecisionMode precision_mode() const { return precision_mode_; }
480 
481   // Calibration will be or was previously performed on this network?
use_calibration()482   bool use_calibration() const { return use_calibration_; }
483 
484   // Whether implicit batch mode is enabled
use_implicit_batch()485   bool use_implicit_batch() const { return use_implicit_batch_; }
486 
487   // This should be called on the inputs and outputs of any layer we create
488   // where we know that the quantization range does not change during that
489   // operation. (e.g. Reshape, Transpose, Identity, MaxPool).
490   void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
491                                           nvinfer1::ITensor* output);
492 
493   // This function should be called when we know the quantization range of a
494   // tensor, either from a quantize/dequantize node or when the output is a
495   // fixed range (e.g. SoftMax, Relu6, Sigmoid).
496   void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range,
497                                 float max_range);
498 
499   // Should be called when full TRT network has been constructed and before
500   // building the engine.
501   void MaybeApplyQuantizationRanges();
502 
503   // Below are helper methods for op converters to add different layers to the
504   // TRT network.
505 
506   // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to
507   // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch
508   // dimension which should always be 0.
509   Status TransposeTensor(nvinfer1::ITensor* input_tensor,
510                          const std::vector<int>& order_with_batch_dim,
511                          absl::string_view name,
512                          nvinfer1::ITensor** output_tensor);
513 
514   // Converts 'input' into 'tensor' with shape specified by 'dims' (which
515   // doesn't contain the batch dimension).
516   //
517   // If validation_only is true, it doesn't do the conversion but only do some
518   // minimum validation for the eligibility of the conversion, and *tensor will
519   // be set to nullptr.
520   Status PrepareTensorForShape(const TRT_TensorOrWeights& input,
521                                const nvinfer1::Dims& dims,
522                                const bool validation_only,
523                                nvinfer1::ITensor** tensor);
524 
525   // Creates an IConstantLayer using 'weights' whose dimensions are specified by
526   // 'dims', and returns the output ITensor.
527   nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
528                                          const nvinfer1::Dims& dims);
529 
530  private:
531   Converter(TrtPrecisionMode precision_mode, bool use_calibration,
532             nvinfer1::ILogger* trt_logger, const bool use_implicit_batch);
533 
534   Status Init(nvinfer1::ILogger* trt_logger);
535 
536   // Verify the provided batch_size is consistent with batch_size_ and update it
537   // if necessary.
538   Status MaybeUpdateBatchSize(int batch_size);
539 
540   // Add the provided tensor/weights to the map trt_tensors_.
541   Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input);
542 
543   // Get the tensor/weights from trt_tensors_ by 'name'.
544   Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output);
545 
546   // Get the inputs of 'node_def' from trt_tensors_.
547   Status GetInputs(const NodeDef& node_def,
548                    std::vector<TRT_TensorOrWeights>* inputs) const;
549 
550   void RegisterOpConverters();
551 
552   void PropagateQuantizationRanges();
553 
554   // Gets the min and max value in a TRT_ShapedWeights
555   Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
556                         float* out_max) const;
557 
558   // Registered op converters by op type.
559   std::unordered_map<string, OpConverter> op_registry_;
560 
561   // Tensors/weights added during construction of trt_network_.
562   std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
563 
564   // The TRT builder used to create the network and build the engine. Not owned.
565   TrtUniquePtrType<nvinfer1::IBuilder> trt_builder_;
566 
567   // The TRT network being built.
568   TrtUniquePtrType<nvinfer1::INetworkDefinition> trt_network_;
569 
570   // Store the weights added during construction of trt_network_.
571   TrtWeightStore weight_store_;
572 
573   // During conversion, this table is populated with quantization ranges per
574   // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT
575   // quantization ranges. Since TRT only supports symmetric ranges, we will
576   // store the range as a single float = max(abs(min_range), abs(max_range)).
577   // Range refers to the floating point values, e.g. min_range = 0.0f, max_range
578   // = 6.0f for Relu6.
579   std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_;
580 
581   // Edges where quantization ranges can be inferred (copied) across ops - from
582   // first tensor to second tensor. PropagateQuantizationRanges() will propagate
583   // known ranges from quantization_ranges_ across these edges, adding the new
584   // ranges to quantization_ranges_ so that they can be applied in
585   // MaybeApplyQuantizationRanges().
586   std::vector<std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*>>
587       quantization_infer_;
588 
589   const TrtPrecisionMode precision_mode_;
590 
591   const bool use_calibration_;
592 
593   // If this is false, all dimensions including the batch dimension are
594   // set explicitely.
595   const bool use_implicit_batch_;
596 
597   // Batch size of inputs to trt_network_ added by AddInputTensor(). During
598   // network construction it will update this, use it to verify the batch
599   // size of all inputs are compatible, and make sure individual TF node is
600   // acceptable by TRT.
601   int batch_size_ = -1;
602 
603   friend class ConverterTest;
604   friend class OpConverterTest;
605 };
606 
// Returns OK if the broadcast scheme is supported, and computes the shapes of
// both operands after broadcasting into *operand_l_new_dims and
// *operand_r_new_dims. check_feasibility can be set to false in cases where
// dimensions do not need to match exactly (as in the case of BatchMatMulV2).
// use_implicit_batch selects whether the shapes exclude (implicit batch mode)
// or include the batch dimension.
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
                            const TRT_TensorOrWeights& operand_r,
                            const bool check_feasibility,
                            const bool use_implicit_batch,
                            nvinfer1::Dims* operand_l_new_dims,
                            nvinfer1::Dims* operand_r_new_dims);
616 
// Returns the map of all supported UnaryOperations. NOTE(review): presumably
// keyed by TF op type — confirm against the implementation.
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
// Returns the map of all supported ActivationTypes. NOTE(review): presumably
// keyed by TF op type — confirm against the implementation.
const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap();
621 
622 }  // namespace convert
623 }  // namespace tensorrt
624 }  // namespace tensorflow
625 
626 #endif  // GOOGLE_TENSORRT
627 #endif  // GOOGLE_CUDA
628 
629 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
630