• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/fp32/broadcast_to_fp32.h"
17 #include <vector>
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21 
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_BroadcastTo;
26 
27 namespace mindspore::kernel {
ReSize()28 int BroadcastToCPUKernel::ReSize() {
29   auto input_shape = in_tensors_.at(0)->shape();
30   for (size_t i = 0; i < input_shape.size(); ++i) {
31     shape_info_.input_shape_[i] = input_shape[i];
32   }
33   auto output_shape = out_tensors_.at(0)->shape();
34   for (size_t i = 0; i < output_shape.size(); ++i) {
35     shape_info_.output_shape_[i] = output_shape[i];
36   }
37   shape_info_.input_shape_size_ = static_cast<int>(input_shape.size());
38   shape_info_.output_shape_size_ = static_cast<int>(output_shape.size());
39 
40   data_type_ = in_tensors_.at(0)->data_type();
41   MS_ASSERT(data_type_ == out_tensors_.at(0)->data_type());
42   return RET_OK;
43 }
44 
Init()45 int BroadcastToCPUKernel::Init() {
46   CHECK_LESS_RETURN(in_tensors_.size(), 1);
47   CHECK_LESS_RETURN(out_tensors_.size(), 1);
48   if (!InferShapeDone()) {
49     return RET_OK;
50   }
51   return ReSize();
52 }
53 
Run()54 int BroadcastToCPUKernel::Run() {
55   const auto input_data = in_tensors_.at(0)->data();
56   auto output_data = out_tensors_.at(0)->data();
57   CHECK_NULL_RETURN(input_data);
58   CHECK_NULL_RETURN(output_data);
59 
60   switch (data_type_) {
61     case kNumberTypeFloat32:
62       return BROADCAST_TO(float, reinterpret_cast<const float *>(input_data), &shape_info_,
63                           reinterpret_cast<float *>(output_data));
64 #ifdef ENABLE_FP16
65     case kNumberTypeFloat16:
66       return BROADCAST_TO(float16_t, reinterpret_cast<const float16_t *>(input_data), &shape_info_,
67                           reinterpret_cast<float16_t *>(output_data));
68 #endif
69     case kNumberTypeInt32:
70     case kNumberTypeInt:
71       return BROADCAST_TO(int, reinterpret_cast<const int *>(input_data), &shape_info_,
72                           reinterpret_cast<int *>(output_data));
73     default:
74       MS_LOG(ERROR) << "UnSupported data type: " << data_type_;
75       return RET_ERROR;
76   }
77 }
78 
// Register this kernel as the CPU implementation of BroadcastTo for each
// element type Run() can dispatch on; the fp16 registration is compiled in
// only when the fp16 backend is enabled.
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
#ifdef ENABLE_FP16
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BroadcastTo, LiteKernelCreator<BroadcastToCPUKernel>)
#endif
84 }  // namespace mindspore::kernel
85