/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_

#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
extern const char* const kInputPHName;
extern const char* const kOutputPHName;

namespace convert {

#define IS_TRT_VERSION_GE(major, minor, patch)                  \
  ((NV_TENSORRT_MAJOR > major) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH >= patch))
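
// Example (illustrative sketch, not part of the original header): guarding
// code that requires TensorRT 5.1 or newer:
//
//   #if IS_TRT_VERSION_GE(5, 1, 0)
//   // ... use APIs only available in TRT 5.1+ ...
//   #endif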

struct EngineConnection {
  // Constructs a non-control edge.
  EngineConnection(const string& outside, int out_id, int out_port,
                   const string& inside, int in_id, int in_port,
                   bool input_edge, int port)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(out_port),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(in_port),
        is_input_edge(input_edge),
        port_number(port) {}

  // Constructs a control edge.
  EngineConnection(const string& outside, int out_id, const string& inside,
                   int in_id, bool input_edge)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(Graph::kControlSlot),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(Graph::kControlSlot),
        is_input_edge(input_edge),
        port_number(Graph::kControlSlot) {}

  bool is_control_edge() const { return port_number == Graph::kControlSlot; }

  const string outside_node_name;
  const int outside_id;
  const int outside_port;
  PartialTensorShape outside_shape;  // Only set for input edge.

  const string inside_node_name;
  const int inside_id;
  const int inside_port;
  PartialTensorShape inside_shape;  // Only set for output edge.

  DataType connection_type;
  const bool is_input_edge;

  // The port number of the TRT node connected with this edge.
  const int port_number;
};
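
// Example (illustrative sketch; the node names, ids and ports are
// hypothetical): a data input edge from output port 0 of "src" into input
// port 1 of "dst", exposed as engine input port 0, versus a control edge
// which carries no tensor:
//
//   EngineConnection data_edge("src", /*out_id=*/3, /*out_port=*/0,
//                              "dst", /*in_id=*/7, /*in_port=*/1,
//                              /*input_edge=*/true, /*port=*/0);
//   EngineConnection ctrl_edge("src", /*out_id=*/3, "dst", /*in_id=*/7,
//                              /*input_edge=*/true);
//   DCHECK(ctrl_edge.is_control_edge());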

struct EngineInfo {
  EngineInfo()
      : engine_type(EngineType::TRTStatic),
        max_workspace_size_bytes(0),
        precision_mode(TrtPrecisionMode::FP32),
        use_calibration(true) {}

  string engine_name;
  string device;
  GraphDef segment_graph_def;

  // Non-control input connections inside this vector are sorted in a way such
  // that the segment nodes connecting to them are topologically sorted. In
  // addition, for non-control connections, there must be no duplicates.
  std::vector<EngineConnection> connections;

  enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
  EngineType engine_type;
  int64 max_workspace_size_bytes;
  int maximum_cached_engines;
  std::vector<int> cached_engine_batches;
  TrtPrecisionMode precision_mode;
  bool use_calibration;
};

// Constructs a GraphDef from the segment in the given graph. Adds placeholder
// nodes for input edges (InputPH_*) and identity nodes for output edges
// (OutputPH_*). This function needs to be called before the TensorRT nodes
// are inserted, in order to correctly get sizes from the original graph.
//
// - subgraph_node_names: the node names of the subgraph.
// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
//   topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
//   sorted in topological order.
// - scope_name: the name of the scope where the TRTEngineOp will be placed.
//
// TODO(aaroey): add tests to validate these properties.
Status ConvertSegmentToGraphDef(
    const Graph* graph, const grappler::GraphProperties& graph_properties,
    const std::vector<const Node*>& subgraph_nodes,
    std::vector<EngineConnection>* connections, GraphDef* segment_def,
    string* scope_name);
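
// Example call (illustrative sketch; 'graph', 'properties' and
// 'segment_nodes' are assumed to be produced by the segmenter):
//
//   std::vector<EngineConnection> connections;
//   GraphDef segment_def;
//   string scope;
//   TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
//       graph, properties, segment_nodes, &connections, &segment_def,
//       &scope));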

// Converts the given subgraph to a TRT engine saved in 'engine'. Returns ok
// iff 'builder' successfully builds the engine. If the result is not ok,
// 'engine' will be set to nullptr.
// Once returned, 'builder' is not needed any more and can be safely destroyed.
//
// - convert_successfully: indicates whether the conversion to the TensorRT
//   network is successful. This is different from successfully building the
//   engine: building can still fail afterwards.
Status ConvertGraphDefToEngine(
    const GraphDef& gdef, TrtPrecisionMode precision_mode, int max_batch_size,
    size_t max_workspace_size_bytes,
    const std::vector<PartialTensorShape>& input_shapes, Logger* logger,
    nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
    bool* convert_successfully);
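
// Example call (illustrative sketch; 'segment_def', 'input_shapes', 'logger'
// and 'allocator' are assumed to be prepared by the caller):
//
//   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
//   bool convert_successfully = false;
//   Status status = ConvertGraphDefToEngine(
//       segment_def, TrtPrecisionMode::FP16, /*max_batch_size=*/8,
//       /*max_workspace_size_bytes=*/1 << 30, input_shapes, &logger,
//       allocator, /*calibrator=*/nullptr, &engine,
//       /*use_calibration=*/false, &convert_successfully);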

// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid.
class OutputEdgeValidator {
 public:
  // Return true if the specified edge is eligible to be an output edge of the
  // TRT segment.
  bool operator()(const Edge* out_edge) const;
};

string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);
int64_t TrtDimsNumElements(const nvinfer1::Dims& dims);

// Class to convert TF compile-time constants (e.g. Const nodes) to TRT
// weights.
class TRT_ShapedWeights {
 public:
  explicit TRT_ShapedWeights(DataType type = DT_FLOAT);

  // Copy from another weights.
  //
  // NOTE: this does not copy the underlying buffer but only increases its
  // reference count.
  TRT_ShapedWeights(const TRT_ShapedWeights& rhs);

  nvinfer1::Weights GetTrtWeights() const;

  // Returns the raw pointer to the underlying buffer which holds the weights
  // value.
  void* GetValues() const {
    return const_cast<char*>(tensor_.tensor_data().data());
  }

  int64_t count() const;

  size_t size_bytes() const;

  string DebugString() const;

  template <typename T>
  absl::Span<const T> GetSpan() const {
    return absl::Span<const T>(tensor_.flat<T>().data(), count());
  }

  template <typename T>
  std::vector<T> ToVector() const {
    auto span = GetSpan<T>();
    return std::vector<T>(span.data(), span.data() + span.size());
  }

  // TODO(aaroey): make these private.
  nvinfer1::Dims shape_;  // Note: shape.type[] is not used.
  DataType type_;

 private:
  // This constructor is only used by TrtWeightStore, which creates the
  // underlying buffer.
  TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor);

  // All weights should be stored inside TrtWeightStore to make sure that the
  // underlying tensors stay alive until the engine is built. For this reason,
  // tensor_ should never be reassigned to a value that is not already present
  // in the TrtWeightStore.
  Tensor tensor_;

  friend class TrtWeightStore;
};
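
// Example (illustrative sketch; 'weights' is assumed to hold DT_FLOAT data):
// reading weight values through the typed accessors:
//
//   absl::Span<const float> values = weights.GetSpan<float>();
//   std::vector<float> copy = weights.ToVector<float>();
//   nvinfer1::Weights trt_weights = weights.GetTrtWeights();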

// Container for TRT_ShapedWeights. We need this container because TRT doesn't
// manage the lifetime of the weights buffer: it only keeps a pointer to it and
// requires that the data referenced by the pointer be available until the
// building of the engine is complete. For more information see
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html
//
// TODO(laigd): consider adding garbage collection to the unused weights.
class TrtWeightStore {
 public:
  // Get a TRT_ShapedWeights with 'type' and 'dims'.
  TRT_ShapedWeights GetTempWeights(DataType type, const nvinfer1::Dims& dims);

  // Get a TRT_ShapedWeights with the same data type and dimensions as
  // 'weights'.
  TRT_ShapedWeights GetTempWeights(const TRT_ShapedWeights& weights) {
    return GetTempWeights(weights.type_, weights.shape_);
  }

 private:
  // The backend storage of the TRT_ShapedWeights.
  std::vector<Tensor> store_;
};
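
// Example (illustrative sketch): creating scratch weights whose buffer stays
// alive in the store until the engine is built:
//
//   TrtWeightStore store;
//   nvinfer1::Dims dims;
//   dims.nbDims = 1;
//   dims.d[0] = 4;
//   TRT_ShapedWeights scratch = store.GetTempWeights(DT_FLOAT, dims);
//   float* data = static_cast<float*>(scratch.GetValues());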

// Represents a TRT-style input to a TF node; it can be either an
// nvinfer1::ITensor, or a TRT_ShapedWeights which is a compile-time constant.
//
// TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument.
class TRT_TensorOrWeights {
 public:
  TRT_TensorOrWeights() {}

  // Constructor that makes it an ITensor, doesn't take ownership of 'tensor'.
  // This is used by Converter when building the TRT network, where the ITensor
  // is owned by the TRT network being built. See comment for 'tensor_' below.
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1);

  // Constructor that makes it an ITensor by creating one using the provided
  // data type and shape, and takes ownership of the created ITensor. This is
  // used by TrtNodeValidator to encapsulate the type and shape information for
  // validation of graph nodes, and the created ITensor is fake and temporary,
  // and should not be used to build any TRT network. See comment for
  // 'simple_itensor_' below.
  explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype,
                               const nvinfer1::Dims& trt_dims, int batch_size);

  // Constructor that makes it a TRT_ShapedWeights.
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights);

  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs);

  void operator=(const TRT_TensorOrWeights& rhs);

  bool is_tensor() const { return initialized_ && is_tensor_; }
  bool is_weights() const { return initialized_ && !is_tensor_; }

  nvinfer1::ITensor* tensor();

  const nvinfer1::ITensor* tensor() const;

  TRT_ShapedWeights& weights() {
    CHECK(is_weights());
    return weights_;
  }

  const TRT_ShapedWeights& weights() const {
    CHECK(is_weights());
    return weights_;
  }

  nvinfer1::Dims GetTrtDims() const;

  int batch_size() const { return batch_size_; }

  string DebugString() const;

 private:
  class SimpleITensor;

  void set_batch_size(int batch_size) { batch_size_ = batch_size; }

  // When it represents an ITensor, the ITensor can be either passed by the
  // caller via the constructor that takes an ITensor* as parameter, or be
  // created as a SimpleITensor.
  //
  // In the first case, the ITensor pointer is stored in 'tensor_' below, and
  // the ITensor itself is not owned by this class. This method is used by
  // Converter (e.g. AddInputTensor) and op converters during TRT network
  // construction, where the TRT network owns the ITensor.
  //
  // In the second case, the created SimpleITensor is stored in
  // 'simple_itensor_' below and is owned by this class. SimpleITensor is a
  // fake implementation of ITensor and is used only by TrtNodeValidator to
  // validate the graph nodes.
  nvinfer1::ITensor* tensor_ = nullptr;  // Not owned.
  std::shared_ptr<SimpleITensor> simple_itensor_ = nullptr;

  // The first dimension of the TF tensor (NOT tensor_) that is represented by
  // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s
  // dimensions (obtained via tensor_->getDimensions()) do not contain the
  // batch dimension. For example, when a TF tensor with shape (A,B,C) is
  // represented in TRT, tensor_->getDimensions() will be (B,C) and batch_size_
  // will be A.
  //
  // This requires that all tensors in the subgraph that is converted to a TRT
  // engine have the same batch size, which is represented by the first
  // dimension of their shape, and Converter will verify this during
  // conversion. The drawback is that currently it cannot convert a graph
  // whose batch size is not represented in the shapes, or whose batch sizes
  // differ. See b/118387490 for more details.
  int batch_size_ = -1;

  TRT_ShapedWeights weights_;
  bool initialized_ = false;
  bool is_tensor_ = false;

  friend class Converter;
};
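
// Example (illustrative sketch): the typical pattern in an op converter that
// dispatches on whether an input is a tensor or a compile-time constant:
//
//   const TRT_TensorOrWeights& input = params->inputs.at(0);
//   if (input.is_tensor()) {
//     nvinfer1::ITensor* tensor = input.tensor();
//     // ... add a TRT layer that consumes 'tensor' ...
//   } else {
//     const TRT_ShapedWeights& weights = input.weights();
//     // ... fold the compile-time constant into the layer ...
//   }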

class Converter;

// Parameters for each op converter.
struct OpConverterParams {
  OpConverterParams(Converter* arg_converter, const NodeDef& arg_node_def,
                    const std::vector<TRT_TensorOrWeights>& arg_inputs,
                    std::vector<TRT_TensorOrWeights>* arg_outputs,
                    bool arg_validation_only, TrtWeightStore* arg_weight_store)
      : converter(arg_converter),
        node_def(arg_node_def),
        inputs(arg_inputs),
        outputs(arg_outputs),
        validation_only(arg_validation_only),
        weight_store(arg_weight_store) {}

  Converter* converter;
  const NodeDef& node_def;
  const std::vector<TRT_TensorOrWeights>& inputs;
  std::vector<TRT_TensorOrWeights>* outputs;
  const bool validation_only;
  TrtWeightStore* weight_store;
};

using OpConverter = std::function<Status(OpConverterParams*)>;
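
// Example (illustrative sketch; 'ConvertMyOp' is a hypothetical converter):
// an op converter matching the OpConverter signature:
//
//   Status ConvertMyOp(OpConverterParams* params) {
//     // Validation-only mode performs checks without adding layers.
//     if (params->validation_only) return Status::OK();
//     // ... use params->converter->network() to add layers, then append the
//     // resulting tensors to *params->outputs ...
//     return Status::OK();
//   }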

// Class to verify whether a specific TF node is supported by TRT.
class TrtNodeValidator {
 public:
  TrtNodeValidator();

  // Validates the node, and returns ok if it's supported by TRT.
  //
  // - 'node_def' is the node to validate.
  // - 'input_node_and_ports' are the input NodeDefs and their output ports
  //   that are connected to 'node_def' in the TF graph.
  // - 'graph_properties' is the GraphProperties of the graph where 'node_def'
  //   belongs. It is used to get the shape and data type information of a
  //   tensor for validation purposes.
  Status ValidateNode(
      const NodeDef& node_def,
      const std::vector<std::pair<const NodeDef*, int>>& input_node_and_ports,
      const TrtPrecisionMode precision_mode,
      const grappler::GraphProperties& graph_properties);

 private:
  static const std::set<string>* quantize_ops;

  void RegisterOpValidators();

  // Converts a Const node to a TRT_TensorOrWeights.
  Status ConvertConstToWeights(const NodeDef& const_node_def,
                               const std::vector<TRT_TensorOrWeights>& inputs,
                               TRT_TensorOrWeights* output);

  // Converts the output tensor at 'output_port' of 'node_def' to a
  // TRT_TensorOrWeights which will later be used as an input to other nodes
  // and passed to ValidateNode() above.
  Status ConvertToTensorOrWeights(
      const NodeDef& node_def, int output_port,
      const grappler::GraphProperties& graph_properties,
      TRT_TensorOrWeights* tensor_or_weights);

  // Stores all the validators by op type. If no validator is registered for a
  // specific op, it means no validation is needed and ValidateNode() will
  // return OK.
  std::unordered_map<string, OpConverter> op_validators_;

  // Stores the weights added during validation. Some validations (e.g.
  // validation for Const node) may produce weights.
  TrtWeightStore weight_store_;

  friend class ValidatorTest;
  friend class OpConverterTest;
};
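
// Example (illustrative sketch; 'node_def', 'input_node_and_ports' and
// 'properties' are assumed to be provided by the caller):
//
//   TrtNodeValidator validator;
//   Status status = validator.ValidateNode(node_def, input_node_and_ports,
//                                          TrtPrecisionMode::FP32,
//                                          properties);
//   const bool supported_by_trt = status.ok();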

// Class to convert TF nodes to a TRT network.
class Converter {
 public:
  // Used for Converter::RenameAndMarkOutputTensors().
  struct EngineOutputInfo {
    // The TRT tensor name which produces the output.
    string source_tensor_name;
    // The TensorFlow node name which is receiving the output from the TRT
    // engine. This should always be the Identity node created in
    // ConvertSegmentToGraphDef.
    string dest_node_name;
    // Output type. TensorRT requires this to be explicitly set for engine
    // outputs.
    nvinfer1::DataType trt_dtype;
  };

  Converter(nvinfer1::INetworkDefinition* trt_network,
            TrtPrecisionMode precision_mode, bool use_calibration);

  //////////////////////////////////////////////////////////////////////////////
  // Methods used by the TRT engine builder to build a TRT network from a TF
  // function/subgraph.

  // Converts the node to the TRT network.
  Status ConvertNode(const NodeDef& node_def);

  // Adds an input tensor to the TRT network with the given 'name', 'dtype',
  // 'dims' and 'batch_size'.
  Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
                        const nvinfer1::Dims& dims, int batch_size);

  // Marks the tensors with names specified by source_tensor_name as outputs
  // of the TRT network, and sets their names in the TRT network as
  // dest_node_name.
  Status RenameAndMarkOutputTensors(
      const std::vector<EngineOutputInfo>& output_tensors);
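
  // Example (illustrative sketch; 'network', 'segment_def', 'input_dims' and
  // 'outputs' are hypothetical inputs from the engine builder): the typical
  // build flow using the methods above:
  //
  //   Converter converter(network, TrtPrecisionMode::FP32,
  //                       /*use_calibration=*/false);
  //   TF_RETURN_IF_ERROR(converter.AddInputTensor(
  //       "input", nvinfer1::DataType::kFLOAT, input_dims, /*batch_size=*/8));
  //   for (const NodeDef& node_def : segment_def.node()) {
  //     TF_RETURN_IF_ERROR(converter.ConvertNode(node_def));
  //   }
  //   TF_RETURN_IF_ERROR(converter.RenameAndMarkOutputTensors(outputs));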

  //////////////////////////////////////////////////////////////////////////////
  // Methods used by op converters to convert individual TF nodes and add
  // layers to the TRT network.

  // Op converters (e.g. ConvertReshape) need to access the TRT network in
  // order to add TRT layers.
  nvinfer1::INetworkDefinition* network() { return trt_network_; }

  // What precision are we targeting?
  TrtPrecisionMode precision_mode() const { return precision_mode_; }

  // Calibration will be or was previously performed on this network?
  bool use_calibration() const { return use_calibration_; }

  // This should be called on the inputs and outputs of any layer we create
  // where we know that the quantization range does not change during that
  // operation (e.g. Reshape, Transpose, Identity, MaxPool).
  void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
                                          nvinfer1::ITensor* output);

  // This function should be called when we know the quantization range of a
  // tensor, either from a quantize/dequantize node or when the output is a
  // fixed range (e.g. SoftMax, Relu6, Sigmoid).
  void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range,
                                float max_range);

  // Should be called when the full TRT network has been constructed and
  // before building the engine.
  void MaybeApplyQuantizationRanges();
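
  // Example (illustrative sketch; the tensor pointers are hypothetical): how
  // an op converter might use these hooks for a Relu6 followed by a Reshape:
  //
  //   converter->ProvideQuantizationRange(relu6_output, 0.0f, 6.0f);
  //   converter->MarkQuantizationRangesAsInferrable(reshape_input,
  //                                                 reshape_output);
  //   // Later, once the full network is constructed:
  //   converter->MaybeApplyQuantizationRanges();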

  // Below are helper methods for op converters to add different layers to the
  // TRT network.

  // Transposes 'input_tensor' with the given permutation
  // 'order_with_batch_dim' to 'output_tensor'. The permutation
  // 'order_with_batch_dim' contains the batch dimension which should always
  // be 0.
  Status TransposeTensor(nvinfer1::ITensor* input_tensor,
                         const std::vector<int>& order_with_batch_dim,
                         const nvinfer1::ITensor** output_tensor);

  // Converts 'input' into 'tensor' with shape specified by 'dims'.
  //
  // If validation_only is true, it doesn't do the conversion but only does
  // some minimal validation for the eligibility of the conversion, and
  // *tensor will be set to nullptr.
  Status PrepareTensorForShape(const TRT_TensorOrWeights& input,
                               const nvinfer1::Dims& dims,
                               const bool validation_only,
                               const nvinfer1::ITensor** tensor);

  // Returns OK if the broadcast scheme is supported and computes the shapes
  // after broadcasting.
  Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
                              const TRT_TensorOrWeights& operand_r,
                              nvinfer1::Dims* operand_l_new_dims,
                              nvinfer1::Dims* operand_r_new_dims) const;

  // Creates an IConstantLayer using 'weights' whose dimensions are specified
  // by 'dims', and returns the output ITensor.
  nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights,
                                         const nvinfer1::Dims& dims);

 private:
  // Verifies that the provided batch_size is consistent with batch_size_ and
  // updates it if necessary.
  Status MaybeUpdateBatchSize(int batch_size);

  // Adds the provided tensor/weights to the map trt_tensors_.
  Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input);

  // Gets the tensor/weights from trt_tensors_ by 'name'.
  Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output);

  // Gets the inputs of 'node_def' from trt_tensors_.
  Status GetInputs(const NodeDef& node_def,
                   std::vector<TRT_TensorOrWeights>* inputs) const;

  void RegisterOpConverters();

  void PropagateQuantizationRanges();

  // Gets the min and max values in a TRT_ShapedWeights.
  Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
                        float* out_max) const;

  // Registered op converters by op type.
  std::unordered_map<string, OpConverter> op_registry_;

  // Tensors/weights added during construction of trt_network_.
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;

  // Special op converter for custom plugins.
  OpConverter plugin_converter_;

  // The TRT network being built.
  nvinfer1::INetworkDefinition* trt_network_;

  // Stores the weights added during construction of trt_network_.
  TrtWeightStore weight_store_;

  // During conversion, this table is populated with quantization ranges per
  // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT
  // quantization ranges. Since TRT only supports symmetric ranges, we store
  // the range as a single float = max(abs(min_range), abs(max_range)). Range
  // refers to the floating point values, e.g. min_range = 0.0f, max_range =
  // 6.0f for Relu6.
  std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_;

  // Edges where quantization ranges can be inferred (copied) across ops, from
  // the first tensor to the second tensor. PropagateQuantizationRanges() will
  // propagate known ranges from quantization_ranges_ across these edges,
  // adding the new ranges to quantization_ranges_ so that they can be applied
  // in MaybeApplyQuantizationRanges().
  std::vector<std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*>>
      quantization_infer_;

  const TrtPrecisionMode precision_mode_;

  const bool use_calibration_;

  // Batch size of inputs to trt_network_ added by AddInputTensor(). During
  // network construction the Converter will update this, use it to verify
  // that the batch sizes of all inputs are compatible, and make sure each
  // individual TF node is acceptable to TRT.
  int batch_size_ = -1;

  friend class ConverterTest;
  friend class OpConverterTest;
};

// Map of all supported UnaryOperations.
const std::unordered_map<string, nvinfer1::UnaryOperation>* UnaryOperationMap();

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_