/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_

#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

namespace convert {
using ::stream_executor::port::StatusOr;

struct EngineConnection {
  // Constructs a non-control edge.
  EngineConnection(const string& outside, int out_id, int out_port,
                   const string& inside, int in_id, int in_port,
                   bool input_edge, int port)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(out_port),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(in_port),
        is_input_edge(input_edge),
        port_number(port) {}

  // Constructs a control edge.
  EngineConnection(const string& outside, int out_id, const string& inside,
                   int in_id, bool input_edge)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(Graph::kControlSlot),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(Graph::kControlSlot),
        is_input_edge(input_edge),
        port_number(Graph::kControlSlot) {}

  bool is_control_edge() const { return port_number == Graph::kControlSlot; }

  const string outside_node_name;
  const int outside_id;
  const int outside_port;
  PartialTensorShape outside_shape;  // Only set for input edge.

  const string inside_node_name;
  const int inside_id;
  const int inside_port;
  PartialTensorShape inside_shape;  // Only set for output edge.

  DataType connection_type;
  const bool is_input_edge;

  // The port number of the TRT node connected with this edge.
  const int port_number;
};

struct EngineInfo {
  EngineInfo()
      : engine_type(EngineType::TRTStatic),
        max_workspace_size_bytes(0),
        max_batch_size(absl::nullopt),
        maximum_cached_engines(0),
        precision_mode(TrtPrecisionMode::FP32),
        use_calibration(true),
        allow_build_at_runtime(true),
        has_int32_input(false) {}

  string engine_name;
  string device;
  GraphDef segment_graph_def;

  // Non-control input connections inside this vector are sorted such that
  // the segment nodes connecting to them are topologically sorted. In
  // addition, for non-control connections, there must be no duplicates.
  std::vector<EngineConnection> connections;

  enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
  EngineType engine_type;
  int64 max_workspace_size_bytes;
  absl::optional<int> max_batch_size;
  int maximum_cached_engines;
  TrtPrecisionMode precision_mode;
  bool use_calibration;
  bool allow_build_at_runtime;
  bool has_int32_input;
};

// Constructs a graphdef from the segment in the given graph and stores it in
// 'engine_info'. Adds _Arg nodes for input edges (InputPH_*) and _Retval
// nodes for output edges (OutputPH_*). Maintains the topological order of the
// non-input/output nodes in the graphdef. This function needs to be called
// before TensorRT layers are created because it prepares the original graph
// for TensorRT conversion.
//
// - graph: the graph containing the subgraph.
// - graph_properties: the shape and type properties of 'graph'.
// - subgraph_nodes: the nodes of the subgraph; must be sorted in topological
//   order.
// - engine_info: a data structure that records the information about the
//   engine containing the subgraph.
//
// TODO(aaroey): add tests to validate these properties.
Status ConvertSegmentToGraphDef(
    const Graph* graph, const grappler::GraphProperties& graph_properties,
    const std::vector<const Node*>& subgraph_nodes, EngineInfo* engine_info);

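// A minimal usage sketch (hypothetical; 'graph', 'props', and 'nodes' are
// assumed to come from the segmenter, they are not defined in this header):
//
//   EngineInfo info;
//   TF_RETURN_IF_ERROR(
//       ConvertSegmentToGraphDef(graph, props, nodes, &info));
//   // info.segment_graph_def now holds the segment with _Arg/_Retval nodes.
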
// Converts the given subgraph to a TRT engine saved in 'engine'. Returns OK
// iff 'builder' successfully builds the engine. If the result is not OK,
// 'engine' will be set to nullptr. Once this function returns, 'builder' is
// no longer needed and can be safely destroyed.
//
// - convert_successfully: indicates whether the conversion to a TensorRT
//   network was successful. This is different from successfully building the
//   engine: building can still fail afterwards.
149 Status ConvertGraphDefToEngine(
150     const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
151     size_t max_workspace_size_bytes,
152     const std::vector<PartialTensorShape>& input_shapes,
153     nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
154     TRTInt8Calibrator* calibrator,
155     TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
156     const bool use_implicit_batch, bool* convert_successfully,
157     TrtShapeOptimizationProfile* profiles, absl::string_view engine_name);
158 
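// A hedged call sketch (hypothetical; every argument name below is assumed
// caller state, not defined in this header):
//
//   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
//   bool convert_ok = false;
//   Status s = ConvertGraphDefToEngine(
//       info.segment_graph_def, TrtPrecisionMode::FP32, /*max_batch_size=*/1,
//       /*max_workspace_size_bytes=*/1 << 30, input_shapes, trt_logger,
//       allocator, /*calibrator=*/nullptr, &engine, /*use_calibration=*/false,
//       /*use_implicit_batch=*/true, &convert_ok, profiles, "my_engine");
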
// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid.
class OutputEdgeValidator {
 public:
  // Returns true if the specified edge is eligible to be an output edge of
  // the TRT segment.
  bool operator()(const Edge* out_edge) const;
};

int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);

// Class to convert TF compile-time constants (e.g. Const nodes) to TRT
// weights.
class TRT_ShapedWeights {
 public:
  explicit TRT_ShapedWeights(
      nvinfer1::DataType type = nvinfer1::DataType::kFLOAT);

  // Copy from another weights.
  //
  // NOTE: this does not copy the underlying buffer but only increases its
  // reference count.
  TRT_ShapedWeights(const TRT_ShapedWeights& rhs);

  nvinfer1::Weights GetTrtWeights() const;

  const Tensor& GetTensor() const { return tensor_; }

  // Returns the raw pointer to the underlying buffer which holds the weights
  // value.
  void* GetValues() const {
    return const_cast<char*>(tensor_.tensor_data().data());
  }

  // Fills all the weight values with value.
  template <typename T>
  Status SetValues(T value);

  Status SetShape(nvinfer1::Dims dims);

  // Returns total number of elements. Returning 0 means either some dim is 0
  // or the number of dims is 0. Note that a TF scalar constant is marked as
  // Dims{0, {1}}, and has a count() == 1.
  int64_t count() const { return count(shape_); }

  // Returns the total number of elements in a weight with shape dims.
  static int64_t count(nvinfer1::Dims dims);

  size_t size_bytes() const;

  string DebugString() const;

  template <typename T>
  absl::Span<const T> GetSpan() const {
    return absl::Span<const T>(tensor_.flat<T>().data(), count());
  }

  template <typename T>
  std::vector<T> ToVector() const {
    auto span = GetSpan<T>();
    return std::vector<T>(span.data(), span.data() + span.size());
  }

  nvinfer1::DataType TrtDType() const { return type_; }

  // TODO(aaroey): make these private.
  // Scalar weights are supported, a scalar constant tensor is represented via
  // TRT_ShapedWeights::shape_ = {0, {1}}.
  nvinfer1::Dims shape_;  // Note: shape.type[] is not used.

 private:
  // This constructor is only used by TrtWeightStore, which creates the
  // underlying buffer.
  TRT_ShapedWeights(nvinfer1::DataType type, nvinfer1::Dims dims,
                    Tensor tensor);

  nvinfer1::DataType type_;

  // All weights should be stored inside TrtWeightStore to make sure the
  // lifetimes of all the underlying tensors extend until the engine is built.
  // For this reason, tensor_ should never be reassigned to a value that is
  // not already present in the TrtWeightStore.
  Tensor tensor_;

  friend class TrtWeightStore;
};

// Container for TRT_ShapedWeights. We need this container because TRT doesn't
// manage the lifetime of the weights buffer; it only keeps a pointer to it and
// requires that the data referenced by the pointer be available until the
// building of the engine is complete. For more information see
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html
//
// TODO(laigd): consider adding garbage collection to the unused weights.
class TrtWeightStore {
 public:
  // Gets a TRT_ShapedWeights with 'trt_type' and 'dims'.
  TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
                                   const nvinfer1::Dims& dims);

  // Gets a TRT_ShapedWeights with the same data type and dimensions as
  // 'weights'.
  TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
    return GetTempWeights(weights.TrtDType(), weights.shape_);
  }

 private:
  // The backend storage of the TRT_ShapedWeights.
  std::vector<Tensor> store_;
};
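
// A minimal usage sketch (hypothetical; 'store' is assumed to outlive the
// engine build, since TRT only keeps raw pointers into these buffers):
//
//   TrtWeightStore store;
//   nvinfer1::Dims dims;
//   dims.nbDims = 1;
//   dims.d[0] = 4;
//   TRT_ShapedWeights w =
//       store.GetTempWeights(nvinfer1::DataType::kFLOAT, dims);
//   w.SetValues(0.0f);  // Fill all four floats with zero.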

// Represents a TRT-style input to a TF node; it can be either an
// ITensorProxyPtr (representing nvinfer1::ITensor* or SimpleITensor), or a
// TRT_ShapedWeights, which is a compile-time constant.
//
// TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument.
class TRT_TensorOrWeights {
 public:
  TRT_TensorOrWeights() {}
  TRT_TensorOrWeights(ITensorProxyPtr);
  TRT_TensorOrWeights(ITensorProxyPtr tensor, int batch_size);

  // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'.
  // This is used by Converter when building the TRT network, where the ITensor
  // is owned by the TRT network being built. See comment for 'trt_tensor_'
  // in trt_proxy_tensor.h.
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);

  // Constructor that makes it an ITensor by creating one using provided data
  // type and shape, and takes ownership of the created ITensor. This is used
  // by TrtNodeValidator to encapsulate the type and shape information for
  // validation of graph nodes, and the created ITensor is fake and temporary,
  // and should not be used to build any TRT network. See comment for
  // 'simple_tensor_' in trt_proxy_tensor.h.
  explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                               const nvinfer1::Dims& trt_dims, int batch_size);

  // Constructor that makes it weights.
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);

  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);

  void operator=(const TRT_TensorOrWeights& rhs);

  bool is_tensor() const { return initialized_ && is_tensor_; }
  bool is_weights() const { return initialized_ && !is_tensor_; }

  ITensorProxyPtr tensor() const;

  TRT_ShapedWeights& weights() {
    CHECK(is_weights());
    return weights_;
  }

  const TRT_ShapedWeights& weights() const {
    CHECK(is_weights());
    return weights_;
  }

  nvinfer1::Dims GetTrtDims() const;

  Status GetTfType(DataType* tf_type) const;

  int batch_size() const { return batch_size_; }

  string DebugString() const;

 private:
  void set_batch_size(int batch_size) { batch_size_ = batch_size; }

  // First dimension of the TF tensor (NOT tensor_) that is represented by
  // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
  // dimensions (obtained via tensor_->getDimensions()) do not contain the
  // batch dimension. For example, when a TF tensor with shape (A,B,C) is
  // represented in TRT, tensor_->getDimensions() will be (B,C) and
  // batch_size_ will be A.
  //
  // This requires that all tensors in the subgraph that is converted to a TRT
  // engine have the same batch size, represented by the first dimension of
  // their shape, and Converter will verify this during conversion. The
  // drawback is that currently it cannot convert a graph whose batch size is
  // not represented in the shapes, or whose batch sizes differ. See
  // b/118387490 for more details.
  //
  // If use_implicit_batch is false, batch_size_ is unused and
  // tensor_->getDimensions() will contain the entire shape (A,B,C).
  ITensorProxyPtr tensor_proxy_ptr_ = nullptr;
  int batch_size_ = -1;

  TRT_ShapedWeights weights_;
  bool initialized_ = false;
  bool is_tensor_ = false;

  friend class Converter;
};

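// A hedged inspection sketch (hypothetical; 'input' is assumed to come from
// Converter::GetInputs() or similar):
//
//   void Inspect(const TRT_TensorOrWeights& input) {
//     if (input.is_weights()) {
//       LOG(INFO) << "weights: " << input.weights().DebugString();
//     } else if (input.is_tensor()) {
//       LOG(INFO) << "tensor rank: " << input.GetTrtDims().nbDims;
//     }
//   }
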
class Converter;

// Parameters for each op converter.
struct OpConverterParams {
  // Constructor used for validation only.
  OpConverterParams(const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store,
                    TrtPrecisionMode precision_mode, bool use_calibration,
                    bool use_implicit_batch);

  // Constructor used for conversion.
  OpConverterParams(Converter* converter, const NodeDef& node_def,
                    const std::vector<TRT_TensorOrWeights>& inputs,
                    std::vector<TRT_TensorOrWeights>* outputs,
                    TrtWeightStore* weight_store);

  Converter* converter = nullptr;
  const NodeDef& node_def;
  const std::vector<TRT_TensorOrWeights>& inputs;
  std::vector<TRT_TensorOrWeights>* outputs;
  const bool validation_only;
  TrtWeightStore* weight_store;
  const TrtPrecisionMode precision_mode;
  const bool use_calibration;
  const bool use_implicit_batch;
};

using OpConverter = std::function<Status(OpConverterParams*)>;

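// A converter is any callable matching OpConverter. A minimal sketch
// (hypothetical; real converters such as ConvertReshape live in the .cc file
// and are registered by RegisterOpConverters()):
//
//   Status ConvertFoo(OpConverterParams* params) {
//     // Validate the inputs first; return early in validation-only mode.
//     if (params->validation_only) return Status::OK();
//     // Otherwise use params->converter->network() to add TRT layers and
//     // append the results to *params->outputs.
//     return Status::OK();
//   }
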
// Class to verify whether a specific TF node is supported by TRT.
class TrtNodeValidator {
 public:
  // 'graph_properties' is the GraphProperties of the graph whose nodes will
  // be checked by IsTensorRTCandidate() later. It is used to get the shape
  // and data type information of a tensor for validation purposes.
  TrtNodeValidator(const grappler::GraphProperties& graph_properties,
                   TrtPrecisionMode precision_mode, bool use_calibration,
                   bool use_implicit_batch);

  // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be
  // added to a TRT subgraph and later converted into a TRT engine.
  Status IsTensorRTCandidate(const Node* node);

  static const std::set<string>* quantize_ops;

 private:
  void RegisterOpValidators();

  // Convert a Const node to a TRT_TensorOrWeights.
  Status ConvertConstToWeights(const NodeDef& const_node_def,
                               const std::vector<TRT_TensorOrWeights>& inputs,
                               TRT_TensorOrWeights* output);

  // Convert the output tensor at 'output_port' of 'node_def' to a
  // TRT_TensorOrWeights which will be later used as an input to other nodes
  // and passed to ValidateNode() below.
  Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port,
                                  TRT_TensorOrWeights* tensor_or_weights);

  // Stores all the validators by op type. If no validator is registered for
  // a specific op, it means no validation is needed and ValidateNode() will
  // return OK.
  std::unordered_map<string, OpConverter> op_validators_;

  // Store the weights added during validation. Some validations (e.g.
  // validation for Const node) may produce weights.
  TrtWeightStore weight_store_;

  // GraphProperties of the graph whose nodes are to be validated by
  // IsTensorRTCandidate().
  const grappler::GraphProperties& graph_properties_;

  // Quantization ops are only converted when using quantized precisions.
  const TrtPrecisionMode precision_mode_;

  const bool use_calibration_;

  const bool use_implicit_batch_;

  friend class ValidatorTest;
  friend class OpConverterTest;
};

// Class to convert TF nodes to TRT network.
class Converter {
 public:
  // Used for Converter::RenameAndMarkOutputTensors()
  struct EngineOutputInfo {
    // The TRT tensor name which produces the output.
    string source_tensor_name;
    // The TensorFlow node name which is receiving the output from the TRT
    // engine. This should always be the Identity node created in
    // ConvertSegmentToGraphDef.
    string dest_node_name;
    // Output type. TensorRT requires this to be explicitly set for engine
    // outputs.
    nvinfer1::DataType trt_dtype;
  };

  static StatusOr<std::unique_ptr<Converter>> Create(
      TrtPrecisionMode precision_mode, bool use_calibration,
      nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
      absl::string_view engine_name);

  //////////////////////////////////////////////////////////////////////////////
  // Methods used by the TRT engine builder to build a TRT network from a TF
  // function/subgraph.

  // Convert the node to TRT network.
  Status ConvertNode(const NodeDef& node_def);

  // Add input tensor to the TRT network with given 'name', 'dtype', 'dims'
  // and 'batch_size'.
  Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
                        const nvinfer1::Dims& dims, int batch_size);

  // Mark the tensors with names specified by source_tensor_name as output of
  // the TRT network, and set their names in the TRT network as dest_node_name.
  Status RenameAndMarkOutputTensors(
      const std::vector<EngineOutputInfo>& output_tensors);

  // Build a TRT engine using the created network.
  Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
                         int max_batch_size, size_t max_workspace_size_bytes,
                         nvinfer1::IGpuAllocator* allocator,
                         TRTInt8Calibrator* calibrator,
                         TrtShapeOptimizationProfile* profiles);

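  // The typical build flow, as a hedged sketch (hypothetical; the real driver
  // lives elsewhere in the TF-TRT bridge):
  //
  //   auto converter = Converter::Create(...);  // returns StatusOr
  //   converter->AddInputTensor(...);           // once per engine input
  //   for (const NodeDef& node : gdef.node()) {
  //     converter->ConvertNode(node);
  //   }
  //   converter->RenameAndMarkOutputTensors(outputs);
  //   converter->BuildCudaEngine(&engine, ...);
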
  //////////////////////////////////////////////////////////////////////////////
  // Methods used by op converters to convert individual TF node and add layers
  // to the TRT network.

  // Op converters (e.g. ConvertReshape) need to access the TRT network in
  // order to add TRT layers.
  nvinfer1::INetworkDefinition* network() { return trt_network_.get(); }

  // What precision are we targeting?
  TrtPrecisionMode precision_mode() const { return precision_mode_; }

  // Calibration will be or was previously performed on this network?
  bool use_calibration() const { return use_calibration_; }

  // Whether implicit batch mode is enabled.
  bool use_implicit_batch() const { return use_implicit_batch_; }

  // This function should be called when we know the quantization range of a
  // tensor from a quantize/dequantize node.
  void ProvideQuantizationRange(ITensorProxyPtr* tensor, float min_range,
                                float max_range);

  // Should be called when the full TRT network has been constructed and
  // before building the engine.
  void MaybeApplyQuantizationRanges();

  // Below are helper methods for op converters to add different layers to the
  // TRT network.

  // Transposes 'input_tensor' with the given permutation 'order_with_batch_dim'
  // into 'output_tensor'. The permutation 'order_with_batch_dim' contains the
  // batch dimension, which should always be 0. If this is for adding a
  // transpose layer to support the conversion of 'node_def', callers need to
  // provide a non-empty 'sub_op_name' to be appended to the name of 'node_def'
  // to avoid layer name conflicts.
  Status TransposeTensor(ITensorProxyPtr input_tensor,
                         const std::vector<int>& order_with_batch_dim,
                         ITensorProxyPtr* output_tensor,
                         const NodeDef& node_def,
                         absl::string_view sub_op_name = "");

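  // For example (a hedged sketch; 'input' and 'node_def' are assumed caller
  // state), transposing NCHW to NHWC uses the permutation {0, 2, 3, 1},
  // where index 0 is the batch dimension:
  //
  //   ITensorProxyPtr output = nullptr;
  //   TF_RETURN_IF_ERROR(
  //       TransposeTensor(input, {0, 2, 3, 1}, &output, node_def, "to_nhwc"));
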
  // Reshapes a dynamic shape tensor by removing or adding dimensions of size
  // 1, and/or permuting the dimensions. The new shape is derived from the
  // shape of the input tensor according to the slices and size_for_added_dims
  // arguments.
  //
  // If there were at most one unknown dimension, we could set the new shape
  // using IShuffleLayer::setReshapeDimensions, which treats -1 as a special
  // value (the same way as TF). In general, we can have more than one unknown
  // dimension, and we have to manipulate the shape tensors during runtime to
  // define the new shape. This helper function defines the necessary shape
  // inference layers and calls reshape using the calculated new shape.
  //
  // Example:
  //
  // Assume that we want to reshape a tensor from shape {A,B,C,D} to {C,D,A,B}
  // (no transpose, just change the shape). In dynamic shape mode, the A,B,C,D
  // values are not necessarily known at conversion time; they can all be -1.
  // We can only define the new shape at runtime, when the actual shape is
  // already known. To define the new shape:
  // - We use an IShapeLayer to retrieve a shape tensor with the {A,B,C,D}
  //   values.
  // - Create two slices {C,D} and {A,B} of the shape tensor.
  // - Concatenate these slices into {C,D,A,B}.
  // - Set the {C,D,A,B} shape tensor as an input shape tensor for
  //   IShuffleLayer.
  //
  // This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}},
  // params).
  //
  // Before each slice we can insert new dims if the corresponding
  // size_for_added_dims element is not negative. The size_for_added_dims array
  // can have more than slices.size() elements, in order to insert a dimension
  // after the last slice. For example, to add two leading 1 dimensions, and
  // three trailing 1 dimensions, call DynamicReshape(input, {{0,nbDims}},
  // {2, 3}).
  //
  // Parameters:
  // input - input tensor
  // slices - [start, end) pairs of slices
  // params - conversion parameters
  // output - reshaped tensor
  // size_for_added_dims - size of dimension inserted right before slice[i]. We
  //   only insert a new dim if size_for_added_dims[i] >= 0.
  Status DynamicReshape(ITensorProxyPtr input,
                        std::vector<std::pair<int, int>> slices,
                        OpConverterParams* params, ITensorProxyPtr* output,
                        std::vector<int> size_for_added_dims = {},
                        absl::optional<int> op_instance = absl::nullopt);

  // Inserts a singleton dimension at axis for a dynamic shape tensor.
  Status DynamicExpandDims(ITensorProxyPtr input, const nvinfer1::Dims& dims,
                           int axis, OpConverterParams* params,
                           ITensorProxyPtr* output,
                           absl::optional<int> op_instance = absl::nullopt);

  // Helper function to add a squeeze op to the network.
  //
  // The input_dims argument stores the TRT dimensions of the input tensor,
  // where the dimensions to be squeezed are replaced by 0.
  Status SqueezeTensor(ITensorProxyPtr input, std::vector<int>* input_dims,
                       OpConverterParams* params, ITensorProxyPtr* output);

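  // For example (a hedged sketch; assumes 'input' has TRT dims {2,1,3} and we
  // want to squeeze the middle dimension):
  //
  //   std::vector<int> input_dims = {2, 0, 3};  // 0 marks the squeezed dim.
  //   ITensorProxyPtr output = nullptr;
  //   TF_RETURN_IF_ERROR(SqueezeTensor(input, &input_dims, params, &output));
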
  // Creates an IConstantLayer using 'weights' whose dimensions are specified
  // by 'dims', and returns the output ITensor.
  ITensorProxyPtr CreateConstantLayer(const TRT_ShapedWeights& weights,
                                      const nvinfer1::Dims& dims);

  // Gets the min and max value in a TRT_ShapedWeights.
  Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
                        float* out_max) const;

  // Constructs a name and passes it to the TensorRT layer to support xprof.
  void SetLayerName(
      nvinfer1::ILayer* layer, const NodeDef& node_def,
      absl::string_view sub_op_name = "",
      absl::optional<int> sub_op_instance = absl::nullopt,
      absl::optional<std::string> origin_node_name = absl::nullopt);

  void SetLayerName(nvinfer1::ILayer* layer, absl::string_view main_op_name,
                    absl::string_view sub_op_name,
                    absl::optional<int> sub_op_instance = absl::nullopt);

 private:
  Converter(TrtPrecisionMode precision_mode, bool use_calibration,
            nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
            absl::string_view engine_name);

  Status Init(nvinfer1::ILogger* trt_logger);

  // Verify the provided batch_size is consistent with batch_size_ and update
  // it if necessary.
  Status MaybeUpdateBatchSize(int batch_size);

  // Add the provided tensor/weights to the map trt_tensors_.
  Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input);

  // Get the tensor/weights from trt_tensors_ by 'name'.
  Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output);

  // Get the inputs of 'node_def' from trt_tensors_.
  Status GetInputs(const NodeDef& node_def,
                   std::vector<TRT_TensorOrWeights>* inputs) const;

  void RegisterOpConverters();

  // Registered op converters by op type.
  std::unordered_map<string, OpConverter> op_registry_;

  // Tensors/weights added during construction of trt_network_.
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;

  // The TRT builder used to create the network and build the engine.
  TrtUniquePtrType<nvinfer1::IBuilder> trt_builder_;

  // The TRT network being built.
  TrtUniquePtrType<nvinfer1::INetworkDefinition> trt_network_;

  // Store the weights added during construction of trt_network_.
  TrtWeightStore weight_store_;

  // During conversion, this table is populated with quantization ranges per
  // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT
  // quantization ranges. Since TRT only supports symmetric ranges, we will
  // store the range as a single float = max(abs(min_range), abs(max_range)).
  // Range refers to the floating point values, e.g. min_range = 0.0f,
  // max_range = 6.0f for Relu6.
  std::unordered_map<ITensorProxyPtr*, float> quantization_ranges_proxy_;
  std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_;

  const TrtPrecisionMode precision_mode_;

  const bool use_calibration_;

  // If this is false, all dimensions including the batch dimension are
  // set explicitly.
  const bool use_implicit_batch_;

  // Batch size of inputs to trt_network_ added by AddInputTensor(). During
  // network construction Converter will update this, use it to verify that
  // the batch sizes of all inputs are compatible, and make sure each
  // individual TF node is acceptable to TRT.
  int batch_size_ = -1;

  // Assigns an ID to each constant layer we create, so that we can assign a
  // unique name to the layer.
  int next_constant_layer_id_ = 0;

  // The name of the TRTEngineOp node.
  absl::string_view engine_name_;

  friend class ConverterTest;
  friend class OpConverterTest;
};

// Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims'
// (which doesn't contain the batch dimension).
//
// If validation_only is true, it doesn't do the conversion but only does some
// minimal validation for the eligibility of the conversion, and *tensor will
// be set to nullptr.
// If validation_only is false, converter must not be nullptr.
Status PrepareTensorForShape(
    Converter* converter, const TRT_TensorOrWeights& input,
    const nvinfer1::Dims& dims, const bool validation_only,
    ITensorProxyPtr* tensor, const NodeDef& node_def,
    absl::optional<int> op_instance = absl::nullopt,
    absl::optional<std::string> origin_node_name = absl::nullopt);

// Returns OK if the broadcast scheme is supported and computes the shapes
// after broadcasting. check_feasibility can be set to false in cases where
// dimensions do not need to match exactly (as in the case of BatchMatMulV2).
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
                            const TRT_TensorOrWeights& operand_r,
                            const bool check_feasibility,
                            const bool use_implicit_batch,
                            nvinfer1::Dims* operand_l_new_dims,
                            nvinfer1::Dims* operand_r_new_dims);

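// A rough worked example (a hedged sketch assuming standard numpy-style
// broadcasting, where the lower-rank operand is left-padded with 1s):
// operands with TRT dims {3,4} and {4} would yield operand_l_new_dims = {3,4}
// and operand_r_new_dims = {1,4}.
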
// Map of all supported UnaryOperations.
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();
// Map of all supported ActivationTypes.
const std::unordered_map<string, nvinfer1::ActivationType>* ActivationTypeMap();
// Map of all supported BinaryOperations.
const std::unordered_map<string, nvinfer1::ElementWiseOperation>*
BinaryOperationMap();

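// A lookup sketch (hypothetical; shows how a converter might map a TF op name
// to the corresponding TRT enum):
//
//   const auto* op_map = UnaryOperationMap();
//   auto it = op_map->find("Exp");
//   if (it != op_map->end()) {
//     nvinfer1::UnaryOperation trt_op = it->second;  // kEXP
//     // ... add an IUnaryLayer using trt_op ...
//   }
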
// Returns true if the node is a quantize and dequantize Op.
bool IsQuantizeAndDequantizeOp(const Node*);

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_