/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
#include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
#include "utils/ms_utils.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kAddNInputsMinNum = 2;
constexpr size_t kAddNOutputsNum = 1;

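// Adds the int32 elements of in_0 and in_1 over the index range [start, end) into out,
// delegating to nnacl's ElementAddInt and raising on any non-NNACL_OK status.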
void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
  int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
  if (ret != NNACL_OK) {
    MS_LOG(EXCEPTION) << "Add failed.";
  }
}
}  // namespace

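// Validates the inputs and builds a oneDNN binary_add primitive over the first two input
// shapes; the primitive and its memory arguments are reused by the float32 path in Launch.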
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num_ < kAddNInputsMinNum) {
    MS_LOG(EXCEPTION) << "Input number should not be less than " << kAddNInputsMinNum << ", but got " << input_num_;
  }
  CheckParam(kernel_node);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
  dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
  dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
  dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
  auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
  primitive_ = std::make_shared<dnnl::binary>(prim_desc);
  AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
  AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
  AddArgument(DNNL_ARG_DST, dst_mem_desc);
}

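// Sums all inputs element-wise into the single output. float32 runs through the oneDNN
// binary_add primitive; int32 uses the nnacl-based AddInt helper parallelized over the thread pool.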
bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                           const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num_, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAddNOutputsNum, kernel_name_);
  if (dtype_ == kNumberTypeFloat32) {
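    // out = in_0 + in_1 first, then accumulate each remaining input into out,
    // rebinding the primitive's source/destination handles at every step.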
    SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
    SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
    SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
    ExecutePrimitive();
    for (size_t index = 2; index < input_num_; ++index) {
      SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
      SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
      SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
      ExecutePrimitive();
    }
  } else if (dtype_ == kNumberTypeInt32) {
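    // Same accumulation scheme for int32: ParallelFor splits [0, elements_num) into chunks
    // and passes them to AddInt through the bound start/end placeholders.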
    size_t elements_num = outputs[0]->size / sizeof(int);
    const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
    const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
    auto output = reinterpret_cast<int *>(outputs[0]->addr);
    auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
    CPUKernelUtils::ParallelFor(task_0, elements_num);
    for (size_t index = 2; index < input_num_; ++index) {
      const auto input = reinterpret_cast<int *>(inputs[index]->addr);
      auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
      CPUKernelUtils::ParallelFor(task, elements_num);
    }
  } else {
    MS_LOG(EXCEPTION) << "AddN only supports float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
  }
  return true;
}

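// Verifies that every input shape matches the output shape; AddN requires all operands to
// share the same shape.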
void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  if (src0_shape != dst_shape) {
    MS_LOG(EXCEPTION) << "AddN output shape must be equal to input shape.";
  }
  for (size_t index = 1; index < input_num_; ++index) {
    auto src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
    if (src0_shape != src_shape) {
      MS_LOG(EXCEPTION) << "AddN input shapes must be equal.";
    }
  }
}
}  // namespace kernel
}  // namespace mindspore