• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/splice.h"
18 
19 #include <vector>
20 
21 #include "mindapi/base/shared_ptr.h"
22 #include "mindapi/ir/value.h"
23 #include "mindapi/src/helper.h"
24 #include "ops/op_name.h"
25 #include "ops/primitive_c.h"
26 #include "utils/log_adapter.h"
27 
28 namespace mindspore {
29 namespace ops {
30 MIND_API_OPERATOR_IMPL(Splice, BaseOperator);
Init(const std::vector<int64_t> & contexts,const std::vector<int64_t> & forward_indexes,int64_t output_dims)31 void Splice::Init(const std::vector<int64_t> &contexts, const std::vector<int64_t> &forward_indexes,
32                   int64_t output_dims) {
33   this->set_context(contexts);
34   this->set_forward_indexes(forward_indexes);
35   this->set_output_dim(output_dims);
36 }
37 
set_context(const std::vector<int64_t> & contexts)38 void Splice::set_context(const std::vector<int64_t> &contexts) {
39   (void)this->AddAttr(kSpliceContext, api::MakeValue(contexts));
40 }
41 
set_forward_indexes(const std::vector<int64_t> & forward_indexes)42 void Splice::set_forward_indexes(const std::vector<int64_t> &forward_indexes) {
43   (void)this->AddAttr(kSpliceForwardIndexes, api::MakeValue(forward_indexes));
44 }
45 
set_output_dim(int64_t output_dim)46 void Splice::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kSpliceOutputDims, api::MakeValue(output_dim)); }
47 
get_context() const48 std::vector<int64_t> Splice::get_context() const {
49   auto value_ptr = GetAttr(kSpliceContext);
50   return GetValue<std::vector<int64_t>>(value_ptr);
51 }
52 
get_forward_indexes() const53 std::vector<int64_t> Splice::get_forward_indexes() const {
54   auto value_ptr = GetAttr(kSpliceForwardIndexes);
55   return GetValue<std::vector<int64_t>>(value_ptr);
56 }
57 
get_output_dim() const58 int64_t Splice::get_output_dim() const {
59   auto value_ptr = GetAttr(kSpliceOutputDims);
60   return GetValue<int64_t>(value_ptr);
61 }
62 
63 REGISTER_PRIMITIVE_C(kNameSplice, Splice);
64 }  // namespace ops
65 }  // namespace mindspore
66