• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
24 
25 #include "absl/types/optional.h"
26 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
27 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
28 #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
29 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
30 #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/graph/graph.h"
33 #include "tensorflow/core/grappler/costs/graph_properties.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/stream_executor/lib/statusor.h"
36 
37 #if GOOGLE_CUDA && GOOGLE_TENSORRT
38 #include "third_party/tensorrt/NvInfer.h"
39 
40 namespace tensorflow {
41 namespace tensorrt {
42 
43 namespace convert {
44 using ::stream_executor::port::StatusOr;
45 
46 struct EngineConnection {
47   // Constructs a non-control edge.
EngineConnectionEngineConnection48   EngineConnection(const string& outside, int out_id, int out_port,
49                    const string& inside, int in_id, int in_port,
50                    bool input_edge, int port)
51       : outside_node_name(outside),
52         outside_id(out_id),
53         outside_port(out_port),
54         inside_node_name(inside),
55         inside_id(in_id),
56         inside_port(in_port),
57         is_input_edge(input_edge),
58         port_number(port) {}
59 
60   // Constructs a control edge.
EngineConnectionEngineConnection61   EngineConnection(const string& outside, int out_id, const string& inside,
62                    int in_id, bool input_edge)
63       : outside_node_name(outside),
64         outside_id(out_id),
65         outside_port(Graph::kControlSlot),
66         inside_node_name(inside),
67         inside_id(in_id),
68         inside_port(Graph::kControlSlot),
69         is_input_edge(input_edge),
70         port_number(Graph::kControlSlot) {}
71 
is_control_edgeEngineConnection72   bool is_control_edge() const { return port_number == Graph::kControlSlot; }
73 
74   const string outside_node_name;
75   const int outside_id;
76   const int outside_port;
77   PartialTensorShape outside_shape;  // Only set for input edge.
78 
79   const string inside_node_name;
80   const int inside_id;
81   const int inside_port;
82   PartialTensorShape inside_shape;  // Only set for output edge.
83 
84   DataType connection_type;
85   const bool is_input_edge;
86 
87   // The port number of the TRT node connected with this edge.
88   const int port_number;
89 };
90 
91 struct EngineInfo {
EngineInfoEngineInfo92   EngineInfo()
93       : engine_type(EngineType::TRTStatic),
94         max_workspace_size_bytes(0),
95         max_batch_size(absl::nullopt),
96         maximum_cached_engines(0),
97         precision_mode(TrtPrecisionMode::FP32),
98         use_calibration(true),
99         allow_build_at_runtime(true) {}
100 
101   string engine_name;
102   string device;
103   GraphDef segment_graph_def;
104 
105   // Non-control input connections inside this vector are sorted in a way such
106   // that, the segment nodes connecting to them are topological sorted.
107   // In addition, for non-control connections, there must be no duplicates.
108   std::vector<EngineConnection> connections;
109 
110   enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
111   EngineType engine_type;
112   int64 max_workspace_size_bytes;
113   absl::optional<int> max_batch_size;
114   int maximum_cached_engines;
115   TrtPrecisionMode precision_mode;
116   bool use_calibration;
117   bool allow_build_at_runtime;
118 };
119 
120 // Constructs a graphdef from the segment in the given graph. Adds _Arg
121 // nodes for input edges (InputPH_*) and _Retval nodes for output edges
122 // (OutputPH_*). This function needs to be called before TensorRT nodes
123 // inserted in order to correctly get sizes from the original graph.
124 //
125 // - subgraph_node_names: the node names of the subgraph.
126 // - subgraph_node_ids: the node ids of the subgraph, must be sorted in
127 //   topological order.
128 // - segment_def: the output GraphDef, whose non-input/output nodedefs will be
129 //   sorted in topological order.
130 // - scope_name: the name of the scope where the TRTEngineOp will be placed.
131 //
132 // TODO(aaroey): add tests to validate these properties.
133 Status ConvertSegmentToGraphDef(
134     const Graph* graph, const grappler::GraphProperties& graph_properties,
135     const std::vector<const Node*>& subgraph_nodes,
136     std::vector<EngineConnection>* connections, GraphDef* segment_def,
137     string* scope_name);
138 
139 // Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff
140 // 'builder' successfully build the engine. If the result is not ok, 'engine'
141 // will be set to nullptr
142 // Once returned, 'builder' is not needed any more and can be safely destroyed.
143 //
144 // - convert_successfully: indicates whether the conversion to TensorRT network
145 //   is successful. This is different than successfully building the engine:
146 //   building can still fail afterwards.
147 Status ConvertGraphDefToEngine(
148     const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
149     size_t max_workspace_size_bytes,
150     const std::vector<PartialTensorShape>& input_shapes,
151     nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
152     TRTInt8Calibrator* calibrator,
153     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
154     const bool use_implicit_batch, bool* convert_successfully,
155     TrtShapeOptimizationProfile* profiles, absl::string_view engine_name);
156 
157 // Helper class for the segmenter to determine whether an output edge from the
158 // TRT segment is valid.
159 class OutputEdgeValidator {
160  public:
161   // Return true if the specified edge is eligible to be an output edge of the
162   // TRT segment.
163   bool operator()(const Edge* out_edge) const;
164 };
165 
166 int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
167 int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
168 
169 // Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight.
170 class TRT_ShapedWeights {
171  public:
172   explicit TRT_ShapedWeights(
173       nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);
174 
175   // Copy from another weights.
176   //
177   // NOTE: this does not copy the underlying buffer but only increase its
178   // reference count.
179   TRT_ShapedWeights(const TRT_ShapedWeights& rhs);
180 
181   nvinfer1::Weights GetTrtWeights() const;
182 
GetTensor()183   const Tensor& GetTensor() const { return tensor_; }
184 
185   // Returns the raw pointer to the underlying buffer which holds the weights
186   // value.
GetValues()187   void* GetValues() const {
188     return const_cast<char*>(tensor_.tensor_data().data());
189   }
190 
191   // Fills all the weight values with value.
192   template <typename T>
193   Status SetValues(T value);
194 
195   int64_t count() const;
196 
197   size_t size_bytes() const;
198 
199   string DebugString() const;
200 
201   template <typename T>
GetSpan()202   absl::Span<const T> GetSpan() const {
203     return absl::Span<const T>(tensor_.flat<T>().data(), count());
204   }
205 
206   template <typename T>
ToVector()207   std::vector<T> ToVector() const {
208     auto span = GetSpan<T>();
209     return std::vector<T>(span.data(), span.data() + span.size());
210   }
211 
TrtDType()212   nvinfer1::DataType TrtDType() const { return type_; }
213 
214   // TODO(aaroey): make these private.
215   nvinfer1::Dims shape_;  // Note: shape.type[] is not used.
216 
217  private:
218   // This constructor is only used by TrtWeightStore, which creates the
219   // underlying buffer.
220   TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims,
221                     Tensor tensor);
222 
223   nvinfer1::DataType type_;
224 
225   // All weights should be stored inside TrtWeightStore to make sure lifetime of
226   // all the underlying tensors are available until the engine is built. For
227   // this reason, tensor_ should never be reassigned to a different value that
228   // is not already present in the TrtWeightStore.
229   Tensor tensor_;
230 
231   friend class TrtWeightStore;
232 };
233 
234 // Container for TRT_ShapedWeights. We need this container because, TRT doesn't
235 // manage the lifetime of the weights buffer, it only keeps a pointer to it and
236 // requires that the data referenced by the pointer be available until the
237 // building of engine is complete. For more information see
238 // https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html
239 //
240 // TODO(laigd): consider adding garbage collection to the unused weights.
241 class TrtWeightStore {
242  public:
243   // Get a TRT_ShapedWeights with 'type' and 'dims'.
244   TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
245                                    const nvinfer1::Dims& dims);
246 
247   // Get a TRT_ShapedWeights with the same data type and dimensions as
248   // 'weights'.
GetTempWeights(const TRT_ShapedWeights & weights)249   TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
250     return GetTempWeights(weights.TrtDType(), weights.shape_);
251   }
252 
253  private:
254   // The backend storage of the TRT_ShapedWeights.
255   std::vector<Tensor> store_;
256 };
257 
258 // Represents a TRT-style input to a TF node, it can be either a
259 // nvinfer1::ITensor, or TRT_ShapedWeights which is compile-time constant.
260 //
261 // TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument.
262 class TRT_TensorOrWeights {
263  public:
TRT_TensorOrWeights()264   TRT_TensorOrWeights() {}
265 
266   // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'.
267   // This is used by Converter when building the TRT network, where the ITensor
268   // is owned by the TRT network being built. See comment for 'tensor_' below.
269   explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);
270 
271   // Constructor that makes it an ITensor by creating one using provided data
272   // type and shape, and takes ownership of the created ITensor. This is used by
273   // TrtNodeValidator to encapsulate the type and shape information for
274   // validation of graph nodes, and the created ITensor is fake and temporary,
275   // and should not be used to build any TRT network. See comment for
276   // 'simple_itensor_' below.
277   explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
278                                const nvinfer1::Dims& trt_dims, int batch_size);
279 
280   // Constructor that makes it a TRT_TensorOrWeights.
281   explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);
282 
283   TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);
284 
285   void operator=(const TRT_TensorOrWeights& rhs);
286 
is_tensor()287   bool is_tensor() const { return initialized_ && is_tensor_; }
is_weights()288   bool is_weights() const { return initialized_ && !is_tensor_; }
289 
290   nvinfer1::ITensor* tensor() const;
291 
weights()292   TRT_ShapedWeights& weights() {
293     CHECK(is_weights());
294     return weights_;
295   }
296 
weights()297   const TRT_ShapedWeights& weights() const {
298     CHECK(is_weights());
299     return weights_;
300   }
301 
302   nvinfer1::Dims GetTrtDims() const;
303 
304   Status GetTfType(DataType* tf_type) const;
305 
batch_size()306   int batch_size() const { return batch_size_; }
307 
308   string DebugString() const;
309 
310  private:
311   class SimpleITensor;
312 
set_batch_size(int batch_size)313   void set_batch_size(int batch_size) { batch_size_ = batch_size; }
314 
315   // When it represents an ITensor, the ITensor can be either passed by the
316   // caller via the constructor that takes an ITensor* as parameter, or be
317   // created as a SimpleITensor.
318   //
319   // In the first case, the ITensor pointer is stored in 'tensor_' below, and
320   // the ITensor itself is not owned by this class. This method is used by
321   // Converter (e.g. AddInputTensor) and op converters during TRT network
322   // construction, where the TRT network owns the ITensor.
323   //
324   // In the second case, the created SimpleITensor is stored in
325   // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake
326   // implementation of ITensor and is used only by TrtNodeValidator to validate
327   // the graph nodes.
328   nvinfer1::ITensor* tensor_ = nullptr;  // Not owned.
329   std::shared_ptr<SimpleITensor> simple_itensor_ = nullptr;
330 
331   // First dimension of the TF tensor (NOT tensor_) that is represented by
332   // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
333   // dimensions (obtained via tensor_->getDimensions()) do not contain the batch
334   // dimension. For example, when a TF tensor with shape (A,B,C) is represented
335   // in TRT, tensor_->getDimensions() will be (B,C) and batch_size_ will be A.
336   //
337   // This requires that all tensors in the subgraph that is converted to a TRT
338   // engine have the same batch size are represented by the first dimension of
339   // their shape, and Converter will verify this during conversion. The drawback
340   // is that currently it cannot convert a graph that doesn't have the batch
341   // size represented in the shapes or the batch sizes are different. See
342   // b/118387490 for more details.
343   //
344   // If use_implicit_batch is false, batch_size_ is unused and
345   // tensor_->getDimensions() will contain the entire shape (A,B,C).
346   int batch_size_ = -1;
347 
348   TRT_ShapedWeights weights_;
349   bool initialized_ = false;
350   bool is_tensor_ = false;
351 
352   friend class Converter;
353 };
354 
355 class Converter;
356 
357 // Parameters for each op converter.
358 struct OpConverterParams {
359   // Constructor used for validation only.
360   OpConverterParams(const NodeDef& node_def,
361                     const std::vector<TRT_TensorOrWeights>& inputs,
362                     std::vector<TRT_TensorOrWeights>* outputs,
363                     TrtWeightStore* weight_store,
364                     TrtPrecisionMode precision_mode, bool use_calibration,
365                     bool use_implicit_batch);
366 
367   // Constructor used for conversion.
368   OpConverterParams(Converter* converter, const NodeDef& node_def,
369                     const std::vector<TRT_TensorOrWeights>& inputs,
370                     std::vector<TRT_TensorOrWeights>* outputs,
371                     TrtWeightStore* weight_store);
372 
373   Converter* converter = nullptr;
374   const NodeDef& node_def;
375   const std::vector<TRT_TensorOrWeights>& inputs;
376   std::vector<TRT_TensorOrWeights>* outputs;
377   const bool validation_only;
378   TrtWeightStore* weight_store;
379   const TrtPrecisionMode precision_mode;
380   const bool use_calibration;
381   const bool use_implicit_batch;
382 };
383 
384 using OpConverter = std::function<Status(OpConverterParams*)>;
385 
386 // Class to verify if specific TF node is supported by TRT.
387 class TrtNodeValidator {
388  public:
389   // 'graph_properties' is the GraphProperties of the graph whose nodes will be
390   // checked by IsTensorRTCandidate() later. It is used to get the shape and
391   // data type information of a tensor for validation purpose.
392   TrtNodeValidator(const grappler::GraphProperties& graph_properties,
393                    TrtPrecisionMode precision_mode, bool use_calibration,
394                    bool use_implicit_batch);
395 
396   // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
397   // to TRT subgraph and later converted into TRT engine.
398   Status IsTensorRTCandidate(const Node* node);
399 
400  private:
401   static const std::set<string>* quantize_ops;
402 
403   void RegisterOpValidators();
404 
405   // Convert a Const node to a TRT_TensorOrWeights.
406   Status ConvertConstToWeights(const NodeDef& const_node_def,
407                                const std::vector<TRT_TensorOrWeights>& inputs,
408                                TRT_TensorOrWeights* output);
409 
410   // Convert the output tensor at 'output_port' of 'node_def' to a
411   // TRT_TensorOrWeights which will be later used as an input to other nodes and
412   // passed to ValidateNode() below.
413   Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port,
414                                   TRT_TensorOrWeights* tensor_or_weights);
415 
416   // Stores all the validators by op type. If no validator is registered for
417   // specific op, it means no validation is needed and ValidateNode() will
418   // return OK.
419   std::unordered_map<string, OpConverter> op_validators_;
420 
421   // Store the weights added during validation. Some validations (e.g.
422   // validation for Const node) may produce weights.
423   TrtWeightStore weight_store_;
424 
425   // GraphProperties of the graph whose nodes are to be validated by
426   // IsTensorRTCandidate().
427   const grappler::GraphProperties& graph_properties_;
428 
429   // Quantization ops are only converted when using quantized precisions.
430   const TrtPrecisionMode precision_mode_;
431 
432   const bool use_calibration_;
433 
434   const bool use_implicit_batch_;
435 
436   friend class ValidatorTest;
437   friend class OpConverterTest;
438 };
439 
440 // Class to convert TF nodes to TRT network.
441 class Converter {
442  public:
443   // Used for Converter::RenameAndMarkOutputTensors()
444   struct EngineOutputInfo {
445     // The TRT tensor name which produces the output.
446     string source_tensor_name;
447     // The TensorFlow node name which is receiving the output from the TRT
448     // engine. This should always be the Identity node created in
449     // ConvertSegmentToGraphDef.
450     string dest_node_name;
451     // Output type. TensorRT requires this to be explicitly set for engine
452     // outputs.
453     nvinfer1::DataType trt_dtype;
454   };
455 
456   static StatusOr<std::unique_ptr<Converter>> Create(
457       TrtPrecisionMode precision_mode, bool use_calibration,
458       nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
459       absl::string_view engine_name);
460 
461   //////////////////////////////////////////////////////////////////////////////
462   // Methods used by the TRT engine builder to build a TRT network from a TF
463   // function/subgraph.
464 
465   // Convert the node to TRT network.
466   Status ConvertNode(const NodeDef& node_def);
467 
468   // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and
469   // 'batch_size'.
470   Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
471                         const nvinfer1::Dims& dims, int batch_size);
472 
473   // Mark the tensors with names specified by source_tensor_name as output of
474   // the TRT network, and set their names in the TRT network as dest_node_name.
475   Status RenameAndMarkOutputTensors(
476       const std::vector<EngineOutputInfo>& output_tensors);
477 
478   // Build a TRT engine using the created network.
479   Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
480                          int max_batch_size, size_t max_workspace_size_bytes,
481                          nvinfer1::IGpuAllocator* allocator,
482                          TRTInt8Calibrator* calibrator,
483                          TrtShapeOptimizationProfile* profiles);
484 
485   //////////////////////////////////////////////////////////////////////////////
486   // Methods used by op converters to convert individual TF node and add layers
487   // to the TRT network.
488 
489   // Op converters (e.g. ConvertReshape) need to access the TRT network in order
490   // to add TRT layers.
network()491   nvinfer1::INetworkDefinition* network() { return trt_network_.get(); }
492 
493   // What precision are we targeting?
precision_mode()494   TrtPrecisionMode precision_mode() const { return precision_mode_; }
495 
496   // Calibration will be or was previously performed on this network?
use_calibration()497   bool use_calibration() const { return use_calibration_; }
498 
499   // Whether implicit batch mode is enabled
use_implicit_batch()500   bool use_implicit_batch() const { return use_implicit_batch_; }
501 
502   // This should be called on the inputs and outputs of any layer we create
503   // where we know that the quantization range does not change during that
504   // operation. (e.g. Reshape, Transpose, Identity, MaxPool).
505   void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
506                                           nvinfer1::ITensor* output);
507 
508   // This function should be called when we know the quantization range of a
509   // tensor, either from a quantize/dequantize node or when the output is a
510   // fixed range (e.g. SoftMax, Relu6, Sigmoid).
511   void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range,
512                                 float max_range);
513 
514   // Should be called when full TRT network has been constructed and before
515   // building the engine.
516   void MaybeApplyQuantizationRanges();
517 
518   // Below are helper methods for op converters to add different layers to the
519   // TRT network.
520 
521   // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to
522   // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch
523   // dimension which should always be 0. If this is for adding a transpose layer
524   // to support the conversion of 'node_def', callers need to provide a
525   // non-empty 'sub_op_name' appended to the name of 'node_def' to avoid layer
526   // name conflicts.
527   Status TransposeTensor(nvinfer1::ITensor* input_tensor,
528                          const std::vector<int>& order_with_batch_dim,
529                          nvinfer1::ITensor** output_tensor,
530                          const NodeDef& node_def,
531                          absl::string_view sub_op_name = "");
532 
533   // Reshapes a dynamic shape tensor by removing or adding dimensions of size 1,
534   // and/or permuting the dimensions. The new shape is derived from the shape of
535   // the input tensor according to the slices and size_for_added_dims arguments.
536   //
537   // If there would be at most one unknown dimension, we could set the new shape
538   // using IShuffleLayer::setReshapeDimensions, which treats -1 as a special
539   // value (the same way as TF). In general, we can have more than one unknown
540   // dimensions, and we have to manipulate the shape tensors during runtime to
541   // define the new shape. This helper function defines the necessary shape
542   // inference layers and calls reshape using the calculated new shape.
543   //
544   // Example:
545   //
546   // Assume that we want to reshape a tensor from shape {A,B,C,D} to {C,D,A,B}
547   // (no transpose, just change the shape). In dynamic shape mode, the A,B,C,D
548   // values are not necessarily known at conversion time, they can be all -1. We
549   // can only define the new shape at runtime, when the actual shape is already
550   // known. To define the new shape:
551   // - We use an IShapeLayer to retrieve a shape tensor with the {A,B,C,D}
552   //   values.
553   // - Create two slices {C,D} and {A,B} of the shape tensor.
554   // - Concatenate these slices {C,D,A,B},
555   // - Set the {C,D,A,B} shape tensor as an input shape tensor for
556   // IShuffleLayer.
557   //
558   // This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}},
559   // params).
560   //
561   // Before each slice we can insert a new dim if the corresponding
562   // size_for_added_dims element is not negative. The size_for_added_dims array
563   // can have more than slices.size() elements, in order to insert a dimension
564   // ater the last slice.
565   //
566   // Parameters:
567   // input - input tensor
568   // slices - [start, end) pairs of slices
569   // params - conversion parameters
570   // output - reshaped tensor
571   // size_for_added_dims - size of dimension inserted right before slice[i]. We
572   //   only insert a new dim if size_for_added_dims[i] >= 0.
573   Status DynamicReshape(nvinfer1::ITensor* input,
574                         std::vector<std::pair<int, int>> slices,
575                         OpConverterParams* params, nvinfer1::ITensor** output,
576                         std::vector<int> size_for_added_dims = {},
577                         absl::optional<int> op_instance = absl::nullopt);
578 
579   // Inserts a singleton dimension at axis for a dynamic shape tensor.
580   Status DynamicExpandDims(nvinfer1::ITensor* input, const nvinfer1::Dims& dims,
581                            int axis, OpConverterParams* params,
582                            nvinfer1::ITensor** output,
583                            absl::optional<int> op_instance = absl::nullopt);
584 
585   // Helper function to add a squeeze op to the network.
586   //
587   // The input_dims argument stores the TRT dimensions of the input tensor,
588   // where the dimensions to be squeezed are replaced by 0.
589   Status SqueezeTensor(nvinfer1::ITensor* input, std::vector<int>* input_dims,
590                        OpConverterParams* params, nvinfer1::ITensor** output);
591 
592   // Creates an IConstantLayer using 'weights' whose dimensions are specified by
593   // 'dims', and returns the output ITensor.
594   nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
595                                          const nvinfer1::Dims& dims);
596 
597   // Gets the min and max value in a TRT_ShapedWeights
598   Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
599                         float* out_max) const;
600 
601   // Constructs a name and passed it to the TensorRT layer to support xprof.
602   void SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def,
603                     absl::string_view sub_op_name = "",
604                     absl::optional<int> sub_op_instance = absl::nullopt);
605   void SetLayerName(nvinfer1::ILayer* layer, absl::string_view main_op_name,
606                     absl::string_view sub_op_name,
607                     absl::optional<int> sub_op_instance = absl::nullopt);
608 
609  private:
610   Converter(TrtPrecisionMode precision_mode, bool use_calibration,
611             nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
612             absl::string_view engine_name);
613 
614   Status Init(nvinfer1::ILogger* trt_logger);
615 
616   // Verify the provided batch_size is consistent with batch_size_ and update it
617   // if necessary.
618   Status MaybeUpdateBatchSize(int batch_size);
619 
620   // Add the provided tensor/weights to the map trt_tensors_.
621   Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input);
622 
623   // Get the tensor/weights from trt_tensors_ by 'name'.
624   Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output);
625 
626   // Get the inputs of 'node_def' from trt_tensors_.
627   Status GetInputs(const NodeDef& node_def,
628                    std::vector<TRT_TensorOrWeights>* inputs) const;
629 
630   void RegisterOpConverters();
631 
632   void PropagateQuantizationRanges();
633 
634   // Registered op converters by op type.
635   std::unordered_map<string, OpConverter> op_registry_;
636 
637   // Tensors/weights added during construction of trt_network_.
638   std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
639 
640   // The TRT builder used to create the network and build the engine. Not owned.
641   TrtUniquePtrType<nvinfer1::IBuilder> trt_builder_;
642 
643   // The TRT network being built.
644   TrtUniquePtrType<nvinfer1::INetworkDefinition> trt_network_;
645 
646   // Store the weights added during construction of trt_network_.
647   TrtWeightStore weight_store_;
648 
649   // During conversion, this table is populated with quantization ranges per
650   // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT
651   // quantization ranges. Since TRT only supports symmetric ranges, we will
652   // store the range as a single float = max(abs(min_range), abs(max_range)).
653   // Range refers to the floating point values, e.g. min_range = 0.0f, max_range
654   // = 6.0f for Relu6.
655   std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_;
656 
657   // Edges where quantization ranges can be inferred (copied) across ops - from
658   // first tensor to second tensor. PropagateQuantizationRanges() will propagate
659   // known ranges from quantization_ranges_ across these edges, adding the new
660   // ranges to quantization_ranges_ so that they can be applied in
661   // MaybeApplyQuantizationRanges().
662   std::vector<std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*>>
663       quantization_infer_;
664 
665   const TrtPrecisionMode precision_mode_;
666 
667   const bool use_calibration_;
668 
669   // If this is false, all dimensions including the batch dimension are
670   // set explicitely.
671   const bool use_implicit_batch_;
672 
673   // Batch size of inputs to trt_network_ added by AddInputTensor(). During
674   // network construction it will update this, use it to verify the batch
675   // size of all inputs are compatible, and make sure individual TF node is
676   // acceptable by TRT.
677   int batch_size_ = -1;
678 
679   // Assign a ID to each constant layer we create, so that we can assign a
680   // unique name to the layer.
681   int next_constant_layer_id_ = 0;
682 
683   // The name of the TRTEngineOp node.
684   absl::string_view engine_name_;
685 
686   friend class ConverterTest;
687   friend class OpConverterTest;
688 };
689 
690 // Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims'
691 // (which doesn't contain the batch dimension).
692 //
693 // If validation_only is true, it doesn't do the conversion but only do some
694 // minimum validation for the eligibility of the conversion, and *tensor will
695 // be set to nullptr.
696 // If validation_only is false converter must not be nullptr.
697 Status PrepareTensorForShape(Converter* converter,
698                              const TRT_TensorOrWeights& input,
699                              const nvinfer1::Dims& dims,
700                              const bool validation_only,
701                              nvinfer1::ITensor** tensor,
702                              const NodeDef& node_def,
703                              absl::optional<int> op_instance = absl::nullopt);
704 
705 // Return OK if the broadcast scheme is supported and compute the shapes after
706 // broadcasting. check_feasibility can be set to false in cases where dimensions
707 // do not need to match exactly (as in the case of BatchMatMulV2).
708 Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
709                             const TRT_TensorOrWeights& operand_r,
710                             const bool check_feasibility,
711                             const bool use_implicit_batch,
712                             nvinfer1::Dims* operand_l_new_dims,
713                             nvinfer1::Dims* operand_r_new_dims);
714 
715 // Map of all supported UnaryOperations
716 const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
717 // Map of all supported ActivationTypes
718 const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap();
719 // Map of all supported BinaryOperations
720 const std::unordered_map<string, nvinfer1::ElementWiseOperation>*
721 BinaryOperationMap();
722 
723 }  // namespace convert
724 }  // namespace tensorrt
725 }  // namespace tensorflow
726 
727 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
728 
729 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
730