Home
last modified time | relevance | path

Searched refs:trt_type (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/
Dutils.cc153 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { in TfTypeToTrtType() argument
156 *trt_type = nvinfer1::DataType::kFLOAT; in TfTypeToTrtType()
159 *trt_type = nvinfer1::DataType::kHALF; in TfTypeToTrtType()
162 *trt_type = nvinfer1::DataType::kINT32; in TfTypeToTrtType()
166 *trt_type = nvinfer1::DataType::kBOOL; in TfTypeToTrtType()
176 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { in TrtTypeToTfType() argument
177 switch (trt_type) { in TrtTypeToTfType()
Dweights.cc164 nvinfer1::DataType trt_type = tensor()->getType(); in GetTfType() local
165 return TrtTypeToTfType(trt_type, tf_type); in GetTfType()
Dutils.h363 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
364 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
Dweights.h152 StatusOr<TRT_ShapedWeights> GetTempWeights(nvinfer1::DataType trt_type,
Dconvert_nodes_test.cc852 nvinfer1::DataType trt_type; in TestGetWeightRange() local
853 TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type)); in TestGetWeightRange()
855 weight_store->GetTempWeights(trt_type, CreateDims({2, 3})).ValueOrDie(); in TestGetWeightRange()
1272 nvinfer1::DataType trt_type; in BuildAndRun() local
1273 TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type)); in BuildAndRun()
1274 output_info.push_back({data.name, data.name, trt_type}); in BuildAndRun()
1345 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT, in AddTestTensorWithTFDims() argument
1348 TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type)); in AddTestTensorWithTFDims()
1368 name, trt_type, dims_adap->AsTrtDims(), batch_size); in AddTestTensorWithTFDims()
1851 nvinfer1::DataType trt_type; local
[all …]
Dconvert_nodes.cc493 nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32, in CreateScalarConstant() argument
496 params->weight_store->GetTempWeights(trt_type, dims);
510 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT; // Default to FP32. in CreateBroadcastableScalarConstant() local
515 TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type)); in CreateBroadcastableScalarConstant()
523 return CreateScalarConstant(params, value, tensor, trt_type, in CreateBroadcastableScalarConstant()