From a303e237bf5506d75b98703d442f01e18fb2c820 Mon Sep 17 00:00:00 2001 From: zhangyanhui Date: Mon, 8 Jul 2024 15:44:46 +0800 Subject: [PATCH] ConstantOfShape and StridedSlice kernel support bool type --- .../device/cpu/kernel/nnacl/constant_of_shape_parameter.h | 1 + .../device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h | 7 +++++++ .../plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c | 1 + .../ops/operator_populate/constant_of_shape_populate.cc | 3 +++ .../src/common/ops/populate/constant_of_shape_populate.cc | 3 +++ .../lite/src/litert/kernel/cpu/base/constant_of_shape.cc | 5 +++++ .../lite/tools/converter/parser/onnx/onnx_node_parser.cc | 6 ++++++ 7 files changed, 26 insertions(+) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h index f108ea98..d75edb6f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h @@ -23,6 +23,7 @@ typedef struct ConstantOfShapeParameter { union value_ { float f32_value_; int32_t int32_value_; + bool bool_value_; } value_; int data_type_; int element_size_; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h index 6c607cf5..c884d031 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h @@ -38,6 +38,13 @@ inline int ConstantOfShapeFp32(float *output, int start, int end, float value) { return NNACL_OK; } +inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c index 1460c2cc..714bcaef 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c @@ -275,3 +275,4 @@ REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice) REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice) REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice) diff --git a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc index 3552b5f9..743f42f5 100644 --- a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc +++ b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc @@ -42,6 +42,9 @@ OpParameter *PopulateConstantOfShapeOpParameter(const BaseOperatorPtr &base_oper case kNumberTypeInt32: param->value_.int32_value_ = static_cast(value[0]); break; + case kNumberTypeBool: + param->value_.bool_value_ = static_cast(value[0]); + break; default: MS_LOG(ERROR) << "The value of constant of shape is invalid"; free(param); diff --git a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc index 56263d13..d8fd6473 100644 --- a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc +++ b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc @@ -48,6 +48,9 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) { case kNumberTypeInt32: param->value_.int32_value_ = static_cast(val[0]); break; + case kNumberTypeBool: + param->value_.bool_value_ = static_cast(val[0]); + break; default: MS_LOG(ERROR) << "The value of constant of shape is invalid"; free(param); diff --git a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc index d8d24146..94f4a490 100644 --- a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc +++ b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc @@ -53,6 +53,10 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) { ConstantOfShapeInt32(reinterpret_cast(output_ptr_), start, start + current_stride, param_->value_.int32_value_); break; + case kNumberTypeBool: + ConstantOfShapeBool(reinterpret_cast(output_ptr_), start, start + current_stride, + param_->value_.bool_value_); + break; #ifdef ENABLE_FP16 case kNumberTypeFloat16: ConstantOfShapeFp16(reinterpret_cast(output_ptr_), start, start + current_stride, @@ -100,4 +104,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCr REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ConstantOfShape, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ConstantOfShape, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc index 39197be6..4d11561e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc @@ -223,6 +223,12 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso value->push_back(static_cast(reinterpret_cast(onnx_tensor.raw_data().data())[i])); } break; + case onnx::TensorProto_DataType_BOOL: + *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL); + for (size_t i = 0; i < data_count; i++) { + value->push_back(static_cast(reinterpret_cast(onnx_tensor.raw_data().data())[i])); + } + break; default: MS_LOG(ERROR) << "The data type is not supported."; return RET_ERROR; -- 2.25.1