• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/infer/broadcast_to_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 #include "nnacl/tensor_c_utils.h"
20 
GetShapeByType(const TensorC * shape_tensor,int shape_size,int * dst_shape)21 int GetShapeByType(const TensorC *shape_tensor, int shape_size, int *dst_shape) {
22   if (shape_tensor == NULL || dst_shape == NULL) {
23     return NNACL_ERR;
24   }
25   if (shape_size == 0) {
26     return NNACL_INFER_INVALID;
27   }
28   NNACL_CHECK_NULL_RETURN_ERR(shape_tensor->data_);
29   switch (shape_tensor->data_type_) {
30     case kNumberTypeInt8: {
31       int8_t *data = (int8_t *)(shape_tensor->data_);
32       for (int i = 0; i < shape_size; i++) {
33         dst_shape[i] = data[i];
34       }
35     } break;
36     case kNumberTypeInt32: {
37       int32_t *data = (int32_t *)(shape_tensor->data_);
38       for (int i = 0; i < shape_size; i++) {
39         dst_shape[i] = data[i];
40       }
41     } break;
42     case kNumberTypeInt64: {
43       int64_t *data = (int64_t *)(shape_tensor->data_);
44       for (int i = 0; i < shape_size; i++) {
45         dst_shape[i] = (int)data[i];
46       }
47     } break;
48     case kNumberTypeFloat: {
49       float *data = (float *)(shape_tensor->data_);
50       for (int i = 0; i < shape_size; i++) {
51         dst_shape[i] = data[i];
52       }
53     } break;
54     case kNumberTypeUInt32: {
55       uint32_t *data = (uint32_t *)(shape_tensor->data_);
56       for (int i = 0; i < shape_size; i++) {
57         dst_shape[i] = (int)data[i];
58       }
59     } break;
60     default: {
61       return NNACL_ERR;
62     }
63   }
64   return NNACL_OK;
65 }
66 
/*
 * Left-pad the shorter of two shapes with 1s so both end up with the same rank.
 * Broadcasting aligns shapes on their trailing dimensions, so padding goes in front.
 *
 * @param input_shape0_size  number of entries in input_shape0
 * @param input_shape1_size  number of entries in input_shape1
 * @param input_shape0       first shape
 * @param input_shape1       second shape
 * @param ndim               out: common rank, max(input_shape0_size, input_shape1_size)
 * @param in_shape0          out: padded copy of input_shape0 (needs room for *ndim ints)
 * @param in_shape1          out: padded copy of input_shape1 (needs room for *ndim ints)
 *
 * *ndim is written in every case, including when both ranks are already equal.
 */
void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
                       const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1) {
  const int rank = input_shape0_size > input_shape1_size ? input_shape0_size : input_shape1_size;
  // Number of leading 1s each shape needs to reach `rank`.
  const int pad0 = rank - input_shape0_size;
  const int pad1 = rank - input_shape1_size;
  *ndim = rank;
  for (int i = 0; i < rank; i++) {
    in_shape0[i] = (i < pad0) ? 1 : input_shape0[i - pad0];
    in_shape1[i] = (i < pad1) ? 1 : input_shape1[i - pad1];
  }
}
100 
BroadCastOutputShape(const int * in_shape0,const int * in_shape1,const int ndim,int * out_shape,bool * has_broad_cast)101 int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape,
102                          bool *has_broad_cast) {
103   for (int i = 0; i < ndim; i++) {
104     if (in_shape0[i] != in_shape1[i]) {
105       if (in_shape0[i] == 1) {
106         out_shape[i] = in_shape1[i];
107       } else if (in_shape1[i] == 1) {
108         out_shape[i] = in_shape0[i];
109       } else {
110         return NNACL_ERR;
111       }
112       *has_broad_cast = true;
113     } else {
114       out_shape[i] = in_shape0[i];
115     }
116   }
117   return NNACL_OK;
118 }
119 
BroadCastToShape(const int input_shape0_size,const int input_shape1_size,const int * input_shape0,const int * input_shape1,int * ndim,int * out_shape,bool * has_broad_cast)120 int BroadCastToShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
121                      const int *input_shape1, int *ndim, int *out_shape, bool *has_broad_cast) {
122   if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) {
123     return NNACL_ERR;
124   }
125 
126   int in_shape0[MAX_SHAPE_SIZE] = {0};
127   int in_shape1[MAX_SHAPE_SIZE] = {0};
128 
129   MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1);
130   if (*ndim >= MAX_SHAPE_SIZE) {
131     return NNACL_INFER_INVALID;
132   }
133 
134   return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast);
135 }
136 
/*
 * Shape inference for BroadcastTo.
 *
 * The target shape comes from the BroadcastToParameter attribute when there is one
 * input, or from the contents of inputs[1] when there are two. In the shape tensor,
 * a -1 entry keeps the size of the input dimension it lines up with. Dimensions
 * line up from the trailing end, exactly as in the broadcast itself.
 *
 * @return NNACL_OK on success.
 *         NNACL_INFER_INVALID when inference must be deferred: InferFlag is false,
 *         or the shape tensor has no data yet.
 *         An error code otherwise. Any broadcasting failure is reported as NNACL_ERR.
 */
int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                          OpParameter *parameter) {
  int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (ret != NNACL_OK) {
    return ret;
  }
  if (inputs_size != 1 && inputs_size != 2) {
    return NNACL_ERR;
  }
  if (outputs_size != 1) {
    return NNACL_ERR;
  }

  const TensorC *input = inputs[0];
  SetDataTypeFormat(outputs[0], input);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }

  const int *input_shape = input->shape_;
  int input_shape_size = (int)input->shape_size_;
  int dst_shape[MAX_SHAPE_SIZE] = {0};
  int dst_shape_size = 0;

  if (inputs_size == 1) {
    const BroadcastToParameter *param = (const BroadcastToParameter *)parameter;
    dst_shape_size = (int)param->shape_size_;
    if (dst_shape_size < 0 || dst_shape_size > MAX_SHAPE_SIZE) {
      return NNACL_PARAM_INVALID;
    }
    for (int i = 0; i < dst_shape_size; i++) {
      dst_shape[i] = param->shape_[i];
    }
  } else {
    const TensorC *shape_tensor = inputs[1];
    if (shape_tensor->data_ == NULL) {
      return NNACL_INFER_INVALID;
    }
    dst_shape_size = GetElementNum(shape_tensor);
    if (dst_shape_size < 0 || dst_shape_size > MAX_SHAPE_SIZE) {
      return NNACL_INPUT_TENSOR_ERROR;
    }
    ret = GetShapeByType(shape_tensor, dst_shape_size, dst_shape);
    if (ret != NNACL_OK) {
      return ret;
    }
    // Resolve -1 entries. Because shapes are aligned on trailing dimensions, target
    // index i corresponds to input index i - offset. A -1 in a leading dimension
    // the input does not have cannot be resolved.
    const int offset = dst_shape_size - input_shape_size;
    for (int i = 0; i < dst_shape_size; ++i) {
      if (dst_shape[i] != -1) {
        continue;
      }
      const int src_dim = i - offset;
      if (src_dim < 0) {
        return NNACL_INPUT_TENSOR_ERROR;
      }
      dst_shape[i] = input_shape[src_dim];
    }
  }

  int output_shape[MAX_SHAPE_SIZE] = {0};
  // Pre-set to the input rank in case the broadcast helpers leave it untouched
  // when both ranks are equal.
  int ndim = input_shape_size;
  bool has_broad_cast = false;
  if (BroadCastToShape(input_shape_size, dst_shape_size, input_shape, dst_shape, &ndim, output_shape,
                       &has_broad_cast) != NNACL_OK) {
    return NNACL_ERR;
  }

  SetShapeArray(outputs[0], output_shape, (size_t)ndim);
  return NNACL_OK;
}
199 
// Presumably registers BroadcastToInferShape as the shape-inference routine for
// PrimType_BroadcastTo via the REG_INFER macro from nnacl/infer/infer_register.h
// (NOTE(review): confirm against that header).
200 REG_INFER(BroadcastTo, PrimType_BroadcastTo, BroadcastToInferShape)
201