1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FILLV2_INFO_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FILLV2_INFO_H_ 19 20 #include <memory> 21 #include <vector> 22 #include <string> 23 24 #include "frontend/parallel/ops_info/operator_info.h" 25 #include "frontend/parallel/strategy.h" 26 #include "frontend/parallel/tensor_layout/tensor_redistribution.h" 27 28 namespace mindspore { 29 namespace parallel { 30 class FillV2Info : public OperatorInfo { 31 public: FillV2Info(const std::string & name,const Shapes & input_shape,const Shapes & output_shape,const PrimitiveAttrs & attrs)32 FillV2Info(const std::string &name, const Shapes &input_shape, const Shapes &output_shape, 33 const PrimitiveAttrs &attrs) 34 : OperatorInfo(name, input_shape, output_shape, attrs, std::make_shared<FillV2Cost>()) {} 35 ~FillV2Info() = default; 36 37 std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; SetCostUnderStrategy(const StrategyPtr & strategy)38 Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); } 39 void ReplaceNodeInputOrAttrs() override; 40 41 protected: 42 Status GetAttrs() override; 43 Status CheckStrategy(const StrategyPtr &strategy) override; 44 Status InferDevMatrixShape() override; 45 Status InferTensorMap() override; 46 Status InferMirrorOps() override; InferForwardCommunication()47 Status InferForwardCommunication() override { return SUCCESS; }; 48 49 private: 50 void ResetInputsShape(); 51 void ReplaceDynamicInput(const CNodePtr &cnode, const Shape &strategy); 52 Shape GetShapeFromTensor(const tensor::TensorPtr &shape_tensor); 53 Shapes fake_inputs_shape_; // if dynamic shape, replace -1 to 1 54 bool is_dynamic_shape_ = false; 55 }; 56 } // namespace parallel 57 } // namespace mindspore 58 59 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_FILLV2_INFO_H_ 60