1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_module.h"
19
20 namespace xla {
21
OutputHasAlias(const ShapeIndex & output_index) const22 bool HloInputOutputAliasConfig::OutputHasAlias(
23 const ShapeIndex& output_index) const {
24 return alias_.element(output_index).has_value();
25 }
26
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,AliasKind kind)27 Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
28 int64 param_number,
29 const ShapeIndex& param_index,
30 AliasKind kind) {
31 TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias)
32 << kind;
33 TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
34 << "Trying to set up alias at " << output_index.ToString()
35 << " which is an invalid index for shape "
36 << ShapeUtil::HumanString(alias_.shape());
37 TF_RET_CHECK(param_number >= 0) << param_number;
38 TF_RET_CHECK(!OutputHasAlias(output_index))
39 << "Output index " << output_index << " already has an alias setup";
40 // Output can't be aliased with multiple parameters.
41 TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
42 "Trying to set up output alias for param %lld at %s but failed: output "
43 "index %s is already aliased with param %lld at %s",
44 param_number, param_index.ToString(), output_index.ToString(),
45 alias_.element(output_index)->parameter_number,
46 alias_.element(output_index)->parameter_index.ToString());
47 (*alias_.mutable_element(output_index)) =
48 Alias(kind, param_number, param_index);
49 VLOG(4) << "Set up alias between output index " << output_index.ToString()
50 << " and parameter " << param_index << " at index "
51 << param_index.ToString();
52 return Status::OK();
53 }
54
ToProto() const55 HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
56 HloInputOutputAliasProto result;
57 alias_.ForEachElement(
58 [&](const ShapeIndex& index, const absl::optional<Alias>& data) {
59 if (data) {
60 HloInputOutputAliasProto::AliasEntryProto entry;
61 switch (data->kind) {
62 case AliasKind::kUserAlias:
63 entry.set_kind(HloInputOutputAliasProto::USER_ALIAS);
64 break;
65 case AliasKind::kSystemAlias:
66 entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS);
67 break;
68 default:
69 LOG(FATAL) << "Unknown alias kind " << data->kind;
70 }
71 for (int64 i : index) {
72 entry.add_output_shape_index(i);
73 }
74 entry.set_parameter_number(data->parameter_number);
75 for (int64 i : data->parameter_index) {
76 entry.add_parameter_shape_index(i);
77 }
78 result.add_entries()->Swap(&entry);
79 }
80 });
81 return result;
82 }
83
CreateFromProto(Shape output_shape,const HloInputOutputAliasProto & proto)84 StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
85 Shape output_shape, const HloInputOutputAliasProto& proto) {
86 HloInputOutputAliasConfig result(std::move(output_shape));
87 for (const HloInputOutputAliasProto::AliasEntryProto& entry :
88 proto.entries()) {
89 ShapeIndex output_index(entry.output_shape_index().begin(),
90 entry.output_shape_index().end());
91 int64 param_number = entry.parameter_number();
92 ShapeIndex param_index(entry.parameter_shape_index().begin(),
93 entry.parameter_shape_index().end());
94 // Handle backward compatibility with existing protos, which only knew of
95 // system aliases.
96 AliasKind kind = AliasKind::kSystemAlias;
97 if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) {
98 kind = AliasKind::kUserAlias;
99 }
100 TF_RETURN_IF_ERROR(
101 result.SetUpAlias(output_index, param_number, param_index, kind));
102 }
103 return result;
104 }
105
shape() const106 const Shape& HloInputOutputAliasConfig::shape() const { return alias_.shape(); }
107
ToString() const108 string HloInputOutputAliasConfig::ToString() const {
109 std::vector<string> pieces;
110 pieces.push_back("HloInputOutputAliasConfig");
111 pieces.push_back(
112 absl::StrFormat(" Output shape: %s", alias_.shape().ToString()));
113
114 ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
115 const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM";
116 pieces.push_back(absl::StrFormat(
117 " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:",
118 output_index.ToString(), kind, alias.parameter_number,
119 alias.parameter_index.ToString()));
120 });
121 return absl::StrJoin(pieces, "\n");
122 }
123
124 HloInputOutputAliasConfig::AliasKind
ParameterAliasKind(int64 param_number,const ShapeIndex & param_index) const125 HloInputOutputAliasConfig::ParameterAliasKind(
126 int64 param_number, const ShapeIndex& param_index) const {
127 AliasKind kind = AliasKind::kNoAlias;
128 alias_.ForEachElement(
129 [&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
130 if (alias && alias->parameter_number == param_number &&
131 alias->parameter_index == param_index) {
132 kind = alias->kind;
133 }
134 });
135 return kind;
136 }
137
GetAliasedOutput(int64 param_number,const ShapeIndex & param_index) const138 absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
139 int64 param_number, const ShapeIndex& param_index) const {
140 absl::optional<ShapeIndex> output;
141 alias_.ForEachElement(
142 [&](const xla::ShapeIndex& output_index, absl::optional<Alias> alias) {
143 if (alias && alias->parameter_number == param_number &&
144 alias->parameter_index == param_index) {
145 output = output_index;
146 }
147 });
148 return output;
149 }
150
151 absl::optional<HloInputOutputAliasConfig::Alias>
GetAliasedParameter(const ShapeIndex & output_index) const152 HloInputOutputAliasConfig::GetAliasedParameter(
153 const ShapeIndex& output_index) const {
154 CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
155 << ToString() << " " << alias_.shape().ToString() << " " << output_index;
156 return alias_.element(output_index);
157 }
158
ForEachAlias(AliasFn fn) const159 void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
160 alias_.ForEachElement(
161 [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
162 if (aliased) {
163 fn(output_index, *aliased);
164 }
165 });
166 }
167
ForEachAliasWithStatus(AliasFnWithStatus fn) const168 Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
169 AliasFnWithStatus fn) const {
170 return alias_.ForEachElementWithStatus(
171 [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
172 if (aliased) {
173 TF_RETURN_IF_ERROR(fn(output_index, *aliased));
174 }
175 return Status::OK();
176 });
177 }
178
Verify(const HloModule & module,std::function<int64 (const Shape &)> size_func) const179 Status HloInputOutputAliasConfig::Verify(
180 const HloModule& module,
181 std::function<int64(const Shape&)> size_func) const {
182 std::vector<ShapeTree<bool>> param_has_seen;
183 const HloComputation* entry = module.entry_computation();
184 for (int64 i = 0; i < entry->num_parameters(); ++i) {
185 HloInstruction* param = entry->parameter_instruction(i);
186 param_has_seen.emplace_back(param->shape());
187 }
188 return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
189 const Alias& alias) -> Status {
190 const HloInstruction* root = entry->root_instruction();
191
192 TF_RET_CHECK(0 <= alias.parameter_number);
193 TF_RET_CHECK(entry->num_parameters() > alias.parameter_number);
194 const Shape& param_shape =
195 entry->parameter_instruction(alias.parameter_number)->shape();
196 const Shape& output_shape = root->shape();
197 TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index));
198 TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
199
200 const Shape& param_subshape =
201 ShapeUtil::GetSubshape(param_shape, alias.parameter_index);
202 const Shape& output_subshape =
203 ShapeUtil::GetSubshape(output_shape, output_index);
204 TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape));
205 TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape));
206
207 if (size_func(param_subshape) != size_func(output_subshape)) {
208 return InternalError(
209 "Expected aliased input %lld at index %s and output at index %s to "
210 "have the same size. Input sub-shape is %s with size %lld, output "
211 "sub-shape is %s with size %lld",
212 alias.parameter_number, alias.parameter_index.ToString(),
213 output_index.ToString(),
214 ShapeUtil::HumanStringWithLayout(param_subshape),
215 size_func(param_subshape),
216 ShapeUtil::HumanStringWithLayout(output_subshape),
217 size_func(output_subshape));
218 }
219
220 // Check each alias.parameter_number and alias.parameter_index pair only
221 // show up once. No input can be aliased with output buffers.
222 TF_RET_CHECK(param_has_seen[alias.parameter_number].element(
223 alias.parameter_index) == false);
224 *(param_has_seen[alias.parameter_number].mutable_element(
225 alias.parameter_index)) = true;
226 return Status::OK();
227 });
228 }
229
operator <<(std::ostream & out,const HloInputOutputAliasConfig & config)230 std::ostream& operator<<(std::ostream& out,
231 const HloInputOutputAliasConfig& config) {
232 out << config.ToString();
233 return out;
234 }
235 } // namespace xla
236