Home
last modified time | relevance | path

Searched refs:trt_type (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/
Dutils.cc206 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { in TfTypeToTrtType() argument
209 *trt_type = nvinfer1::DataType::kFLOAT; in TfTypeToTrtType()
212 *trt_type = nvinfer1::DataType::kHALF; in TfTypeToTrtType()
215 *trt_type = nvinfer1::DataType::kINT32; in TfTypeToTrtType()
224 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { in TrtTypeToTfType() argument
225 switch (trt_type) { in TrtTypeToTfType()
Dutils.h171 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
172 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
Dconvert_nodes_test.cc1053 nvinfer1::DataType trt_type; in TestGetWeightRange() local
1054 TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type)); in TestGetWeightRange()
1056 weight_store->GetTempWeights(trt_type, GetTestDims({2, 3})); in TestGetWeightRange()
1468 nvinfer1::DataType trt_type; in BuildAndRun() local
1469 TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type)); in BuildAndRun()
1470 output_info.push_back({data.name, data.name, trt_type}); in BuildAndRun()
1533 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT, in AddTestTensorWithTFDims() argument
1536 TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type)); in AddTestTensorWithTFDims()
1549 converter_->AddInputTensor(name, trt_type, trt_dims, batch_size); in AddTestTensorWithTFDims()
1843 nvinfer1::DataType trt_type; local
[all …]
Dconvert_nodes.h244 TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
Dconvert_nodes.cc493 nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32, in CreateScalarConstant() argument
496 params->weight_store->GetTempWeights(trt_type, dims);
510 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT; // Default to FP32. in CreateBroadcastableScalarConstant() local
514 TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type)); in CreateBroadcastableScalarConstant()
522 return CreateScalarConstant(params, value, tensor, trt_type, in CreateBroadcastableScalarConstant()
881 nvinfer1::DataType trt_type = tensor()->getType(); in GetTfType() local
882 return TrtTypeToTfType(trt_type, tf_type); in GetTfType()