/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/concat_offset_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "mindspore/core/ops/concat_offset.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kConcatOffsetOutputNum = 1;
constexpr size_t kConcatOffsetOutputShapeSize = 2;
} // namespace
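// Init: checks that there is exactly one output and at least one input, reads the optional
// 'axis' attribute, and selects the typed launch function matching the input/output dtypes.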
bool ConcatOffsetCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
                                    const std::vector<KernelTensor *> &outputs) {
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
  if (inputs.empty()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', input tensors can not be empty";
    return false;
  }
  if (primitive_->HasAttr(kAttrAxis)) {
    axis_ = GetValue<int64_t>(primitive_->GetAttr(kAttrAxis));
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "Concat offset does not support this kernel data type: " << kernel_attr;
    return false;
  }

  kernel_func_ = func_list_[index].second;
  return true;
}

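// Resize: verifies that the output shape is 2-D with shape (n, rank), where n is the number
// of inputs, and caches every input shape while checking that all inputs have the same rank.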
int ConcatOffsetCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
                                     const std::vector<KernelTensor *> &outputs) {
  if (auto ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
    return ret;
  }
  output_shape_ = outputs[kIndex0]->GetShapeVector();
  if (output_shape_.size() != kConcatOffsetOutputShapeSize) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of output must be " << kConcatOffsetOutputShapeSize
                  << ", but got: " << output_shape_.size();
    return KRET_RESIZE_FAILED;
  }
  if (LongToSize(output_shape_[kIndex0]) != inputs.size()) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the first dimension value of output must be equal to "
                     "the number of inputs, but got the first dimension value of output: "
                  << output_shape_[kIndex0] << ", and the number of inputs: " << inputs.size();
    return KRET_RESIZE_FAILED;
  }
  input_shapes_.clear();
  for (size_t i = 0; i < inputs.size(); i++) {
    ShapeVector shape_i = inputs[i]->GetShapeVector();
    input_shapes_.push_back(shape_i);
    if (shape_i.size() != input_shapes_[kIndex0].size()) {
      MS_LOG(ERROR) << "For '" << kernel_name_
                    << "', the rank of all input tensors must be equal, but got input[0]'s rank = "
                    << input_shapes_[kIndex0].size() << " and input[" << i << "]'s rank = " << shape_i.size();
      return KRET_RESIZE_FAILED;
    }
  }
  return KRET_OK;
}

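// LaunchKernel: writes, for each input i, the start offset of that input along the concat
// axis into row i of the output; every other position in the row is set to zero.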
template <typename T>
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::KernelTensor *> &inputs,
                                            const std::vector<kernel::KernelTensor *> &outputs) {
  auto output_addr = reinterpret_cast<int64_t *>(outputs[kIndex0]->device_ptr());

  auto x_rank = SizeToLong(input_shapes_[kIndex0].size());
  if (axis_ < -x_rank || axis_ >= x_rank) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'axis' must be in range [-" << x_rank << ", " << x_rank
                      << "), but got " << axis_;
  }
  if (axis_ < 0) {
    axis_ += x_rank;
  }
  auto axis = LongToSize(axis_);

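  // For example, if three inputs have sizes 2, 3 and 4 along the concat axis, their offsets
  // along that axis are 0, 2 and 5 (a running prefix sum of the sizes).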
  ShapeVector offset{0};
  auto all_shape = input_shapes_[0][axis];

  // Accumulate the prefix sum of input sizes along the concat axis.
  for (size_t i = 1; i < inputs.size(); i++) {
    offset.emplace_back(all_shape);
    all_shape += input_shapes_[i][axis];
  }
  size_t rank = LongToSize(output_shape_[kIndex1]);
  size_t idx = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    for (size_t j = 0; j < rank; ++j) {
      if (j == axis) {
        output_addr[idx] = offset[i];
      } else {
        output_addr[idx] = 0;
      }
      idx++;
    }
  }
  return true;
}

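// Supported type combinations: any of the listed input dtypes is accepted, and the computed
// offsets are always emitted as int64.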
std::vector<std::pair<KernelAttr, ConcatOffsetCpuKernelMod::ConcatOffsetFunc>> ConcatOffsetCpuKernelMod::func_list_ = {
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<double>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<int32_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
   &ConcatOffsetCpuKernelMod::LaunchKernel<bool>}};

std::vector<KernelAttr> ConcatOffsetCpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, ConcatOffsetFunc> &pair) { return pair.first; });

  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ConcatOffset, ConcatOffsetCpuKernelMod);
} // namespace kernel
} // namespace mindspore