/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/shape_util.h"

namespace mindspore {
namespace parallel {
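// A reshape transfer is only valid when both layouts share the same device arrangement.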
Status ReshapeLayoutTransfer::CheckValidTransfer() {
  if (!IsSameDeviceArrangement()) {
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

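// Alternately expands the "from" and "to" tensor shapes until both layouts describe the same tensor
// shape; if an expansion step fails, the returned layout is marked as not expandable.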
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const {
  bool is_unified = IsSameTensorShape();
  std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
  if (out_layout_ptr == nullptr) {
    return nullptr;
  }
  while (!is_unified) {
    std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
    if (temp_layout_ptr == nullptr) {
      out_layout_ptr->SetExpandAble(false);
      return out_layout_ptr;
    }
    out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
    if (out_layout_ptr == nullptr) {
      std::shared_ptr<ReshapeLayoutTransfer> layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
      layout_ptr->SetExpandAble(false);
      return layout_ptr;
    }
    is_unified = out_layout_ptr->IsSameTensorShape();
  }
  return out_layout_ptr;
}

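// Expands the "from" tensor shape step by step until it can be expanded by the "to" tensor shape.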
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendFromTensorShapeByTo() const {
  std::shared_ptr<ReshapeLayoutTransfer> out_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
  bool is_expanded = FromTensorShapeCanBeExpandByTo();
  while (!is_expanded) {
    out_ptr = out_ptr->ExtendFromTensorShapeByExpandedTensorShape();
    if (out_ptr == nullptr) {
      return nullptr;
    }
    is_expanded = out_ptr->FromTensorShapeCanBeExpandByTo();
  }
  return out_ptr;
}

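// Expands the "to" tensor shape step by step until it can be expanded by the "from" tensor shape.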
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShapeByFrom() const {
  std::shared_ptr<ReshapeLayoutTransfer> out_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
  bool is_expanded = ToTensorShapeCanBeExpandByFrom();
  while (!is_expanded) {
    out_ptr = out_ptr->ExtendToTensorShapeByExpandedTensorShape();
    if (out_ptr == nullptr) {
      return nullptr;
    }
    is_expanded = out_ptr->ToTensorShapeCanBeExpandByFrom();
  }
  return out_ptr;
}

bool ReshapeLayoutTransfer::FromTensorShapeCanBeExpandByTo() const {
  return from_in_.TensorShapeCanBeExpanded(to_in_.tensor_shape());
}

bool ReshapeLayoutTransfer::ToTensorShapeCanBeExpandByFrom() const {
  return to_in_.TensorShapeCanBeExpanded(from_in_.tensor_shape());
}

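// Computes the expanded "from" tensor shape and applies it, together with the matching expansion of
// the "to" layout's device arrangement.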
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendFromTensorShapeByExpandedTensorShape() const {
  std::shared_ptr<Arrangement> expanded_shape_ptr = ComputeExpandedFromTensorShapeByTo();
  if (expanded_shape_ptr == nullptr) {
    return nullptr;
  }
  return ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr);
}

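// Performs the same expansion on the "to" side by temporarily exchanging the "from" and "to" layouts,
// expanding, and then exchanging them back.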
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShapeByExpandedTensorShape() const {
  std::shared_ptr<ReshapeLayoutTransfer> exchanged_from_and_to_ptr = ExchangeFromAndTo();
  if (exchanged_from_and_to_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<Arrangement> expanded_shape_ptr = exchanged_from_and_to_ptr->ComputeExpandedFromTensorShapeByTo();
  if (expanded_shape_ptr == nullptr) {
    return nullptr;
  }
  std::shared_ptr<ReshapeLayoutTransfer> exchanged_out =
    exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr);
  if (exchanged_out == nullptr) {
    return nullptr;
  }
  return exchanged_out->ExchangeFromAndTo();
}

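// Returns a new transfer whose "from" and "to" layouts are swapped, or nullptr if initialization fails.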
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExchangeFromAndTo() const {
  ReshapeLayoutTransfer out;
  Status status = out.Init(to_in_, from_in_);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<ReshapeLayoutTransfer>(out);
}

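// Expands the "from" tensor shape to expand_shape, then expands the "to" layout's device arrangement
// to match the device arrangement that results from that expansion.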
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement(
  const Arrangement &expand_shape) const {
  std::shared_ptr<TensorLayout> extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape);
  if (extend_tensor_shape_from_ptr == nullptr) {
    return nullptr;
  }
  Arrangement unified_device_arrangement = extend_tensor_shape_from_ptr->device_arrangement();
  std::shared_ptr<TensorLayout> extend_device_arrangement_to_ptr =
    to_in_.ExpandDeviceArrangement(unified_device_arrangement);
  if (extend_device_arrangement_to_ptr == nullptr) {
    return nullptr;
  }
  ReshapeLayoutTransfer out;
  Status status = out.Init(*extend_tensor_shape_from_ptr, *extend_device_arrangement_to_ptr);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  return std::make_shared<ReshapeLayoutTransfer>(out);
}

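// Computes the "from" tensor shape expanded with respect to the "to" tensor shape.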
std::shared_ptr<Arrangement> ReshapeLayoutTransfer::ComputeExpandedFromTensorShapeByTo() const {
  return from_in_.ComputeExpandedTensorShape(to_in_.tensor_shape());
}
}  // namespace parallel
}  // namespace mindspore