/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <type_traits>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

namespace {

using absl::nullopt;
using absl::optional;
using absl::StrAppend;
using absl::StrCat;
using absl::StrFormat;
using absl::StrJoin;
// Creates and returns a schedule that follows the order of the instructions in
// each HloComputation::instructions() vector in the module.
HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
  HloSchedule schedule(module);
  for (HloComputation* computation : module->computations()) {
    if (!computation->IsFusionComputation()) {
      for (HloInstruction* instruction : computation->instructions()) {
        schedule.GetOrCreateSequence(computation).push_back(instruction);
      }
    }
  }
  return schedule;
}

// Some functions accept either a linear index or a multi-dimensional index
// (used for indexing into sparse literals).
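// For example, a dense literal element is addressed by a single int64 linear
// offset, while a sparse literal element is addressed by a full
// multi-dimensional index such as {1, 0, 3}.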
using LinearOrMultiIndex = absl::variant<int64, absl::Span<const int64>>;

// Parser for the HloModule::ToString() format text.
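//
// A minimal usage sketch (illustrative only; `text` is a caller-provided HLO
// string and `module` an initially empty HloModule):
//
//   HloParser parser(text);
//   Status status = parser.Run(module);
//   if (!status.ok()) LOG(ERROR) << parser.GetError();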
class HloParser {
 public:
  using LocTy = HloLexer::LocTy;

  explicit HloParser(absl::string_view str) : lexer_(str) {}

  // Runs the parser and constructs the resulting HLO in the given (empty)
  // HloModule. Returns an error status if parsing fails.
  Status Run(HloModule* module);

  // Returns the error information.
  string GetError() const { return StrJoin(error_, "\n"); }

  // Stand alone parsing utils for various aggregate data types.
  StatusOr<Shape> ParseShapeOnly();
  StatusOr<HloSharding> ParseShardingOnly();
  StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
  StatusOr<Window> ParseWindowOnly();
  StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
  StatusOr<PaddingConfig> ParsePaddingConfigOnly();

 private:
  using InstrNameTable =
      std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;

  // Returns the map from the instruction name to the instruction itself and
  // its location in the current scope.
  InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }

  // Locates an instruction with the given name in the current_name_table() or
  // returns nullptr.
  //
  // When the name is not found or is empty, and the create_missing_instruction_
  // hook is registered and a "shape" is provided, the hook will be called to
  // create an instruction. This is useful when we reify parameters as they're
  // resolved; i.e. for ParseSingleInstruction.
  std::pair<HloInstruction*, LocTy>* FindInstruction(
      const string& name, const optional<Shape>& shape = nullopt);

  // Parse a single instruction worth of text.
  bool ParseSingleInstruction(HloModule* module);

  // Parses a module, returning false if an error occurred.
  bool ParseHloModule(HloModule* module);

  bool ParseComputations(HloModule* module);
  bool ParseComputation(HloComputation** entry_computation);
  bool ParseInstructionList(HloComputation** computation,
                            const string& computation_name);
  bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
  bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name,
                           LocTy name_loc);
  bool ParseControlPredecessors(HloInstruction* instruction);
  bool ParseLiteral(Literal* literal, const Shape& shape);
  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseDenseLiteral(Literal* literal, const Shape& shape);
  bool ParseSparseLiteral(Literal* literal, const Shape& shape);
  // Sets the sub-value of literal at the given linear or sparse index to the
  // given value. If the literal is dense, it must have the default layout.
  //
  // `loc` should be the source location of the value.
  bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, std::complex<double> value,
                         LinearOrMultiIndex index, Literal* literal);
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
                               LinearOrMultiIndex index, Literal* literal);

  // Checks whether the given value is within the range of LiteralNativeT.
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
  template <typename LiteralNativeT>
  bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);

  bool ParseOperands(std::vector<HloInstruction*>* operands);
  // Fills parsed operands into 'operands' and expects a certain number of
  // operands.
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     const int expected_size);

  // Describes the start, limit, and stride on every dimension of the operand
  // being sliced.
  struct SliceRanges {
    std::vector<int64> starts;
    std::vector<int64> limits;
    std::vector<int64> strides;
  };

  // The data parsed for the kDomain instruction.
  struct DomainData {
    std::unique_ptr<DomainMetadata> entry_metadata;
    std::unique_ptr<DomainMetadata> exit_metadata;
  };

  // Types of attributes.
  enum class AttrTy {
    kBool,
    kInt64,
    kInt32,
    kFloat,
    kString,
    kBracedInt64List,
    kBracedInt64ListList,
    kHloComputation,
    kBracedHloComputationList,
    kFftType,
    kComparisonDirection,
    kWindow,
    kConvolutionDimensionNumbers,
    kSharding,
    kParameterReplication,
    kInstructionList,
    kSliceRanges,
    kPaddingConfig,
    kMetadata,
    kFusionKind,
    kDistribution,
    kDomain,
    kPrecisionList,
    kShapeList
  };

  struct AttrConfig {
    bool required;     // whether it's required or optional
    AttrTy attr_type;  // what type it is
    void* result;      // where to store the parsed result.
  };

  // attributes ::= (',' attribute)*
  //
  // Parses attributes given names and configs of the attributes. Each parsed
  // result is passed back through the result pointer in the corresponding
  // AttrConfig. Note that the result pointer must point to an optional<T>-typed
  // variable which outlives this function. Returns false on error. You should
  // not use any of the results if this function failed.
  //
  // Example usage:
  //
  //  std::unordered_map<string, AttrConfig> attrs;
  //  optional<int64> foo;
  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
  //  optional<Window> bar;
  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
  //  if (!ParseAttributes(attrs)) {
  //    return false; // Do not use 'foo' 'bar' if failed.
  //  }
  //  // Do something with 'bar'.
  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
  //
  bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);

  // sub_attributes ::= '{' (','? attribute)* '}'
  //
  // Usage is the same as ParseAttributes. See immediately above.
  bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributes or ParseSubAttributes.
  bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs,
                            std::unordered_set<string>* seen_attrs);

  // Parses an attribute string into a protocol buffer `message`.
  // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
  // set of mandatory attributes.
  bool ParseAttributesAsProtoMessage(
      const std::unordered_set<string>& required_attrs,
      tensorflow::protobuf::Message* message);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributesAsProtoMessage.
  bool ParseAttributeAsProtoMessageHelper(
      tensorflow::protobuf::Message* message,
      std::unordered_set<string>* seen_attrs);

  // Parses a name and finds the corresponding hlo computation.
  bool ParseComputationName(HloComputation** value);
  // Parses a list of names and finds the corresponding hlo instructions.
  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
  // Pass expect_outer_curlies == true when parsing a Window in the context of a
  // larger computation. Pass false when parsing a stand-alone Window string.
  bool ParseWindow(Window* window, bool expect_outer_curlies);
  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
  bool ParsePaddingConfig(PaddingConfig* padding);
  bool ParseMetadata(OpMetadata* metadata);
  bool ParseSharding(OpSharding* sharding);
  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
  bool ParseParameterReplication(ParameterReplication* parameter_replication);

  // Parses the metadata behind a kDomain instruction.
  bool ParseDomain(DomainData* domain);

  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
  bool ParseDxD(const string& name, std::vector<int64>* result);
  // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
  bool ParseWindowPad(std::vector<std::vector<int64>>* pad);

  bool ParseSliceRanges(SliceRanges* result);
  bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
  bool ParseHloComputation(HloComputation** result);
  bool ParseHloComputationList(std::vector<HloComputation*>* result);
  bool ParseShapeList(std::vector<Shape>* result);
  bool ParseInt64List(const TokKind start, const TokKind end,
                      const TokKind delim, std::vector<int64>* result);
  // 'parse_and_add_item' is a lambda to parse an element in the list and add
  // the parsed element to the result. It's supposed to capture the result.
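  //
  // Example (an illustrative sketch only; the delimiter token name,
  // TokKind::kComma, is assumed here rather than taken from hlo_lexer.h):
  //
  //   std::vector<int64> xs;
  //   ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, [&] {
  //     int64 x;
  //     if (!ParseInt64(&x)) return false;
  //     xs.push_back(x);
  //     return true;
  //   });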
  bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
                 const std::function<bool()>& parse_and_add_item);

  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
  bool ParseParamList();
  bool ParseName(string* result);
  bool ParseAttributeName(string* result);
  bool ParseString(string* result);
  bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
                           std::vector<bool>* dynamic_dimensions);
  bool ParseShape(Shape* result);
  bool ParseLayout(Layout* layout);
  bool ParseTiles(std::vector<Tile>* tiles);
  bool ParseOpcode(HloOpcode* result);
  bool ParseFftType(FftType* result);
  bool ParseComparisonDirection(ComparisonDirection* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
  bool ParseRandomDistribution(RandomDistribution* result);
  bool ParsePrecision(PrecisionConfig::Precision* result);
  bool ParseInt64(int64* result);
  bool ParseDouble(double* result);
  bool ParseComplex(std::complex<double>* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const string& msg);

  // Returns true if the current token is the beginning of a shape.
  bool CanBeShape();
  // Returns true if the current token is the beginning of a
  // param_list_to_shape.
  bool CanBeParamListToShape();

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(absl::string_view msg);
  bool Error(LocTy loc, absl::string_view msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const string& name, HloInstruction* instruction,
                      LocTy name_loc);
  // Adds the computation to the pool. Returns false and emits an error if the
  // computation already exists.
  bool AddComputation(const string& name, HloComputation* computation,
                      LocTy name_loc);

  HloLexer lexer_;

  // A stack for the instruction names. The top of the stack stores the
  // instruction name table for the current scope.
  //
  // An instruction's name is unique within its scope (i.e. its parent
  // computation), but it's not necessarily unique among all computations in
  // the module. When there are multiple levels of nested computations, the
  // same name could appear in both an outer computation and an inner
  // computation. So we need a stack to make sure a name is only visible
  // within its scope.
  std::vector<InstrNameTable> scoped_name_tables_;

  // A helper class which pushes and pops to an InstrNameTable stack via RAII.
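  //
  // Usage sketch: ParseInstructionList opens a fresh name scope for the
  // computation it builds, e.g.
  //
  //   Scope scope(&scoped_name_tables_);  // pushed here, popped on return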
  class Scope {
   public:
    explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
        : scoped_name_tables_(scoped_name_tables) {
      scoped_name_tables_->emplace_back();
    }
    ~Scope() { scoped_name_tables_->pop_back(); }

   private:
    std::vector<InstrNameTable>* scoped_name_tables_;
  };

  // Map from the computation name to the computation itself and its location.
  std::unordered_map<string, std::pair<HloComputation*, LocTy>>
      computation_pool_;

  std::vector<std::unique_ptr<HloComputation>> computations_;
  std::vector<string> error_;

  // When an operand name cannot be resolved, this function is called to create
  // a parameter instruction with the given name and shape. It registers the
  // name, instruction, and a placeholder location in the name table. It returns
  // the newly-created instruction and the placeholder location. If `name` is
  // empty, this should create the parameter with a generated name. This is
  // supposed to be set and used only in ParseSingleInstruction.
  std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
                                                   const Shape& shape)>
      create_missing_instruction_;
};

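// Splits `s` on `delim` and appends each piece, parsed as an int64, to `out`,
// returning false if any piece fails to parse. An illustrative sketch:
//
//   std::vector<int64> dims;
//   SplitToInt64s("2x3x4", 'x', &dims);  // dims == {2, 3, 4}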
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
  for (const auto& split : absl::StrSplit(s, delim)) {
    int64 val;
    if (!absl::SimpleAtoi(split, &val)) {
      return false;
    }
    out->push_back(val);
  }
  return true;
}

// Creates replica groups from the provided nested array. groups[i] represents
// the replica ids for group 'i'.
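// e.g. groups = {{0, 1}, {2, 3}} yields two ReplicaGroup protos whose
// replica_ids are {0, 1} and {2, 3} respectively.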
std::vector<ReplicaGroup> CreateReplicaGroups(
    absl::Span<const std::vector<int64>> groups) {
  std::vector<ReplicaGroup> replica_groups;
  absl::c_transform(groups, std::back_inserter(replica_groups),
                    [](const std::vector<int64>& ids) {
                      ReplicaGroup group;
                      *group.mutable_replica_ids() = {ids.begin(), ids.end()};
                      return group;
                    });
  return replica_groups;
}

bool HloParser::Error(LocTy loc, absl::string_view msg) {
  auto line_col = lexer_.GetLineAndColumn(loc);
  const unsigned line = line_col.first;
  const unsigned col = line_col.second;
  std::vector<string> error_lines;
  error_lines.push_back(
      StrCat("was parsing ", line, ":", col, ": error: ", msg));
  error_lines.emplace_back(lexer_.GetLine(loc));
  error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));

  error_.push_back(StrJoin(error_lines, "\n"));
  VLOG(1) << "Error: " << error_.back();
  return false;
}

bool HloParser::TokenError(absl::string_view msg) {
  return Error(lexer_.GetLoc(), msg);
}

Status HloParser::Run(HloModule* module) {
  lexer_.Lex();
  if (lexer_.GetKind() == TokKind::kw_HloModule) {
    // This means that the text contains a full HLO module.
    if (!ParseHloModule(module)) {
      return InvalidArgument(
          "Syntax error when trying to parse the text as a HloModule:\n%s",
          GetError());
    }
    return Status::OK();
  }
  // This means that the text is a single HLO instruction.
  if (!ParseSingleInstruction(module)) {
    return InvalidArgument(
        "Syntax error when trying to parse the text as a single "
        "HloInstruction:\n%s",
        GetError());
  }
  return Status::OK();
}

std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
    const string& name, const optional<Shape>& shape) {
  std::pair<HloInstruction*, LocTy>* instr = nullptr;
  if (!name.empty()) {
    instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
  }

  // Potentially call the missing instruction hook.
  if (instr == nullptr && create_missing_instruction_ != nullptr &&
      scoped_name_tables_.size() == 1) {
    if (!shape.has_value()) {
      Error(lexer_.GetLoc(),
            "Operand had no shape in HLO text; cannot create parameter for "
            "single-instruction module.");
      return nullptr;
    }
    return create_missing_instruction_(name, *shape);
  }

  if (instr != nullptr && shape.has_value() &&
      !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
    Error(
        lexer_.GetLoc(),
        StrCat("The declared operand shape ",
               ShapeUtil::HumanStringWithLayout(shape.value()),
               " is not compatible with the shape of the operand instruction ",
               ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
    return nullptr;
  }

  return instr;
}

// ::= 'HloModule' name computations
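//
// e.g. a module header of the form accepted here (illustrative only; the
// attribute list may be empty):
//
//   HloModule my_module, is_scheduled=true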
bool HloParser::ParseHloModule(HloModule* module) {
  if (lexer_.GetKind() != TokKind::kw_HloModule) {
    return TokenError("expects HloModule");
  }
  // Eat 'HloModule'
  lexer_.Lex();

  string name;
  if (!ParseName(&name)) {
    return false;
  }

  absl::optional<bool> is_scheduled;
  std::unordered_map<string, AttrConfig> attrs;
  attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
  if (!ParseAttributes(attrs)) {
    return false;
  }

  module->set_name(name);
  if (!ParseComputations(module)) {
    return false;
  }

  if (is_scheduled.has_value() && *is_scheduled) {
    TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  }

  return true;
}

// computations ::= (computation)+
bool HloParser::ParseComputations(HloModule* module) {
  HloComputation* entry_computation = nullptr;
  do {
    if (!ParseComputation(&entry_computation)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kEof);

  for (int i = 0; i < computations_.size(); i++) {
    // If entry_computation is not nullptr, it means the computation it pointed
    // to is marked with "ENTRY"; otherwise, no computation is marked with
    // "ENTRY", and we use the last computation as the entry computation. We
    // add the non-entry computations as embedded computations to the module.
    if ((entry_computation != nullptr &&
         computations_[i].get() != entry_computation) ||
        (entry_computation == nullptr && i != computations_.size() - 1)) {
      module->AddEmbeddedComputation(std::move(computations_[i]));
      continue;
    }
    auto computation = module->AddEntryComputation(std::move(computations_[i]));
    // The parameters and result layouts were set to default layout. Here we
    // set the layouts to what the hlo text says.
    for (int p = 0; p < computation->num_parameters(); p++) {
      const Shape& param_shape = computation->parameter_instruction(p)->shape();
      TF_CHECK_OK(module->mutable_entry_computation_layout()
                      ->mutable_parameter_layout(p)
                      ->CopyLayoutFromShape(param_shape));
    }
    const Shape& result_shape = computation->root_instruction()->shape();
    TF_CHECK_OK(module->mutable_entry_computation_layout()
                    ->mutable_result_layout()
                    ->CopyLayoutFromShape(result_shape));
  }
  return true;
}

// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
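//
// e.g. an ENTRY computation of the form accepted here (an illustrative sketch;
// the parameter list and result shape are optional):
//
//   ENTRY %main (p: f32[8]) -> f32[8] {
//     ...
//   }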
bool HloParser::ParseComputation(HloComputation** entry_computation) {
  LocTy maybe_entry_loc = lexer_.GetLoc();
  const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);

  string name;
  LocTy name_loc = lexer_.GetLoc();
  if (!ParseName(&name)) {
    return false;
  }

  LocTy shape_loc = nullptr;
  Shape shape;
  if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
    return false;
  }

  HloComputation* computation = nullptr;
  if (!ParseInstructionList(&computation, name)) {
    return false;
  }

  // If param_list_to_shape was present, check compatibility.
  if (shape_loc != nullptr &&
      !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
    return Error(
        shape_loc,
        StrCat(
            "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
            ", is not compatible with that of its root instruction ",
            computation->root_instruction()->name(), ", ",
            ShapeUtil::HumanString(computation->root_instruction()->shape())));
  }

  if (is_entry_computation) {
    if (*entry_computation != nullptr) {
      return Error(maybe_entry_loc, "expects only one ENTRY");
    }
    *entry_computation = computation;
  }

  return AddComputation(name, computation, name_loc);
}

// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParser::ParseInstructionList(HloComputation** computation,
                                     const string& computation_name) {
  Scope scope(&scoped_name_tables_);
  HloComputation::Builder builder(computation_name);
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of instruction list.")) {
    return false;
  }
  string root_name;
  do {
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kRbrace);
  if (!ParseToken(TokKind::kRbrace,
                  "expects '}' at the end of instruction list.")) {
    return false;
  }
  HloInstruction* root = nullptr;
  if (!root_name.empty()) {
    std::pair<HloInstruction*, LocTy>* root_node =
        tensorflow::gtl::FindOrNull(current_name_table(), root_name);

    // This means some instruction was marked as ROOT but we didn't find it in
    // the pool, which should not happen.
    if (root_node == nullptr) {
      LOG(FATAL) << "instruction " << root_name
                 << " was marked as ROOT but the parser has not seen it before";
    }
    root = root_node->first;
  }

  // Now root can be either an existing instruction or a nullptr. If it's a
  // nullptr, the implementation of Builder will set the last instruction as
  // the root instruction.
  computations_.emplace_back(builder.Build(root));
  *computation = computations_.back().get();
  return true;
}

// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
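//
// e.g. an instruction line of the form accepted here (illustrative only):
//
//   ROOT %add.1 = f32[8]{0} add(f32[8]{0} %x, f32[8]{0} %y)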
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
                                 string* root_name) {
  string name;
  LocTy maybe_root_loc = lexer_.GetLoc();
  bool is_root = EatIfPresent(TokKind::kw_ROOT);

  const LocTy name_loc = lexer_.GetLoc();
  if (!ParseName(&name) ||
      !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
    return false;
  }

  if (is_root) {
    if (!root_name->empty()) {
      return Error(maybe_root_loc, "one computation should have only one ROOT");
    }
    *root_name = name;
  }

  return ParseInstructionRhs(builder, name, name_loc);
}

bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
                                    const string& name, LocTy name_loc) {
  Shape shape;
  HloOpcode opcode;
  std::vector<HloInstruction*> operands;

  if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
    return false;
  }

  // Add optional attributes.
  std::unordered_map<string, AttrConfig> attrs;
  optional<OpSharding> sharding;
  attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
  optional<ParameterReplication> parameter_replication;
  attrs["parameter_replication"] = {/*required=*/false,
                                    AttrTy::kParameterReplication,
                                    &parameter_replication};
  optional<std::vector<HloInstruction*>> predecessors;
  attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
                                   &predecessors};
  optional<OpMetadata> metadata;
  attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};

  optional<string> backend_config;
  attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
                             &backend_config};

  HloInstruction* instruction;
  switch (opcode) {
    case HloOpcode::kParameter: {
      int64 parameter_number;
      if (!ParseToken(TokKind::kLparen,
                      "expects '(' before parameter number") ||
          !ParseInt64(&parameter_number)) {
        return false;
      }
      if (parameter_number < 0) {
        Error(lexer_.GetLoc(), "parameter number must be >= 0");
        return false;
      }
      if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateParameter(parameter_number, shape, name));
      break;
    }
    case HloOpcode::kConstant: {
      Literal literal;
      if (!ParseToken(TokKind::kLparen,
                      "expects '(' before constant literal") ||
          !ParseLiteral(&literal, shape) ||
          !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateConstant(std::move(literal)));
      break;
    }
    case HloOpcode::kIota: {
      optional<int64> iota_dimension;
      attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
                                 &iota_dimension};
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateIota(shape, *iota_dimension));
      break;
    }
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateUnary(shape, opcode, operands[0]));
      break;
    }
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical: {
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBinary(
          shape, opcode, operands[0], operands[1]));
      break;
    }
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect: {
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateTernary(
          shape, opcode, operands[0], operands[1], operands[2]));
      break;
    }
    // Other supported ops.
    case HloOpcode::kConvert: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateConvert(shape, operands[0]));
      break;
    }
    case HloOpcode::kBitcastConvert: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateBitcastConvert(shape, operands[0]));
      break;
    }
    case HloOpcode::kAllReduce: {
      optional<std::vector<std::vector<int64>>> tmp_groups;
      optional<HloComputation*> to_apply;
      optional<std::vector<int64>> replica_group_ids;
      optional<string> barrier;
      optional<int64> all_reduce_id;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      attrs["replica_groups"] = {/*required=*/false,
                                 AttrTy::kBracedInt64ListList, &tmp_groups};
      attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
      attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
                                &all_reduce_id};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      std::vector<ReplicaGroup> replica_groups;
      if (tmp_groups) {
        replica_groups = CreateReplicaGroups(*tmp_groups);
      }
      instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
          shape, operands, *to_apply, replica_groups, barrier ? *barrier : "",
          all_reduce_id));
      break;
    }
    case HloOpcode::kAllToAll: {
      optional<std::vector<std::vector<int64>>> tmp_groups;
      optional<string> barrier;
      attrs["replica_groups"] = {/*required=*/false,
                                 AttrTy::kBracedInt64ListList, &tmp_groups};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      std::vector<ReplicaGroup> replica_groups;
      if (tmp_groups) {
        replica_groups = CreateReplicaGroups(*tmp_groups);
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateAllToAll(shape, operands, replica_groups));
      break;
    }
    case HloOpcode::kCollectivePermute: {
      optional<std::vector<std::vector<int64>>> source_targets;
      attrs["source_target_pairs"] = {
          /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      std::vector<std::pair<int64, int64>> pairs(source_targets->size());
      for (int i = 0; i < pairs.size(); i++) {
        if ((*source_targets)[i].size() != 2) {
          return TokenError(
              "expects 'source_target_pairs=' to be a list of pairs");
        }
        pairs[i].first = (*source_targets)[i][0];
        pairs[i].second = (*source_targets)[i][1];
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCollectivePermute(shape, operands[0], pairs));
      break;
    }
    case HloOpcode::kReplicaId: {
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
      break;
    }
    case HloOpcode::kReshape: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateReshape(shape, operands[0]));
      break;
    }
    case HloOpcode::kAfterAll: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.empty()) {
        instruction = builder->AddInstruction(HloInstruction::CreateToken());
      } else {
        instruction =
            builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
      }
      break;
    }
    case HloOpcode::kAddDependency: {
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateAddDependency(operands[0], operands[1]));
      break;
    }
    case HloOpcode::kSort: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      optional<bool> is_stable = false;
      attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateSort(shape, dimensions->at(0), operands,
                                     to_apply.value(), is_stable.value()));
      break;
    }
    case HloOpcode::kTuple: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateTuple(operands));
      break;
    }
    case HloOpcode::kWhile: {
      optional<HloComputation*> condition;
      optional<HloComputation*> body;
      attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
                            &condition};
      attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateWhile(
          shape, *condition, *body, /*init=*/operands[0]));
      break;
    }
    case HloOpcode::kRecv: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // If the is_host_transfer attribute is not present then default to false.
      instruction = builder->AddInstruction(HloInstruction::CreateRecv(
          shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
      break;
    }
    case HloOpcode::kRecvDone: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
      break;
    }
    case HloOpcode::kSend: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSend(
          operands[0], operands[1], *channel_id, *is_host_transfer));
      break;
    }
    case HloOpcode::kSendDone: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
      break;
    }
    case HloOpcode::kGetTupleElement: {
      optional<int64> index;
      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
      break;
    }
    case HloOpcode::kCall: {
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCall(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduceWindow: {
      optional<HloComputation*> reduce_computation;
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
          shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
          *reduce_computation));
      break;
    }
    case HloOpcode::kConvolution: {
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64> feature_group_count;
      optional<int64> batch_group_count;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/true,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      if (!feature_group_count) {
        feature_group_count = 1;
      }
      if (!batch_group_count) {
        batch_group_count = 1;
      }
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
          shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
          feature_group_count.value(), batch_group_count.value(), *window,
          *dnums, precision_config));
      break;
    }
    case HloOpcode::kFft: {
      optional<FftType> fft_type;
      optional<std::vector<int64>> fft_length;
      attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
      attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &fft_length};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFft(
          shape, operands[0], *fft_type, *fft_length));
      break;
    }
    case HloOpcode::kTriangularSolve: {
      TriangularSolveOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributesAsProtoMessage(
              /*required_attrs=*/std::unordered_set<string>(), &options)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateTriangularSolve(
              shape, operands[0], operands[1], options));
      break;
    }
    case HloOpcode::kCompare: {
      optional<ComparisonDirection> direction;
      attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
                            &direction};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateCompare(
          shape, operands[0], operands[1], *direction));
      break;
    }
    case HloOpcode::kCholesky: {
      CholeskyOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributesAsProtoMessage(
              /*required_attrs=*/std::unordered_set<string>(), &options)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCholesky(shape, operands[0], options));
      break;
    }
    case HloOpcode::kBroadcast: {
      optional<std::vector<int64>> broadcast_dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &broadcast_dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
          shape, operands[0], *broadcast_dimensions));
      break;
    }
    case HloOpcode::kConcatenate: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
          shape, operands, dimensions->at(0)));
      break;
    }
    case HloOpcode::kMap: {
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateMap(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduce: {
      auto loc = lexer_.GetLoc();

      optional<HloComputation*> reduce_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      optional<std::vector<int64>> dimensions_to_reduce;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions_to_reduce};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() % 2) {
        return Error(loc, StrCat("expects an even number of operands, but has ",
                                 operands.size(), " operands"));
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReduce(
          shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *dimensions_to_reduce, *reduce_computation));
      break;
    }
    case HloOpcode::kReverse: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateReverse(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kSelectAndScatter: {
      optional<HloComputation*> select;
      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
      optional<HloComputation*> scatter;
      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
              shape, /*operand=*/operands[0], *select, *window,
              /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
      break;
    }
    case HloOpcode::kSlice: {
      optional<SliceRanges> slice_ranges;
      attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSlice(
          shape, operands[0], slice_ranges->starts, slice_ranges->limits,
          slice_ranges->strides));
      break;
    }
    case HloOpcode::kDynamicSlice: {
      optional<std::vector<int64>> dynamic_slice_sizes;
      attrs["dynamic_slice_sizes"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.empty()) {
        return Error(loc, "Expected at least one operand.");
      }
      if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
          operands.size() != 1 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
          shape, /*operand=*/operands[0],
          /*start_indices=*/absl::MakeSpan(operands).subspan(1),
          *dynamic_slice_sizes));
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() < 2) {
        return Error(loc, "Expected at least two operands.");
      }
      if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
          operands.size() != 2 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              shape, /*operand=*/operands[0], /*update=*/operands[1],
              /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
      break;
    }
    case HloOpcode::kTranspose: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kBatchNormTraining: {
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormInference: {
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormInference(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], /*mean=*/operands[3],
              /*variance=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormGrad: {
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
          shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*mean=*/operands[2], /*variance=*/operands[3],
          /*grad_output=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kPad: {
      optional<PaddingConfig> padding;
      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreatePad(
          shape, operands[0], /*padding_value=*/operands[1], *padding));
      break;
    }
    case HloOpcode::kFusion: {
      optional<HloComputation*> fusion_computation;
      attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
                        &fusion_computation};
      optional<HloInstruction::FusionKind> fusion_kind;
      attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFusion(
          shape, *fusion_kind, operands, *fusion_computation));
      break;
    }
    case HloOpcode::kInfeed: {
      optional<string> config;
      attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // We need to know the infeed data shape to construct the infeed
      // instruction. This is the zero-th element of the tuple-shaped output of
      // the infeed instruction. ShapeUtil::GetTupleElementShape will check-fail
      // if the shape is not a non-empty tuple, so add a guard so that an error
      // message can be emitted instead of a check failure.
      if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
        return Error(lexer_.GetLoc(),
                     "infeed must have a non-empty tuple shape");
      }
      instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
          ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
          config ? *config : ""));
      break;
    }
    case HloOpcode::kOutfeed: {
      optional<string> config;
      attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
                                        operands[1], config ? *config : ""));
      break;
    }
    case HloOpcode::kRng: {
      optional<RandomDistribution> distribution;
      attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
                               &distribution};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRng(shape, *distribution, operands));
      break;
    }
    case HloOpcode::kReducePrecision: {
      optional<int64> exponent_bits;
      optional<int64> mantissa_bits;
      attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &exponent_bits};
      attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &mantissa_bits};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateReducePrecision(
              shape, operands[0], static_cast<int>(*exponent_bits),
              static_cast<int>(*mantissa_bits)));
      break;
    }
    case HloOpcode::kConditional: {
      optional<HloComputation*> true_computation;
      optional<HloComputation*> false_computation;
      optional<std::vector<HloComputation*>> branch_computations;
      if (!ParseOperands(&operands)) {
        return false;
      }
      const bool branch_index_is_bool =
          operands[0]->shape().element_type() == PRED;
      if (branch_index_is_bool) {
        attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
                                     &true_computation};
        attrs["false_computation"] = {
            /*required=*/true, AttrTy::kHloComputation, &false_computation};
      } else {
        attrs["branch_computations"] = {/*required=*/true,
                                        AttrTy::kBracedHloComputationList,
                                        &branch_computations};
      }
      if (!ParseAttributes(attrs)) {
        return false;
      }
      if (branch_index_is_bool) {
        branch_computations.emplace({*true_computation, *false_computation});
      }
      if (branch_computations->empty() ||
          operands.size() != branch_computations->size() + 1) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConditional(
          shape, /*branch_index=*/operands[0],
          absl::MakeSpan(*branch_computations),
          absl::MakeSpan(operands).subspan(1)));
      break;
    }
    case HloOpcode::kCustomCall: {
      optional<string> custom_call_target;
      optional<string> opaque;
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64> feature_group_count;
      optional<int64> batch_group_count;
      optional<std::vector<Shape>> operand_layout_constraints;
      attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                     &custom_call_target};
      attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/false,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      attrs["operand_layout_constraints"] = {
          /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operand_layout_constraints.has_value()) {
        if (!LayoutUtil::HasLayout(shape)) {
          return Error(lexer_.GetLoc(),
                       "Layout must be set on layout-constrained custom call");
        }
        if (operands.size() != operand_layout_constraints->size()) {
          return Error(lexer_.GetLoc(),
                       StrCat("Expected ", operands.size(),
                              " operand layout constraints, ",
                              operand_layout_constraints->size(), " given"));
        }
        for (int64 i = 0; i < operands.size(); ++i) {
          const Shape& operand_shape_with_layout =
              (*operand_layout_constraints)[i];
          if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
            return Error(lexer_.GetLoc(),
                         StrCat("Operand layout constraint shape ",
                                ShapeUtil::HumanStringWithLayout(
                                    operand_shape_with_layout),
                                " for operand ", i, " does not have a layout"));
          }
          if (!ShapeUtil::Compatible(operand_shape_with_layout,
                                     operands[i]->shape())) {
            return Error(
                lexer_.GetLoc(),
                StrCat(
                    "Operand layout constraint shape ",
                    ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
                    " for operand ", i,
                    " is not compatible with operand shape ",
                    ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
          }
        }
        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
            shape, operands, *custom_call_target, *operand_layout_constraints,
            opaque.has_value() ? *opaque : ""));
      } else {
        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
            shape, operands, *custom_call_target,
            opaque.has_value() ? *opaque : ""));
      }
      if (window.has_value()) {
        instruction->set_window(*window);
      }
      if (dnums.has_value()) {
        instruction->set_convolution_dimension_numbers(*dnums);
      }
      if (feature_group_count.has_value()) {
        instruction->set_feature_group_count(*feature_group_count);
      }
      if (batch_group_count.has_value()) {
        instruction->set_batch_group_count(*batch_group_count);
      }
      break;
    }
    case HloOpcode::kDot: {
      optional<std::vector<int64>> lhs_contracting_dims;
      attrs["lhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
      optional<std::vector<int64>> rhs_contracting_dims;
      attrs["rhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
      optional<std::vector<int64>> lhs_batch_dims;
      attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &lhs_batch_dims};
      optional<std::vector<int64>> rhs_batch_dims;
      attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &rhs_batch_dims};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};

      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }

      DotDimensionNumbers dnum;
1584 if (lhs_contracting_dims) {
1585 *dnum.mutable_lhs_contracting_dimensions() = {
1586 lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
1587 }
1588 if (rhs_contracting_dims) {
1589 *dnum.mutable_rhs_contracting_dimensions() = {
1590 rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
1591 }
1592 if (lhs_batch_dims) {
1593 *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
1594 lhs_batch_dims->end()};
1595 }
1596 if (rhs_batch_dims) {
1597 *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
1598 rhs_batch_dims->end()};
1599 }
1600
1601 PrecisionConfig precision_config;
1602 if (operand_precision) {
1603 *precision_config.mutable_operand_precision() = {
1604 operand_precision->begin(), operand_precision->end()};
1605 } else {
1606 precision_config.mutable_operand_precision()->Resize(
1607 operands.size(), PrecisionConfig::DEFAULT);
1608 }
1609
1610 instruction = builder->AddInstruction(HloInstruction::CreateDot(
1611 shape, operands[0], operands[1], dnum, precision_config));
1612 break;
1613 }
1614 case HloOpcode::kGather: {
1615 optional<std::vector<int64>> offset_dims;
1616 attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
1617 &offset_dims};
1618 optional<std::vector<int64>> collapsed_slice_dims;
1619 attrs["collapsed_slice_dims"] = {
1620 /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
1621 optional<std::vector<int64>> start_index_map;
1622 attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
1623 &start_index_map};
1624 optional<int64> index_vector_dim;
1625 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
1626 &index_vector_dim};
1627 optional<std::vector<int64>> slice_sizes;
1628 attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
1629 &slice_sizes};
1630
1631 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1632 !ParseAttributes(attrs)) {
1633 return false;
1634 }
1635
1636 GatherDimensionNumbers dim_numbers =
1637 HloGatherInstruction::MakeGatherDimNumbers(
1638 /*offset_dims=*/*offset_dims,
1639 /*collapsed_slice_dims=*/*collapsed_slice_dims,
1640 /*start_index_map=*/*start_index_map,
1641 /*index_vector_dim=*/*index_vector_dim);
1642
1643 instruction = builder->AddInstruction(HloInstruction::CreateGather(
1644 shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
1645 dim_numbers, *slice_sizes));
1646 break;
1647 }
1648 case HloOpcode::kScatter: {
1649 optional<std::vector<int64>> update_window_dims;
1650 attrs["update_window_dims"] = {
1651 /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
1652 optional<std::vector<int64>> inserted_window_dims;
1653 attrs["inserted_window_dims"] = {
1654 /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
1655 optional<std::vector<int64>> scatter_dims_to_operand_dims;
1656 attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
1657 AttrTy::kBracedInt64List,
1658 &scatter_dims_to_operand_dims};
1659 optional<int64> index_vector_dim;
1660 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
1661 &index_vector_dim};
1662
1663 optional<HloComputation*> update_computation;
1664 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1665 &update_computation};
1666
1667 if (!ParseOperands(&operands, /*expected_size=*/3) ||
1668 !ParseAttributes(attrs)) {
1669 return false;
1670 }
1671
1672 ScatterDimensionNumbers dim_numbers =
1673 HloScatterInstruction::MakeScatterDimNumbers(
1674 /*update_window_dims=*/*update_window_dims,
1675 /*inserted_window_dims=*/*inserted_window_dims,
1676 /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
1677 /*index_vector_dim=*/*index_vector_dim);
1678
1679 instruction = builder->AddInstruction(HloInstruction::CreateScatter(
1680 shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
1681 /*updates=*/operands[2], *update_computation, dim_numbers));
1682 break;
1683 }
1684 case HloOpcode::kDomain: {
1685 DomainData domain;
1686 attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
1687 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1688 !ParseAttributes(attrs)) {
1689 return false;
1690 }
1691 instruction = builder->AddInstruction(HloInstruction::CreateDomain(
1692 shape, operands[0], std::move(domain.exit_metadata),
1693 std::move(domain.entry_metadata)));
1694 break;
1695 }
1696 case HloOpcode::kTrace:
1697 return TokenError(StrCat("parsing not yet implemented for op: ",
1698 HloOpcodeString(opcode)));
1699 case HloOpcode::kGetDimensionSize:
1700 optional<std::vector<int64>> dimensions;
1701 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1702 &dimensions};
1703 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1704 !ParseAttributes(attrs)) {
1705 return false;
1706 }
1707 instruction =
1708 builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
1709 shape, operands[0], (*dimensions)[0]));
1710 break;
1711 }
1712
1713 instruction->SetAndSanitizeName(name);
1714 if (instruction->name() != name) {
1715 return Error(name_loc,
1716 StrCat("illegal instruction name: ", name,
1717 "; suggest renaming to: ", instruction->name()));
1718 }
1719
1720 // Add shared attributes like metadata to the instruction, if they were seen.
1721 if (sharding) {
1722 instruction->set_sharding(
1723 HloSharding::FromProto(sharding.value()).ValueOrDie());
1724 }
1725 if (parameter_replication) {
1726 int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
1727 const auto& replicated =
1728 parameter_replication->replicated_at_leaf_buffers();
1729 if (leaf_count != replicated.size()) {
1730 return Error(lexer_.GetLoc(),
1731 StrCat("parameter has ", leaf_count,
1732 " leaf buffers, but parameter_replication has ",
1733 replicated.size(), " elements."));
1734 }
1735 instruction->set_parameter_replicated_at_leaf_buffers(replicated);
1736 }
1737 if (predecessors) {
1738 for (auto* pre : *predecessors) {
1739 Status status = pre->AddControlDependencyTo(instruction);
1740 if (!status.ok()) {
1741 return Error(name_loc, StrCat("error adding control dependency for: ",
1742 name, " status: ", status.ToString()));
1743 }
1744 }
1745 }
1746 if (metadata) {
1747 instruction->set_metadata(*metadata);
1748 }
1749 if (backend_config) {
1750 instruction->set_raw_backend_config_string(std::move(*backend_config));
1751 }
1752 return AddInstruction(name, instruction, name_loc);
1753 } // NOLINT(readability/fn_size)
1754
1755 // ::= '{' (single_sharding | tuple_sharding) '}'
1756 //
1757 // tuple_sharding ::= single_sharding* (',' single_sharding)*
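// As an illustration (derived from the parsing logic below, not an exhaustive
// grammar), a tuple sharding nests one per-element sharding in braces, e.g.:
//   sharding={{maximal device=0}, {replicated}}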
1758 bool HloParser::ParseSharding(OpSharding* sharding) {
1759 // A single sharding starts with '{' and is not followed by '{'.
1760 // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
1761 // an empty tuple.
1762 if (!ParseToken(TokKind::kLbrace,
1763 "expected '{' to start sharding attribute")) {
1764 return false;
1765 }
1766
1767 if (lexer_.GetKind() != TokKind::kLbrace &&
1768 lexer_.GetKind() != TokKind::kRbrace) {
1769 return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
1770 }
1771
1772 // Tuple sharding.
1773 // Allow empty tuple shardings.
1774 if (lexer_.GetKind() != TokKind::kRbrace) {
1775 do {
1776 if (!ParseSingleSharding(sharding->add_tuple_shardings(),
1777 /*lbrace_pre_lexed=*/false)) {
1778 return false;
1779 }
1780 } while (EatIfPresent(TokKind::kComma));
1781 }
1782 sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
1783
1784 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
1785 }
1786
1787 // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
1788 // ('devices=' ('[' dims ']')* device_list)? '}'
1789 // dims ::= int_list device_list ::= int_list
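// Illustrative single-sharding strings accepted here (a sketch based on the
// cases handled below; device numbers are made up):
//   {replicated}              every device holds a full copy
//   {maximal device=3}        the whole value lives on device 3
//   {devices=[2,2]0,1,2,3}    a 2x2 tile assignment over devices 0..3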
1790 bool HloParser::ParseSingleSharding(OpSharding* sharding,
1791 bool lbrace_pre_lexed) {
1792 if (!lbrace_pre_lexed &&
1793 !ParseToken(TokKind::kLbrace,
1794 "expected '{' to start sharding attribute")) {
1795 return false;
1796 }
1797
1798 LocTy loc = lexer_.GetLoc();
1799 bool maximal = false;
1800 bool replicated = false;
1801 std::vector<int64> devices;
1802 std::vector<int64> tile_assignment_dimensions;
1803 while (lexer_.GetKind() != TokKind::kRbrace) {
1804 switch (lexer_.GetKind()) {
1805 case TokKind::kw_maximal:
1806 maximal = true;
1807 lexer_.Lex();
1808 break;
1809 case TokKind::kw_replicated:
1810 replicated = true;
1811 lexer_.Lex();
1812 break;
1813 case TokKind::kAttributeName: {
1814 if (lexer_.GetStrVal() == "device") {
1815 if (lexer_.Lex() != TokKind::kInt) {
1816 return TokenError("device= attribute must be an integer");
1817 }
1818 devices = {lexer_.GetInt64Val()};
1819 lexer_.Lex();
1820 } else if (lexer_.GetStrVal() == "devices") {
1821 lexer_.Lex();
1822 if (!ParseToken(TokKind::kLsquare,
1823 "expected '[' to start sharding devices shape")) {
1824 return false;
1825 }
1826
1827 do {
1828 int64 dim;
1829 if (!ParseInt64(&dim)) {
1830 return false;
1831 }
1832 tile_assignment_dimensions.push_back(dim);
1833 } while (EatIfPresent(TokKind::kComma));
1834
1835 if (!ParseToken(TokKind::kRsquare,
1836 "expected ']' to start sharding devices shape")) {
1837 return false;
1838 }
1839 do {
1840 int64 device;
1841 if (!ParseInt64(&device)) {
1842 return false;
1843 }
1844 devices.push_back(device);
1845 } while (EatIfPresent(TokKind::kComma));
1846 } else {
1847 return TokenError(
1848 "unknown attribute in sharding: expected device= or devices=");
1849 }
1850 break;
1851 }
1852 case TokKind::kRbrace:
1853 break;
1854 default:
1855 return TokenError("unexpected token");
1856 }
1857 }
1858
1859 if (replicated) {
1860 if (!devices.empty()) {
1861 return Error(loc,
1862 "replicated shardings should not have any devices assigned");
1863 }
1864 sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
1865 } else if (maximal) {
1866 if (devices.size() != 1) {
1867 return Error(loc,
1868 "maximal shardings should have exactly one device assigned");
1869 }
1870 sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
1871 sharding->add_tile_assignment_devices(devices[0]);
1872 } else {
1873 if (devices.size() <= 1) {
1874 return Error(
1875 loc, "non-maximal shardings must have more than one device assigned");
1876 }
1877 if (tile_assignment_dimensions.empty()) {
1878 return Error(
1879 loc,
1880 "non-maximal shardings must have a tile assignment list including "
1881 "dimensions");
1882 }
1883 sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
1884 for (int64 dim : tile_assignment_dimensions) {
1885 sharding->add_tile_assignment_dimensions(dim);
1886 }
1887 for (int64 device : devices) {
1888 sharding->add_tile_assignment_devices(device);
1889 }
1890 }
1891
1892 lexer_.Lex();
1893 return true;
1894 }
1895
1896 // parameter_replication ::=
1897 // '{' ('true' | 'false')* (',' ('true' | 'false'))* '}'
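// For example, "parameter_replication={true,false}" marks the first leaf
// buffer of the parameter as replicated and the second as not (illustrative
// example implied by the grammar above).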
1898 bool HloParser::ParseParameterReplication(
1899 ParameterReplication* parameter_replication) {
1900 if (!ParseToken(TokKind::kLbrace,
1901 "expected '{' to start parameter_replication attribute")) {
1902 return false;
1903 }
1904
1905 if (lexer_.GetKind() != TokKind::kRbrace) {
1906 do {
1907 if (lexer_.GetKind() == TokKind::kw_true) {
1908 parameter_replication->add_replicated_at_leaf_buffers(true);
1909 } else if (lexer_.GetKind() == TokKind::kw_false) {
1910 parameter_replication->add_replicated_at_leaf_buffers(false);
1911 } else {
1912 return false;
1913 }
1914 lexer_.Lex();
1915 } while (EatIfPresent(TokKind::kComma));
1916 }
1917
1918 return ParseToken(TokKind::kRbrace,
1919 "expected '}' to end parameter_replication attribute");
1920 }
1921
1922 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
1923 // 'exit=' exit_sharding '}'
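// Illustrative example (assuming ShardingMetadata::KindName() is "sharding"):
//   domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}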
1924 bool HloParser::ParseDomain(DomainData* domain) {
1925 std::unordered_map<string, AttrConfig> attrs;
1926 optional<string> kind;
1927 optional<OpSharding> entry_sharding;
1928 optional<OpSharding> exit_sharding;
1929 attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
1930 attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
1931 attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
1932 if (!ParseSubAttributes(attrs)) {
1933 return false;
1934 }
1935 if (*kind == ShardingMetadata::KindName()) {
1936 auto entry_sharding_ptr = absl::make_unique<HloSharding>(
1937 HloSharding::FromProto(*entry_sharding).ValueOrDie());
1938 auto exit_sharding_ptr = absl::make_unique<HloSharding>(
1939 HloSharding::FromProto(*exit_sharding).ValueOrDie());
1940 domain->entry_metadata =
1941 absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
1942 domain->exit_metadata =
1943 absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
1944 } else {
1945 return TokenError(StrCat("unsupported domain kind: ", *kind));
1946 }
1947 return true;
1948 }
1949
1950 // '{' name+ '}'
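// For example, an attribute such as control-predecessors={%constant.1, %add.2}
// is parsed with this helper (the instruction names here are illustrative).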
1951 bool HloParser::ParseInstructionNames(
1952 std::vector<HloInstruction*>* instructions) {
1953 if (!ParseToken(TokKind::kLbrace,
1954 "expects '{' at the beginning of instruction name list")) {
1955 return false;
1956 }
1957 LocTy loc = lexer_.GetLoc();
1958 do {
1959 string name;
1960 if (!ParseName(&name)) {
1961 return Error(loc, "expects an instruction name");
1962 }
1963 std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
1964 if (!instr) {
1965 return TokenError(StrFormat("instruction '%s' is not defined", name));
1966 }
1967 instructions->push_back(instr->first);
1968 } while (EatIfPresent(TokKind::kComma));
1969
1970 return ParseToken(TokKind::kRbrace,
1971 "expects '}' at the end of instruction name list");
1972 }
1973
1974 bool HloParser::SetValueInLiteral(LocTy loc, int64 value,
1975 LinearOrMultiIndex index, Literal* literal) {
1976 const Shape& shape = literal->shape();
1977 switch (shape.element_type()) {
1978 case S8:
1979 return SetValueInLiteralHelper<int8>(loc, value, index, literal);
1980 case S16:
1981 return SetValueInLiteralHelper<int16>(loc, value, index, literal);
1982 case S32:
1983 return SetValueInLiteralHelper<int32>(loc, value, index, literal);
1984 case S64:
1985 return SetValueInLiteralHelper<int64>(loc, value, index, literal);
1986 case U8:
1987 return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
1988 literal);
1989 case U16:
1990 return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
1991 literal);
1992 case U32:
1993 return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
1994 literal);
1995 case U64:
1996 return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
1997 literal);
1998 case PRED:
1999 // Bool type literals with rank >= 1 are printed in 0s and 1s.
2000 return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
2001 literal);
2002 default:
2003 LOG(FATAL) << "unknown integral primitive type "
2004 << PrimitiveType_Name(shape.element_type());
2005 }
2006 }
2007
2008 bool HloParser::SetValueInLiteral(LocTy loc, double value,
2009 LinearOrMultiIndex index, Literal* literal) {
2010 const Shape& shape = literal->shape();
2011 switch (shape.element_type()) {
2012 case F16:
2013 return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
2014 case BF16:
2015 return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
2016 literal);
2017 case F32:
2018 return SetValueInLiteralHelper<float>(loc, value, index, literal);
2019 case F64:
2020 return SetValueInLiteralHelper<double>(loc, value, index, literal);
2021 default:
2022 LOG(FATAL) << "unknown floating point primitive type "
2023 << PrimitiveType_Name(shape.element_type());
2024 }
2025 }
2026
2027 bool HloParser::SetValueInLiteral(LocTy loc, bool value,
2028 LinearOrMultiIndex index, Literal* literal) {
2029 const Shape& shape = literal->shape();
2030 switch (shape.element_type()) {
2031 case PRED:
2032 return SetValueInLiteralHelper<bool>(loc, value, index, literal);
2033 default:
2034 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
2035 << " is not PRED type";
2036 }
2037 }
2038
2039 bool HloParser::SetValueInLiteral(LocTy loc, std::complex<double> value,
2040 LinearOrMultiIndex index, Literal* literal) {
2041 const Shape& shape = literal->shape();
2042 switch (shape.element_type()) {
2043 case C64:
2044 return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
2045 literal);
2046 case C128:
2047 return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
2048 literal);
2049 default:
2050 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
2051 << " is not a complex type type";
2052 }
2053 }
2054
2055 template <typename T>
2056 string StringifyValue(T val) {
2057 return StrCat(val);
2058 }
2059 template <>
2060 string StringifyValue(std::complex<double> val) {
2061 return StrFormat("(%f, %f)", std::real(val), std::imag(val));
2062 }
2063
2064 template <typename LiteralNativeT, typename ParsedElemT>
2065 bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
2066 LinearOrMultiIndex index,
2067 Literal* literal) {
2068 if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
2069 return false;
2070 }
2071
2072 // Check that the index is in range and assign into the literal
2073 if (auto* linear_index = absl::get_if<int64>(&index)) {
2074 if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
2075 return Error(loc, StrCat("trys to set value ", StringifyValue(value),
2076 " to a literal in shape ",
2077 ShapeUtil::HumanString(literal->shape()),
2078 " at linear index ", *linear_index,
2079 ", but the index is out of range"));
2080 }
2081 literal->data<LiteralNativeT>().at(*linear_index) =
2082 static_cast<LiteralNativeT>(value);
2083 } else {
2084 auto* multi_index = absl::get_if<absl::Span<const int64>>(&index);
2085 CHECK(multi_index != nullptr);
2086
2087 auto invalid_idx = [&](string msg) {
2088 return Error(loc, StrFormat("Invalid sparse index [%s]. %s",
2089 absl::StrJoin(*multi_index, ", "), msg));
2090 };
2091
2092 const auto& shape = literal->shape();
2093 if (shape.rank() != multi_index->size()) {
2094 return invalid_idx(
2095 StrFormat("Has rank %d, but constant has shape %s, which has rank %d",
2096 multi_index->size(), shape.ToString(), shape.rank()));
2097 }
2098 for (int64 i = 0; i < shape.rank(); ++i) {
2099 auto idx = (*multi_index)[i];
2100 if (idx < 0) {
2101 return invalid_idx(StrFormat(
2102 "Sub-index value at %d, namely %d, cannot be negative.", i, idx));
2103 }
2104 if (idx >= shape.dimensions(i)) {
2105 return invalid_idx(
2106 StrFormat("Sub-index at %d, namely %d, doesn't fit within shape "
2107 "dimension %d in %s",
2108 i, idx, shape.dimensions(i), shape.ToString()));
2109 }
2110 }
2111 literal->AppendSparseElement(*multi_index,
2112 static_cast<LiteralNativeT>(value));
2113 }
2114 return true;
2115 }
2116
2117 // literal
2118 // ::= tuple
2119 // ::= non_tuple
2120 bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
2121 return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
2122 : ParseNonTupleLiteral(literal, shape);
2123 }
2124
2125 // tuple
2126 // ::= shape '(' literal_list ')'
2127 // literal_list
2128 // ::= /*empty*/
2129 // ::= literal (',' literal)*
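// For example, given the tuple shape (f32[3], s32[]) the element list might
// read "({1, 2, 3}, 42)": each element is parsed recursively against the
// corresponding tuple element shape (an illustrative sketch of the grammar
// above, not output copied from a real module).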
2130 bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
2131 if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
2132 return false;
2133 }
2134 std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
2135
2136 if (lexer_.GetKind() == TokKind::kRparen) {
2137 // empty
2138 } else {
2139 // literal, (',' literal)*
2140 for (int i = 0; i < elements.size(); i++) {
2141 if (i > 0) {
2142 ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
2143 }
2144 if (!ParseLiteral(&elements[i],
2145 ShapeUtil::GetTupleElementShape(shape, i))) {
2146 return TokenError(StrCat("expects the ", i, "th element"));
2147 }
2148 }
2149 }
2150 *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
2151 return ParseToken(TokKind::kRparen,
2152 StrCat("expects ')' at the end of the tuple with ",
2153 ShapeUtil::TupleElementCount(shape), " elements"));
2154 }
2155
2156 // non_tuple
2157 // ::= rank01
2158 // ::= rank2345
2159 // rank2345 ::= shape sparse_or_nested_array
2160 bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
2161 if (LayoutUtil::IsSparseArray(shape)) {
2162 return ParseSparseLiteral(literal, shape);
2163 }
2164
2165 CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
2166 return ParseDenseLiteral(literal, shape);
2167 }
2168
2169 bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
2170 // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
2171 // `rank` is an int64, that's an implicit narrowing conversion, which is
2172 // implementation-defined behavior.
2173 const int rank = static_cast<int>(shape.rank());
2174
2175 // Create a literal with the given shape in default layout.
2176 *literal = LiteralUtil::CreateFromDimensions(
2177 shape.element_type(), AsInt64Slice(shape.dimensions()));
2178 int64 nest_level = 0;
2179 int64 linear_index = 0;
2180 // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
2181 // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
2182 // when we are parsing the 2nd '{' (right before '1'), we are seeing a
2183 // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
2184 // the first '}' (right after '3'), it means the sub-array ends, and the
2185 // sub-array is supposed to contain exactly 3 elements, so check if
2186 // elems_seen_per_dim[1] is 3.
2187 std::vector<int64> elems_seen_per_dim(rank);
2188 auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
2189 std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
2190 elems_seen_per_dim.begin() + dim);
2191 return StrCat("[",
2192 StrJoin(elems_seen_until_dim, ",",
2193 [](string* out, const int64& num_elems) {
2194 StrAppend(out, num_elems - 1);
2195 }),
2196 "]");
2197 };
2198
2199 auto add_one_elem_seen = [&] {
2200 if (rank > 0) {
2201 if (nest_level != rank) {
2202 return TokenError(absl::StrFormat(
2203 "expects nested array in rank %d, but sees %d", rank, nest_level));
2204 }
2205 elems_seen_per_dim[rank - 1]++;
2206 if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
2207 return TokenError(absl::StrFormat(
2208 "expects %d elements on the minor-most dimension, but "
2209 "sees more",
2210 shape.dimensions(rank - 1)));
2211 }
2212 }
2213 return true;
2214 };
2215
2216 do {
2217 switch (lexer_.GetKind()) {
2218 default:
2219 return TokenError("unexpected token type in a literal");
2220 case TokKind::kLbrace: {
2221 nest_level++;
2222 if (nest_level > rank) {
2223 return TokenError(absl::StrFormat(
2224 "expects nested array in rank %d, but sees larger", rank));
2225 }
2226 if (nest_level > 1) {
2227 elems_seen_per_dim[nest_level - 2]++;
2228 if (elems_seen_per_dim[nest_level - 2] >
2229 shape.dimensions(nest_level - 2)) {
2230 return TokenError(absl::StrFormat(
2231 "expects %d elements in the %sth element, but sees more",
2232 shape.dimensions(nest_level - 2),
2233 get_index_str(nest_level - 2)));
2234 }
2235 }
2236 lexer_.Lex();
2237 break;
2238 }
2239 case TokKind::kRbrace: {
2240 nest_level--;
2241 if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
2242 return TokenError(absl::StrFormat(
2243 "expects %d elements in the %sth element, but sees %d",
2244 shape.dimensions(nest_level), get_index_str(nest_level),
2245 elems_seen_per_dim[nest_level]));
2246 }
2247 elems_seen_per_dim[nest_level] = 0;
2248 lexer_.Lex();
2249 break;
2250 }
2251 case TokKind::kLparen: {
2252 if (!primitive_util::IsComplexType(shape.element_type())) {
2253 return TokenError(
2254 absl::StrFormat("unexpected '(' in literal. Parens are only "
2255 "valid for complex literals"));
2256 }
2257
2258 std::complex<double> value;
2259 LocTy loc = lexer_.GetLoc();
2260 if (!add_one_elem_seen() || !ParseComplex(&value) ||
2261 !SetValueInLiteral(loc, value, linear_index++, literal)) {
2262 return false;
2263 }
2264 break;
2265 }
2266 case TokKind::kDots: {
2267 if (nest_level != 1) {
2268 return TokenError(absl::StrFormat(
2269 "expects `...` at nest level 1, but sees it at nest level %d",
2270 nest_level));
2271 }
2272 elems_seen_per_dim[0] = shape.dimensions(0);
2273 lexer_.Lex();
2274 break;
2275 }
2276 case TokKind::kComma:
2277 // Skip.
2278 lexer_.Lex();
2279 break;
2280 case TokKind::kw_true:
2281 case TokKind::kw_false:
2282 case TokKind::kInt:
2283 case TokKind::kDecimal:
2284 case TokKind::kw_nan:
2285 case TokKind::kw_inf:
2286 case TokKind::kNegInf: {
2287 add_one_elem_seen();
2288 if (lexer_.GetKind() == TokKind::kw_true ||
2289 lexer_.GetKind() == TokKind::kw_false) {
2290 if (!SetValueInLiteral(lexer_.GetLoc(),
2291 lexer_.GetKind() == TokKind::kw_true,
2292 linear_index++, literal)) {
2293 return false;
2294 }
2295 lexer_.Lex();
2296 } else if (primitive_util::IsIntegralType(shape.element_type()) ||
2297 shape.element_type() == PRED) {
2298 LocTy loc = lexer_.GetLoc();
2299 int64 value;
2300 if (!ParseInt64(&value)) {
2301 return Error(loc, StrCat("expects integer for primitive type: ",
2302 PrimitiveType_Name(shape.element_type())));
2303 }
2304 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
2305 return false;
2306 }
2307 } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
2308 LocTy loc = lexer_.GetLoc();
2309 double value;
2310 if (!ParseDouble(&value)) {
2311 return Error(
2312 loc, StrCat("expect floating point value for primitive type: ",
2313 PrimitiveType_Name(shape.element_type())));
2314 }
2315 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
2316 return false;
2317 }
2318 } else {
2319 return TokenError(StrCat("unsupported primitive type ",
2320 PrimitiveType_Name(shape.element_type())));
2321 }
2322 break;
2323 }
2324 } // end of switch
2325 } while (nest_level > 0);
2326
2327 *literal = literal->Relayout(shape.layout());
2328 return true;
2329 }
2330
2331 bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
2332 *literal = Literal(shape);
2333 if (!ParseToken(TokKind::kLbrace,
2334 "expects '{' at the beginning of a sparse literal")) {
2335 return false;
2336 }
2337
2338 for (;;) {
2339 if (lexer_.GetKind() == TokKind::kRbrace) {
2340 lexer_.Lex();
2341 break;
2342 }
2343
2344 std::vector<int64> index;
2345 if (lexer_.GetKind() == TokKind::kInt) {
2346 int64 single_index = lexer_.GetInt64Val();
2347 lexer_.Lex();
2348 index.push_back(single_index);
2349 } else {
2350 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
2351 &index)) {
2352 return false;
2353 }
2354 }
2355 if (!ParseToken(TokKind::kColon,
2356 "expects ':' after after the sparse array index and before "
2357 "the sparse array value")) {
2358 return false;
2359 }
2360
2361 LocTy value_loc = lexer_.GetLoc();
2362 if (lexer_.GetKind() == TokKind::kw_true ||
2363 lexer_.GetKind() == TokKind::kw_false) {
2364 bool value = lexer_.GetKind() == TokKind::kw_true;
2365 if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) {
2366 return false;
2367 }
2368 lexer_.Lex();
2369 } else if (primitive_util::IsIntegralType(shape.element_type())) {
2370 int64 value;
2371 if (!ParseInt64(&value)) {
2372 return Error(value_loc,
2373 StrCat("expects integer for primitive type: ",
2374 PrimitiveType_Name(shape.element_type())));
2375 }
2376 if (!SetValueInLiteral(value_loc, value, index, literal)) {
2377 return false;
2378 }
2379 } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
2380 double value;
2381 if (!ParseDouble(&value)) {
2382 return Error(value_loc,
2383 StrCat("expects floating point value for primitive type: ",
2384 PrimitiveType_Name(shape.element_type())));
2385 }
2386 if (!SetValueInLiteral(value_loc, value, index, literal)) {
2387 return false;
2388 }
2389 } else if (primitive_util::IsComplexType(shape.element_type())) {
2390 std::complex<double> value;
2391 if (!ParseComplex(&value)) {
2392 return Error(value_loc,
2393 StrCat("expects complex value for primitive type: ",
2394 PrimitiveType_Name(shape.element_type())));
2395 }
2396 if (!SetValueInLiteral(value_loc, value, index, literal)) {
2397 return false;
2398 }
2399 } else {
2400 LOG(FATAL) << "Unexpected element type: "
2401 << PrimitiveType_Name(shape.element_type());
2402 }
2403
2404 if (lexer_.GetKind() != TokKind::kRbrace &&
2405 !ParseToken(TokKind::kComma,
2406 "expects ',' separator between sparse array elements")) {
2407 return false;
2408 }
2409
2410 if (literal->sparse_element_count() + 1 ==
2411 LayoutUtil::MaxSparseElements(shape.layout())) {
2412 return Error(
2413 lexer_.GetLoc(),
2414 StrCat("number of sparse elements exceeds maximum for layout: ",
2415 ShapeUtil::HumanStringWithLayout(shape)));
2416 }
2417 }
2418
2419 literal->SortSparseElements();
2420 return true;
2421 }
2422
2423 // MinMaxFiniteValue is a type-traits helper used by
2424 // HloParser::CheckParsedValueIsInRange.
2425 template <typename T>
2426 struct MinMaxFiniteValue {
2427 static T max() { return std::numeric_limits<T>::max(); }
2428 static T min() { return std::numeric_limits<T>::lowest(); }
2429 };
2430
2431 template <>
2432 struct MinMaxFiniteValue<Eigen::half> {
2433 static double max() {
2434 // Sadly this is not constexpr, which is why max() is a method here rather
2434 // than a constant.
2435 return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
2436 }
2437 static double min() { return -max(); }
2438 };
2439
2440 template <>
2441 struct MinMaxFiniteValue<bfloat16> {
2442 static double max() { return static_cast<double>(bfloat16::highest()); }
2443 static double min() { return -max(); }
2444 };
2445
2446 template <typename LiteralNativeT, typename ParsedElemT>
2447 bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
2448 PrimitiveType literal_ty =
2449 primitive_util::NativeToPrimitiveType<LiteralNativeT>();
2450 if (std::isnan(value) ||
2451 (std::numeric_limits<ParsedElemT>::has_infinity &&
2452 (std::numeric_limits<ParsedElemT>::infinity() == value ||
2453 -std::numeric_limits<ParsedElemT>::infinity() == value))) {
2454 // Skip range checking for non-finite value.
2455 } else if (std::is_unsigned<LiteralNativeT>::value) {
2456 CHECK((std::is_same<ParsedElemT, int64>::value ||
2457 std::is_same<ParsedElemT, bool>::value))
2458 << "Unimplemented checking for ParsedElemT";
2459
2460 ParsedElemT upper_bound;
2461 if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) {
2462 upper_bound = std::numeric_limits<ParsedElemT>::max();
2463 } else {
2464 upper_bound =
2465 static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max());
2466 }
2467 if (value > upper_bound || value < 0) {
2468 // Value is out of range for LiteralNativeT.
2469 return Error(loc, StrCat("value ", value,
2470 " is out of range for literal's primitive type ",
2471 PrimitiveType_Name(literal_ty), " namely [0, ",
2472 upper_bound, "]."));
2473 }
2474 } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
2475 value < MinMaxFiniteValue<LiteralNativeT>::min()) {
2476 // Value is out of range for LiteralNativeT.
2477 return Error(loc, StrCat("value ", value,
2478 " is out of range for literal's primitive type ",
2479 PrimitiveType_Name(literal_ty), " namely [",
2480 MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
2481 MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
2482 }
2483 return true;
2484 }
2485
2486 template <typename LiteralNativeT>
2487 bool HloParser::CheckParsedValueIsInRange(LocTy loc,
2488 std::complex<double> value) {
2489 // e.g. `float` for std::complex<float>
2490 using LiteralComplexComponentT =
2491 decltype(std::real(std::declval<LiteralNativeT>()));
2492
2493 // We could do simply
2494 //
2495 // return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
2496 // CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
2497 //
2498 // but this would give bad error messages on failure.
2499
2500 auto check_component = [&](absl::string_view name, double v) {
2501 if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
2502 v == -std::numeric_limits<double>::infinity()) {
2503 // Skip range-checking for non-finite values.
2504 return true;
2505 }
2506
2507 double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
2508 double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
2509 if (v < min || v > max) {
2510 // Value is out of range for LiteralComplexComponentT.
2511 return Error(
2512 loc,
2513 StrCat(name, " part ", v,
2514 " is out of range for literal's primitive type ",
2515 PrimitiveType_Name(
2516 primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
2517 ", namely [", min, ", ", max, "]."));
2518 }
2519 return true;
2520 };
2521 return check_component("real", std::real(value)) &&
2522 check_component("imaginary", std::imag(value));
2523 }
2524
2525 // operands ::= '(' operands1 ')'
2526 // operands1
2527 // ::= /*empty*/
2528 // ::= operand (, operand)*
2529 // operand ::= (shape)? name
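// For example, "(f32[10]{0} %p0, %p1)" supplies two operands, the first with
// an explicit shape and the second by name only (the names are illustrative).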
2530 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
2531 CHECK(operands != nullptr);
2532 if (!ParseToken(TokKind::kLparen,
2533 "expects '(' at the beginning of operands")) {
2534 return false;
2535 }
2536 if (lexer_.GetKind() == TokKind::kRparen) {
2537 // empty
2538 } else {
2539 do {
2540 LocTy loc = lexer_.GetLoc();
2541 string name;
2542 optional<Shape> shape;
2543 if (CanBeShape()) {
2544 shape.emplace();
2545 if (!ParseShape(&shape.value())) {
2546 return false;
2547 }
2548 }
2549 if (!ParseName(&name)) {
2550 // When parsing a single instruction (as opposed to a whole module), an
2551 // HLO may have one or more operands with a shape but no name:
2552 //
2553 // foo = add(f32[10], f32[10])
2554 //
2555 // create_missing_instruction_ is always non-null when parsing a single
2556 // instruction, and is responsible for creating kParameter instructions
2557 // for these operands.
2558 if (shape.has_value() && create_missing_instruction_ != nullptr &&
2559 scoped_name_tables_.size() == 1) {
2560 name = "";
2561 } else {
2562 return false;
2563 }
2564 }
2565 std::pair<HloInstruction*, LocTy>* instruction =
2566 FindInstruction(name, shape);
2567 if (instruction == nullptr) {
2568 return Error(loc, StrCat("instruction does not exist: ", name));
2569 }
2570 operands->push_back(instruction->first);
2571 } while (EatIfPresent(TokKind::kComma));
2572 }
2573 return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
2574 }
2575
2576 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
2577 const int expected_size) {
2578 CHECK(operands != nullptr);
2579 LocTy loc = lexer_.GetLoc();
2580 if (!ParseOperands(operands)) {
2581 return false;
2582 }
2583 if (expected_size != operands->size()) {
2584 return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
2585 operands->size(), " operands"));
2586 }
2587 return true;
2588 }
2589
2590 // sub_attributes ::= '{' (','? attribute)* '}'
2591 bool HloParser::ParseSubAttributes(
2592 const std::unordered_map<string, AttrConfig>& attrs) {
2593 LocTy loc = lexer_.GetLoc();
2594 if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
2595 return false;
2596 }
2597 std::unordered_set<string> seen_attrs;
2598 if (lexer_.GetKind() == TokKind::kRbrace) {
2599 // empty
2600 } else {
2601 do {
2602 EatIfPresent(TokKind::kComma);
2603 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
2604 return false;
2605 }
2606 } while (lexer_.GetKind() != TokKind::kRbrace);
2607 }
2608 // Check that all required attrs were seen.
2609 for (const auto& attr_it : attrs) {
2610 if (attr_it.second.required &&
2611 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
2612 return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
2613 attr_it.first));
2614 }
2615 }
2616 return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
2617 }
2618
2619 // attributes ::= (',' attribute)*
2620 bool HloParser::ParseAttributes(
2621 const std::unordered_map<string, AttrConfig>& attrs) {
2622 LocTy loc = lexer_.GetLoc();
2623 std::unordered_set<string> seen_attrs;
2624 while (EatIfPresent(TokKind::kComma)) {
2625 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
2626 return false;
2627 }
2628 }
2629 // Check that all required attrs were seen.
2630 for (const auto& attr_it : attrs) {
2631 if (attr_it.second.required &&
2632 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
2633 return Error(loc, StrFormat("attribute %s is expected but not seen",
2634 attr_it.first));
2635 }
2636 }
2637 return true;
2638 }
2639
2640 bool HloParser::ParseAttributeHelper(
2641 const std::unordered_map<string, AttrConfig>& attrs,
2642 std::unordered_set<string>* seen_attrs) {
2643 LocTy loc = lexer_.GetLoc();
2644 string name;
2645 if (!ParseAttributeName(&name)) {
2646 return Error(loc, "error parsing attributes");
2647 }
2648 VLOG(1) << "Parsing attribute " << name;
2649 if (!seen_attrs->insert(name).second) {
2650 return Error(loc, StrFormat("attribute %s already exists", name));
2651 }
2652 auto attr_it = attrs.find(name);
2653 if (attr_it == attrs.end()) {
2654 string allowed_attrs;
2655 if (attrs.empty()) {
2656 allowed_attrs = "No attributes are allowed here.";
2657 } else {
2658 allowed_attrs = StrCat(
2659 "Allowed attributes: ",
2660 StrJoin(attrs, ", ",
2661 [&](string* out, const std::pair<string, AttrConfig>& kv) {
2662 StrAppend(out, kv.first);
2663 }));
2664 }
2665 return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
2666 allowed_attrs));
2667 }
2668 AttrTy attr_type = attr_it->second.attr_type;
2669 void* attr_out_ptr = attr_it->second.result;
2670 bool success = [&] {
2671 LocTy attr_loc = lexer_.GetLoc();
2672 switch (attr_type) {
2673 case AttrTy::kBool: {
2674 bool result;
2675 if (!ParseBool(&result)) {
2676 return false;
2677 }
2678 static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
2679 return true;
2680 }
2681 case AttrTy::kInt64: {
2682 int64 result;
2683 if (!ParseInt64(&result)) {
2684 return false;
2685 }
2686 static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
2687 return true;
2688 }
2689 case AttrTy::kInt32: {
2690 int64 result;
2691 if (!ParseInt64(&result)) {
2692 return false;
2693 }
2694 if (result != static_cast<int32>(result)) {
2695 return Error(attr_loc, "value out of range for int32");
2696 }
2697 static_cast<optional<int32>*>(attr_out_ptr)
2698 ->emplace(static_cast<int32>(result));
2699 return true;
2700 }
2701 case AttrTy::kFloat: {
2702 double result;
2703 if (!ParseDouble(&result)) {
2704 return false;
2705 }
2706 if (result > std::numeric_limits<float>::max() ||
2707 result < std::numeric_limits<float>::lowest()) {
2708 return Error(attr_loc, "value out of range for float");
2709 }
2710 static_cast<optional<float>*>(attr_out_ptr)
2711 ->emplace(static_cast<float>(result));
2712 return true;
2713 }
2714 case AttrTy::kHloComputation: {
2715 HloComputation* result = nullptr;
2716 if (!ParseHloComputation(&result)) {
2717 return false;
2718 }
2719 static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
2720 return true;
2721 }
2722 case AttrTy::kBracedHloComputationList: {
2723 std::vector<HloComputation*> result;
2724 if (!ParseHloComputationList(&result)) {
2725 return false;
2726 }
2727 static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
2728 ->emplace(result);
2729 return true;
2730 }
2731 case AttrTy::kFftType: {
2732 FftType result;
2733 if (!ParseFftType(&result)) {
2734 return false;
2735 }
2736 static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
2737 return true;
2738 }
2739 case AttrTy::kComparisonDirection: {
2740 ComparisonDirection result;
2741 if (!ParseComparisonDirection(&result)) {
2742 return false;
2743 }
2744 static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
2745 ->emplace(result);
2746 return true;
2747 }
2748 case AttrTy::kWindow: {
2749 Window result;
2750 if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
2751 return false;
2752 }
2753 static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
2754 return true;
2755 }
2756 case AttrTy::kConvolutionDimensionNumbers: {
2757 ConvolutionDimensionNumbers result;
2758 if (!ParseConvolutionDimensionNumbers(&result)) {
2759 return false;
2760 }
2761 static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
2762 ->emplace(result);
2763 return true;
2764 }
2765 case AttrTy::kSharding: {
2766 OpSharding sharding;
2767 if (!ParseSharding(&sharding)) {
2768 return false;
2769 }
2770 static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
2771 return true;
2772 }
2773 case AttrTy::kParameterReplication: {
2774 ParameterReplication parameter_replication;
2775 if (!ParseParameterReplication(&parameter_replication)) {
2776 return false;
2777 }
2778 static_cast<optional<ParameterReplication>*>(attr_out_ptr)
2779 ->emplace(parameter_replication);
2780 return true;
2781 }
2782 case AttrTy::kInstructionList: {
2783 std::vector<HloInstruction*> result;
2784 if (!ParseInstructionNames(&result)) {
2785 return false;
2786 }
2787 static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
2788 ->emplace(result);
2789 return true;
2790 }
2791 case AttrTy::kFusionKind: {
2792 HloInstruction::FusionKind result;
2793 if (!ParseFusionKind(&result)) {
2794 return false;
2795 }
2796 static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
2797 ->emplace(result);
2798 return true;
2799 }
2800 case AttrTy::kBracedInt64List: {
2801 std::vector<int64> result;
2802 if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
2803 &result)) {
2804 return false;
2805 }
2806 static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
2807 ->emplace(result);
2808 return true;
2809 }
2810 case AttrTy::kBracedInt64ListList: {
2811 std::vector<std::vector<int64>> result;
2812 auto parse_and_add_item = [&]() {
2813 std::vector<int64> item;
2814 if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace,
2815 TokKind::kComma, &item)) {
2816 return false;
2817 }
2818 result.push_back(item);
2819 return true;
2820 };
2821 if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
2822 parse_and_add_item)) {
2823 return false;
2824 }
2825 static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
2826 ->emplace(result);
2827 return true;
2828 }
2829 case AttrTy::kSliceRanges: {
2830 SliceRanges result;
2831 if (!ParseSliceRanges(&result)) {
2832 return false;
2833 }
2834 static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
2835 return true;
2836 }
2837 case AttrTy::kPaddingConfig: {
2838 PaddingConfig result;
2839 if (!ParsePaddingConfig(&result)) {
2840 return false;
2841 }
2842 static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
2843 return true;
2844 }
2845 case AttrTy::kString: {
2846 string result;
2847 if (!ParseString(&result)) {
2848 return false;
2849 }
2850 static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
2851 return true;
2852 }
2853 case AttrTy::kMetadata: {
2854 OpMetadata result;
2855 if (!ParseMetadata(&result)) {
2856 return false;
2857 }
2858 static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
2859 return true;
2860 }
2861 case AttrTy::kDistribution: {
2862 RandomDistribution result;
2863 if (!ParseRandomDistribution(&result)) {
2864 return false;
2865 }
2866 static_cast<optional<RandomDistribution>*>(attr_out_ptr)
2867 ->emplace(result);
2868 return true;
2869 }
2870 case AttrTy::kDomain: {
2871 return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
2872 }
2873 case AttrTy::kPrecisionList: {
2874 std::vector<PrecisionConfig::Precision> result;
2875 if (!ParsePrecisionList(&result)) {
2876 return false;
2877 }
2878 static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
2879 attr_out_ptr)
2880 ->emplace(result);
2881 return true;
2882 }
2883 case AttrTy::kShapeList: {
2884 std::vector<Shape> result;
2885 if (!ParseShapeList(&result)) {
2886 return false;
2887 }
2888 static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
2889 ->emplace(result);
2890 return true;
2891 }
2892 }
2893 }();
2894 if (!success) {
2895 return Error(loc, StrFormat("error parsing attribute %s", name));
2896 }
2897 return true;
2898 }
2899
2900 // attributes ::= (',' attribute)*
2901 bool HloParser::ParseAttributesAsProtoMessage(
2902 const std::unordered_set<string>& required_attrs,
2903 tensorflow::protobuf::Message* message) {
2904 LocTy loc = lexer_.GetLoc();
2905 std::unordered_set<string> seen_attrs;
2906 while (EatIfPresent(TokKind::kComma)) {
2907 if (!ParseAttributeAsProtoMessageHelper(message, &seen_attrs)) {
2908 return false;
2909 }
2910 }
2911 // Check that all required attrs were seen.
2912 for (const string& attr : required_attrs) {
2913 if (seen_attrs.find(attr) == seen_attrs.end()) {
2914 return Error(loc,
2915 StrFormat("attribute %s is expected but not seen", attr));
2916 }
2917 }
2918 return true;
2919 }
2920
2921 bool HloParser::ParseAttributeAsProtoMessageHelper(
2922 tensorflow::protobuf::Message* message,
2923 std::unordered_set<string>* seen_attrs) {
2924 LocTy loc = lexer_.GetLoc();
2925 string name;
2926 if (!ParseAttributeName(&name)) {
2927 return Error(loc, "error parsing attributes");
2928 }
2929 VLOG(1) << "Parsing attribute " << name;
2930 if (!seen_attrs->insert(name).second) {
2931 return Error(loc, StrFormat("attribute %s already exists", name));
2932 }
2933 const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
2934 const tensorflow::protobuf::FieldDescriptor* fd =
2935 descriptor->FindFieldByName(name);
2936 if (!fd) {
2937 string allowed_attrs = "Allowed attributes: ";
2938
2939 for (int i = 0; i < descriptor->field_count(); ++i) {
2940 if (i == 0) {
2941 absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
2942 } else {
2943 absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
2944 }
2945 }
2946 return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
2947 allowed_attrs));
2948 }
2949 const tensorflow::protobuf::Reflection* reflection = message->GetReflection();
2950 CHECK(!fd->is_repeated()); // Repeated fields not implemented.
2951 bool success = [&] {
2952 switch (fd->type()) {
2953 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
2954 bool result;
2955 if (!ParseBool(&result)) {
2956 return false;
2957 }
2958 reflection->SetBool(message, fd, result);
2959 return true;
2960 }
2961 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
2962 if (lexer_.GetKind() != TokKind::kIdent) {
2963 return TokenError(
2964 StrFormat("expects %s type", fd->enum_type()->name()));
2965 }
2966 string val = lexer_.GetStrVal();
2967 const tensorflow::protobuf::EnumValueDescriptor* evd =
2968 fd->enum_type()->FindValueByName(val);
2969 if (evd == nullptr) {
2970 return TokenError(StrFormat("expects %s type but sees: %s",
2971 fd->enum_type()->name(), val));
2972 }
2973 reflection->SetEnum(message, fd, evd);
2974 lexer_.Lex();
2975 return true;
2976 }
2977 default:
2978 LOG(ERROR) << "Unimplemented protocol buffer type "
2979 << fd->DebugString();
2980 return false;
2981 }
2982 }();
2983 if (!success) {
2984 return Error(loc, StrFormat("error parsing attribute %s", name));
2985 }
2986 return true;
2987 }
2988
2989 bool HloParser::ParseComputationName(HloComputation** value) {
2990 string name;
2991 LocTy loc = lexer_.GetLoc();
2992 if (!ParseName(&name)) {
2993 return Error(loc, "expects computation name");
2994 }
2995 std::pair<HloComputation*, LocTy>* computation =
2996 tensorflow::gtl::FindOrNull(computation_pool_, name);
2997 if (computation == nullptr) {
2998 return Error(loc, StrCat("computation does not exist: ", name));
2999 }
3000 *value = computation->first;
3001 return true;
3002 }
3003
3004 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
3005 // The subattributes can appear in any order. 'size=' is required, others are
3006 // optional.
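// An illustrative window string accepted here (sub-attribute names are taken
// from the cases handled below; the particular sizes are made up):
//   {size=3x3 stride=2x2 pad=1_1x1_1 rhs_dilate=2x2}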
3007 bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) {
3008 LocTy loc = lexer_.GetLoc();
3009 if (expect_outer_curlies &&
3010 !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
3011 return false;
3012 }
3013
3014 std::vector<int64> size;
3015 std::vector<int64> stride;
3016 std::vector<std::vector<int64>> pad;
3017 std::vector<int64> lhs_dilate;
3018 std::vector<int64> rhs_dilate;
3019 std::vector<int64> rhs_reversal;
3020 const auto end_token =
3021 expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
3022 while (lexer_.GetKind() != end_token) {
3023 LocTy attr_loc = lexer_.GetLoc();
3024 string field_name;
3025 if (!ParseAttributeName(&field_name)) {
3026 return Error(attr_loc, "expects sub-attributes in window");
3027 }
3028 bool ok = [&] {
3029 if (field_name == "size") {
3030 return ParseDxD("size", &size);
3031 }
3032 if (field_name == "stride") {
3033 return ParseDxD("stride", &stride);
3034 }
3035 if (field_name == "lhs_dilate") {
3036 return ParseDxD("lhs_dilate", &lhs_dilate);
3037 }
3038 if (field_name == "rhs_dilate") {
3039 return ParseDxD("rls_dilate", &rhs_dilate);
3040 }
3041 if (field_name == "pad") {
3042 return ParseWindowPad(&pad);
3043 }
3044 if (field_name == "rhs_reversal") {
3045 return ParseDxD("rhs_reversal", &rhs_reversal);
3046 }
3047 return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
3048 }();
3049 if (!ok) {
3050 return false;
3051 }
3052 }
3053
3054 if (size.empty()) {
3055 return Error(loc,
3056 "sub-attribute 'size=' is required in the window attribute");
3057 }
3058 if (!stride.empty() && stride.size() != size.size()) {
3059 return Error(loc, "expects 'stride=' has the same size as 'size='");
3060 }
3061 if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
3062 return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
3063 }
3064 if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
3065 return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
3066 }
3067 if (!pad.empty() && pad.size() != size.size()) {
3068 return Error(loc, "expects 'pad=' has the same size as 'size='");
3069 }
3070
3071 for (int i = 0; i < size.size(); i++) {
3072 window->add_dimensions()->set_size(size[i]);
3073 if (!pad.empty()) {
3074 window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
3075 window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
3076 }
3077 // If some field is not present, it has the default value.
3078 window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
3079 window->mutable_dimensions(i)->set_base_dilation(
3080 lhs_dilate.empty() ? 1 : lhs_dilate[i]);
3081 window->mutable_dimensions(i)->set_window_dilation(
3082 rhs_dilate.empty() ? 1 : rhs_dilate[i]);
3083 window->mutable_dimensions(i)->set_window_reversal(
3084 rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
3085 }
3086 return !expect_outer_curlies ||
3087 ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
3088 }
3089
3090 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
3091 // The string looks like "dim_labels=0bf_0io->0bf".
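// For example, in "0bf_0io->0bf" (illustrative): on the lhs, spatial dimension
// 0 maps to dimension 0, 'b' (batch) to dimension 1, and 'f' (feature) to
// dimension 2; on the rhs, 'i' marks the input-feature and 'o' the
// output-feature dimension; the output labels are read the same way as the lhs.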
3092 bool HloParser::ParseConvolutionDimensionNumbers(
3093 ConvolutionDimensionNumbers* dnums) {
3094 if (lexer_.GetKind() != TokKind::kDimLabels) {
3095     return TokenError("expects dim labels pattern, e.g., '0bf_0io->0bf'");
3096 }
3097 string str = lexer_.GetStrVal();
3098
3099 // The string is expected to have 3 items (lhs, rhs, out) and must look like
3100 // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
3101 std::vector<string> split1 = absl::StrSplit(str, '_');
3102 if (split1.size() != 2) {
3103 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
3104 << str;
3105 }
3106 std::vector<string> split2 = absl::StrSplit(split1[1], "->");
3107 if (split2.size() != 2) {
3108 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
3109 << str;
3110 }
3111 absl::string_view lhs = split1[0];
3112 absl::string_view rhs = split2[0];
3113 absl::string_view out = split2[1];
3114
3115 const int64 rank = lhs.length();
3116 if (rank != rhs.length() || rank != out.length()) {
3117 return TokenError(
3118 "convolution lhs, rhs, and output must have the same rank");
3119 }
3120 if (rank < 2) {
3121     return TokenError("convolution rank must be >= 2");
3122 }
3123
3124 auto is_unique = [](string str) -> bool {
3125 absl::c_sort(str);
3126 return std::unique(str.begin(), str.end()) == str.end();
3127 };
3128
3129 // lhs
3130 {
3131 if (!is_unique(string(lhs))) {
3132 return TokenError(
3133 StrCat("expects unique lhs dimension numbers, but sees ", lhs));
3134 }
3135 for (int i = 0; i < rank - 2; i++) {
3136 dnums->add_input_spatial_dimensions(-1);
3137 }
3138 for (int i = 0; i < rank; i++) {
3139 char c = lhs[i];
3140 if (c == 'b') {
3141 dnums->set_input_batch_dimension(i);
3142 } else if (c == 'f') {
3143 dnums->set_input_feature_dimension(i);
3144 } else if (c < '0' + rank && c >= '0') {
3145 dnums->set_input_spatial_dimensions(c - '0', i);
3146 } else {
3147 return TokenError(
3148 StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
3149 }
3150 }
3151 }
3152 // rhs
3153 {
3154 if (!is_unique(string(rhs))) {
3155 return TokenError(
3156 StrCat("expects unique rhs dimension numbers, but sees ", rhs));
3157 }
3158 for (int i = 0; i < rank - 2; i++) {
3159 dnums->add_kernel_spatial_dimensions(-1);
3160 }
3161 for (int i = 0; i < rank; i++) {
3162 char c = rhs[i];
3163 if (c == 'i') {
3164 dnums->set_kernel_input_feature_dimension(i);
3165 } else if (c == 'o') {
3166 dnums->set_kernel_output_feature_dimension(i);
3167 } else if (c < '0' + rank && c >= '0') {
3168 dnums->set_kernel_spatial_dimensions(c - '0', i);
3169 } else {
3170 return TokenError(
3171 StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
3172 }
3173 }
3174 }
3175 // output
3176 {
3177 if (!is_unique(string(out))) {
3178 return TokenError(
3179 StrCat("expects unique output dimension numbers, but sees ", out));
3180 }
3181 for (int i = 0; i < rank - 2; i++) {
3182 dnums->add_output_spatial_dimensions(-1);
3183 }
3184 for (int i = 0; i < rank; i++) {
3185 char c = out[i];
3186 if (c == 'b') {
3187 dnums->set_output_batch_dimension(i);
3188 } else if (c == 'f') {
3189 dnums->set_output_feature_dimension(i);
3190 } else if (c < '0' + rank && c >= '0') {
3191 dnums->set_output_spatial_dimensions(c - '0', i);
3192 } else {
3193 return TokenError(StrFormat(
3194 "expects [0-%dbf] in output dimension numbers", rank - 1));
3195 }
3196 }
3197 }
3198
3199 lexer_.Lex();
3200 return true;
3201 }
3202
3203 // ::= '{' ranges '}'
3204 // ::= /*empty*/
3205 // ::= range (',' range)*
3206 // range ::= '[' start ':' limit (':' stride)? ']'
3207 //
3208 // The slice ranges are printed as:
3209 //
3210 //  {[dim0_start:dim0_limit:dim0_stride], [dim1_start:dim1_limit], ...}
3211 //
3212 // This function extracts the starts, limits, and strides as 3 vectors into the
3213 // result. If a stride is not present, it defaults to 1. For example, if the
3214 // slice ranges are printed as:
3215 //
3216 // {[2:3:4], [5:6:7], [8:9]}
3217 //
3218 // The parsed result will be:
3219 //
3220 // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
3221 //
ParseSliceRanges(SliceRanges * result)3222 bool HloParser::ParseSliceRanges(SliceRanges* result) {
3223 if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
3224 return false;
3225 }
3226 std::vector<std::vector<int64>> ranges;
3227 if (lexer_.GetKind() == TokKind::kRbrace) {
3228 // empty
3229 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
3230 }
3231 do {
3232 LocTy loc = lexer_.GetLoc();
3233 ranges.emplace_back();
3234 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
3235 &ranges.back())) {
3236 return false;
3237 }
3238 const auto& range = ranges.back();
3239 if (range.size() != 2 && range.size() != 3) {
3240 return Error(loc,
3241 StrFormat("expects [start:limit:step] or [start:limit], "
3242 "but sees %d elements.",
3243 range.size()));
3244 }
3245 } while (EatIfPresent(TokKind::kComma));
3246
3247 for (const auto& range : ranges) {
3248 result->starts.push_back(range[0]);
3249 result->limits.push_back(range[1]);
3250 result->strides.push_back(range.size() == 3 ? range[2] : 1);
3251 }
3252 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
3253 }
3254
3255 // precisionlist ::= start precision_elements end
3256 // precision_elements
3257 // ::= /*empty*/
3258 // ::= precision_val (delim precision_val)*
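//
// For example, operand_precision={high,default} would parse into the vector
// {PrecisionConfig::HIGH, PrecisionConfig::DEFAULT}, assuming the lowercase
// spellings accepted by StringToPrecision (via ParsePrecision below). The
// example is illustrative only.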
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)3259 bool HloParser::ParsePrecisionList(
3260 std::vector<PrecisionConfig::Precision>* result) {
3261 auto parse_and_add_item = [&]() {
3262 PrecisionConfig::Precision item;
3263 if (!ParsePrecision(&item)) {
3264 return false;
3265 }
3266 result->push_back(item);
3267 return true;
3268 };
3269 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3270 parse_and_add_item);
3271 }
3272
ParseHloComputation(HloComputation ** result)3273 bool HloParser::ParseHloComputation(HloComputation** result) {
3274 if (lexer_.GetKind() == TokKind::kLbrace) {
3275 // This means it is a nested computation.
3276 return ParseInstructionList(result, /*computation_name=*/"_");
3277 }
3278 // This means it is a computation name.
3279 return ParseComputationName(result);
3280 }
3281
ParseHloComputationList(std::vector<HloComputation * > * result)3282 bool HloParser::ParseHloComputationList(std::vector<HloComputation*>* result) {
3283 auto parse_and_add_item = [&]() {
3284 HloComputation* computation;
3285 if (!ParseHloComputation(&computation)) {
3286 return false;
3287 }
3288 LOG(INFO) << "parsed computation " << computation->name();
3289 result->push_back(computation);
3290 return true;
3291 };
3292 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3293 parse_and_add_item);
3294 }
3295
3296 // shapelist ::= '{' shapes '}'
3297 // shapes
3298 //   ::= /*empty*/
3299 //   ::= shape (',' shape)*
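//
// For example, a shape list such as {f32[2,3]{1,0}, (s32[], pred[4])} would
// parse into two Shapes: an array shape and a tuple shape. This is an
// illustrative sketch of the format accepted by ParseShape below.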
ParseShapeList(std::vector<Shape> * result)3300 bool HloParser::ParseShapeList(std::vector<Shape>* result) {
3301 auto parse_and_add_item = [&]() {
3302 Shape shape;
3303 if (!ParseShape(&shape)) {
3304 return false;
3305 }
3306 result->push_back(std::move(shape));
3307 return true;
3308 };
3309 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3310 parse_and_add_item);
3311 }
3312
3313 // int64list ::= start int64_elements end
3314 // int64_elements
3315 // ::= /*empty*/
3316 // ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)3317 bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
3318 const TokKind delim,
3319 std::vector<int64>* result) {
3320 auto parse_and_add_item = [&]() {
3321 int64 i;
3322 if (!ParseInt64(&i)) {
3323 return false;
3324 }
3325 result->push_back(i);
3326 return true;
3327 };
3328 return ParseList(start, end, delim, parse_and_add_item);
3329 }
3330
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)3331 bool HloParser::ParseList(const TokKind start, const TokKind end,
3332 const TokKind delim,
3333 const std::function<bool()>& parse_and_add_item) {
3334 if (!ParseToken(start, StrCat("expects a list starting with ",
3335 TokKindToString(start)))) {
3336 return false;
3337 }
3338 if (lexer_.GetKind() == end) {
3339 // empty
3340 } else {
3341 do {
3342 if (!parse_and_add_item()) {
3343 return false;
3344 }
3345 } while (EatIfPresent(delim));
3346 }
3347 return ParseToken(
3348 end, StrCat("expects a list to end with ", TokKindToString(end)));
3349 }
3350
3351 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)3352 bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
3353 if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
3354 return false;
3355 }
3356 *shape_loc = lexer_.GetLoc();
3357 return ParseShape(shape);
3358 }
3359
CanBeParamListToShape()3360 bool HloParser::CanBeParamListToShape() {
3361 return lexer_.GetKind() == TokKind::kLparen;
3362 }
3363
3364 // param_list ::= '(' param_list1 ')'
3365 // param_list1
3366 // ::= /*empty*/
3367 // ::= param (',' param)*
3368 // param ::= name shape
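//
// For example, a parameter list as printed in a computation signature, e.g.
// "(x: f32[42], y: s32[])", is accepted here; note that the parsed names and
// shapes are only checked for well-formedness, not recorded (see the body
// below).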
ParseParamList()3369 bool HloParser::ParseParamList() {
3370 if (!ParseToken(TokKind::kLparen,
3371 "expects '(' at the beginning of param list")) {
3372 return false;
3373 }
3374
3375 if (lexer_.GetKind() == TokKind::kRparen) {
3376 // empty
3377 } else {
3378 do {
3379 Shape shape;
3380 string name;
3381 if (!ParseName(&name) || !ParseShape(&shape)) {
3382 return false;
3383 }
3384 } while (EatIfPresent(TokKind::kComma));
3385 }
3386 return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
3387 }
3388
3389 // dimension_sizes ::= '[' dimension_list ']'
3390 // dimension_list
3391 //   ::= /*empty*/
3392 //   ::= <=? int64 (',' <=? int64)*
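//
// For example, "[10,<=20]" would parse into dimension_sizes={10, 20} and
// dynamic_dimensions={false, true}; the '<=' prefix marks a dimension as
// dynamic.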
ParseDimensionSizes(std::vector<int64> * dimension_sizes,std::vector<bool> * dynamic_dimensions)3394 bool HloParser::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
3395 std::vector<bool>* dynamic_dimensions) {
3396 auto parse_and_add_item = [&]() {
3397 int64 i;
3398 bool is_dynamic = false;
3399 if (lexer_.GetKind() == TokKind::kLeq) {
3400 is_dynamic = true;
3401 lexer_.Lex();
3402 }
3403 if (!ParseInt64(&i)) {
3404 return false;
3405 }
3406 dimension_sizes->push_back(i);
3407 dynamic_dimensions->push_back(is_dynamic);
3408 return true;
3409 };
3410 return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
3411 parse_and_add_item);
3412 }
3413
3414 // tiles
3415 // ::= /*empty*/
3416 // ::= 'T' '(' dim_list ')'
3417 // dim_list
3418 // ::= /*empty*/
3419 // ::= (int64 | '*') (',' (int64 | '*'))*
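//
// For example, "T(2,128)" describes a single 2x128 tile and "T(8,128)(2,1)"
// describes two nested tiles; a '*' in a dimension maps to
// Tile::kCombineDimension. These examples are illustrative, based on the
// parsing loop below.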
ParseTiles(std::vector<Tile> * tiles)3420 bool HloParser::ParseTiles(std::vector<Tile>* tiles) {
3421 auto parse_and_add_tile_dimension = [&]() {
3422 tensorflow::int64 i;
3423 if (ParseInt64(&i)) {
3424 tiles->back().add_dimensions(i);
3425 return true;
3426 }
3427 if (lexer_.GetKind() == TokKind::kAsterisk) {
3428 tiles->back().add_dimensions(Tile::kCombineDimension);
3429 lexer_.Lex();
3430 return true;
3431 }
3432 return false;
3433 };
3434
3435 do {
3436 tiles->push_back(Tile());
3437 if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
3438 parse_and_add_tile_dimension)) {
3439 return false;
3440 }
3441 } while (lexer_.GetKind() == TokKind::kLparen);
3442 return true;
3443 }
3444
3445 // layout ::= '{' int64_list (':' tiles element_size_in_bits)? '}'
3446 // element_size_in_bits
3447 // ::= /*empty*/
3448 // ::= 'E' '(' int64 ')'
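//
// For example, "{1,0}" is a plain minor-to-major layout, while
// "{1,0:T(2,128)E(32)}" additionally carries a tiling and an element size of
// 32 bits. This is an illustrative reading of the grammar above.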
ParseLayout(Layout * layout)3449 bool HloParser::ParseLayout(Layout* layout) {
3450 std::vector<int64> minor_to_major;
3451 std::vector<Tile> tiles;
3452 tensorflow::int64 element_size_in_bits = 0;
3453
3454 auto parse_and_add_item = [&]() {
3455 int64 i;
3456 if (!ParseInt64(&i)) {
3457 return false;
3458 }
3459 minor_to_major.push_back(i);
3460 return true;
3461 };
3462
3463 if (!ParseToken(TokKind::kLbrace,
3464 StrCat("expects layout to start with ",
3465 TokKindToString(TokKind::kLbrace)))) {
3466 return false;
3467 }
3468 if (lexer_.GetKind() != TokKind::kRbrace) {
3469 if (lexer_.GetKind() == TokKind::kInt) {
3470 // Parse minor to major.
3471 do {
3472 if (!parse_and_add_item()) {
3473 return false;
3474 }
3475 } while (EatIfPresent(TokKind::kComma));
3476 }
3477
3478 if (lexer_.GetKind() == TokKind::kColon) {
3479 lexer_.Lex();
3480 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
3481 lexer_.Lex();
3482 ParseTiles(&tiles);
3483 }
3484
3485 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
3486 // Parse element size in bits.
3487 lexer_.Lex();
3488 if (!ParseToken(TokKind::kLparen,
3489 StrCat("expects element size in bits to start with ",
3490 TokKindToString(TokKind::kLparen)))) {
3491 return false;
3492 }
3493 if (!ParseInt64(&element_size_in_bits)) {
3494 return false;
3495 }
3496 if (!ParseToken(TokKind::kRparen,
3497 StrCat("expects element size in bits to end with ",
3498 TokKindToString(TokKind::kRparen)))) {
3499 return false;
3500 }
3501 }
3502 }
3503 }
3504 if (!ParseToken(TokKind::kRbrace,
3505 StrCat("expects layout to end with ",
3506 TokKindToString(TokKind::kRbrace)))) {
3507 return false;
3508 }
3509
3510 std::vector<Tile> vec_tiles(tiles.size());
3511 for (int i = 0; i < tiles.size(); i++) {
3512 vec_tiles[i] = Tile(tiles[i]);
3513 }
3514 *layout =
3515 LayoutUtil::MakeLayout(minor_to_major, vec_tiles, element_size_in_bits);
3516 return true;
3517 }
3518
3519 // shape ::= shape_val_
3520 // shape ::= '(' tuple_elements ')'
3521 // tuple_elements
3522 // ::= /*empty*/
3523 // ::= shape (',' shape)*
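//
// For example, "f32[10,20]{1,0}" is an array shape with an explicit layout,
// "f32[<=10]" has a dynamic dimension, and "(f32[], (s32[2], pred[]))" is a
// nested tuple shape. The examples are illustrative; the exact accepted forms
// are defined by the parsing below.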
ParseShape(Shape * result)3524 bool HloParser::ParseShape(Shape* result) {
3525 if (EatIfPresent(TokKind::kLparen)) { // Tuple
3526 std::vector<Shape> shapes;
3527 if (lexer_.GetKind() == TokKind::kRparen) {
3528 /*empty*/
3529 } else {
3530 // shape (',' shape)*
3531 do {
3532 shapes.emplace_back();
3533 if (!ParseShape(&shapes.back())) {
3534 return false;
3535 }
3536 } while (EatIfPresent(TokKind::kComma));
3537 }
3538 *result = ShapeUtil::MakeTupleShape(shapes);
3539 return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
3540 }
3541
3542 if (lexer_.GetKind() != TokKind::kPrimitiveType) {
3543 return TokenError(absl::StrCat("expected primitive type, saw ",
3544 TokKindToString(lexer_.GetKind())));
3545 }
3546 PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
3547 lexer_.Lex();
3548
3549 // Each element contains a dimension size and a bool indicating whether this
3550 // is a dynamic dimension.
3551 std::vector<int64> dimension_sizes;
3552 std::vector<bool> dynamic_dimensions;
3553 if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
3554 return false;
3555 }
3556 result->set_element_type(primitive_type);
3557 for (int i = 0; i < dimension_sizes.size(); ++i) {
3558 result->add_dimensions(dimension_sizes[i]);
3559 result->set_dynamic_dimension(i, dynamic_dimensions[i]);
3560 }
3561 LayoutUtil::SetToDefaultLayout(result);
3562
3563 if (lexer_.GetKind() == TokKind::kw_sparse) {
3564 lexer_.Lex();
3565 const string message =
3566 "expects a brace-bracketed integer for sparse layout";
3567 int64 max_sparse_elements;
3568 if (!ParseToken(TokKind::kLbrace, message) ||
3569 !ParseInt64(&max_sparse_elements) ||
3570 !ParseToken(TokKind::kRbrace, message)) {
3571 return false;
3572 }
3573 *result->mutable_layout() =
3574 LayoutUtil::MakeSparseLayout(max_sparse_elements);
3575 return true;
3576 }
3577
3578 // We need to lookahead to see if a following open brace is the start of a
3579 // layout. The specific problematic case is:
3580 //
3581 // ENTRY %foo (x: f32[42]) -> f32[123] {
3582 // ...
3583 // }
3584 //
3585 // The open brace could either be the start of a computation or the start of a
3586 // layout for the f32[123] shape. We consider it the start of a layout if the
3587 // next token after the open brace is an integer or a colon.
3588 if (lexer_.GetKind() == TokKind::kLbrace &&
3589 (lexer_.LookAhead() == TokKind::kInt ||
3590 lexer_.LookAhead() == TokKind::kColon)) {
3591 Layout layout;
3592 if (!ParseLayout(&layout)) {
3593 return false;
3594 }
3595 if (layout.minor_to_major_size() != result->rank()) {
3596 return Error(
3597 lexer_.GetLoc(),
3598 StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
3599 result->rank(), layout.minor_to_major_size()));
3600 }
3601 *result->mutable_layout() = layout;
3602 }
3603 return true;
3604 }
3605
CanBeShape()3606 bool HloParser::CanBeShape() {
3607 // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
3608 // with '('.
3609 return lexer_.GetKind() == TokKind::kPrimitiveType ||
3610 lexer_.GetKind() == TokKind::kLparen;
3611 }
3612
ParseName(string * result)3613 bool HloParser::ParseName(string* result) {
3614 VLOG(1) << "ParseName";
3615 if (lexer_.GetKind() != TokKind::kIdent &&
3616 lexer_.GetKind() != TokKind::kName) {
3617 return TokenError("expects name");
3618 }
3619 *result = lexer_.GetStrVal();
3620 lexer_.Lex();
3621 return true;
3622 }
3623
ParseAttributeName(string * result)3624 bool HloParser::ParseAttributeName(string* result) {
3625 if (lexer_.GetKind() != TokKind::kAttributeName) {
3626 return TokenError("expects attribute name");
3627 }
3628 *result = lexer_.GetStrVal();
3629 lexer_.Lex();
3630 return true;
3631 }
3632
ParseString(string * result)3633 bool HloParser::ParseString(string* result) {
3634 VLOG(1) << "ParseString";
3635 if (lexer_.GetKind() != TokKind::kString) {
3636 return TokenError("expects string");
3637 }
3638 *result = lexer_.GetStrVal();
3639 lexer_.Lex();
3640 return true;
3641 }
3642
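// Parses a sub-attribute value that is either a single integer ("3") or an
// 'x'-separated list ("2x3x4"), appending the parsed values to 'result';
// e.g. "2x3x4" yields {2, 3, 4}.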
ParseDxD(const string & name,std::vector<int64> * result)3643 bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
3644 LocTy loc = lexer_.GetLoc();
3645 if (!result->empty()) {
3646 return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
3647 }
3648 // 1D
3649 if (lexer_.GetKind() == TokKind::kInt) {
3650 int64 number;
3651 if (!ParseInt64(&number)) {
3652 return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
3653 }
3654 result->push_back(number);
3655 return true;
3656 }
3657 // 2D or higher.
3658 if (lexer_.GetKind() == TokKind::kDxD) {
3659 string str = lexer_.GetStrVal();
3660 if (!SplitToInt64s(str, 'x', result)) {
3661 return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
3662 }
3663 lexer_.Lex();
3664 return true;
3665 }
3666 return TokenError("expects token type kInt or kDxD");
3667 }
3668
ParseWindowPad(std::vector<std::vector<int64>> * pad)3669 bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
3670 LocTy loc = lexer_.GetLoc();
3671 if (!pad->empty()) {
3672 return Error(loc, "sub-attribute 'pad=' already exists");
3673 }
3674 if (lexer_.GetKind() != TokKind::kPad) {
3675 return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
3676 }
3677 string str = lexer_.GetStrVal();
3678 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
3679 std::vector<int64> low_high;
3680 if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
3681 low_high.size() != 2) {
3682 return Error(loc,
3683 "expects padding_low and padding_high separated by '_'");
3684 }
3685 pad->push_back(low_high);
3686 }
3687 lexer_.Lex();
3688 return true;
3689 }
3690
3691 // This is the inverse of xla::ToString(PaddingConfig). The padding config
3692 // string looks like "0_0_0x3_3_1". The string is first split on 'x'; each
3693 // substring represents one PaddingConfigDimension and consists of 3 (or 2)
3694 // numbers joined by '_'.
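//
// For example, "0_0_0x3_3_1" parses into two dimensions: the first with
// low/high/interior padding 0/0/0 and the second with 3/3/1; a 2-number form
// like "3_3" leaves interior padding at 0.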
ParsePaddingConfig(PaddingConfig * padding)3695 bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
3696 if (lexer_.GetKind() != TokKind::kPad) {
3697 return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
3698 }
3699 LocTy loc = lexer_.GetLoc();
3700 string str = lexer_.GetStrVal();
3701 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
3702 std::vector<int64> padding_dim;
3703 if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
3704 (padding_dim.size() != 2 && padding_dim.size() != 3)) {
3705 return Error(loc,
3706 "expects padding config pattern like 'low_high_interior' or "
3707 "'low_high'");
3708 }
3709 auto* dim = padding->add_dimensions();
3710 dim->set_edge_padding_low(padding_dim[0]);
3711 dim->set_edge_padding_high(padding_dim[1]);
3712 dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
3713 }
3714 lexer_.Lex();
3715 return true;
3716 }
3717
3718 // '{' metadata_string '}'
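// For example, metadata={op_type="Conv2D" op_name="conv1"
// source_file="model.py" source_line=42} sets the corresponding OpMetadata
// fields; all four sub-attributes are optional. The field values here are
// illustrative only.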
ParseMetadata(OpMetadata * metadata)3719 bool HloParser::ParseMetadata(OpMetadata* metadata) {
3720 std::unordered_map<string, AttrConfig> attrs;
3721 optional<string> op_type;
3722 optional<string> op_name;
3723 optional<string> source_file;
3724 optional<int32> source_line;
3725 attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
3726 attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
3727 attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
3728 attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
3729 if (!ParseSubAttributes(attrs)) {
3730 return false;
3731 }
3732 if (op_type) {
3733 metadata->set_op_type(*op_type);
3734 }
3735 if (op_name) {
3736 metadata->set_op_name(*op_name);
3737 }
3738 if (source_file) {
3739 metadata->set_source_file(*source_file);
3740 }
3741 if (source_line) {
3742 metadata->set_source_line(*source_line);
3743 }
3744 return true;
3745 }
3746
ParseOpcode(HloOpcode * result)3747 bool HloParser::ParseOpcode(HloOpcode* result) {
3748 VLOG(1) << "ParseOpcode";
3749 if (lexer_.GetKind() != TokKind::kIdent) {
3750 return TokenError("expects opcode");
3751 }
3752 string val = lexer_.GetStrVal();
3753 auto status_or_result = StringToHloOpcode(val);
3754 if (!status_or_result.ok()) {
3755 return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
3756 status_or_result.status().error_message()));
3757 }
3758 *result = status_or_result.ValueOrDie();
3759 lexer_.Lex();
3760 return true;
3761 }
3762
ParseFftType(FftType * result)3763 bool HloParser::ParseFftType(FftType* result) {
3764 VLOG(1) << "ParseFftType";
3765 if (lexer_.GetKind() != TokKind::kIdent) {
3766 return TokenError("expects fft type");
3767 }
3768 string val = lexer_.GetStrVal();
3769 if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
3770 return TokenError(StrFormat("expects fft type but sees: %s", val));
3771 }
3772 lexer_.Lex();
3773 return true;
3774 }
3775
ParseComparisonDirection(ComparisonDirection * result)3776 bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
3777 VLOG(1) << "ParseComparisonDirection";
3778 if (lexer_.GetKind() != TokKind::kIdent) {
3779 return TokenError("expects comparison direction");
3780 }
3781 string val = lexer_.GetStrVal();
3782 auto status_or_result = StringToComparisonDirection(val);
3783 if (!status_or_result.ok()) {
3784 return TokenError(
3785 StrFormat("expects comparison direction but sees: %s", val));
3786 }
3787 *result = status_or_result.ValueOrDie();
3788 lexer_.Lex();
3789 return true;
3790 }
3791
ParseFusionKind(HloInstruction::FusionKind * result)3792 bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
3793 VLOG(1) << "ParseFusionKind";
3794 if (lexer_.GetKind() != TokKind::kIdent) {
3795 return TokenError("expects fusion kind");
3796 }
3797 string val = lexer_.GetStrVal();
3798 auto status_or_result = StringToFusionKind(val);
3799 if (!status_or_result.ok()) {
3800 return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
3801 val,
3802 status_or_result.status().error_message()));
3803 }
3804 *result = status_or_result.ValueOrDie();
3805 lexer_.Lex();
3806 return true;
3807 }
3808
ParseRandomDistribution(RandomDistribution * result)3809 bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
3810 VLOG(1) << "ParseRandomDistribution";
3811 if (lexer_.GetKind() != TokKind::kIdent) {
3812 return TokenError("expects random distribution");
3813 }
3814 string val = lexer_.GetStrVal();
3815 auto status_or_result = StringToRandomDistribution(val);
3816 if (!status_or_result.ok()) {
3817 return TokenError(
3818 StrFormat("expects random distribution but sees: %s, error: %s", val,
3819 status_or_result.status().error_message()));
3820 }
3821 *result = status_or_result.ValueOrDie();
3822 lexer_.Lex();
3823 return true;
3824 }
3825
ParsePrecision(PrecisionConfig::Precision * result)3826 bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
3827 VLOG(1) << "ParsePrecision";
3828 if (lexer_.GetKind() != TokKind::kIdent) {
3829     return TokenError("expects precision");
3830 }
3831 string val = lexer_.GetStrVal();
3832 auto status_or_result = StringToPrecision(val);
3833 if (!status_or_result.ok()) {
3834 return TokenError(StrFormat("expects precision but sees: %s, error: %s",
3835 val,
3836 status_or_result.status().error_message()));
3837 }
3838 *result = status_or_result.ValueOrDie();
3839 lexer_.Lex();
3840 return true;
3841 }
3842
ParseInt64(int64 * result)3843 bool HloParser::ParseInt64(int64* result) {
3844 VLOG(1) << "ParseInt64";
3845 if (lexer_.GetKind() != TokKind::kInt) {
3846 return TokenError("expects integer");
3847 }
3848 *result = lexer_.GetInt64Val();
3849 lexer_.Lex();
3850 return true;
3851 }
3852
ParseDouble(double * result)3853 bool HloParser::ParseDouble(double* result) {
3854 switch (lexer_.GetKind()) {
3855 case TokKind::kDecimal: {
3856 double val = lexer_.GetDecimalVal();
3857 // If GetDecimalVal returns +/-inf, that means that we overflowed
3858 // `double`.
3859 if (std::isinf(val)) {
3860 return TokenError(StrCat("Constant is out of range for double (+/-",
3861 std::numeric_limits<double>::max(),
3862 ") and so is unparsable."));
3863 }
3864 *result = val;
3865 break;
3866 }
3867 case TokKind::kInt:
3868 *result = static_cast<double>(lexer_.GetInt64Val());
3869 break;
3870 case TokKind::kw_nan:
3871 *result = std::numeric_limits<double>::quiet_NaN();
3872 break;
3873 case TokKind::kw_inf:
3874 *result = std::numeric_limits<double>::infinity();
3875 break;
3876 case TokKind::kNegInf:
3877 *result = -std::numeric_limits<double>::infinity();
3878 break;
3879 default:
3880 return TokenError("expects decimal or integer");
3881 }
3882 lexer_.Lex();
3883 return true;
3884 }
3885
ParseComplex(std::complex<double> * result)3886 bool HloParser::ParseComplex(std::complex<double>* result) {
3887 if (lexer_.GetKind() != TokKind::kLparen) {
3888 return TokenError("expects '(' before complex number");
3889 }
3890 lexer_.Lex();
3891
3892 double real;
3893 LocTy loc = lexer_.GetLoc();
3894 if (!ParseDouble(&real)) {
3895 return Error(loc,
3896 "expect floating-point value for real part of complex number");
3897 }
3898
3899 if (lexer_.GetKind() != TokKind::kComma) {
3900 return TokenError(
3901 absl::StrFormat("expect comma after real part of complex literal"));
3902 }
3903 lexer_.Lex();
3904
3905 double imag;
3906 loc = lexer_.GetLoc();
3907 if (!ParseDouble(&imag)) {
3908 return Error(
3909 loc,
3910 "expect floating-point value for imaginary part of complex number");
3911 }
3912
3913 if (lexer_.GetKind() != TokKind::kRparen) {
3914 return TokenError(absl::StrFormat("expect ')' after complex number"));
3915 }
3916
3917 *result = std::complex<double>(real, imag);
3918 lexer_.Lex();
3919 return true;
3920 }
3921
ParseBool(bool * result)3922 bool HloParser::ParseBool(bool* result) {
3923 if (lexer_.GetKind() != TokKind::kw_true &&
3924 lexer_.GetKind() != TokKind::kw_false) {
3925 return TokenError("expects true or false");
3926 }
3927 *result = lexer_.GetKind() == TokKind::kw_true;
3928 lexer_.Lex();
3929 return true;
3930 }
3931
ParseToken(TokKind kind,const string & msg)3932 bool HloParser::ParseToken(TokKind kind, const string& msg) {
3933 VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg;
3934 if (lexer_.GetKind() != kind) {
3935 return TokenError(msg);
3936 }
3937 lexer_.Lex();
3938 return true;
3939 }
3940
EatIfPresent(TokKind kind)3941 bool HloParser::EatIfPresent(TokKind kind) {
3942 if (lexer_.GetKind() != kind) {
3943 return false;
3944 }
3945 lexer_.Lex();
3946 return true;
3947 }
3948
AddInstruction(const string & name,HloInstruction * instruction,LocTy name_loc)3949 bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
3950 LocTy name_loc) {
3951 auto result = current_name_table().insert({name, {instruction, name_loc}});
3952 if (!result.second) {
3953 Error(name_loc, StrCat("instruction already exists: ", name));
3954 return Error(/*loc=*/result.first->second.second,
3955 "instruction previously defined here");
3956 }
3957 return true;
3958 }
3959
AddComputation(const string & name,HloComputation * computation,LocTy name_loc)3960 bool HloParser::AddComputation(const string& name, HloComputation* computation,
3961 LocTy name_loc) {
3962 auto result = computation_pool_.insert({name, {computation, name_loc}});
3963 if (!result.second) {
3964 Error(name_loc, StrCat("computation already exists: ", name));
3965 return Error(/*loc=*/result.first->second.second,
3966 "computation previously defined here");
3967 }
3968 return true;
3969 }
3970
ParseShapeOnly()3971 StatusOr<Shape> HloParser::ParseShapeOnly() {
3972 lexer_.Lex();
3973 Shape shape;
3974 if (!ParseShape(&shape)) {
3975 return InvalidArgument("Syntax error:\n%s", GetError());
3976 }
3977 if (lexer_.GetKind() != TokKind::kEof) {
3978 return InvalidArgument("Syntax error:\nExtra content after shape");
3979 }
3980 return shape;
3981 }
3982
ParseShardingOnly()3983 StatusOr<HloSharding> HloParser::ParseShardingOnly() {
3984 lexer_.Lex();
3985 OpSharding op_sharding;
3986 if (!ParseSharding(&op_sharding)) {
3987 return InvalidArgument("Syntax error:\n%s", GetError());
3988 }
3989 if (lexer_.GetKind() != TokKind::kEof) {
3990 return InvalidArgument("Syntax error:\nExtra content after sharding");
3991 }
3992 return HloSharding::FromProto(op_sharding);
3993 }
3994
ParseParameterReplicationOnly()3995 StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
3996 lexer_.Lex();
3997 ParameterReplication parameter_replication;
3998   if (!ParseParameterReplication(&parameter_replication)) {
3999 return InvalidArgument("Syntax error:\n%s", GetError());
4000 }
4001 if (lexer_.GetKind() != TokKind::kEof) {
4002 return InvalidArgument(
4003 "Syntax error:\nExtra content after parameter replication");
4004 }
4005 return std::vector<bool>(
4006 parameter_replication.replicated_at_leaf_buffers().begin(),
4007 parameter_replication.replicated_at_leaf_buffers().end());
4008 }
4009
ParseWindowOnly()4010 StatusOr<Window> HloParser::ParseWindowOnly() {
4011 lexer_.Lex();
4012 Window window;
4013 if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
4014 return InvalidArgument("Syntax error:\n%s", GetError());
4015 }
4016 if (lexer_.GetKind() != TokKind::kEof) {
4017 return InvalidArgument("Syntax error:\nExtra content after window");
4018 }
4019 return window;
4020 }
4021
4022 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()4023 HloParser::ParseConvolutionDimensionNumbersOnly() {
4024 lexer_.Lex();
4025 ConvolutionDimensionNumbers dnums;
4026 if (!ParseConvolutionDimensionNumbers(&dnums)) {
4027 return InvalidArgument("Syntax error:\n%s", GetError());
4028 }
4029 if (lexer_.GetKind() != TokKind::kEof) {
4030 return InvalidArgument(
4031 "Syntax error:\nExtra content after convolution dnums");
4032 }
4033 return dnums;
4034 }
4035
ParsePaddingConfigOnly()4036 StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
4037 lexer_.Lex();
4038 PaddingConfig padding_config;
4039 if (!ParsePaddingConfig(&padding_config)) {
4040 return InvalidArgument("Syntax error:\n%s", GetError());
4041 }
4042 if (lexer_.GetKind() != TokKind::kEof) {
4043 return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
4044 }
4045 return padding_config;
4046 }
4047
ParseSingleInstruction(HloModule * module)4048 bool HloParser::ParseSingleInstruction(HloModule* module) {
4049 if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
4050 LOG(FATAL) << "Parser state is not clean. Please do not call any other "
4051 "methods before calling ParseSingleInstruction.";
4052 }
4053 HloComputation::Builder builder(module->name());
4054
4055 // The missing instruction hook we register creates the shaped instruction on
4056 // the fly as a parameter and returns it.
4057 int64 parameter_count = 0;
4058 create_missing_instruction_ =
4059       [this, &builder, &parameter_count](
4060 const string& name,
4061 const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
4062 string new_name = name.empty() ? StrCat("_", parameter_count) : name;
4063 HloInstruction* parameter = builder.AddInstruction(
4064 HloInstruction::CreateParameter(parameter_count++, shape, new_name));
4065 current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
4066 return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
4067 };
4068
4069 // Parse the instruction with the registered hook.
4070 Scope scope(&scoped_name_tables_);
4071 if (CanBeShape()) {
4072 // This means that the instruction's left-hand side is probably omitted,
4073 // e.g.
4074 //
4075 // f32[10] fusion(...), calls={...}
4076 if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
4077 return false;
4078 }
4079 } else {
4080 // This means that the instruction's left-hand side might exist, e.g.
4081 //
4082 // foo = f32[10] fusion(...), calls={...}
4083 string root_name;
4084 if (!ParseInstruction(&builder, &root_name)) {
4085 return false;
4086 }
4087 }
4088
4089 module->AddEntryComputation(builder.Build());
4090 for (auto& comp : computations_) {
4091 module->AddEmbeddedComputation(std::move(comp));
4092 }
4093 return true;
4094 }
4095
4096 } // namespace
4097
ParseHloString(absl::string_view str,const HloModuleConfig & config)4098 StatusOr<std::unique_ptr<HloModule>> ParseHloString(
4099 absl::string_view str, const HloModuleConfig& config) {
4100 auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
4101 HloParser parser(str);
4102 TF_RETURN_IF_ERROR(parser.Run(module.get()));
4103 return std::move(module);
4104 }
4105
ParseHloString(absl::string_view str)4106 StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
4107 auto module = absl::make_unique<HloModule>(/*name=*/"_", HloModuleConfig());
4108 HloParser parser(str);
4109 TF_RETURN_IF_ERROR(parser.Run(module.get()));
4110 return std::move(module);
4111 }
4112
ParseHloString(absl::string_view str,HloModule * module)4113 Status ParseHloString(absl::string_view str, HloModule* module) {
4114 TF_RET_CHECK(module->computation_count() == 0);
4115 HloParser parser(str);
4116 TF_RETURN_IF_ERROR(parser.Run(module));
4117 return Status::OK();
4118 }
4119
ParseSharding(absl::string_view str)4120 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
4121 HloParser parser(str);
4122 return parser.ParseShardingOnly();
4123 }
4124
ParseParameterReplication(absl::string_view str)4125 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
4126 HloParser parser(str);
4127 return parser.ParseParameterReplicationOnly();
4128 }
4129
ParseWindow(absl::string_view str)4130 StatusOr<Window> ParseWindow(absl::string_view str) {
4131 HloParser parser(str);
4132 return parser.ParseWindowOnly();
4133 }
4134
ParseConvolutionDimensionNumbers(absl::string_view str)4135 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
4136 absl::string_view str) {
4137 HloParser parser(str);
4138 return parser.ParseConvolutionDimensionNumbersOnly();
4139 }
4140
ParsePaddingConfig(absl::string_view str)4141 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
4142 HloParser parser(str);
4143 return parser.ParsePaddingConfigOnly();
4144 }
4145
ParseShape(absl::string_view str)4146 StatusOr<Shape> ParseShape(absl::string_view str) {
4147 HloParser parser(str);
4148 return parser.ParseShapeOnly();
4149 }
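
// Illustrative usage of the stand-alone entry points above (error handling
// elided; the input strings are examples only):
//
//   StatusOr<Shape> shape = ParseShape("f32[10,20]{1,0}");
//   StatusOr<Window> window = ParseWindow("size=3x3 stride=2x2");
//   StatusOr<PaddingConfig> pad = ParsePaddingConfig("0_0_0x3_3_1");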
4150
4151 } // namespace xla
4152