/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "frontend/parallel/ops_info/fillv2_info.h"

#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "frontend/parallel/dynamic_creator.h"

namespace mindspore {
namespace parallel {
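// Record the input shapes. Any dynamic dimension (-1) of the first input (the target
// shape) is replaced with 1 in fake_inputs_shape_, since a dynamic dimension cannot be split.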
Status FillV2Info::GetAttrs() {
  ResetInputsShape();
  fake_inputs_shape_ = inputs_shape_;
  MS_LOG(INFO) << name_ << ": The original shape is " << inputs_shape_;
  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
    if (inputs_shape_[0][i] == -1) {  // For a dynamic dimension, replace -1 with 1; this dimension cannot be split.
      fake_inputs_shape_[0][i] = 1;
      is_dynamic_shape_ = true;
    }
  }

  if (is_dynamic_shape_) {
    MS_LOG(INFO) << name_ << ": The fake shape is " << fake_inputs_shape_;
  }

  return SUCCESS;
}

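// Validate the strategy against the fake shapes, so that any attempt to split a
// dynamic (-1) dimension is rejected.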
Status FillV2Info::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, fake_inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy " << strategy->ToString();
    return FAILED;
  }
  return SUCCESS;
}

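// The device matrix is the strategy of the first input (the target shape).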
Status FillV2Info::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  auto strategies = strategy_->GetInputDim();
  if (strategies.empty()) {
    MS_LOG(ERROR) << name_ << ": Failed to infer the device matrix: the input strategy is empty.";
    return FAILED;
  }
  dev_matrix_shape_ = strategies.at(0);
  return SUCCESS;
}

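// Map each dimension of the output (and of the shape input) to the matching
// device-matrix axis; the fill value is a scalar, so its tensor map is empty.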
Status FillV2Info::InferTensorMap() {
  TensorMap tensor_map;
  std::vector<Dimensions> strategies = strategy_->GetInputDim();
  auto input_shape_strategy = strategies.at(0);
  auto size = input_shape_strategy.size();
  for (size_t i = 0; i < size; ++i) {
    tensor_map.push_back(SizeToLong(size - i - 1));
  }
  inputs_tensor_map_.push_back(tensor_map);
  (void)inputs_tensor_map_.emplace_back(TensorMap());
  outputs_tensor_map_.push_back(tensor_map);
  return SUCCESS;
}

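// For auto-parallel, every dimension of the fake shape input may be split, while the
// scalar fill value may not.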
std::vector<StrategyPtr> FillV2Info::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(fake_inputs_shape_.at(0).size(), 1);
  Shape input1_split;
  Shapes splittable_inputs = {input0_split, input1_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, fake_inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Generating strategies for independent inputs failed.";
  }
  if (sp_vector.empty()) {
    MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
  }
  return sp_vector;
}

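// For a dynamic shape assembled by MakeTuple, divide each constant element that is
// split by its shard size, so every device fills only its local slice.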
void FillV2Info::ReplaceDynamicInput(const CNodePtr &cnode, const Shape &strategy) {
  auto dynamic_node = cnode->input(kIndex1);
  if (!IsPrimitiveCNode(dynamic_node, prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << name_ << ": The dynamic input must be a MakeTuple cnode, but got "
                      << dynamic_node->fullname_with_scope();
  }

  auto make_tuple_cnode = dynamic_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple_cnode);

  // Input 0 of MakeTuple is the primitive itself, so the tuple elements start at index 1.
  for (size_t i = 1; i < make_tuple_cnode->inputs().size(); ++i) {
    if (strategy[i - 1] <= 1) {
      continue;
    }

    auto input_node = make_tuple_cnode->input(i);
    MS_EXCEPTION_IF_NULL(input_node);
    auto value_node = GetValueNode(input_node);
    if (value_node != nullptr && value_node->isa<Int64Imm>()) {
      auto origin_ele = GetValue<int64_t>(value_node);
      if (origin_ele % strategy[i - 1] != 0) {
        MS_LOG(EXCEPTION) << name_ << ": The original dimension size " << origin_ele
                          << " cannot be divided evenly by the shard size " << strategy[i - 1];
      }
      int64_t replace_shape = origin_ele / strategy[i - 1];
      MS_LOG(INFO) << name_ << ": Replace the dimension size " << origin_ele << " with " << replace_shape
                   << " at index " << (i - 1);
      auto replace_value_ptr = MakeValue(replace_shape);
      auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
      make_tuple_cnode->set_input(i, replace_value_node);
    }
  }
}

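// Rewrite the shape input of each FillV2 cnode to the sliced shape: replace the constant
// shape tensor directly in the static case, or patch the MakeTuple elements in the dynamic case.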
void FillV2Info::ReplaceNodeInputOrAttrs() {
  Shape strategy = strategy_->GetInputDim()[0];
  // Nothing to do if the shape input is not split at all.
  if (std::accumulate(strategy.cbegin(), strategy.cend(), int64_t(1), std::multiplies<int64_t>()) == 1) {
    return;
  }

  for (auto &cnode : cnodes_) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (!is_dynamic_shape_) {  // Static shape: replace the shape tensor with the sliced shape.
      auto input_shape = inputs_shape_.at(kIndex0);
      for (size_t i = 0; i < strategy.size(); i++) {
        input_shape[i] /= strategy[i];
      }
      auto func_graph = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      auto manager = func_graph->manager();
      MS_EXCEPTION_IF_NULL(manager);
      auto val_tensor_node = NewValueNode(MakeValue(std::make_shared<tensor::Tensor>(input_shape)));
      MS_LOG(INFO) << name_ << ": The new shape is " << input_shape;
      cnode->set_input(kIndex1, val_tensor_node);
    } else {  // Dynamic shape: patch the constant elements of the MakeTuple input.
      ReplaceDynamicInput(cnode, strategy);
    }
  }
}

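// Reuse the default mirror-op inference; if only the mirror op of the value input was
// generated, prepend an empty one so the list stays aligned with the input order.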
Status FillV2Info::InferMirrorOps() {
  if (OperatorInfo::InferMirrorOps() != SUCCESS) {
    return FAILED;
  }
  // No need to insert mirror ops.
  if (mirror_ops_.empty()) {
    return SUCCESS;
  }
  if (mirror_ops_.size() == kSizeOne) {
    // Insert an empty mirror op for the shape input.
    (void)mirror_ops_.insert(mirror_ops_.begin(), OperatorVector());
  }
  return SUCCESS;
}

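// Read a Shape out of a rank-1 int32/int64 shape tensor; any other rank, an empty
// tensor, or another dtype raises an exception.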
Shape FillV2Info::GetShapeFromTensor(const tensor::TensorPtr &shape_tensor) {
  MS_EXCEPTION_IF_NULL(shape_tensor);
  auto dim = shape_tensor->DataDim();
  if (IntToSize(dim) != kDim1) {
    MS_LOG(EXCEPTION) << name_ << ": The rank of 'input_shape' must be 1, but got rank " << dim;
  }
  auto size = shape_tensor->DataSize();
  if (size <= 0) {
    MS_LOG(EXCEPTION) << name_ << ": The size of 'input_shape' must be greater than 0, but got size " << size;
  }
  auto dtype = shape_tensor->data_type();
  auto data = shape_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  if (dtype == kNumberTypeInt32) {
    auto shape_data = static_cast<int32_t *>(data);
    Shape shape(shape_data, shape_data + size);
    return shape;
  } else if (dtype == kNumberTypeInt64) {
    auto shape_data = static_cast<int64_t *>(data);
    Shape shape(shape_data, shape_data + size);
    return shape;
  }
  MS_LOG(EXCEPTION) << name_ << ": The dtype of 'input_shape' must be int32 or int64, but got type "
                    << TypeIdToString(dtype);
}

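// Replace the recorded shape of input 0 with the constant value of the 'shape' input
// (a Tensor or a ValueTuple), so that parallel planning sees the concrete target shape.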
void FillV2Info::ResetInputsShape() {
  auto input_value_shape = input_value_[0];
  if (input_value_shape == nullptr) {
    MS_LOG(EXCEPTION) << name_ << ": The value of the input 'shape' must be a constant. "
                      << "If you pass this value via construct, try to define its value in __init__.";
  }
  if (input_value_shape->isa<tensor::Tensor>()) {
    auto tensor_shape_ptr = GetValue<tensor::TensorPtr>(input_value_shape);
    auto shape = GetShapeFromTensor(tensor_shape_ptr);
    inputs_shape_[0] = shape;
    is_parameter_[0] = false;
    return;
  } else if (input_value_shape->isa<ValueTuple>()) {
    (void)inputs_shape_.insert(inputs_shape_.begin(), GetValue<Shape>(input_value_shape));
    (void)is_parameter_.insert(is_parameter_.begin(), false);
    return;
  }
  MS_LOG(EXCEPTION) << name_ << ": The type of the input 'shape' must be Tensor or Tuple, but got "
                    << input_value_shape->type()->ToString();
}

REGISTER(FillV2Info);
}  // namespace parallel
}  // namespace mindspore