• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/infer/log_softmax_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 
/**
 * Infer the data type, format and shape of the LogSoftmax output tensor.
 *
 * The op takes exactly one input and produces exactly one output; the output
 * inherits data type, format and (once known) shape from the input.
 *
 * @param inputs        Array of input tensors; exactly one is expected.
 * @param inputs_size   Number of entries in @p inputs.
 * @param outputs       Array of output tensors; exactly one is expected.
 * @param outputs_size  Number of entries in @p outputs.
 * @param parameter     Op parameter, read as a SoftmaxParameter for its axis_.
 *
 * @return NNACL_OK on success;
 *         NNACL_ERR if the tensor counts are wrong or the input rank exceeds 5;
 *         NNACL_INFER_INVALID if InferFlag() reports the input shape is not
 *         usable yet (data type/format are still propagated in that case);
 *         NNACL_PARAM_INVALID if axis_ lies outside [-rank, rank);
 *         otherwise whatever error CheckAugmentWithMinSize() reports.
 */
int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                         OpParameter *parameter) {
  /* size_t so the comparisons below do not mix signed and unsigned operands. */
  const size_t input_size_limit = 1;
  const size_t output_size_limit = 1;
  if (inputs_size != input_size_limit || outputs_size != output_size_limit) {
    return NNACL_ERR;
  }
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  const TensorC *input = inputs[0];
  TensorC *output = outputs[0];
  /* Type and format are propagated even when the shape cannot be inferred yet. */
  SetDataTypeFormat(output, input);

  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  /* Inputs with more than 5 dimensions are rejected. */
  if (input->shape_size_ > 5) {
    return NNACL_ERR;
  }

  /* Validate the axis before touching the output shape, so that a rejected
   * parameter does not leave a shape committed on the output tensor. */
  SoftmaxParameter *param = (SoftmaxParameter *)parameter;
  NNACL_CHECK_NULL_RETURN_ERR(param);
  const int rank = (int)(input->shape_size_);
  if (param->axis_ < -rank || param->axis_ >= rank) {
    return NNACL_PARAM_INVALID;
  }

  /* LogSoftmax is shape-preserving: the output takes the input's shape. */
  SetShapeTensor(output, input);
  return NNACL_OK;
}
50 
/* NOTE(review): presumably binds LogSoftmaxInferShape to PrimType_LogSoftmax in the
 * infer-function registry; REG_INFER is not defined in this file (likely in
 * nnacl/infer/infer_register.h) - confirm against the macro definition. */
51 REG_INFER(LogSoftmax, PrimType_LogSoftmax, LogSoftmaxInferShape)
52