1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" 18 #include "frontend/parallel/status.h" 19 #include "frontend/parallel/tensor_layout/shape_util.h" 20 21 namespace mindspore { 22 namespace parallel { CheckValidTransfer()23Status ReshapeLayoutTransfer::CheckValidTransfer() { 24 if (!IsSameDeviceArrangement()) { 25 return Status::FAILED; 26 } 27 return Status::SUCCESS; 28 } 29 UnifyDeviceArrangementAndTensorShape() const30std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { 31 bool is_unified = IsSameTensorShape(); 32 std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); 33 if (out_layout_ptr == nullptr) { 34 return nullptr; 35 } 36 while (!is_unified) { 37 std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); 38 if (temp_layout_ptr == nullptr) { 39 out_layout_ptr->SetExpandAble(false); 40 return out_layout_ptr; 41 } 42 out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); 43 if (out_layout_ptr == nullptr) { 44 std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); 45 layout_ptr->SetExpandAble(false); 46 return layout_ptr; 47 } 48 is_unified = out_layout_ptr->IsSameTensorShape(); 49 } 50 return out_layout_ptr; 51 } 52 ExtendFromTensorShapeByTo() const53std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendFromTensorShapeByTo() const { 54 std::shared_ptr<ReshapeLayoutTransfer> out_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); 55 bool is_expanded = FromTensorShapeCanBeExpandByTo(); 56 while (!is_expanded) { 57 out_ptr = out_ptr->ExtendFromTensorShapeByExpandedTensorShape(); 58 if (out_ptr == nullptr) { 59 return nullptr; 60 } 61 is_expanded = out_ptr->FromTensorShapeCanBeExpandByTo(); 62 } 63 return out_ptr; 64 } 65 ExtendToTensorShapeByFrom() const66std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShapeByFrom() const { 67 std::shared_ptr<ReshapeLayoutTransfer> out_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); 68 bool is_expanded = ToTensorShapeCanBeExpandByFrom(); 69 while (!is_expanded) { 70 out_ptr = out_ptr->ExtendToTensorShapeByExpandedTensorShape(); 71 if (out_ptr == nullptr) { 72 return nullptr; 73 } 74 is_expanded = out_ptr->ToTensorShapeCanBeExpandByFrom(); 75 } 76 return out_ptr; 77 } 78 FromTensorShapeCanBeExpandByTo() const79bool ReshapeLayoutTransfer::FromTensorShapeCanBeExpandByTo() const { 80 return from_in_.TensorShapeCanBeExpanded(to_in_.tensor_shape()); 81 } 82 ToTensorShapeCanBeExpandByFrom() const83bool ReshapeLayoutTransfer::ToTensorShapeCanBeExpandByFrom() const { 84 return to_in_.TensorShapeCanBeExpanded(from_in_.tensor_shape()); 85 } 86 ExtendFromTensorShapeByExpandedTensorShape() const87std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendFromTensorShapeByExpandedTensorShape() const { 88 std::shared_ptr<Arrangement> expanded_shape_ptr = ComputeExpandedFromTensorShapeByTo(); 89 if (expanded_shape_ptr == nullptr) { 90 return nullptr; 91 } 92 return ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); 93 } 94 ExtendToTensorShapeByExpandedTensorShape() const95std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShapeByExpandedTensorShape() const { 96 std::shared_ptr<ReshapeLayoutTransfer> exchanged_from_and_to_ptr = ExchangeFromAndTo(); 97 if (exchanged_from_and_to_ptr == nullptr) { 98 return nullptr; 99 } 100 std::shared_ptr<Arrangement> expanded_shape_ptr = exchanged_from_and_to_ptr->ComputeExpandedFromTensorShapeByTo(); 101 if (expanded_shape_ptr == nullptr) { 102 return nullptr; 103 } 104 std::shared_ptr<ReshapeLayoutTransfer> exchanged_out = 105 exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); 106 if (exchanged_out == nullptr) { 107 return nullptr; 108 } 109 return exchanged_out->ExchangeFromAndTo(); 110 } 111 ExchangeFromAndTo() const112std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExchangeFromAndTo() const { 113 ReshapeLayoutTransfer out; 114 Status status = out.Init(to_in_, from_in_); 115 if (status != Status::SUCCESS) { 116 return nullptr; 117 } 118 return std::make_shared<ReshapeLayoutTransfer>(out); 119 } 120 ExpandFromTensorShapeAndExpandToDeviceArrangement(const Arrangement & expand_shape) const121std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( 122 const Arrangement &expand_shape) const { 123 std::shared_ptr<TensorLayout> extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); 124 if (extend_tensor_shape_from_ptr == nullptr) { 125 return nullptr; 126 } 127 Arrangement unified_device_arrangement = extend_tensor_shape_from_ptr->device_arrangement(); 128 std::shared_ptr<TensorLayout> extend_device_arrangement_to_ptr = 129 to_in_.ExpandDeviceArrangement(unified_device_arrangement); 130 if (extend_device_arrangement_to_ptr == nullptr) { 131 return nullptr; 132 } 133 ReshapeLayoutTransfer out; 134 Status status = out.Init(*extend_tensor_shape_from_ptr, *extend_device_arrangement_to_ptr); 135 if (status != Status::SUCCESS) { 136 return nullptr; 137 } 138 return std::make_shared<ReshapeLayoutTransfer>(out); 139 } 140 ComputeExpandedFromTensorShapeByTo() const141std::shared_ptr<Arrangement> ReshapeLayoutTransfer::ComputeExpandedFromTensorShapeByTo() const { 142 return from_in_.ComputeExpandedTensorShape(to_in_.tensor_shape()); 143 } 144 } // namespace parallel 145 } // namespace mindspore 146