• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/litert/tensor_category.h"
17 #include "src/common/utils.h"
18 #include "schema/model_generated.h"
19 
20 namespace mindspore {
21 namespace lite {
TensorCategory(const int node_type,const size_t shape_num,const TypeId data_type,const size_t data_size)22 Category TensorCategory(const int node_type, const size_t shape_num, const TypeId data_type, const size_t data_size) {
23   return (node_type == NodeType_ValueNode)
24            ? (shape_num == 0 && data_size == DataTypeSize(data_type) ? Category::CONST_SCALAR : Category::CONST_TENSOR)
25            : Category::VAR;
26 }
27 
TensorCategory(const schema::Tensor & tensor)28 Category TensorCategory(const schema::Tensor &tensor) {
29   auto shape_num = tensor.dims() == nullptr ? 0 : tensor.dims()->size();
30   auto data_size = tensor.data() == nullptr ? 0 : tensor.data()->size();
31   return TensorCategory(tensor.nodeType(), shape_num, TypeId(tensor.dataType()), data_size);
32 }
33 
IsConstTensor(const schema::Tensor & tensor)34 bool IsConstTensor(const schema::Tensor &tensor) {
35   return TensorCategory(tensor) != Category::VAR;
36 }
37 }  // namespace lite
38 }  // namespace mindspore
39