1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17
18 #include "tensorflow/compiler/xla/service/hlo.pb.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20
21 namespace xla {
22
OutputHasAlias(const ShapeIndex & output_index) const23 bool HloInputOutputAliasConfig::OutputHasAlias(
24 const ShapeIndex& output_index) const {
25 return alias_.element(output_index).has_value();
26 }
27
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,HloInputOutputAliasConfig::AliasKind must_alias)28 Status HloInputOutputAliasConfig::SetUpAlias(
29 const ShapeIndex& output_index, int64 param_number,
30 const ShapeIndex& param_index,
31 HloInputOutputAliasConfig::AliasKind must_alias) {
32 TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
33 << "Trying to set up alias at " << output_index.ToString()
34 << " which is an invalid index for shape "
35 << ShapeUtil::HumanString(alias_.shape());
36 TF_RET_CHECK(param_number >= 0) << param_number;
37 TF_RET_CHECK(!OutputHasAlias(output_index))
38 << "Output index " << output_index << " already has an alias setup";
39 // Output can't be aliased with multiple parameters.
40 TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
41 "Trying to set up output alias for param %lld at %s but failed: output "
42 "index %s is already aliased with param %lld at %s",
43 param_number, param_index.ToString(), output_index.ToString(),
44 alias_.element(output_index)->parameter_number,
45 alias_.element(output_index)->parameter_index.ToString());
46 (*alias_.mutable_element(output_index)) =
47 Alias(param_number, param_index, must_alias);
48 VLOG(4) << "Set up alias between output index " << output_index.ToString()
49 << " and parameter " << param_index << " at index "
50 << param_index.ToString();
51 return Status::OK();
52 }
53
ToProto() const54 HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
55 HloInputOutputAliasProto result;
56 alias_.ForEachElement(
57 [&](const ShapeIndex& index, const absl::optional<Alias>& data) {
58 if (data) {
59 HloInputOutputAliasProto::AliasEntryProto entry;
60 for (int64 i : index) {
61 entry.add_output_shape_index(i);
62 }
63 entry.set_parameter_number(data->parameter_number);
64 for (int64 i : data->parameter_index) {
65 entry.add_parameter_shape_index(i);
66 }
67 if (data->must_alias()) {
68 entry.set_kind(Kind::MUST_ALIAS);
69 } else {
70 entry.set_kind(Kind::MAY_ALIAS);
71 }
72 result.add_entries()->Swap(&entry);
73 }
74 });
75 return result;
76 }
77
CreateFromProto(Shape output_shape,const HloInputOutputAliasProto & proto)78 StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
79 Shape output_shape, const HloInputOutputAliasProto& proto) {
80 HloInputOutputAliasConfig result(std::move(output_shape));
81 for (const HloInputOutputAliasProto::AliasEntryProto& entry :
82 proto.entries()) {
83 ShapeIndex output_index(entry.output_shape_index().begin(),
84 entry.output_shape_index().end());
85 int64 param_number = entry.parameter_number();
86 ShapeIndex param_index(entry.parameter_shape_index().begin(),
87 entry.parameter_shape_index().end());
88 AliasKind kind = entry.kind() == Kind::MAY_ALIAS ? kMayAlias : kMustAlias;
89 TF_RETURN_IF_ERROR(
90 result.SetUpAlias(output_index, param_number, param_index, kind));
91 }
92 return result;
93 }
94
shape() const95 const Shape& HloInputOutputAliasConfig::shape() const { return alias_.shape(); }
96
ToString() const97 string HloInputOutputAliasConfig::ToString() const {
98 std::vector<string> pieces;
99 pieces.push_back("HloInputOutputAliasConfig");
100 pieces.push_back(
101 absl::StrFormat(" Output shape: %s", alias_.shape().ToString()));
102
103 ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
104 pieces.push_back(absl::StrFormat(
105 " OutputIndex %s is %saliased with parameter %lld at %s:",
106 output_index.ToString(), alias.kind == kMustAlias ? "must-" : "may-",
107 alias.parameter_number, alias.parameter_index.ToString()));
108 });
109 return absl::StrJoin(pieces, "\n");
110 }
111
ToShortString() const112 string HloInputOutputAliasConfig::ToShortString() const {
113 std::vector<string> pieces;
114 for (const auto& p : alias_) {
115 const ShapeIndex& index = p.first;
116 if (absl::optional<Alias> alias = p.second) {
117 pieces.push_back(
118 absl::StrFormat("%s: %s", index.ToString(), alias->ToString()));
119 }
120 }
121 return absl::StrJoin(pieces, ", ");
122 }
123
ParameterMustAlias(int64 param_number,const ShapeIndex & param_index) const124 bool HloInputOutputAliasConfig::ParameterMustAlias(
125 int64 param_number, const ShapeIndex& param_index) const {
126 bool result = false;
127 alias_.ForEachElement(
128 [&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
129 if (alias && alias->parameter_number == param_number &&
130 alias->parameter_index == param_index && alias->must_alias()) {
131 result = true;
132 }
133 });
134 return result;
135 }
136
GetAliasedOutput(int64 param_number,const ShapeIndex & param_index) const137 absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
138 int64 param_number, const ShapeIndex& param_index) const {
139 absl::optional<ShapeIndex> output;
140 alias_.ForEachElement(
141 [&](const xla::ShapeIndex& output_index, absl::optional<Alias> alias) {
142 if (alias && alias->parameter_number == param_number &&
143 alias->parameter_index == param_index) {
144 output = output_index;
145 }
146 });
147 return output;
148 }
149
150 absl::optional<HloInputOutputAliasConfig::Alias>
GetAliasedParameter(const ShapeIndex & output_index) const151 HloInputOutputAliasConfig::GetAliasedParameter(
152 const ShapeIndex& output_index) const {
153 CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
154 << ToString() << " " << alias_.shape().ToString() << " " << output_index;
155 return alias_.element(output_index);
156 }
157
ForEachAlias(AliasFn fn) const158 void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
159 alias_.ForEachElement(
160 [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
161 if (aliased) {
162 fn(output_index, *aliased);
163 }
164 });
165 }
166
ForEachAliasWithStatus(AliasFnWithStatus fn) const167 Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
168 AliasFnWithStatus fn) const {
169 return alias_.ForEachElementWithStatus(
170 [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
171 if (aliased) {
172 TF_RETURN_IF_ERROR(fn(output_index, *aliased));
173 }
174 return Status::OK();
175 });
176 }
177
Verify(const HloModule & module,std::function<int64 (const Shape &)> size_func) const178 Status HloInputOutputAliasConfig::Verify(
179 const HloModule& module,
180 std::function<int64(const Shape&)> size_func) const {
181 std::vector<ShapeTree<bool>> param_has_seen;
182 const HloComputation* entry = module.entry_computation();
183 for (int64 i = 0; i < entry->num_parameters(); ++i) {
184 HloInstruction* param = entry->parameter_instruction(i);
185 param_has_seen.emplace_back(param->shape());
186 }
187 return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
188 const Alias& alias) -> Status {
189 const HloInstruction* root = entry->root_instruction();
190
191 TF_RET_CHECK(0 <= alias.parameter_number);
192 TF_RET_CHECK(entry->num_parameters() > alias.parameter_number);
193 const Shape& param_shape =
194 entry->parameter_instruction(alias.parameter_number)->shape();
195 const Shape& output_shape = root->shape();
196 TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index));
197 TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
198
199 const Shape& param_subshape =
200 ShapeUtil::GetSubshape(param_shape, alias.parameter_index);
201 const Shape& output_subshape =
202 ShapeUtil::GetSubshape(output_shape, output_index);
203 TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape));
204 TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape));
205
206 if (size_func(param_subshape) != size_func(output_subshape)) {
207 return InternalError(
208 "Expected aliased input %lld at index %s and output at index %s to "
209 "have the same size. Input sub-shape is %s with size %lld, output "
210 "sub-shape is %s with size %lld",
211 alias.parameter_number, alias.parameter_index.ToString(),
212 output_index.ToString(),
213 ShapeUtil::HumanStringWithLayout(param_subshape),
214 size_func(param_subshape),
215 ShapeUtil::HumanStringWithLayout(output_subshape),
216 size_func(output_subshape));
217 }
218
219 // Check each alias.parameter_number and alias.parameter_index pair only
220 // show up once. No input can be aliased with output buffers.
221 TF_RET_CHECK(param_has_seen[alias.parameter_number].element(
222 alias.parameter_index) == false);
223 *(param_has_seen[alias.parameter_number].mutable_element(
224 alias.parameter_index)) = true;
225 return Status::OK();
226 });
227 }
228
operator <<(std::ostream & out,const HloInputOutputAliasConfig & config)229 std::ostream& operator<<(std::ostream& out,
230 const HloInputOutputAliasConfig& config) {
231 out << config.ToString();
232 return out;
233 }
234 } // namespace xla
235