Searched refs:trt_type (Results 1 – 6 of 6) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/ |
D | utils.cc | 153 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { in TfTypeToTrtType() argument 156 *trt_type = nvinfer1::DataType::kFLOAT; in TfTypeToTrtType() 159 *trt_type = nvinfer1::DataType::kHALF; in TfTypeToTrtType() 162 *trt_type = nvinfer1::DataType::kINT32; in TfTypeToTrtType() 166 *trt_type = nvinfer1::DataType::kBOOL; in TfTypeToTrtType() 176 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { in TrtTypeToTfType() argument 177 switch (trt_type) { in TrtTypeToTfType()
|
D | weights.cc | 164 nvinfer1::DataType trt_type = tensor()->getType(); in GetTfType() local 165 return TrtTypeToTfType(trt_type, tf_type); in GetTfType()
|
D | utils.h | 363 Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); 364 Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
|
D | weights.h | 152 StatusOr<TRT_ShapedWeights> GetTempWeights(nvinfer1::DataType trt_type,
|
D | convert_nodes_test.cc | 852 nvinfer1::DataType trt_type; in TestGetWeightRange() local 853 TF_ASSERT_OK(TfTypeToTrtType(DataTypeToEnum<T>::v(), &trt_type)); in TestGetWeightRange() 855 weight_store->GetTempWeights(trt_type, CreateDims({2, 3})).ValueOrDie(); in TestGetWeightRange() 1272 nvinfer1::DataType trt_type; in BuildAndRun() local 1273 TF_RETURN_IF_ERROR(TfTypeToTrtType(data.tensor.dtype(), &trt_type)); in BuildAndRun() 1274 output_info.push_back({data.name, data.name, trt_type}); in BuildAndRun() 1345 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT, in AddTestTensorWithTFDims() argument 1348 TF_ASSERT_OK(TrtTypeToTfType(trt_type, &tf_type)); in AddTestTensorWithTFDims() 1368 name, trt_type, dims_adap->AsTrtDims(), batch_size); in AddTestTensorWithTFDims() 1851 nvinfer1::DataType trt_type; local [all …]
|
D | convert_nodes.cc | 493 nvinfer1::DataType trt_type = nvinfer1::DataType::kINT32, in CreateScalarConstant() argument 496 params->weight_store->GetTempWeights(trt_type, dims); 510 nvinfer1::DataType trt_type = nvinfer1::DataType::kFLOAT; // Default to FP32. in CreateBroadcastableScalarConstant() local 515 TF_RETURN_IF_ERROR(TfTypeToTrtType(dtype, &trt_type)); in CreateBroadcastableScalarConstant() 523 return CreateScalarConstant(params, value, tensor, trt_type, in CreateBroadcastableScalarConstant()
|