1From a303e237bf5506d75b98703d442f01e18fb2c820 Mon Sep 17 00:00:00 2001 2From: zhangyanhui <zhangyanhui17@huawei.com> 3Date: Mon, 8 Jul 2024 15:44:46 +0800 4Subject: [PATCH] ConstantOfShape and StridedSlice kernel support bool type 5 6--- 7 .../device/cpu/kernel/nnacl/constant_of_shape_parameter.h | 1 + 8 .../device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h | 7 +++++++ 9 .../plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c | 1 + 10 .../ops/operator_populate/constant_of_shape_populate.cc | 3 +++ 11 .../src/common/ops/populate/constant_of_shape_populate.cc | 3 +++ 12 .../lite/src/litert/kernel/cpu/base/constant_of_shape.cc | 5 +++++ 13 .../lite/tools/converter/parser/onnx/onnx_node_parser.cc | 6 ++++++ 14 7 files changed, 26 insertions(+) 15 16diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h 17index f108ea98..d75edb6f 100644 18--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h 19+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h 20@@ -23,6 +23,7 @@ typedef struct ConstantOfShapeParameter { 21 union value_ { 22 float f32_value_; 23 int32_t int32_value_; 24+ bool bool_value_; 25 } value_; 26 int data_type_; 27 int element_size_; 28diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h 29index 6c607cf5..c884d031 100644 30--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h 31+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h 32@@ -38,6 +38,13 @@ inline int ConstantOfShapeFp32(float *output, int start, int end, float value) { 33 return NNACL_OK; 34 } 35 36+inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) { 37+ for (int i = start; i < end; i++) { 38+ output[i] = value; 39+ } 40+ return NNACL_OK; 41+} 42+ 43 #ifdef __cplusplus 44 } 45 #endif 46diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c 47index 1460c2cc..714bcaef 100644 48--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c 49+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c 50@@ -275,3 +275,4 @@ REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice 51 REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice) 52 REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice) 53 REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice) 54+REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice) 55diff --git a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc 56index 3552b5f9..743f42f5 100644 57--- a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc 58+++ b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc 59@@ -42,6 +42,9 @@ OpParameter *PopulateConstantOfShapeOpParameter(const BaseOperatorPtr &base_oper 60 case kNumberTypeInt32: 61 param->value_.int32_value_ = static_cast<int32_t>(value[0]); 62 break; 63+ case kNumberTypeBool: 64+ param->value_.bool_value_ = static_cast<bool>(value[0]); 65+ break; 66 default: 67 MS_LOG(ERROR) << "The value of constant of shape is invalid"; 68 free(param); 69diff --git a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc 70index 56263d13..d8fd6473 100644 71--- a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc 72+++ b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc 73@@ -48,6 +48,9 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) { 74 case kNumberTypeInt32: 75 param->value_.int32_value_ = static_cast<int32_t>(val[0]); 76 break; 77+ case kNumberTypeBool: 78+ param->value_.bool_value_ = static_cast<bool>(val[0]); 79+ break; 80 default: 81 MS_LOG(ERROR) << "The value of constant of shape is invalid"; 82 free(param); 83diff --git a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc 84index d8d24146..94f4a490 100644 85--- a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc 86+++ b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc 87@@ -53,6 +53,10 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) { 88 ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride, 89 param_->value_.int32_value_); 90 break; 91+ case kNumberTypeBool: 92+ ConstantOfShapeBool(reinterpret_cast<bool *>(output_ptr_), start, start + current_stride, 93+ param_->value_.bool_value_); 94+ break; 95 #ifdef ENABLE_FP16 96 case kNumberTypeFloat16: 97 ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride, 98@@ -100,4 +104,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCr 99 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 100 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 101 REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 102+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 103 } // namespace mindspore::kernel 104diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 105index 39197be6..4d11561e 100644 106--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 107+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 108@@ -223,6 +223,12 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso 109 value->push_back(static_cast<float>(reinterpret_cast<const float16 *>(onnx_tensor.raw_data().data())[i])); 110 } 111 break; 112+ case onnx::TensorProto_DataType_BOOL: 113+ *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL); 114+ for (size_t i = 0; i < data_count; i++) { 115+ value->push_back(static_cast<float>(reinterpret_cast<const bool *>(onnx_tensor.raw_data().data())[i])); 116+ } 117+ break; 118 default: 119 MS_LOG(ERROR) << "The data type is not supported."; 120 return RET_ERROR; 121-- 1222.25.1 123 124