/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/delegate/npu/op/strided_slice_npu.h"
#include <cstring>
#include "src/delegate/npu/npu_converter_utils.h"
#include "src/delegate/npu/pass/npu_pass_utils.h"

namespace mindspore {
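// Checks whether this StridedSlice can run on NPU. The onnx variant carries five inputs (the 4th is axes); custom
// axes are rejected because the NPU op only handles the default contiguous axes. Non-float inputs are allowed but
// flagged so that Cast operators are inserted later. The index constants (ONNX_INPUT_SIZE, BEGIN_INDEX, END_INDEX,
// STRIDE_INDEX, ONNX_STRIDE_INDEX) are assumed to be defined in strided_slice_npu.h.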
int StridedSliceNPUOp::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
                                 const std::vector<mindspore::MSTensor> &out_tensors) {
  // Only onnx StridedSlice has 5 in_tensors, of which the 4th input is axes and the 5th input is strides.
  if (in_tensors.size() == ONNX_INPUT_SIZE) {
    std::vector<int> axes;
    size_t size = in_tensors[STRIDE_INDEX].Shape()[0];
    axes.resize(size);
    MS_ASSERT(in_tensors[STRIDE_INDEX].Data());
    memcpy(axes.data(), in_tensors[STRIDE_INDEX].Data().get(), sizeof(int) * size);
    for (int i = 0; i < axes.size(); ++i) {
      if (i != axes[i]) {
        MS_LOG(WARNING) << "Setting axes is not supported; the axes must be contiguous and start from 0.";
        return RET_NOT_SUPPORT;
      }
    }
  }
  auto input_x = in_tensors.at(0);
  if (input_x.DataType() != DataType::kNumberTypeFloat32 && input_x.DataType() != DataType::kNumberTypeFloat16) {
    need_cast_ = true;
    MS_LOG(INFO) << "StridedSlice does not support input datatype other than FLOAT. Cast op will be inserted.";
  }
  return RET_OK;
}

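// Creates the underlying hiai::op::StridedSlice operator and caches the begin/end/ellipsis/new-axis/shrink-axis
// masks from the flatbuffer primitive; the masks are applied to the NPU op later in SetNPUInputs.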
int StridedSliceNPUOp::Init(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
                            const std::vector<mindspore::MSTensor> &out_tensors) {
  strided_slice_ = new (std::nothrow) hiai::op::StridedSlice(name_);
  if (strided_slice_ == nullptr) {
    MS_LOG(ERROR) << "New stridedSlice npu operator for op " << name_ << " failed.";
    return RET_ERROR;
  }
  auto strided_slice_prim = primitive->value_as_StridedSlice();
  if (strided_slice_prim == nullptr) {
    MS_LOG(ERROR) << "Get null primitive value for op " << name_ << ".";
    return RET_ERROR;
  }
  begins_mask_ = strided_slice_prim->begin_mask();
  ends_mask_ = strided_slice_prim->end_mask();
  ellipsis_mask_ = strided_slice_prim->ellipsis_mask();
  new_axis_mask_ = strided_slice_prim->new_axis_mask();
  shrink_axis_mask_ = strided_slice_prim->shrink_axis_mask();
  return RET_OK;
}

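// Applies the cached masks to the NPU operator and wires up its inputs (x, begin, end, strides). When a cast is
// needed, the x input is routed through the in/out Cast operators created by SetCast instead of being connected
// directly.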
int StridedSliceNPUOp::SetNPUInputs(const std::vector<mindspore::MSTensor> &in_tensors,
                                    const std::vector<mindspore::MSTensor> &out_tensors,
                                    const std::vector<ge::Operator *> &npu_inputs) {
  strided_slice_->set_attr_begin_mask(begins_mask_);
  strided_slice_->set_attr_ellipsis_mask(ellipsis_mask_);
  strided_slice_->set_attr_end_mask(ends_mask_);
  strided_slice_->set_attr_shrink_axis_mask(shrink_axis_mask_);
  strided_slice_->set_attr_new_axis_mask(new_axis_mask_);
  // StridedSliceV2 supports setting axes, but using it causes an endless loop, so plain StridedSlice is used here.
  if (need_cast_) {
    auto ret = SetCast(npu_inputs[0], strided_slice_, in_tensors[0], out_tensors[0]);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Insert Cast operator for op " << name_ << " failed.";
      return ret;
    }
  } else {
    strided_slice_->set_input_x(*npu_inputs[0]);
  }
  strided_slice_->set_input_begin(*npu_inputs[BEGIN_INDEX]);
  strided_slice_->set_input_end(*npu_inputs[END_INDEX]);

  // For onnx, strides is the 5th input; for the other formats, it is the 4th.
  if (npu_inputs.size() == ONNX_INPUT_SIZE) {
    strided_slice_->set_input_strides(*npu_inputs[ONNX_STRIDE_INDEX]);
  } else {
    strided_slice_->set_input_strides(*npu_inputs[STRIDE_INDEX]);
  }
  return RET_OK;
}

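// Returns the operator that produces this node's output: the trailing Cast when one was inserted, otherwise the
// StridedSlice itself.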
ge::Operator *StridedSliceNPUOp::GetNPUOp() {
  if (need_cast_) {
    return this->out_cast_;
  } else {
    return this->strided_slice_;
  }
}

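// Remaps every mask from NHWC bit positions to NCHW, presumably because the NPU subgraph runs in NCHW layout while
// the original lite model is NHWC.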
int StridedSliceNPUOp::HandleAxis() {
  begins_mask_ = NPUPassUtils::MaskDataNHWC2NCHW(begins_mask_);
  ends_mask_ = NPUPassUtils::MaskDataNHWC2NCHW(ends_mask_);
  ellipsis_mask_ = NPUPassUtils::MaskDataNHWC2NCHW(ellipsis_mask_);
  shrink_axis_mask_ = NPUPassUtils::MaskDataNHWC2NCHW(shrink_axis_mask_);
  new_axis_mask_ = NPUPassUtils::MaskDataNHWC2NCHW(new_axis_mask_);
  return RET_OK;
}

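// Wraps the StridedSlice with a pair of CastT operators: in_cast_ converts the input to float before slicing and
// out_cast_ converts the result back to the original output data type.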
int StridedSliceNPUOp::SetCast(const ge::Operator *input, const ge::Operator *cur_op,
                               const mindspore::MSTensor in_tensor, const mindspore::MSTensor out_tensor) {
  in_cast_ = new (std::nothrow) hiai::op::CastT(name_ + "_in_cast");
  out_cast_ = new (std::nothrow) hiai::op::CastT(name_ + "_out_cast");
  if (in_cast_ == nullptr || out_cast_ == nullptr) {
    MS_LOG(ERROR) << "New cast npu operator for op " << name_ << " failed.";
    return RET_ERROR;
  }
  in_cast_->set_input_x(*input);
  in_cast_->set_attr_src_dtype(ConverterToNPUDataType(static_cast<DataType>(in_tensor.DataType())));
  in_cast_->set_attr_dst_dtype(ge::DT_FLOAT);
  strided_slice_->set_input_x(*in_cast_);
  out_cast_->set_input_x(*cur_op);
  out_cast_->set_attr_src_dtype(ge::DT_FLOAT);
  out_cast_->set_attr_dst_dtype(ConverterToNPUDataType(static_cast<DataType>(out_tensor.DataType())));
  return RET_OK;
}

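// Releases the HiAI operators created in Init and SetCast.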
StridedSliceNPUOp::~StridedSliceNPUOp() {
  if (strided_slice_ != nullptr) {
    delete strided_slice_;
    strided_slice_ = nullptr;
  }
  if (in_cast_ != nullptr) {
    delete in_cast_;
    in_cast_ = nullptr;
  }
  if (out_cast_ != nullptr) {
    delete out_cast_;
    out_cast_ = nullptr;
  }
}
}  // namespace mindspore