1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/fp32/broadcast_to_fp32.h"
17 #include <vector>
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_BroadcastTo;
26
27 namespace mindspore::kernel {
ReSize()28 int BroadcastToCPUKernel::ReSize() {
29 auto input_shape = in_tensors_.at(0)->shape();
30 for (size_t i = 0; i < input_shape.size(); ++i) {
31 shape_info_.input_shape_[i] = input_shape[i];
32 }
33 auto output_shape = out_tensors_.at(0)->shape();
34 for (size_t i = 0; i < output_shape.size(); ++i) {
35 shape_info_.output_shape_[i] = output_shape[i];
36 }
37 shape_info_.input_shape_size_ = static_cast<int>(input_shape.size());
38 shape_info_.output_shape_size_ = static_cast<int>(output_shape.size());
39
40 data_type_ = in_tensors_.at(0)->data_type();
41 MS_ASSERT(data_type_ == out_tensors_.at(0)->data_type());
42 return RET_OK;
43 }
44
Init()45 int BroadcastToCPUKernel::Init() {
46 CHECK_LESS_RETURN(in_tensors_.size(), 1);
47 CHECK_LESS_RETURN(out_tensors_.size(), 1);
48 if (!InferShapeDone()) {
49 return RET_OK;
50 }
51 return ReSize();
52 }
53
Run()54 int BroadcastToCPUKernel::Run() {
55 const auto input_data = in_tensors_.at(0)->data();
56 auto output_data = out_tensors_.at(0)->data();
57 CHECK_NULL_RETURN(input_data);
58 CHECK_NULL_RETURN(output_data);
59
60 switch (data_type_) {
61 case kNumberTypeFloat32:
62 return BROADCAST_TO(float, reinterpret_cast<const float *>(input_data), &shape_info_,
63 reinterpret_cast<float *>(output_data));
64 #ifdef ENABLE_FP16
65 case kNumberTypeFloat16:
66 return BROADCAST_TO(float16_t, reinterpret_cast<const float16_t *>(input_data), &shape_info_,
67 reinterpret_cast<float16_t *>(output_data));
68 #endif
69 case kNumberTypeInt32:
70 case kNumberTypeInt:
71 return BROADCAST_TO(int, reinterpret_cast<const int *>(input_data), &shape_info_,
72 reinterpret_cast<int *>(output_data));
73 default:
74 MS_LOG(ERROR) << "UnSupported data type: " << data_type_;
75 return RET_ERROR;
76 }
77 }
78
// Register BroadcastToCPUKernel as the CPU implementation of BroadcastTo for
// int32 and float32 inputs, and for float16 inputs when built with
// ENABLE_FP16. NOTE(review): Run() also handles kNumberTypeInt, but no
// registration for that type is made here.
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
#ifdef ENABLE_FP16
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
#endif
84 } // namespace mindspore::kernel
85