1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/splice.h"
18
19 #include <vector>
20
21 #include "mindapi/base/shared_ptr.h"
22 #include "mindapi/ir/value.h"
23 #include "mindapi/src/helper.h"
24 #include "ops/op_name.h"
25 #include "ops/primitive_c.h"
26 #include "utils/log_adapter.h"
27
28 namespace mindspore {
29 namespace ops {
// Operator implementation boilerplate for Splice, with BaseOperator as its base.
// NOTE(review): the macro's expansion lives in a project header not visible here.
MIND_API_OPERATOR_IMPL(Splice, BaseOperator);
Init(const std::vector<int64_t> & contexts,const std::vector<int64_t> & forward_indexes,int64_t output_dims)31 void Splice::Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes,
32 int64_t output_dims) {
33 this->set_context(contexts);
34 this->set_forward_indexes(forward_indexes);
35 this->set_output_dim(output_dims);
36 }
37
set_context(const std::vector<int64_t> & contexts)38 void Splice::set_context(const std::vector<int64_t> &contexts) {
39 (void)this->AddAttr(kSpliceContext, api::MakeValue(contexts));
40 }
41
set_forward_indexes(const std::vector<int64_t> & forward_indexes)42 void Splice::set_forward_indexes(const std::vector<int64_t> &forward_indexes) {
43 (void)this->AddAttr(kSpliceForwardIndexes, api::MakeValue(forward_indexes));
44 }
45
set_output_dim(int64_t output_dim)46 void Splice::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kSpliceOutputDims, api::MakeValue(output_dim)); }
47
get_context() const48 std::vector<int64_t> Splice::get_context() const {
49 auto value_ptr = GetAttr(kSpliceContext);
50 return GetValue<std::vector<int64_t>>(value_ptr);
51 }
52
get_forward_indexes() const53 std::vector<int64_t> Splice::get_forward_indexes() const {
54 auto value_ptr = GetAttr(kSpliceForwardIndexes);
55 return GetValue<std::vector<int64_t>>(value_ptr);
56 }
57
get_output_dim() const58 int64_t Splice::get_output_dim() const {
59 auto value_ptr = GetAttr(kSpliceOutputDims);
60 return GetValue<int64_t>(value_ptr);
61 }
62
// Registers Splice under the name kNameSplice.
// NOTE(review): the macro presumably comes from ops/primitive_c.h (included above).
// The registry it populates is not visible here, so verify there.
REGISTER_PRIMITIVE_C(kNameSplice, Splice);
64 } // namespace ops
65 } // namespace mindspore
66