Searched refs:trt_type (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ |
D | utils.cc | 206 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { in TfTypeToTrtType() argument 209 *trt_type = nvinfer1::DataType::kFLOAT; in TfTypeToTrtType() 212 *trt_type = nvinfer1::DataType::kHALF; in TfTypeToTrtType() 215 *trt_type = nvinfer1::DataType::kINT32; in TfTypeToTrtType() 224 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { in TrtTypeToTfType() argument 225 switch (trt_type) { in TrtTypeToTfType()
|
D | utils.h | 171 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); 172 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
|
D | convert_nodes_test.cc | 1053 nvinfer1::DataType trt_type; in TestGetWeightRange() local 1054 TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type)); in TestGetWeightRange() 1056 weight_store->GetTempWeights(trt_type, GetTestDims({2, 3})); in TestGetWeightRange() 1468 nvinfer1::DataType trt_type; in BuildAndRun() local 1469 TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type)); in BuildAndRun() 1470 output_info.push_back({data.name, data.name, trt_type}); in BuildAndRun() 1533 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT, in AddTestTensorWithTFDims() argument 1536 TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type)); in AddTestTensorWithTFDims() 1549 converter_->AddInputTensor(name, trt_type, trt_dims, batch_size); in AddTestTensorWithTFDims() 1843 nvinfer1::DataType trt_type; local [all …]
|
D | convert_nodes.h | 244 TRT_ShapedWeights GetTempWeights(nvinfer1::DataType trt_type,
|
D | convert_nodes.cc | 493 nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32, in CreateScalarConstant() argument 496 params->weight_store->GetTempWeights(trt_type, dims); 510 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT; // Default to FP32. in CreateBroadcastableScalarConstant() local 514 TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type)); in CreateBroadcastableScalarConstant() 522 return CreateScalarConstant(params, value, tensor, trt_type, in CreateBroadcastableScalarConstant() 881 nvinfer1::DataType trt_type = tensor()->getType(); in GetTfType() local 882 return TrtTypeToTfType(trt_type, tf_type); in GetTfType()
|