From c947cdfd1d1e0558ff07ea245c51447b05c57eb1 Mon Sep 17 00:00:00 2001 From: Zhu Guodong Date: Mon, 29 May 2023 11:54:03 +0800 Subject: [PATCH] mindir MaxPoolFusion add roundMode --- mindspore/lite/mindir/include/mindir.h | 2 + mindspore/lite/mindir/src/mindir.cc | 45 +++++++++++++++++++ .../mindir_nnrt_lite_graph_to_model_v2_0.cc | 1 + 3 files changed, 48 insertions(+) diff --git a/mindspore/lite/mindir/include/mindir.h b/mindspore/lite/mindir/include/mindir.h index b73920ff..ca811dce 100644 --- a/mindspore/lite/mindir/include/mindir.h +++ b/mindspore/lite/mindir/include/mindir.h @@ -259,6 +259,8 @@ Format MindIR_MaxPoolFusion_GetFormat(ConstPrimitivePtr primitive); void MindIR_MaxPoolFusion_SetFormat(PrimitivePtr *primitive, Format format); bool MindIR_MaxPoolFusion_GetGlobal(ConstPrimitivePtr primitive); void MindIR_MaxPoolFusion_SetGlobal(PrimitivePtr *primitive, bool global); +RoundMode MindIR_MaxPoolFusion_GetRoundMode(ConstPrimitivePtr primitive); +void MindIR_MaxPoolFusion_SetRoundMode(PrimitivePtr *primitive, RoundMode round_mode); ActivationType MindIR_MaxPoolFusion_GetActivationType(ConstPrimitivePtr primitive); void MindIR_MaxPoolFusion_SetActivationType(PrimitivePtr *primitive, ActivationType activation_type); diff --git a/mindspore/lite/mindir/src/mindir.cc b/mindspore/lite/mindir/src/mindir.cc index dd249738..7fc9c00e 100644 --- a/mindspore/lite/mindir/src/mindir.cc +++ b/mindspore/lite/mindir/src/mindir.cc @@ -2452,6 +2452,50 @@ void MindIR_MaxPoolFusion_SetGlobal(PrimitivePtr *primitive, bool global) { } } } + +RoundMode MindIR_MaxPoolFusion_GetRoundMode(ConstPrimitivePtr primitive) { + RoundMode round_mode = static_cast(0); // set default value: RoundMode_FLOOR + + if (primitive == nullptr) { + return round_mode; + } + auto prim = static_cast(primitive); + + auto value = prim->value_as_MaxPoolFusion(); + if (value == nullptr) { + return round_mode; + } + round_mode = static_cast(value->round_mode()); + return round_mode; +} + +void MindIR_MaxPoolFusion_SetRoundMode(PrimitivePtr *primitive, RoundMode round_mode) { + if (primitive == nullptr) { + return; + } + auto prim = static_cast(*primitive); + if (prim == nullptr) { + return; + } + auto value = prim->value_as_MaxPoolFusion(); + if (value == nullptr) { + return; + } + flatbuffers::FlatBufferBuilder fbb; + auto ops_offset = schema::CreateMaxPoolFusion( + fbb, fbb.CreateVector(value->kernel_size()->data(), value->kernel_size()->size()), + fbb.CreateVector(value->strides()->data(), value->strides()->size()), + fbb.CreateVector(value->pad()->data(), value->pad()->size()), static_cast(value->pad_mode()), + static_cast(round_mode), static_cast(value->format()), value->global(), + static_cast(value->activation_type())); + auto prim_offset = + schema::CreatePrimitive(fbb, static_cast(NODE_TYPE_MAX_POOL_FUSION), ops_offset.o); + fbb.Finish(prim_offset); + auto new_addr = MindIRMemoryManager::GetInstance()->CreatePrimitiveFromBuilder(fbb, prim); + auto new_prim = flatbuffers::GetMutableRoot(new_addr); + *primitive = new_prim; +} + ActivationType MindIR_MaxPoolFusion_GetActivationType(ConstPrimitivePtr primitive) { if (primitive != nullptr) { auto prim = static_cast(primitive); @@ -2501,6 +2545,7 @@ PrimitivePtr MindIR_MulFusion_CreatePrimitive(ActivationType activation_type) { auto ret_value = flatbuffers::GetMutableRoot(new_addr); return ret_value; } + ActivationType MindIR_MulFusion_GetActivationType(ConstPrimitivePtr primitive) { if (primitive != nullptr) { auto prim = static_cast(primitive); diff --git a/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc b/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc index 11012492..4e109eff 100644 --- a/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc +++ b/mindspore/lite/mindir/src/mindir_nnrt_lite_graph_to_model_v2_0.cc @@ -642,6 +642,7 @@ std::vector ConvertMaxPoolFusion_V2_0(PrimitivePtr primitive) { max_pool_fusion.format = static_cast(value->format()); max_pool_fusion.global = value->global(); max_pool_fusion.activationType = static_cast(value->activation_type()); + max_pool_fusion.roundMode = static_cast(value->round_mode()); OHOS::MessageParcel data; (void)MaxPoolFusionBlockMarshalling(data, max_pool_fusion); std::vector ret(reinterpret_cast(data.GetData()), -- 2.34.1