• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1From c947cdfd1d1e0558ff07ea245c51447b05c57eb1 Mon Sep 17 00:00:00 2001
2From: Zhu Guodong <zhuguodong0001@163.com>
3Date: Mon, 29 May 2023 11:54:03 +0800
4Subject: [PATCH] mindir MaxPoolFusion add roundMode
5
6---
7 mindspore/lite/mindir/include/mindir.h        |  2 +
8 mindspore/lite/mindir/src/mindir.cc           | 45 +++++++++++++++++++
9 .../mindir_nnrt_lite_graph_to_model_v2_0.cc   |  1 +
10 3 files changed, 48 insertions(+)
11
12diff --git a/mindspore/lite/mindir/include/mindir.h b/mindspore/lite/mindir/include/mindir.h
13index b73920ff..ca811dce 100644
14--- a/mindspore/lite/mindir/include/mindir.h
15+++ b/mindspore/lite/mindir/include/mindir.h
16@@ -259,6 +259,8 @@ Format MindIR_MaxPoolFusion_GetFormat(ConstPrimitivePtr primitive);
17 void MindIR_MaxPoolFusion_SetFormat(PrimitivePtr *primitive, Format format);
18 bool MindIR_MaxPoolFusion_GetGlobal(ConstPrimitivePtr primitive);
19 void MindIR_MaxPoolFusion_SetGlobal(PrimitivePtr *primitive, bool global);
20+RoundMode MindIR_MaxPoolFusion_GetRoundMode(ConstPrimitivePtr primitive);
21+void MindIR_MaxPoolFusion_SetRoundMode(PrimitivePtr *primitive, RoundMode round_mode);
22 ActivationType MindIR_MaxPoolFusion_GetActivationType(ConstPrimitivePtr primitive);
23 void MindIR_MaxPoolFusion_SetActivationType(PrimitivePtr *primitive, ActivationType activation_type);
24
25diff --git a/mindspore/lite/mindir/src/mindir.cc b/mindspore/lite/mindir/src/mindir.cc
26index dd249738..7fc9c00e 100644
27--- a/mindspore/lite/mindir/src/mindir.cc
28+++ b/mindspore/lite/mindir/src/mindir.cc
29@@ -2452,6 +2452,50 @@ void MindIR_MaxPoolFusion_SetGlobal(PrimitivePtr *primitive, bool global) {
30     }
31   }
32 }
33+
34+RoundMode MindIR_MaxPoolFusion_GetRoundMode(ConstPrimitivePtr primitive) {
35+  RoundMode round_mode = static_cast<RoundMode>(0); // set default value: RoundMode_FLOOR
36+
37+  if (primitive == nullptr) {
38+    return round_mode;
39+  }
40+  auto prim = static_cast<const schema::Primitive *>(primitive);
41+
42+  auto value = prim->value_as_MaxPoolFusion();
43+  if (value == nullptr) {
44+    return round_mode;
45+  }
46+  round_mode = static_cast<RoundMode>(value->round_mode());
47+  return round_mode;
48+}
49+
50+void MindIR_MaxPoolFusion_SetRoundMode(PrimitivePtr *primitive, RoundMode round_mode) {
51+  if (primitive == nullptr) {
52+    return;
53+  }
54+  auto prim = static_cast<schema::Primitive *>(*primitive);
55+  if (prim == nullptr) {
56+    return;
57+  }
58+  auto value = prim->value_as_MaxPoolFusion();
59+  if (value == nullptr) {
60+    return;
61+  }
62+  flatbuffers::FlatBufferBuilder fbb;
63+  auto ops_offset = schema::CreateMaxPoolFusion(
64+      fbb, fbb.CreateVector(value->kernel_size()->data(), value->kernel_size()->size()),
65+      fbb.CreateVector(value->strides()->data(), value->strides()->size()),
66+      fbb.CreateVector(value->pad()->data(), value->pad()->size()), static_cast<schema::PadMode>(value->pad_mode()),
67+      static_cast<schema::RoundMode>(round_mode), static_cast<schema::Format>(value->format()), value->global(),
68+      static_cast<schema::ActivationType>(value->activation_type()));
69+  auto prim_offset =
70+      schema::CreatePrimitive(fbb, static_cast<schema::PrimitiveType>(NODE_TYPE_MAX_POOL_FUSION), ops_offset.o);
71+  fbb.Finish(prim_offset);
72+  auto new_addr = MindIRMemoryManager::GetInstance()->CreatePrimitiveFromBuilder(fbb, prim);
73+  auto new_prim = flatbuffers::GetMutableRoot<schema::Primitive>(new_addr);
74+  *primitive = new_prim;
75+}
76+
77 ActivationType MindIR_MaxPoolFusion_GetActivationType(ConstPrimitivePtr primitive) {
78   if (primitive != nullptr) {
79     auto prim = static_cast<const schema::Primitive *>(primitive);
80@@ -2501,6 +2545,7 @@ PrimitivePtr MindIR_MulFusion_CreatePrimitive(ActivationType activation_type) {
81   auto ret_value = flatbuffers::GetMutableRoot<schema::Primitive>(new_addr);
82   return ret_value;
83 }
84+
85 ActivationType MindIR_MulFusion_GetActivationType(ConstPrimitivePtr primitive) {
86   if (primitive != nullptr) {
87     auto prim = static_cast<const schema::Primitive *>(primitive);
88diff --git a/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc b/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc
89index 11012492..4e109eff 100644
90--- a/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc
91+++ b/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc
92@@ -642,6 +642,7 @@ std::vector<int8_t> ConvertMaxPoolFusion_V2_0(PrimitivePtr primitive) {
93       max_pool_fusion.format = static_cast<HDI::Nnrt::V2_0::Format>(value->format());
94       max_pool_fusion.global = value->global();
95       max_pool_fusion.activationType = static_cast<HDI::Nnrt::V2_0::ActivationType>(value->activation_type());
96+      max_pool_fusion.roundMode = static_cast<HDI::Nnrt::V2_0::RoundMode>(value->round_mode());
97       OHOS::MessageParcel data;
98       (void)MaxPoolFusionBlockMarshalling(data, max_pool_fusion);
99       std::vector<int8_t> ret(reinterpret_cast<const int8_t *>(data.GetData()),
100--
1012.34.1
102
103