1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17 #include "tensorflow/compiler/xla/service/hlo_module.h"
18
19 namespace xla {
20
OutputHasAlias(const ShapeIndex & output_index) const21 bool HloInputOutputAliasConfig::OutputHasAlias(
22 const ShapeIndex& output_index) const {
23 return alias_.element(output_index).has_value();
24 }
25
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,AliasKind kind)26 Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
27 int64 param_number,
28 const ShapeIndex& param_index,
29 AliasKind kind) {
30 TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias)
31 << kind;
32 TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
33 << absl::StrCat("Tring to set up alias at ", output_index.ToString(),
34 " which is an invalid index for shape ",
35 ShapeUtil::HumanString(alias_.shape()));
36 TF_RET_CHECK(param_number >= 0) << param_number;
37 TF_RET_CHECK(!OutputHasAlias(output_index))
38 << "Output index " << output_index << " already has an alias setup";
39 // Output can't be aliased with multiple parameters.
40 TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
41 "Trying to set up output alias for param %lld at %s but failed: output "
42 "index %s is already aliased with param %lld at %s",
43 param_number, param_index.ToString(), output_index.ToString(),
44 alias_.element(output_index)->parameter_number,
45 alias_.element(output_index)->parameter_index.ToString());
46 (*alias_.mutable_element(output_index)) =
47 Alias(kind, param_number, param_index);
48 VLOG(4) << "Set up alias between output index " << output_index.ToString()
49 << " and parameter " << param_index << " at index "
50 << param_index.ToString();
51 return Status::OK();
52 }
53
ToProto() const54 HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
55 HloInputOutputAliasProto result;
56 alias_.ForEachElement(
57 [&](const ShapeIndex& index, const absl::optional<Alias>& data) {
58 if (data) {
59 HloInputOutputAliasProto::AliasEntryProto entry;
60 switch (data->kind) {
61 case AliasKind::kUserAlias:
62 entry.set_kind(HloInputOutputAliasProto::USER_ALIAS);
63 break;
64 case AliasKind::kSystemAlias:
65 entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS);
66 break;
67 default:
68 LOG(FATAL) << "Unknown alias kind " << data->kind;
69 }
70 for (int64 i : index) {
71 entry.add_output_shape_index(i);
72 }
73 entry.set_parameter_number(data->parameter_number);
74 for (int64 i : data->parameter_index) {
75 entry.add_parameter_shape_index(i);
76 }
77 result.add_entries()->Swap(&entry);
78 }
79 });
80 return result;
81 }
82
CreateFromProto(const Shape & output_shape,const HloInputOutputAliasProto & proto)83 StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
84 const Shape& output_shape, const HloInputOutputAliasProto& proto) {
85 HloInputOutputAliasConfig result(output_shape);
86 for (const HloInputOutputAliasProto::AliasEntryProto& entry :
87 proto.entries()) {
88 ShapeIndex output_index(entry.output_shape_index().begin(),
89 entry.output_shape_index().end());
90 int64 param_number = entry.parameter_number();
91 ShapeIndex param_index(entry.parameter_shape_index().begin(),
92 entry.parameter_shape_index().end());
93 // Handle backward compatibility with existing protos, which only knew of
94 // system aliases.
95 AliasKind kind = AliasKind::kSystemAlias;
96 if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) {
97 kind = AliasKind::kUserAlias;
98 }
99 TF_RETURN_IF_ERROR(
100 result.SetUpAlias(output_index, param_number, param_index, kind));
101 }
102 return result;
103 }
104
ToString() const105 string HloInputOutputAliasConfig::ToString() const {
106 std::vector<string> pieces;
107 pieces.push_back("HloInputOutputAliasConfig");
108
109 ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
110 const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM";
111 pieces.push_back(absl::StrFormat(
112 " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:",
113 output_index.ToString(), kind, alias.parameter_number,
114 alias.parameter_index.ToString()));
115 });
116 return absl::StrJoin(pieces, "\n");
117 }
118
119 HloInputOutputAliasConfig::AliasKind
ParameterAliasKind(int64 param_number,const ShapeIndex & param_index) const120 HloInputOutputAliasConfig::ParameterAliasKind(
121 int64 param_number, const ShapeIndex& param_index) const {
122 AliasKind kind = AliasKind::kNoAlias;
123 alias_.ForEachElement(
124 [&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
125 if (alias && alias->parameter_number == param_number &&
126 alias->parameter_index == param_index) {
127 kind = alias->kind;
128 }
129 });
130 return kind;
131 }
132
GetAliasedOutput(int64 param_number,const ShapeIndex & param_index) const133 absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
134 int64 param_number, const ShapeIndex& param_index) const {
135 absl::optional<ShapeIndex> output;
136 alias_.ForEachElement(
137 [&](const xla::ShapeIndex& output_index, absl::optional<Alias> alias) {
138 if (alias && alias->parameter_number == param_number &&
139 alias->parameter_index == param_index) {
140 output = output_index;
141 }
142 });
143 return output;
144 }
145
146 absl::optional<HloInputOutputAliasConfig::Alias>
GetAliasedParameter(const ShapeIndex & output_index) const147 HloInputOutputAliasConfig::GetAliasedParameter(
148 const ShapeIndex& output_index) const {
149 CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
150 return alias_.element(output_index);
151 }
152
ForEachAlias(AliasFn fn) const153 void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
154 alias_.ForEachElement(
155 [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
156 if (aliased) {
157 fn(output_index, *aliased);
158 }
159 });
160 }
161
ForEachAliasWithStatus(AliasFnWithStatus fn) const162 Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
163 AliasFnWithStatus fn) const {
164 return alias_.ForEachElementWithStatus(
165 [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
166 if (aliased) {
167 TF_RETURN_IF_ERROR(fn(output_index, *aliased));
168 }
169 return Status::OK();
170 });
171 }
172
Verify(const HloModule & module,std::function<int64 (const Shape &)> size_func) const173 Status HloInputOutputAliasConfig::Verify(
174 const HloModule& module,
175 std::function<int64(const Shape&)> size_func) const {
176 std::vector<ShapeTree<bool>> param_has_seen;
177 const HloComputation* entry = module.entry_computation();
178 for (int64 i = 0; i < entry->num_parameters(); ++i) {
179 HloInstruction* param = entry->parameter_instruction(i);
180 param_has_seen.emplace_back(param->shape());
181 }
182 return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
183 const Alias& alias) -> Status {
184 const HloInstruction* root = entry->root_instruction();
185
186 TF_RET_CHECK(0 <= alias.parameter_number);
187 TF_RET_CHECK(entry->num_parameters() > alias.parameter_number);
188 const Shape& param_shape =
189 entry->parameter_instruction(alias.parameter_number)->shape();
190 const Shape& output_shape = root->shape();
191 TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index));
192 TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
193
194 const Shape& param_subshape =
195 ShapeUtil::GetSubshape(param_shape, alias.parameter_index);
196 const Shape& output_subshape =
197 ShapeUtil::GetSubshape(output_shape, output_index);
198 TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape));
199 TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape));
200
201 if (size_func(param_subshape) != size_func(output_subshape)) {
202 return InternalError(
203 "Expected aliased input %lld at index %s and output at index %s to "
204 "have the same size. Input sub-shape is %s with size %lld, output "
205 "sub-shape is %s with size %lld",
206 alias.parameter_number, alias.parameter_index.ToString(),
207 output_index.ToString(),
208 ShapeUtil::HumanStringWithLayout(param_subshape),
209 size_func(param_subshape),
210 ShapeUtil::HumanStringWithLayout(output_subshape),
211 size_func(output_subshape));
212 }
213
214 // Check each alias.parameter_number and alias.parameter_index pair only
215 // show up once. No input can be aliased with output buffers.
216 TF_RET_CHECK(param_has_seen[alias.parameter_number].element(
217 alias.parameter_index) == false);
218 *(param_has_seen[alias.parameter_number].mutable_element(
219 alias.parameter_index)) = true;
220 return Status::OK();
221 });
222 }
223
operator <<(std::ostream & out,const HloInputOutputAliasConfig & config)224 std::ostream& operator<<(std::ostream& out,
225 const HloInputOutputAliasConfig& config) {
226 out << config.ToString();
227 return out;
228 }
229 } // namespace xla
230