/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/transpose_info.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
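// Transpose has a single tensor input, so its strategy only needs to be a legal
// split of that input's shape; the shared CheckStrategyValue covers this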
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

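// the device matrix is simply the strategy of the single input,
// e.g. strategy (2, 4, 1) gives dev_matrix_shape_ = [2, 4, 1]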
Status TransposeInfo::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  input_strategy_ = stra.at(0);
  for (auto &iter : input_strategy_) {
    dev_matrix_shape_.push_back(iter);
  }
  return SUCCESS;
}

// there is no Parameter input for the Transpose primitive, so there is no need to insert mirror operators
Status TransposeInfo::InferMirrorOps() { return SUCCESS; }

// the forward computation of the Transpose primitive has no reduction dimension, so there is no need for an AllReduce
Status TransposeInfo::InferForwardCommunication() { return SUCCESS; }

/*
 * get the 'perm' input of the Transpose primitive
 * perm is a permutation of the dimensions of the input
 * the result is saved in axis_v_
 */
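// e.g. for a rank-3 input with perm = (1, 0, 2), axis_v_ ends up as {1, 0, 2}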
Status TransposeInfo::ComputeAxis() {
  if (input_value_[1] == nullptr) {
    MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr.";
    return FAILED;
  }
  std::vector<ValuePtr> elements;
  ValueTuplePtr dim_tuple = input_value_[1]->cast<ValueTuplePtr>();
  if (dim_tuple == nullptr) {
    MS_LOG(ERROR) << name_ << ": input_value_[1] must be a ValueTuplePtr.";
    return FAILED;
  }
  elements = dim_tuple->value();
  if (elements.size() != inputs_shape_[0].size()) {
    MS_LOG(ERROR) << name_ << ": the size of perm must be equal to the size of inputs_shape_[0].";
    return FAILED;
  }
  axis_v_.clear();
  for (auto &element : elements) {
    MS_EXCEPTION_IF_NULL(element);
    if (element->isa<Int64Imm>()) {
      int64_t axis = element->cast<Int64ImmPtr>()->value();
      axis_v_.push_back(axis);
    } else {
      MS_LOG(ERROR) << name_ << ": the value of axis must be int64.";
      return FAILED;
    }
  }

  // every index in [0, rank) must appear in axis_v_, otherwise it is not a permutation
  for (int64_t i = 0; i < SizeToLong(axis_v_.size()); i++) {
    auto iter = std::find(axis_v_.begin(), axis_v_.end(), i);
    if (iter == axis_v_.end()) {
      MS_LOG(ERROR) << name_ << ": axis_v_ must be a permutation.";
      return FAILED;
    }
  }
  return SUCCESS;
}

// the output tensor map is a permutation of the input tensor map, where the permutation is axis_v_
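// e.g. a rank-3 input has input tensor map (2, 1, 0); with perm = (1, 0, 2) the
// output tensor map becomes (1, 2, 0)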
Status TransposeInfo::InferTensorMap() {
  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
    MS_LOG(ERROR) << name_ << ": the sizes of inputs_shape_ and outputs_shape_ must both be 1, but they are "
                  << inputs_shape_.size() << " and " << outputs_shape_.size();
    return FAILED;
  }

  Shape tensor_map_index_input;
  for (size_t j = 0; j < inputs_shape_[0].size(); ++j) {
    tensor_map_index_input.push_back(SizeToLong(inputs_shape_[0].size() - j - 1));
  }
  inputs_tensor_map_.push_back(tensor_map_index_input);

  Shape tensor_map_index_output = tensor_map_index_input;
  for (uint64_t i = 0; i < tensor_map_index_output.size(); i++) {
    tensor_map_index_output[i] = tensor_map_index_input[LongToUlong(axis_v_[i])];
  }
  outputs_tensor_map_.push_back(tensor_map_index_output);
  return SUCCESS;
}

// compute axis_v_ in this method
Status TransposeInfo::GetAttrs() { return ComputeAxis(); }

Status TransposeInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}

Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  return SetCostUnderStrategyBase(strategy);
}

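// every dimension of the single input is marked as splittable (1), so candidate
// strategies are generated independently for each dimension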
std::vector<StrategyPtr> TransposeInfo::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed";
  }

  return sp_vector;
}
}  // namespace parallel
}  // namespace mindspore