• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "frontend/parallel/ops_info/fillv2_info.h"

#include <functional>
#include <numeric>
#include "frontend/parallel/dynamic_creator.h"
20 
21 namespace mindspore {
22 namespace parallel {
GetAttrs()23 Status FillV2Info::GetAttrs() {
24   ResetInputsShape();
25   fake_inputs_shape_ = inputs_shape_;
26   MS_LOG(INFO) << name_ << ": The origin shape is " << inputs_shape_;
27   for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
28     if (inputs_shape_[0][i] == -1) {  // if dynamic shape, replace -1 to 1, this dimension can not be split
29       fake_inputs_shape_[0][i] = 1;
30       is_dynamic_shape_ = true;
31     }
32   }
33 
34   if (is_dynamic_shape_) {
35     MS_LOG(INFO) << name_ << ": the fake shape is " << fake_inputs_shape_;
36   }
37 
38   return SUCCESS;
39 }
40 
CheckStrategy(const StrategyPtr & strategy)41 Status FillV2Info::CheckStrategy(const StrategyPtr &strategy) {
42   MS_EXCEPTION_IF_NULL(strategy);
43   if (CheckStrategyValue(strategy, fake_inputs_shape_) != SUCCESS) {
44     MS_LOG(ERROR) << name_ << ": Invalid strategy " << strategy->ToString();
45     return FAILED;
46   }
47   return SUCCESS;
48 }
49 
InferDevMatrixShape()50 Status FillV2Info::InferDevMatrixShape() {
51   MS_EXCEPTION_IF_NULL(strategy_);
52   auto strategies = strategy_->GetInputDim();
53   if (strategies.empty()) {
54     MS_LOG(ERROR) << name_ << ": Infer device matric failed, inputs_startegy is empty.";
55     return FAILED;
56   }
57   dev_matrix_shape_ = strategies.at(0);
58   return SUCCESS;
59 }
60 
InferTensorMap()61 Status FillV2Info::InferTensorMap() {
62   TensorMap tensor_map;
63   std::vector<Dimensions> strategies = strategy_->GetInputDim();
64   auto input_shape_strategy = strategies.at(0);
65   auto size = input_shape_strategy.size();
66   for (size_t i = 0; i < size; ++i) {
67     tensor_map.push_back(SizeToLong(size - i - 1));
68   }
69   inputs_tensor_map_.push_back(tensor_map);
70   (void)inputs_tensor_map_.emplace_back(TensorMap());
71   outputs_tensor_map_.push_back(tensor_map);
72   return SUCCESS;
73 }
74 
GenerateOpStrategies(int64_t stage_id)75 std::vector<StrategyPtr> FillV2Info::GenerateOpStrategies(int64_t stage_id) {
76   Shape input0_split(fake_inputs_shape_.at(0).size(), 1);
77   Shape input1_split;
78   Shapes splittable_inputs = {input0_split, input1_split};
79   std::vector<StrategyPtr> sp_vector;
80   if (GenerateStrategiesForIndependentInputs(stage_id, fake_inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
81     MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
82   }
83   if (sp_vector.empty()) {
84     MS_LOG(EXCEPTION) << name_ << ": No available strategy.";
85   }
86   return sp_vector;
87 }
88 
ReplaceDynamicInput(const CNodePtr & cnode,const Shape & strategy)89 void FillV2Info::ReplaceDynamicInput(const CNodePtr &cnode, const Shape &strategy) {
90   auto dynamic_node = cnode->input(kIndex1);
91   if (!IsPrimitiveCNode(dynamic_node, prim::kPrimMakeTuple)) {
92     MS_LOG(EXCEPTION) << name_ << "The dynamic input must be MakeTuple cnode, but got "
93                       << dynamic_node->fullname_with_scope();
94     return;
95   }
96 
97   auto make_tuple_cnode = dynamic_node->cast<CNodePtr>();
98   MS_EXCEPTION_IF_NULL(make_tuple_cnode);
99 
100   for (size_t i = 1; i < make_tuple_cnode->inputs().size(); ++i) {
101     if (strategy[i - 1] <= 1) {
102       continue;
103     }
104 
105     auto input_node = make_tuple_cnode->input(i);
106     MS_EXCEPTION_IF_NULL(input_node);
107     auto value_node = GetValueNode(input_node);
108     if (value_node != nullptr && value_node->isa<Int64Imm>()) {
109       auto origin_ele = GetValue<int64_t>(value_node);
110       if (origin_ele % strategy[i - 1] != 0) {
111         MS_LOG(EXCEPTION) << name_ << ": the origin shape is " << origin_ele << ", can not be div by shard size "
112                           << strategy[i - 1];
113       }
114       int64_t replace_shape = origin_ele / strategy[i - 1];
115       MS_LOG(INFO) << name_ << ": replace shape from " << origin_ele << " to " << replace_shape << ", the index is "
116                    << (i - 1);
117       auto replace_value_ptr = MakeValue(replace_shape);
118       auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
119       make_tuple_cnode->set_input(i, replace_value_node);
120     }
121   }
122 }
123 
ReplaceNodeInputOrAttrs()124 void FillV2Info::ReplaceNodeInputOrAttrs() {
125   Shape strategy = strategy_->GetInputDim()[0];
126   if (std::accumulate(strategy.cbegin(), strategy.cend(), 1, std::multiplies<int64_t>()) == 1) {
127     return;
128   }
129 
130   for (auto &cnode : cnodes_) {
131     MS_EXCEPTION_IF_NULL(cnode);
132     if (!is_dynamic_shape_) {  // static shape
133       auto input_shape = inputs_shape_.at(kIndex0);
134       for (size_t i = 0; i < strategy.size(); i++) {
135         input_shape[i] /= strategy[i];
136       }
137       auto func_graph = cnode->func_graph();
138       MS_EXCEPTION_IF_NULL(func_graph);
139       auto manager = func_graph->manager();
140       MS_EXCEPTION_IF_NULL(manager);
141       auto val_tensor_node = NewValueNode(MakeValue(std::make_shared<tensor::Tensor>(input_shape)));
142       MS_LOG(INFO) << name_ << ": the new shape is " << input_shape;
143       cnode->set_input(kIndex1, val_tensor_node);
144     } else {  // dynamic shape
145       ReplaceDynamicInput(cnode, strategy);
146     }
147   }
148 }
149 
InferMirrorOps()150 Status FillV2Info::InferMirrorOps() {
151   if (OperatorInfo::InferMirrorOps() != SUCCESS) {
152     return FAILED;
153   }
154   // No need to insert mirror ops
155   if (mirror_ops_.empty()) {
156     return SUCCESS;
157   }
158   if (mirror_ops_.size() == kSizeOne) {
159     // Insert empty mirror op for shape
160     (void)mirror_ops_.insert(mirror_ops_.begin(), OperatorVector());
161   }
162   return SUCCESS;
163 }
164 
GetShapeFromTensor(const tensor::TensorPtr & shape_tensor)165 Shape FillV2Info::GetShapeFromTensor(const tensor::TensorPtr &shape_tensor) {
166   MS_EXCEPTION_IF_NULL(shape_tensor);
167   auto dim = shape_tensor->DataDim();
168   if (IntToSize(dim) != kDim1) {
169     MS_LOG(EXCEPTION) << name_ << ": The rank of 'input_shape' must be 1, but got rank " << dim;
170   }
171   auto size = shape_tensor->DataSize();
172   if (size <= 0) {
173     MS_LOG(EXCEPTION) << name_ << ": The size of 'input_shape' must be greater than 0, but got size " << size;
174   }
175   auto dtype = shape_tensor->data_type();
176   auto data = shape_tensor->data_c();
177   MS_EXCEPTION_IF_NULL(data);
178   if (dtype == kNumberTypeInt32) {
179     auto shape_data = static_cast<int32_t *>(data);
180     Shape shape(shape_data, shape_data + size);
181     return shape;
182   } else if (dtype == kNumberTypeInt64) {
183     auto shape_data = static_cast<int64_t *>(data);
184     Shape shape(shape_data, shape_data + size);
185     return shape;
186   }
187   MS_LOG(EXCEPTION) << name_ << ": The dtype of 'input_shape' must be int32 or int64, but got type "
188                     << TypeIdToString(dtype);
189 }
190 
ResetInputsShape()191 void FillV2Info::ResetInputsShape() {
192   auto input_value_shape = input_value_[0];
193   if (input_value_shape == nullptr) {
194     MS_LOG(EXCEPTION) << name_ << ": The value of input 'shape' must be a constant. "
195                       << "If you pass this value via construct, try to define its value in __init__";
196   }
197   MS_EXCEPTION_IF_NULL(input_value_shape);
198   if (input_value_shape->isa<tensor::Tensor>()) {
199     auto tensor_shape_ptr = GetValue<tensor::TensorPtr>(input_value_shape);
200     auto shape = GetShapeFromTensor(tensor_shape_ptr);
201     inputs_shape_[0] = shape;
202     is_parameter_[0] = false;
203     return;
204   } else if (input_value_shape->isa<ValueTuple>()) {
205     (void)inputs_shape_.insert(inputs_shape_.begin(), GetValue<Shape>(input_value_shape));
206     (void)is_parameter_.insert(is_parameter_.begin(), false);
207     return;
208   }
209   MS_LOG(EXCEPTION) << name_ << ": The type of input 'shape' must be Tensor or Tuple, but got "
210                     << input_value_shape->type()->ToString();
211 }
212 
// Registers FillV2Info through the REGISTER macro. The macro presumably comes from
// frontend/parallel/dynamic_creator.h (included above), so the operator can be created by name — TODO confirm.
213 REGISTER(FillV2Info);
214 }  // namespace parallel
215 }  // namespace mindspore
216