1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
#include <algorithm>
#include <limits>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
#include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
#include "utils/ms_utils.h"
#include "common/thread_pool.h"
24
25 namespace mindspore {
26 namespace kernel {
27 namespace {
28 constexpr size_t kAddNInputsMinNum = 2;
29 constexpr size_t kAddNOutputsNum = 1;
30
AddInt(const int * in_0,const int * in_1,int * out,int start,int end)31 void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
32 int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
33 if (ret != NNACL_OK) {
34 MS_LOG(EXCEPTION) << "Add failed.";
35 }
36 }
37 } // namespace
38
InitKernel(const CNodePtr & kernel_node)39 void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
40 MS_EXCEPTION_IF_NULL(kernel_node);
41 kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
42 input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
43 if (input_num_ < kAddNInputsMinNum) {
44 MS_LOG(EXCEPTION) << "Input numbers should not less " << kAddNInputsMinNum << ", but got " << input_num_;
45 }
46 CheckParam(kernel_node);
47 dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
48 std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
49 std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
50 std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
51 dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
52 dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
53 dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
54 dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
55 auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
56 primitive_ = std::make_shared<dnnl::binary>(prim_desc);
57 AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
58 AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
59 AddArgument(DNNL_ARG_DST, dst_mem_desc);
60 }
61
Launch(const std::vector<kernel::AddressPtr> & inputs,const std::vector<kernel::AddressPtr> &,const std::vector<kernel::AddressPtr> & outputs)62 bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
63 const std::vector<kernel::AddressPtr> &outputs) {
64 CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num_, kernel_name_);
65 CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAddNOutputsNum, kernel_name_);
66 if (dtype_ == kNumberTypeFloat32) {
67 SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
68 SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
69 SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
70 ExecutePrimitive();
71 for (size_t index = 2; index < input_num_; ++index) {
72 SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
73 SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
74 SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
75 ExecutePrimitive();
76 }
77 } else if (dtype_ == kNumberTypeInt32) {
78 size_t elements_num = outputs[0]->size / sizeof(int);
79 const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
80 const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
81 auto output = reinterpret_cast<int *>(outputs[0]->addr);
82 auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
83 CPUKernelUtils::ParallelFor(task_0, elements_num);
84 for (size_t index = 2; index < input_num_; ++index) {
85 const auto input = reinterpret_cast<int *>(inputs[index]->addr);
86 auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
87 CPUKernelUtils::ParallelFor(task, elements_num);
88 }
89 } else {
90 MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
91 }
92 return true;
93 }
94
CheckParam(const CNodePtr & kernel_node)95 void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
96 auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
97 auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
98 if (src0_shape != dst_shape) {
99 MS_LOG(EXCEPTION) << "AddN output shape must be equal to input shape.";
100 }
101 for (size_t index = 1; index < input_num_; ++index) {
102 auto src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
103 if (src0_shape != src_shape) {
104 MS_LOG(EXCEPTION) << "AddN input shapes must be equal.";
105 }
106 }
107 }
108 } // namespace kernel
109 } // namespace mindspore
110