• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1From a303e237bf5506d75b98703d442f01e18fb2c820 Mon Sep 17 00:00:00 2001
2From: zhangyanhui <zhangyanhui17@huawei.com>
3Date: Mon, 8 Jul 2024 15:44:46 +0800
4Subject: [PATCH] ConstantOfShape and StridedSlice kernel support bool type
5
6---
7 .../device/cpu/kernel/nnacl/constant_of_shape_parameter.h  | 1 +
8 .../device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h  | 7 +++++++
9 .../plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c  | 1 +
10 .../ops/operator_populate/constant_of_shape_populate.cc    | 3 +++
11 .../src/common/ops/populate/constant_of_shape_populate.cc  | 3 +++
12 .../lite/src/litert/kernel/cpu/base/constant_of_shape.cc   | 5 +++++
13 .../lite/tools/converter/parser/onnx/onnx_node_parser.cc   | 6 ++++++
14 7 files changed, 26 insertions(+)
15
16diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h
17index f108ea98..d75edb6f 100644
18--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h
19+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h
20@@ -23,6 +23,7 @@ typedef struct ConstantOfShapeParameter {
21   union value_ {
22     float f32_value_;
23     int32_t int32_value_;
24+    bool bool_value_;
25   } value_;
26   int data_type_;
27   int element_size_;
28diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h
29index 6c607cf5..c884d031 100644
30--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h
31+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h
32@@ -38,6 +38,13 @@ inline int ConstantOfShapeFp32(float *output, int start, int end, float value) {
33   return NNACL_OK;
34 }
35
36+inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) {
37+  for (int i = start; i < end; i++) {
38+    output[i] = value;
39+  }
40+  return NNACL_OK;
41+}
42+
43 #ifdef __cplusplus
44 }
45 #endif
46diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c
47index 1460c2cc..714bcaef 100644
48--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c
49+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c
50@@ -275,3 +275,4 @@ REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice
51 REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice)
52 REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice)
53 REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice)
54+REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice)
55diff --git a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
56index 3552b5f9..743f42f5 100644
57--- a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
58+++ b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
59@@ -42,6 +42,9 @@ OpParameter *PopulateConstantOfShapeOpParameter(const BaseOperatorPtr &base_oper
60     case kNumberTypeInt32:
61       param->value_.int32_value_ = static_cast<int32_t>(value[0]);
62       break;
63+    case kNumberTypeBool:
64+      param->value_.bool_value_ = static_cast<bool>(value[0]);
65+      break;
66     default:
67       MS_LOG(ERROR) << "The value of constant of shape is invalid";
68       free(param);
69diff --git a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc
70index 56263d13..d8fd6473 100644
71--- a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc
72+++ b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc
73@@ -48,6 +48,9 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) {
74     case kNumberTypeInt32:
75       param->value_.int32_value_ = static_cast<int32_t>(val[0]);
76       break;
77+    case kNumberTypeBool:
78+      param->value_.bool_value_ = static_cast<bool>(val[0]);
79+      break;
80     default:
81       MS_LOG(ERROR) << "The value of constant of shape is invalid";
82       free(param);
83diff --git a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc
84index d8d24146..94f4a490 100644
85--- a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc
86+++ b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc
87@@ -53,6 +53,10 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
88       ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride,
89                            param_->value_.int32_value_);
90       break;
91+    case kNumberTypeBool:
92+      ConstantOfShapeBool(reinterpret_cast<bool *>(output_ptr_), start, start + current_stride,
93+                           param_->value_.bool_value_);
94+      break;
95 #ifdef ENABLE_FP16
96     case kNumberTypeFloat16:
97       ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride,
98@@ -100,4 +104,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCr
99 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
100 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
101 REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
102+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
103 }  // namespace mindspore::kernel
104diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
105index 39197be6..4d11561e 100644
106--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
107+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
108@@ -223,6 +223,12 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso
109         value->push_back(static_cast<float>(reinterpret_cast<const float16 *>(onnx_tensor.raw_data().data())[i]));
110       }
111       break;
112+    case onnx::TensorProto_DataType_BOOL:
113+      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL);
114+      for (size_t i = 0; i < data_count; i++) {
115+        value->push_back(static_cast<float>(reinterpret_cast<const bool *>(onnx_tensor.raw_data().data())[i]));
116+      }
117+      break;
118     default:
119       MS_LOG(ERROR) << "The data type is not supported.";
120       return RET_ERROR;
121--
1222.25.1
123
124