/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/infer/broadcast_to_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/tensor_c_utils.h"

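// Copy shape_size entries from shape_tensor's data into dst_shape as ints,
// converting from whichever of the supported element types the tensor holds.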
int GetShapeByType(const TensorC *shape_tensor, int shape_size, int *dst_shape) {
  if (shape_tensor == NULL || dst_shape == NULL) {
    return NNACL_ERR;
  }
  if (shape_size == 0) {
    return NNACL_INFER_INVALID;
  }
  NNACL_CHECK_NULL_RETURN_ERR(shape_tensor->data_);
  switch (shape_tensor->data_type_) {
    case kNumberTypeInt8: {
      int8_t *data = (int8_t *)(shape_tensor->data_);
      for (int i = 0; i < shape_size; i++) {
        dst_shape[i] = data[i];
      }
    } break;
    case kNumberTypeInt32: {
      int32_t *data = (int32_t *)(shape_tensor->data_);
      for (int i = 0; i < shape_size; i++) {
        dst_shape[i] = data[i];
      }
    } break;
    case kNumberTypeInt64: {
      int64_t *data = (int64_t *)(shape_tensor->data_);
      for (int i = 0; i < shape_size; i++) {
        dst_shape[i] = (int)data[i];
      }
    } break;
    case kNumberTypeFloat: {
      float *data = (float *)(shape_tensor->data_);
      for (int i = 0; i < shape_size; i++) {
        dst_shape[i] = data[i];
      }
    } break;
    case kNumberTypeUInt32: {
      uint32_t *data = (uint32_t *)(shape_tensor->data_);
      for (int i = 0; i < shape_size; i++) {
        dst_shape[i] = (int)data[i];
      }
    } break;
    default: {
      return NNACL_ERR;
    }
  }
  return NNACL_OK;
}

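// Align the two input shapes to a common rank: the shorter shape is left-padded
// with 1s; *ndim is set to the larger rank only when the ranks differ.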
void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
                       const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1) {
  if (input_shape0_size < input_shape1_size) {
    *ndim = input_shape1_size;
    int fill_dim_num = input_shape1_size - input_shape0_size;
    int j = 0;
    for (int i = 0; i < input_shape1_size; i++) {
      if (i < fill_dim_num) {
        in_shape0[i] = 1;
      } else {
        in_shape0[i] = input_shape0[j++];
      }
      in_shape1[i] = input_shape1[i];
    }
  } else if (input_shape0_size > input_shape1_size) {
    *ndim = input_shape0_size;
    int fill_dim_num = input_shape0_size - input_shape1_size;
    int j = 0;
    for (int i = 0; i < input_shape0_size; i++) {
      if (i < fill_dim_num) {
        in_shape1[i] = 1;
      } else {
        in_shape1[i] = input_shape1[j++];
      }
      in_shape0[i] = input_shape0[i];
    }
  } else {
    for (int i = 0; i < input_shape0_size; i++) {
      in_shape1[i] = input_shape1[i];
      in_shape0[i] = input_shape0[i];
    }
  }
}

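// Compute the per-dimension broadcast of two rank-aligned shapes. A mismatched
// pair of dimensions is only allowed when one of them is 1; *has_broad_cast is
// set if any dimension was actually broadcast.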
int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape,
                         bool *has_broad_cast) {
  for (int i = 0; i < ndim; i++) {
    if (in_shape0[i] != in_shape1[i]) {
      if (in_shape0[i] == 1) {
        out_shape[i] = in_shape1[i];
      } else if (in_shape1[i] == 1) {
        out_shape[i] = in_shape0[i];
      } else {
        return NNACL_ERR;
      }
      *has_broad_cast = true;
    } else {
      out_shape[i] = in_shape0[i];
    }
  }
  return NNACL_OK;
}

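// Broadcast input_shape0 against input_shape1: pad both to a common rank
// (returned through ndim) and write the broadcast result into out_shape.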
int BroadCastToShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
                     const int *input_shape1, int *ndim, int *out_shape, bool *has_broad_cast) {
  if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }

  int in_shape0[MAX_SHAPE_SIZE] = {0};
  int in_shape1[MAX_SHAPE_SIZE] = {0};

  MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1);
  if (*ndim >= MAX_SHAPE_SIZE) {
    return NNACL_INFER_INVALID;
  }

  return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast);
}

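// Shape inference for BroadcastTo: derive the output shape by broadcasting the
// first input's shape against the requested target shape.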
int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                          OpParameter *parameter) {
  int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (ret != NNACL_OK) {
    return ret;
  }
  if (inputs_size != 1 && inputs_size != 2) {
    return NNACL_ERR;
  }
  if (outputs_size != 1) {
    return NNACL_ERR;
  }

  const TensorC *input = inputs[0];
  SetDataTypeFormat(outputs[0], input);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  int dst_shape[MAX_SHAPE_SIZE] = {0};
  int dst_shape_size;
  const int *input_shape = input->shape_;
  int input_shape_size = input->shape_size_;
  int output_shape[MAX_SHAPE_SIZE] = {0};
  int ndim = input_shape_size;
  bool has_broad_cast = false;
  if (inputs_size == 1) {
    // Single-input model: the target shape is carried in the op parameter.
    BroadcastToParameter *param = (BroadcastToParameter *)parameter;
    dst_shape_size = (int)param->shape_size_;
    if (dst_shape_size > MAX_SHAPE_SIZE) {
      return NNACL_PARAM_INVALID;
    }
    for (int i = 0; i < dst_shape_size; i++) {
      dst_shape[i] = param->shape_[i];
    }
  } else {
    // Two-input model: the target shape is carried in the second input tensor.
    const TensorC *shape_tensor = inputs[1];
    if (shape_tensor->data_ == NULL) {
      return NNACL_INFER_INVALID;
    }
    dst_shape_size = GetElementNum(shape_tensor);
    if (dst_shape_size > MAX_SHAPE_SIZE) {
      return NNACL_INPUT_TENSOR_ERROR;
    }
    ret = GetShapeByType(shape_tensor, dst_shape_size, dst_shape);
    if (ret != NNACL_OK) {
      return ret;
    }
    // A -1 entry in the target shape keeps the corresponding dimension of the first input.
    for (int i = 0; i < dst_shape_size; ++i) {
      if (dst_shape[i] == -1) {
        dst_shape[i] = inputs[0]->shape_[i];
      }
    }
  }

  if (BroadCastToShape(input_shape_size, dst_shape_size, input_shape, dst_shape, &ndim, output_shape,
                       &has_broad_cast) != NNACL_OK) {
    return NNACL_ERR;
  }

  SetShapeArray(outputs[0], output_shape, (size_t)ndim);
  return NNACL_OK;
}

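// Register the shape-inference routine for the BroadcastTo primitive.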
REG_INFER(BroadcastTo, PrimType_BroadcastTo, BroadcastToInferShape)