• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/infer/addn_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 #include "nnacl/tensor_c_utils.h"
20 
/**
 * Shape inference for AddN (element-wise sum of two or more tensors).
 *
 * The output's data type/format is taken from inputs[0] via SetDataTypeFormat()
 * before the readiness check, so it is set even when NNACL_INFER_INVALID is returned.
 *
 * Shape rules enforced once InferFlag() reports the inputs are ready:
 *   - no input may contain a zero-sized dimension;
 *   - the reference input is the first input with the largest rank;
 *   - inputs of that largest rank must be pairwise broadcast-compatible
 *     (per dimension: equal, or one of them is 1); the output shape is their
 *     broadcast shape;
 *   - inputs of a smaller rank are only accepted when they hold exactly as many
 *     elements as the reference input (i.e. they are treated as a same-size
 *     reshape, not broadcast), so they do not influence the output shape.
 *
 * @param inputs       input tensors; at least 2 are required.
 * @param inputs_size  number of entries in @p inputs.
 * @param outputs      output tensors; outputs[0] receives the inferred shape.
 * @param outputs_size number of entries in @p outputs.
 * @param parameter    operator parameter, forwarded to CheckAugmentWithMinSize().
 * @return NNACL_OK on success; the CheckAugmentWithMinSize() error code if the
 *         argument check fails; NNACL_INFER_INVALID when InferFlag() is false;
 *         NNACL_ERR when the input shapes are incompatible.
 */
int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                   OpParameter *parameter) {
  // Presumably validates the pointers and enforces >= 2 inputs and >= 1 output — TODO confirm in tensor_c_utils.
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }
  if (inputs_size < 2) {
    return NNACL_ERR;
  }

  const TensorC *input = inputs[0];
  TensorC *output = outputs[0];
  SetDataTypeFormat(output, input);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }

  // Reject zero-sized dimensions in every input. Checking only inputs[0] let a 0 in
  // another input pass the broadcast check below as if it were broadcast against 1.
  for (size_t i = 0; i < inputs_size; ++i) {
    for (size_t j = 0; j < inputs[i]->shape_size_; ++j) {
      if (inputs[i]->shape_[j] == 0) {
        return NNACL_ERR;
      }
    }
  }

  // The reference input is the first one with the largest rank.
  size_t max_dims = input->shape_size_;
  size_t max_dims_idx = 0;
  for (size_t i = 1; i < inputs_size; ++i) {
    if (inputs[i]->shape_size_ > max_dims) {
      max_dims = inputs[i]->shape_size_;
      max_dims_idx = i;
    }
  }
  const TensorC *ref = inputs[max_dims_idx];
  ShapeSet(output->shape_, &output->shape_size_, ref->shape_, ref->shape_size_);

  // Validate every input (including inputs[0]) and fold max-rank inputs into the output shape.
  for (size_t i = 0; i < inputs_size; ++i) {
    const TensorC *cur = inputs[i];
    if (cur->shape_size_ != max_dims) {
      // Lower-rank inputs are only accepted as a same-size reshape of the reference.
      // They never index the output shape, so no entries past their shape_size_ are read.
      if (GetElementNum(cur) != GetElementNum(ref)) {
        return NNACL_ERR;
      }
      continue;
    }
    // Same rank: merge into the running output shape. Comparing against the running
    // result (not only the reference) makes the check pairwise, so e.g. [2,1] and [3,1]
    // are rejected even when the reference is [1,1].
    for (size_t j = 0; j < max_dims; ++j) {
      int out_dim = output->shape_[j];
      int cur_dim = cur->shape_[j];
      if (cur_dim == out_dim || cur_dim == 1) {
        continue;
      }
      if (out_dim != 1) {
        return NNACL_ERR;
      }
      output->shape_[j] = cur_dim;  // broadcast a 1 in the running shape up to cur_dim
    }
  }

  return NNACL_OK;
}
84 
85 REG_INFER(AddN, PrimType_AddN, AddnInferShape)
86