• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "ops/tensor_array.h"

#include <cstdint>
#include <vector>

#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "ops/primitive_c.h"
#include "utils/log_adapter.h"
26 
27 namespace mindspore {
28 namespace ops {
29 MIND_API_OPERATOR_IMPL(TensorArray, BaseOperator);
30 constexpr auto kTensorArrayDynamicSize = "dynamic_size";
31 constexpr auto kTensorArrayIdenticalElementShapes = "identical_element_shapes";
32 constexpr auto kTensorArrayElementShape = "element_shape";
33 constexpr auto kTensorArrayDataType = "data_type";
34 
Init(bool dynamic_size,bool identical_element_shapes,const std::vector<int> & element_shape,int data_type)35 void TensorArray::Init(bool dynamic_size, bool identical_element_shapes, const std::vector<int> &element_shape,
36                        int data_type) {
37   this->set_dynamic_size(dynamic_size);
38   this->set_identical_element_shapes(identical_element_shapes);
39   this->set_element_shape(element_shape);
40   this->set_data_type(data_type);
41 }
42 
set_dynamic_size(bool dynamic_size)43 void TensorArray::set_dynamic_size(bool dynamic_size) {
44   (void)this->AddAttr(kTensorArrayDynamicSize, api::MakeValue(dynamic_size));
45 }
46 
set_identical_element_shapes(bool identical_element_shapes)47 void TensorArray::set_identical_element_shapes(bool identical_element_shapes) {
48   (void)this->AddAttr(kTensorArrayIdenticalElementShapes, api::MakeValue(identical_element_shapes));
49 }
50 
set_element_shape(const std::vector<int> & element_shape)51 void TensorArray::set_element_shape(const std::vector<int> &element_shape) {
52   (void)this->AddAttr(kTensorArrayElementShape, api::MakeValue(element_shape));
53 }
54 
set_data_type(int data_type)55 void TensorArray::set_data_type(int data_type) { (void)this->AddAttr(kTensorArrayDataType, api::MakeValue(data_type)); }
56 
get_dynamic_size() const57 bool TensorArray::get_dynamic_size() const {
58   auto value_ptr = GetAttr(kTensorArrayDynamicSize);
59   return GetValue<bool>(value_ptr);
60 }
61 
get_identical_element_shapes() const62 bool TensorArray::get_identical_element_shapes() const {
63   auto value_ptr = GetAttr(kTensorArrayIdenticalElementShapes);
64   return GetValue<bool>(value_ptr);
65 }
66 
get_element_shape() const67 const std::vector<int> TensorArray::get_element_shape() const {
68   auto value_ptr = GetAttr(kTensorArrayElementShape);
69   auto tmp = GetValue<std::vector<int64_t>>(value_ptr);
70   std::vector<int> res(tmp.begin(), tmp.end());
71   return res;
72 }
73 
get_data_type() const74 int TensorArray::get_data_type() const {
75   auto value_ptr = GetAttr(kTensorArrayDataType);
76   auto tmp = GetValue<int64_t>(value_ptr);
77   return static_cast<int>(tmp);
78 }
79 
80 REGISTER_PRIMITIVE_C(kNameTensorArray, TensorArray);
81 }  // namespace ops
82 }  // namespace mindspore
83