1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/tensor_array.h"
18
19 #include <vector>
20
21 #include "mindapi/base/shared_ptr.h"
22 #include "mindapi/ir/value.h"
23 #include "mindapi/src/helper.h"
24 #include "ops/primitive_c.h"
25 #include "utils/log_adapter.h"
26
27 namespace mindspore {
28 namespace ops {
29 MIND_API_OPERATOR_IMPL(TensorArray, BaseOperator);
30 constexpr auto kTensorArrayDynamicSize = "dynamic_size";
31 constexpr auto kTensorArrayIdenticalElementShapes = "identical_element_shapes";
32 constexpr auto kTensorArrayElementShape = "element_shape";
33 constexpr auto kTensorArrayDataType = "data_type";
34
Init(bool dynamic_size,bool identical_element_shapes,const std::vector<int> & element_shape,int data_type)35 void TensorArray::Init(bool dynamic_size, bool identical_element_shapes, const std::vector<int> &element_shape,
36 int data_type) {
37 this->set_dynamic_size(dynamic_size);
38 this->set_identical_element_shapes(identical_element_shapes);
39 this->set_element_shape(element_shape);
40 this->set_data_type(data_type);
41 }
42
set_dynamic_size(bool dynamic_size)43 void TensorArray::set_dynamic_size(bool dynamic_size) {
44 (void)this->AddAttr(kTensorArrayDynamicSize, api::MakeValue(dynamic_size));
45 }
46
set_identical_element_shapes(bool identical_element_shapes)47 void TensorArray::set_identical_element_shapes(bool identical_element_shapes) {
48 (void)this->AddAttr(kTensorArrayIdenticalElementShapes, api::MakeValue(identical_element_shapes));
49 }
50
set_element_shape(const std::vector<int> & element_shape)51 void TensorArray::set_element_shape(const std::vector<int> &element_shape) {
52 (void)this->AddAttr(kTensorArrayElementShape, api::MakeValue(element_shape));
53 }
54
set_data_type(int data_type)55 void TensorArray::set_data_type(int data_type) { (void)this->AddAttr(kTensorArrayDataType, api::MakeValue(data_type)); }
56
get_dynamic_size() const57 bool TensorArray::get_dynamic_size() const {
58 auto value_ptr = GetAttr(kTensorArrayDynamicSize);
59 return GetValue<bool>(value_ptr);
60 }
61
get_identical_element_shapes() const62 bool TensorArray::get_identical_element_shapes() const {
63 auto value_ptr = GetAttr(kTensorArrayIdenticalElementShapes);
64 return GetValue<bool>(value_ptr);
65 }
66
get_element_shape() const67 const std::vector<int> TensorArray::get_element_shape() const {
68 auto value_ptr = GetAttr(kTensorArrayElementShape);
69 auto tmp = GetValue<std::vector<int64_t>>(value_ptr);
70 std::vector<int> res(tmp.begin(), tmp.end());
71 return res;
72 }
73
get_data_type() const74 int TensorArray::get_data_type() const {
75 auto value_ptr = GetAttr(kTensorArrayDataType);
76 auto tmp = GetValue<int64_t>(value_ptr);
77 return static_cast<int>(tmp);
78 }
79
80 REGISTER_PRIMITIVE_C(kNameTensorArray, TensorArray);
81 } // namespace ops
82 } // namespace mindspore
83