• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 #include <type_traits>
18 
19 #include "absl/algorithm/container.h"
20 #include "absl/memory/memory.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/types/span.h"
26 #include "absl/types/variant.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/core/lib/gtl/map_util.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 
41 namespace xla {
42 
43 namespace {
44 
45 using absl::nullopt;
46 using absl::optional;
47 using absl::StrAppend;
48 using absl::StrCat;
49 using absl::StrFormat;
50 using absl::StrJoin;
51 
52 // Creates and returns a schedule created using the order of the instructions in
53 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)54 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
55   HloSchedule schedule(module);
56   for (HloComputation* computation : module->computations()) {
57     if (!computation->IsFusionComputation()) {
58       for (HloInstruction* instruction : computation->instructions()) {
59         schedule.GetOrCreateSequence(computation).push_back(instruction);
60       }
61     }
62   }
63   return schedule;
64 }
65 
66 // Some functions accept either a linear index or a multi-dimensional index
67 // (used for indexing into sparse literals).
68 using LinearOrMultiIndex = absl::variant<int64, absl::Span<const int64>>;
69 
70 // Parser for the HloModule::ToString() format text.
71 class HloParser {
72  public:
73   using LocTy = HloLexer::LocTy;
74 
HloParser(absl::string_view str)75   explicit HloParser(absl::string_view str) : lexer_(str) {}
76 
77   // Runs the parser and constructs the resulting HLO in the given (empty)
78   // HloModule. Returns false if an error occurred.
79   Status Run(HloModule* module);
80 
81   // Returns the error information.
GetError() const82   string GetError() const { return StrJoin(error_, "\n"); }
83 
84   // Stand alone parsing utils for various aggregate data types.
85   StatusOr<Shape> ParseShapeOnly();
86   StatusOr<HloSharding> ParseShardingOnly();
87   StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
88   StatusOr<Window> ParseWindowOnly();
89   StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
90   StatusOr<PaddingConfig> ParsePaddingConfigOnly();
91 
92  private:
93   using InstrNameTable =
94       std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;
95 
96   // Returns the map from the instruction name to the instruction itself and its
97   // location in the current scope.
current_name_table()98   InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
99 
100   // Locates an instruction with the given name in the current_name_table() or
101   // returns nullptr.
102   //
103   // When the name is not found or name is empty, if create_missing_instruction_
104   // hook is registered and a "shape" is provided, the hook will be called to
105   // create an instruction. This is useful when we reify parameters as they're
106   // resolved; i.e. for ParseSingleInstruction.
107   std::pair<HloInstruction*, LocTy>* FindInstruction(
108       const string& name, const optional<Shape>& shape = nullopt);
109 
110   // Parse a single instruction worth of text.
111   bool ParseSingleInstruction(HloModule* module);
112 
113   // Parses a module, returning false if an error occurred.
114   bool ParseHloModule(HloModule* module);
115 
116   bool ParseComputations(HloModule* module);
117   bool ParseComputation(HloComputation** entry_computation);
118   bool ParseInstructionList(HloComputation** computation,
119                             const string& computation_name);
120   bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
121   bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name,
122                            LocTy name_loc);
123   bool ParseControlPredecessors(HloInstruction* instruction);
124   bool ParseLiteral(Literal* literal, const Shape& shape);
125   bool ParseTupleLiteral(Literal* literal, const Shape& shape);
126   bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
127   bool ParseDenseLiteral(Literal* literal, const Shape& shape);
128   bool ParseSparseLiteral(Literal* literal, const Shape& shape);
129 
130   // Sets the sub-value of literal at the given linear or sparse index to the
131   // given value. If the literal is dense, it myst have the default layout.
132   //
133   // `loc` should be the source location of the value.
134   bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index,
135                          Literal* literal);
136   bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index,
137                          Literal* literal);
138   bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index,
139                          Literal* literal);
140   bool SetValueInLiteral(LocTy loc, std::complex<double> value,
141                          LinearOrMultiIndex index, Literal* literal);
142   // `loc` should be the source location of the value.
143   template <typename LiteralNativeT, typename ParsedElemT>
144   bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
145                                LinearOrMultiIndex index, Literal* literal);
146 
147   // Checks whether the given value is within the range of LiteralNativeT.
148   // `loc` should be the source location of the value.
149   template <typename LiteralNativeT, typename ParsedElemT>
150   bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
151   template <typename LiteralNativeT>
152   bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);
153 
154   bool ParseOperands(std::vector<HloInstruction*>* operands);
155   // Fills parsed operands into 'operands' and expects a certain number of
156   // operands.
157   bool ParseOperands(std::vector<HloInstruction*>* operands,
158                      const int expected_size);
159 
160   // Describes the start, limit, and stride on every dimension of the operand
161   // being sliced.
162   struct SliceRanges {
163     std::vector<int64> starts;
164     std::vector<int64> limits;
165     std::vector<int64> strides;
166   };
167 
168   // The data parsed for the kDomain instruction.
169   struct DomainData {
170     std::unique_ptr<DomainMetadata> entry_metadata;
171     std::unique_ptr<DomainMetadata> exit_metadata;
172   };
173 
174   // Types of attributes.
175   enum class AttrTy {
176     kBool,
177     kInt64,
178     kInt32,
179     kFloat,
180     kString,
181     kBracedInt64List,
182     kBracedInt64ListList,
183     kHloComputation,
184     kBracedHloComputationList,
185     kFftType,
186     kComparisonDirection,
187     kWindow,
188     kConvolutionDimensionNumbers,
189     kSharding,
190     kParameterReplication,
191     kInstructionList,
192     kSliceRanges,
193     kPaddingConfig,
194     kMetadata,
195     kFusionKind,
196     kDistribution,
197     kDomain,
198     kPrecisionList,
199     kShapeList
200   };
201 
202   struct AttrConfig {
203     bool required;     // whether it's required or optional
204     AttrTy attr_type;  // what type it is
205     void* result;      // where to store the parsed result.
206   };
207 
208   // attributes ::= (',' attribute)*
209   //
210   // Parses attributes given names and configs of the attributes. Each parsed
211   // result is passed back through the result pointer in corresponding
212   // AttrConfig. Note that the result pointer must point to a optional<T> typed
213   // variable which outlives this function. Returns false on error. You should
214   // not use the any of the results if this function failed.
215   //
216   // Example usage:
217   //
218   //  std::unordered_map<string, AttrConfig> attrs;
219   //  optional<int64> foo;
220   //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
221   //  optional<Window> bar;
222   //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
223   //  if (!ParseAttributes(attrs)) {
224   //    return false; // Do not use 'foo' 'bar' if failed.
225   //  }
226   //  // Do something with 'bar'.
227   //  if (foo) { // If attr foo is seen, do something with 'foo'. }
228   //
229   bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
230 
231   // sub_attributes ::= '{' (','? attribute)* '}'
232   //
233   // Usage is the same as ParseAttributes. See immediately above.
234   bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs);
235 
236   // Parses one attribute. If it has already been seen, return error. Returns
237   // true and adds to seen_attrs on success.
238   //
239   // Do not call this except in ParseAttributes or ParseSubAttributes.
240   bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs,
241                             std::unordered_set<string>* seen_attrs);
242 
243   // Parses an attribute string into a protocol buffer `message`.
244   // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
245   // set of mandatory attributes.
246   bool ParseAttributesAsProtoMessage(
247       const std::unordered_set<string>& required_attrs,
248       tensorflow::protobuf::Message* message);
249 
250   // Parses one attribute. If it has already been seen, return error. Returns
251   // true and adds to seen_attrs on success.
252   //
253   // Do not call this except in ParseAttributesAsProtoMessage.
254   bool ParseAttributeAsProtoMessageHelper(
255       tensorflow::protobuf::Message* message,
256       std::unordered_set<string>* seen_attrs);
257 
258   // Parses a name and finds the corresponding hlo computation.
259   bool ParseComputationName(HloComputation** value);
260   // Parses a list of names and finds the corresponding hlo instructions.
261   bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
262   // Pass expect_outer_curlies == true when parsing a Window in the context of a
263   // larger computation.  Pass false when parsing a stand-alone Window string.
264   bool ParseWindow(Window* window, bool expect_outer_curlies);
265   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
266   bool ParsePaddingConfig(PaddingConfig* padding);
267   bool ParseMetadata(OpMetadata* metadata);
268   bool ParseSharding(OpSharding* sharding);
269   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
270   bool ParseParameterReplication(ParameterReplication* parameter_replication);
271 
272   // Parses the metadata behind a kDOmain instruction.
273   bool ParseDomain(DomainData* domain);
274 
275   // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
276   bool ParseDxD(const string& name, std::vector<int64>* result);
277   // Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
278   bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
279 
280   bool ParseSliceRanges(SliceRanges* result);
281   bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
282   bool ParseHloComputation(HloComputation** result);
283   bool ParseHloComputationList(std::vector<HloComputation*>* result);
284   bool ParseShapeList(std::vector<Shape>* result);
285   bool ParseInt64List(const TokKind start, const TokKind end,
286                       const TokKind delim, std::vector<int64>* result);
287   // 'parse_and_add_item' is an lambda to parse an element in the list and add
288   // the parsed element to the result. It's supposed to capture the result.
289   bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
290                  const std::function<bool()>& parse_and_add_item);
291 
292   bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
293   bool ParseParamList();
294   bool ParseName(string* result);
295   bool ParseAttributeName(string* result);
296   bool ParseString(string* result);
297   bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
298                            std::vector<bool>* dynamic_dimensions);
299   bool ParseShape(Shape* result);
300   bool ParseLayout(Layout* layout);
301   bool ParseTiles(std::vector<Tile>* tiles);
302   bool ParseOpcode(HloOpcode* result);
303   bool ParseFftType(FftType* result);
304   bool ParseComparisonDirection(ComparisonDirection* result);
305   bool ParseFusionKind(HloInstruction::FusionKind* result);
306   bool ParseRandomDistribution(RandomDistribution* result);
307   bool ParsePrecision(PrecisionConfig::Precision* result);
308   bool ParseInt64(int64* result);
309   bool ParseDouble(double* result);
310   bool ParseComplex(std::complex<double>* result);
311   bool ParseBool(bool* result);
312   bool ParseToken(TokKind kind, const string& msg);
313 
314   // Returns true if the current token is the beginning of a shape.
315   bool CanBeShape();
316   // Returns true if the current token is the beginning of a
317   // param_list_to_shape.
318   bool CanBeParamListToShape();
319 
320   // Logs the current parsing line and the given message. Always returns false.
321   bool TokenError(absl::string_view msg);
322   bool Error(LocTy loc, absl::string_view msg);
323 
324   // If the current token is 'kind', eats it (i.e. lexes the next token) and
325   // returns true.
326   bool EatIfPresent(TokKind kind);
327 
328   // Adds the instruction to the pool. Returns false and emits an error if the
329   // instruction already exists.
330   bool AddInstruction(const string& name, HloInstruction* instruction,
331                       LocTy name_loc);
332   // Adds the computation to the pool. Returns false and emits an error if the
333   // computation already exists.
334   bool AddComputation(const string& name, HloComputation* computation,
335                       LocTy name_loc);
336 
337   HloLexer lexer_;
338 
339   // A stack for the instruction names. The top of the stack stores the
340   // instruction name table for the current scope.
341   //
342   // A instruction's name is unique among its scope (i.e. its parent
343   // computation), but it's not necessarily unique among all computations in the
344   // module. When there are multiple levels of nested computations, the same
345   // name could appear in both an outer computation and an inner computation. So
346   // we need a stack to make sure a name is only visible within its scope,
347   std::vector<InstrNameTable> scoped_name_tables_;
348 
349   // A helper class which pushes and pops to an InstrNameTable stack via RAII.
350   class Scope {
351    public:
Scope(std::vector<InstrNameTable> * scoped_name_tables)352     explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
353         : scoped_name_tables_(scoped_name_tables) {
354       scoped_name_tables_->emplace_back();
355     }
~Scope()356     ~Scope() { scoped_name_tables_->pop_back(); }
357 
358    private:
359     std::vector<InstrNameTable>* scoped_name_tables_;
360   };
361 
362   // Map from the computation name to the computation itself and its location.
363   std::unordered_map<string, std::pair<HloComputation*, LocTy>>
364       computation_pool_;
365 
366   std::vector<std::unique_ptr<HloComputation>> computations_;
367   std::vector<string> error_;
368 
369   // When an operand name cannot be resolved, this function is called to create
370   // a parameter instruction with the given name and shape. It registers the
371   // name, instruction, and a placeholder location in the name table. It returns
372   // the newly-created instruction and the placeholder location. If `name` is
373   // empty, this should create the parameter with a generated name. This is
374   // supposed to be set and used only in ParseSingleInstruction.
375   std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
376                                                    const Shape& shape)>
377       create_missing_instruction_;
378 };
379 
SplitToInt64s(absl::string_view s,char delim,std::vector<int64> * out)380 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
381   for (const auto& split : absl::StrSplit(s, delim)) {
382     int64 val;
383     if (!absl::SimpleAtoi(split, &val)) {
384       return false;
385     }
386     out->push_back(val);
387   }
388   return true;
389 }
390 
391 // Creates replica groups from the provided nested array. groups[i] represents
392 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64>> groups)393 std::vector<ReplicaGroup> CreateReplicaGroups(
394     absl::Span<const std::vector<int64>> groups) {
395   std::vector<ReplicaGroup> replica_groups;
396   absl::c_transform(groups, std::back_inserter(replica_groups),
397                     [](const std::vector<int64>& ids) {
398                       ReplicaGroup group;
399                       *group.mutable_replica_ids() = {ids.begin(), ids.end()};
400                       return group;
401                     });
402   return replica_groups;
403 }
404 
Error(LocTy loc,absl::string_view msg)405 bool HloParser::Error(LocTy loc, absl::string_view msg) {
406   auto line_col = lexer_.GetLineAndColumn(loc);
407   const unsigned line = line_col.first;
408   const unsigned col = line_col.second;
409   std::vector<string> error_lines;
410   error_lines.push_back(
411       StrCat("was parsing ", line, ":", col, ": error: ", msg));
412   error_lines.emplace_back(lexer_.GetLine(loc));
413   error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
414 
415   error_.push_back(StrJoin(error_lines, "\n"));
416   VLOG(1) << "Error: " << error_.back();
417   return false;
418 }
419 
TokenError(absl::string_view msg)420 bool HloParser::TokenError(absl::string_view msg) {
421   return Error(lexer_.GetLoc(), msg);
422 }
423 
Run(HloModule * module)424 Status HloParser::Run(HloModule* module) {
425   lexer_.Lex();
426   if (lexer_.GetKind() == TokKind::kw_HloModule) {
427     // This means that the text contains a full HLO module.
428     if (!ParseHloModule(module)) {
429       return InvalidArgument(
430           "Syntax error when trying to parse the text as a HloModule:\n%s",
431           GetError());
432     }
433     return Status::OK();
434   }
435   // This means that the text is a single HLO instruction.
436   if (!ParseSingleInstruction(module)) {
437     return InvalidArgument(
438         "Syntax error when trying to parse the text as a single "
439         "HloInstruction:\n%s",
440         GetError());
441   }
442   return Status::OK();
443 }
444 
FindInstruction(const string & name,const optional<Shape> & shape)445 std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
446     const string& name, const optional<Shape>& shape) {
447   std::pair<HloInstruction*, LocTy>* instr = nullptr;
448   if (!name.empty()) {
449     instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
450   }
451 
452   // Potentially call the missing instruction hook.
453   if (instr == nullptr && create_missing_instruction_ != nullptr &&
454       scoped_name_tables_.size() == 1) {
455     if (!shape.has_value()) {
456       Error(lexer_.GetLoc(),
457             "Operand had no shape in HLO text; cannot create parameter for "
458             "single-instruction module.");
459       return nullptr;
460     }
461     return create_missing_instruction_(name, *shape);
462   }
463 
464   if (instr != nullptr && shape.has_value() &&
465       !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
466     Error(
467         lexer_.GetLoc(),
468         StrCat("The declared operand shape ",
469                ShapeUtil::HumanStringWithLayout(shape.value()),
470                " is not compatible with the shape of the operand instruction ",
471                ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
472     return nullptr;
473   }
474 
475   return instr;
476 }
477 
478 // ::= 'HloModule' name computations
ParseHloModule(HloModule * module)479 bool HloParser::ParseHloModule(HloModule* module) {
480   if (lexer_.GetKind() != TokKind::kw_HloModule) {
481     return TokenError("expects HloModule");
482   }
483   // Eat 'HloModule'
484   lexer_.Lex();
485 
486   string name;
487   if (!ParseName(&name)) {
488     return false;
489   }
490 
491   absl::optional<bool> is_scheduled;
492   std::unordered_map<string, AttrConfig> attrs;
493   attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
494   if (!ParseAttributes(attrs)) {
495     return false;
496   }
497 
498   module->set_name(name);
499   if (!ParseComputations(module)) {
500     return false;
501   }
502 
503   if (is_scheduled.has_value() && *is_scheduled) {
504     TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
505   }
506 
507   return true;
508 }
509 
510 // computations ::= (computation)+
ParseComputations(HloModule * module)511 bool HloParser::ParseComputations(HloModule* module) {
512   HloComputation* entry_computation = nullptr;
513   do {
514     if (!ParseComputation(&entry_computation)) {
515       return false;
516     }
517   } while (lexer_.GetKind() != TokKind::kEof);
518 
519   for (int i = 0; i < computations_.size(); i++) {
520     // If entry_computation is not nullptr, it means the computation it pointed
521     // to is marked with "ENTRY"; otherwise, no computation is marked with
522     // "ENTRY", and we use the last computation as the entry computation. We
523     // add the non-entry computations as embedded computations to the module.
524     if ((entry_computation != nullptr &&
525          computations_[i].get() != entry_computation) ||
526         (entry_computation == nullptr && i != computations_.size() - 1)) {
527       module->AddEmbeddedComputation(std::move(computations_[i]));
528       continue;
529     }
530     auto computation = module->AddEntryComputation(std::move(computations_[i]));
531     // The parameters and result layouts were set to default layout. Here we
532     // set the layouts to what the hlo text says.
533     for (int p = 0; p < computation->num_parameters(); p++) {
534       const Shape& param_shape = computation->parameter_instruction(p)->shape();
535       TF_CHECK_OK(module->mutable_entry_computation_layout()
536                       ->mutable_parameter_layout(p)
537                       ->CopyLayoutFromShape(param_shape));
538     }
539     const Shape& result_shape = computation->root_instruction()->shape();
540     TF_CHECK_OK(module->mutable_entry_computation_layout()
541                     ->mutable_result_layout()
542                     ->CopyLayoutFromShape(result_shape));
543   }
544   return true;
545 }
546 
547 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
ParseComputation(HloComputation ** entry_computation)548 bool HloParser::ParseComputation(HloComputation** entry_computation) {
549   LocTy maybe_entry_loc = lexer_.GetLoc();
550   const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
551 
552   string name;
553   LocTy name_loc = lexer_.GetLoc();
554   if (!ParseName(&name)) {
555     return false;
556   }
557 
558   LocTy shape_loc = nullptr;
559   Shape shape;
560   if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
561     return false;
562   }
563 
564   HloComputation* computation = nullptr;
565   if (!ParseInstructionList(&computation, name)) {
566     return false;
567   }
568 
569   // If param_list_to_shape was present, check compatibility.
570   if (shape_loc != nullptr &&
571       !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
572     return Error(
573         shape_loc,
574         StrCat(
575             "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
576             ", is not compatible with that of its root instruction ",
577             computation->root_instruction()->name(), ", ",
578             ShapeUtil::HumanString(computation->root_instruction()->shape())));
579   }
580 
581   if (is_entry_computation) {
582     if (*entry_computation != nullptr) {
583       return Error(maybe_entry_loc, "expects only one ENTRY");
584     }
585     *entry_computation = computation;
586   }
587 
588   return AddComputation(name, computation, name_loc);
589 }
590 
591 // instruction_list ::= '{' instruction_list1 '}'
592 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation ** computation,const string & computation_name)593 bool HloParser::ParseInstructionList(HloComputation** computation,
594                                      const string& computation_name) {
595   Scope scope(&scoped_name_tables_);
596   HloComputation::Builder builder(computation_name);
597   if (!ParseToken(TokKind::kLbrace,
598                   "expects '{' at the beginning of instruction list.")) {
599     return false;
600   }
601   string root_name;
602   do {
603     if (!ParseInstruction(&builder, &root_name)) {
604       return false;
605     }
606   } while (lexer_.GetKind() != TokKind::kRbrace);
607   if (!ParseToken(TokKind::kRbrace,
608                   "expects '}' at the end of instruction list.")) {
609     return false;
610   }
611   HloInstruction* root = nullptr;
612   if (!root_name.empty()) {
613     std::pair<HloInstruction*, LocTy>* root_node =
614         tensorflow::gtl::FindOrNull(current_name_table(), root_name);
615 
616     // This means some instruction was marked as ROOT but we didn't find it in
617     // the pool, which should not happen.
618     if (root_node == nullptr) {
619       LOG(FATAL) << "instruction " << root_name
620                  << " was marked as ROOT but the parser has not seen it before";
621     }
622     root = root_node->first;
623   }
624 
625   // Now root can be either an existing instruction or a nullptr. If it's a
626   // nullptr, the implementation of Builder will set the last instruction as
627   // the root instruction.
628   computations_.emplace_back(builder.Build(root));
629   *computation = computations_.back().get();
630   return true;
631 }
632 
633 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,string * root_name)634 bool HloParser::ParseInstruction(HloComputation::Builder* builder,
635                                  string* root_name) {
636   string name;
637   LocTy maybe_root_loc = lexer_.GetLoc();
638   bool is_root = EatIfPresent(TokKind::kw_ROOT);
639 
640   const LocTy name_loc = lexer_.GetLoc();
641   if (!ParseName(&name) ||
642       !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
643     return false;
644   }
645 
646   if (is_root) {
647     if (!root_name->empty()) {
648       return Error(maybe_root_loc, "one computation should have only one ROOT");
649     }
650     *root_name = name;
651   }
652 
653   return ParseInstructionRhs(builder, name, name_loc);
654 }
655 
ParseInstructionRhs(HloComputation::Builder * builder,const string & name,LocTy name_loc)656 bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
657                                     const string& name, LocTy name_loc) {
658   Shape shape;
659   HloOpcode opcode;
660   std::vector<HloInstruction*> operands;
661 
662   if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
663     return false;
664   }
665 
666   // Add optional attributes.
667   std::unordered_map<string, AttrConfig> attrs;
668   optional<OpSharding> sharding;
669   attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
670   optional<ParameterReplication> parameter_replication;
671   attrs["parameter_replication"] = {/*required=*/false,
672                                     AttrTy::kParameterReplication,
673                                     &parameter_replication};
674   optional<std::vector<HloInstruction*>> predecessors;
675   attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
676                                    &predecessors};
677   optional<OpMetadata> metadata;
678   attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
679 
680   optional<string> backend_config;
681   attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
682                              &backend_config};
683 
684   HloInstruction* instruction;
685   switch (opcode) {
686     case HloOpcode::kParameter: {
687       int64 parameter_number;
688       if (!ParseToken(TokKind::kLparen,
689                       "expects '(' before parameter number") ||
690           !ParseInt64(&parameter_number)) {
691         return false;
692       }
693       if (parameter_number < 0) {
694         Error(lexer_.GetLoc(), "parameter number must be >= 0");
695         return false;
696       }
697       if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
698           !ParseAttributes(attrs)) {
699         return false;
700       }
701       instruction = builder->AddInstruction(
702           HloInstruction::CreateParameter(parameter_number, shape, name));
703       break;
704     }
705     case HloOpcode::kConstant: {
706       Literal literal;
707       if (!ParseToken(TokKind::kLparen,
708                       "expects '(' before constant literal") ||
709           !ParseLiteral(&literal, shape) ||
710           !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
711           !ParseAttributes(attrs)) {
712         return false;
713       }
714       instruction = builder->AddInstruction(
715           HloInstruction::CreateConstant(std::move(literal)));
716       break;
717     }
718     case HloOpcode::kIota: {
719       optional<int64> iota_dimension;
720       attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
721                                  &iota_dimension};
722       if (!ParseOperands(&operands, /*expected_size=*/0) ||
723           !ParseAttributes(attrs)) {
724         return false;
725       }
726       instruction = builder->AddInstruction(
727           HloInstruction::CreateIota(shape, *iota_dimension));
728       break;
729     }
730     // Unary ops.
731     case HloOpcode::kAbs:
732     case HloOpcode::kRoundNearestAfz:
733     case HloOpcode::kBitcast:
734     case HloOpcode::kCeil:
735     case HloOpcode::kClz:
736     case HloOpcode::kCopy:
737     case HloOpcode::kCos:
738     case HloOpcode::kExp:
739     case HloOpcode::kExpm1:
740     case HloOpcode::kImag:
741     case HloOpcode::kIsFinite:
742     case HloOpcode::kFloor:
743     case HloOpcode::kLog:
744     case HloOpcode::kLog1p:
745     case HloOpcode::kNot:
746     case HloOpcode::kNegate:
747     case HloOpcode::kReal:
748     case HloOpcode::kRsqrt:
749     case HloOpcode::kSign:
750     case HloOpcode::kSin:
751     case HloOpcode::kSqrt:
752     case HloOpcode::kTanh: {
753       if (!ParseOperands(&operands, /*expected_size=*/1) ||
754           !ParseAttributes(attrs)) {
755         return false;
756       }
757       instruction = builder->AddInstruction(
758           HloInstruction::CreateUnary(shape, opcode, operands[0]));
759       break;
760     }
761     // Binary ops.
762     case HloOpcode::kAdd:
763     case HloOpcode::kDivide:
764     case HloOpcode::kMultiply:
765     case HloOpcode::kSubtract:
766     case HloOpcode::kAtan2:
767     case HloOpcode::kComplex:
768     case HloOpcode::kMaximum:
769     case HloOpcode::kMinimum:
770     case HloOpcode::kPower:
771     case HloOpcode::kRemainder:
772     case HloOpcode::kAnd:
773     case HloOpcode::kOr:
774     case HloOpcode::kXor:
775     case HloOpcode::kShiftLeft:
776     case HloOpcode::kShiftRightArithmetic:
777     case HloOpcode::kShiftRightLogical: {
778       if (!ParseOperands(&operands, /*expected_size=*/2) ||
779           !ParseAttributes(attrs)) {
780         return false;
781       }
782       instruction = builder->AddInstruction(HloInstruction::CreateBinary(
783           shape, opcode, operands[0], operands[1]));
784       break;
785     }
786     // Ternary ops.
787     case HloOpcode::kClamp:
788     case HloOpcode::kSelect:
789     case HloOpcode::kTupleSelect: {
790       if (!ParseOperands(&operands, /*expected_size=*/3) ||
791           !ParseAttributes(attrs)) {
792         return false;
793       }
794       instruction = builder->AddInstruction(HloInstruction::CreateTernary(
795           shape, opcode, operands[0], operands[1], operands[2]));
796       break;
797     }
798     // Other supported ops.
799     case HloOpcode::kConvert: {
800       if (!ParseOperands(&operands, /*expected_size=*/1) ||
801           !ParseAttributes(attrs)) {
802         return false;
803       }
804       instruction = builder->AddInstruction(
805           HloInstruction::CreateConvert(shape, operands[0]));
806       break;
807     }
808     case HloOpcode::kBitcastConvert: {
809       if (!ParseOperands(&operands, /*expected_size=*/1) ||
810           !ParseAttributes(attrs)) {
811         return false;
812       }
813       instruction = builder->AddInstruction(
814           HloInstruction::CreateBitcastConvert(shape, operands[0]));
815       break;
816     }
817     case HloOpcode::kAllReduce: {
818       optional<std::vector<std::vector<int64>>> tmp_groups;
819       optional<HloComputation*> to_apply;
820       optional<std::vector<int64>> replica_group_ids;
821       optional<string> barrier;
822       optional<int64> all_reduce_id;
823       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
824                            &to_apply};
825       attrs["replica_groups"] = {/*required=*/false,
826                                  AttrTy::kBracedInt64ListList, &tmp_groups};
827       attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
828       attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
829                                 &all_reduce_id};
830       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
831         return false;
832       }
833       std::vector<ReplicaGroup> replica_groups;
834       if (tmp_groups) {
835         replica_groups = CreateReplicaGroups(*tmp_groups);
836       }
837       instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
838           shape, operands, *to_apply, replica_groups, barrier ? *barrier : "",
839           all_reduce_id));
840       break;
841     }
842     case HloOpcode::kAllToAll: {
843       optional<std::vector<std::vector<int64>>> tmp_groups;
844       optional<string> barrier;
845       attrs["replica_groups"] = {/*required=*/false,
846                                  AttrTy::kBracedInt64ListList, &tmp_groups};
847       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
848         return false;
849       }
850       std::vector<ReplicaGroup> replica_groups;
851       if (tmp_groups) {
852         replica_groups = CreateReplicaGroups(*tmp_groups);
853       }
854       instruction = builder->AddInstruction(
855           HloInstruction::CreateAllToAll(shape, operands, replica_groups));
856       break;
857     }
858     case HloOpcode::kCollectivePermute: {
859       optional<std::vector<std::vector<int64>>> source_targets;
860       attrs["source_target_pairs"] = {
861           /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
862       if (!ParseOperands(&operands, /*expected_size=*/1) ||
863           !ParseAttributes(attrs)) {
864         return false;
865       }
866       std::vector<std::pair<int64, int64>> pairs(source_targets->size());
867       for (int i = 0; i < pairs.size(); i++) {
868         if ((*source_targets)[i].size() != 2) {
869           return TokenError(
870               "expects 'source_target_pairs=' to be a list of pairs");
871         }
872         pairs[i].first = (*source_targets)[i][0];
873         pairs[i].second = (*source_targets)[i][1];
874       }
875       instruction = builder->AddInstruction(
876           HloInstruction::CreateCollectivePermute(shape, operands[0], pairs));
877       break;
878     }
879     case HloOpcode::kReplicaId: {
880       if (!ParseOperands(&operands, /*expected_size=*/0) ||
881           !ParseAttributes(attrs)) {
882         return false;
883       }
884       instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
885       break;
886     }
887     case HloOpcode::kReshape: {
888       if (!ParseOperands(&operands, /*expected_size=*/1) ||
889           !ParseAttributes(attrs)) {
890         return false;
891       }
892       instruction = builder->AddInstruction(
893           HloInstruction::CreateReshape(shape, operands[0]));
894       break;
895     }
896     case HloOpcode::kAfterAll: {
897       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
898         return false;
899       }
900       if (operands.empty()) {
901         instruction = builder->AddInstruction(HloInstruction::CreateToken());
902       } else {
903         instruction =
904             builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
905       }
906       break;
907     }
908     case HloOpcode::kAddDependency: {
909       if (!ParseOperands(&operands, /*expected_size=*/2) ||
910           !ParseAttributes(attrs)) {
911         return false;
912       }
913       instruction = builder->AddInstruction(
914           HloInstruction::CreateAddDependency(operands[0], operands[1]));
915       break;
916     }
917     case HloOpcode::kSort: {
918       optional<std::vector<int64>> dimensions;
919       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
920                              &dimensions};
921       optional<bool> is_stable = false;
922       attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
923       optional<HloComputation*> to_apply;
924       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
925                            &to_apply};
926       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
927           dimensions->size() != 1) {
928         return false;
929       }
930       instruction = builder->AddInstruction(
931           HloInstruction::CreateSort(shape, dimensions->at(0), operands,
932                                      to_apply.value(), is_stable.value()));
933       break;
934     }
935     case HloOpcode::kTuple: {
936       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
937         return false;
938       }
939       instruction =
940           builder->AddInstruction(HloInstruction::CreateTuple(operands));
941       break;
942     }
943     case HloOpcode::kWhile: {
944       optional<HloComputation*> condition;
945       optional<HloComputation*> body;
946       attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
947                             &condition};
948       attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
949       if (!ParseOperands(&operands, /*expected_size=*/1) ||
950           !ParseAttributes(attrs)) {
951         return false;
952       }
953       instruction = builder->AddInstruction(HloInstruction::CreateWhile(
954           shape, *condition, *body, /*init=*/operands[0]));
955       break;
956     }
957     case HloOpcode::kRecv: {
958       optional<int64> channel_id;
959       // If the is_host_transfer attribute is not present then default to false.
960       optional<bool> is_host_transfer = false;
961       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
962       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
963                                    &is_host_transfer};
964       if (!ParseOperands(&operands, /*expected_size=*/1) ||
965           !ParseAttributes(attrs)) {
966         return false;
967       }
968       // If the is_host_transfer attribute is not present then default to false.
969       instruction = builder->AddInstruction(HloInstruction::CreateRecv(
970           shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
971       break;
972     }
973     case HloOpcode::kRecvDone: {
974       optional<int64> channel_id;
975       // If the is_host_transfer attribute is not present then default to false.
976       optional<bool> is_host_transfer = false;
977       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
978       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
979                                    &is_host_transfer};
980       if (!ParseOperands(&operands, /*expected_size=*/1) ||
981           !ParseAttributes(attrs)) {
982         return false;
983       }
984       if (channel_id != operands[0]->channel_id()) {
985         return false;
986       }
987       instruction = builder->AddInstruction(
988           HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
989       break;
990     }
991     case HloOpcode::kSend: {
992       optional<int64> channel_id;
993       // If the is_host_transfer attribute is not present then default to false.
994       optional<bool> is_host_transfer = false;
995       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
996       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
997                                    &is_host_transfer};
998       if (!ParseOperands(&operands, /*expected_size=*/2) ||
999           !ParseAttributes(attrs)) {
1000         return false;
1001       }
1002       instruction = builder->AddInstruction(HloInstruction::CreateSend(
1003           operands[0], operands[1], *channel_id, *is_host_transfer));
1004       break;
1005     }
1006     case HloOpcode::kSendDone: {
1007       optional<int64> channel_id;
1008       // If the is_host_transfer attribute is not present then default to false.
1009       optional<bool> is_host_transfer = false;
1010       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1011       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1012                                    &is_host_transfer};
1013       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1014           !ParseAttributes(attrs)) {
1015         return false;
1016       }
1017       if (channel_id != operands[0]->channel_id()) {
1018         return false;
1019       }
1020       instruction = builder->AddInstruction(
1021           HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
1022       break;
1023     }
1024     case HloOpcode::kGetTupleElement: {
1025       optional<int64> index;
1026       attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
1027       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1028           !ParseAttributes(attrs)) {
1029         return false;
1030       }
1031       instruction = builder->AddInstruction(
1032           HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
1033       break;
1034     }
1035     case HloOpcode::kCall: {
1036       optional<HloComputation*> to_apply;
1037       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1038                            &to_apply};
1039       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1040         return false;
1041       }
1042       instruction = builder->AddInstruction(
1043           HloInstruction::CreateCall(shape, operands, *to_apply));
1044       break;
1045     }
1046     case HloOpcode::kReduceWindow: {
1047       optional<HloComputation*> reduce_computation;
1048       optional<Window> window;
1049       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1050       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1051                            &reduce_computation};
1052       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1053           !ParseAttributes(attrs)) {
1054         return false;
1055       }
1056       if (!window) {
1057         window.emplace();
1058       }
1059       instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
1060           shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
1061           *reduce_computation));
1062       break;
1063     }
1064     case HloOpcode::kConvolution: {
1065       optional<Window> window;
1066       optional<ConvolutionDimensionNumbers> dnums;
1067       optional<int64> feature_group_count;
1068       optional<int64> batch_group_count;
1069       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1070       attrs["dim_labels"] = {/*required=*/true,
1071                              AttrTy::kConvolutionDimensionNumbers, &dnums};
1072       attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
1073                                       &feature_group_count};
1074       attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
1075                                     &batch_group_count};
1076       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
1077       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
1078                                     &operand_precision};
1079       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1080           !ParseAttributes(attrs)) {
1081         return false;
1082       }
1083       if (!window) {
1084         window.emplace();
1085       }
1086       if (!feature_group_count) {
1087         feature_group_count = 1;
1088       }
1089       if (!batch_group_count) {
1090         batch_group_count = 1;
1091       }
1092       PrecisionConfig precision_config;
1093       if (operand_precision) {
1094         *precision_config.mutable_operand_precision() = {
1095             operand_precision->begin(), operand_precision->end()};
1096       } else {
1097         precision_config.mutable_operand_precision()->Resize(
1098             operands.size(), PrecisionConfig::DEFAULT);
1099       }
1100       instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
1101           shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
1102           feature_group_count.value(), batch_group_count.value(), *window,
1103           *dnums, precision_config));
1104       break;
1105     }
1106     case HloOpcode::kFft: {
1107       optional<FftType> fft_type;
1108       optional<std::vector<int64>> fft_length;
1109       attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
1110       attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
1111                              &fft_length};
1112       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1113           !ParseAttributes(attrs)) {
1114         return false;
1115       }
1116       instruction = builder->AddInstruction(HloInstruction::CreateFft(
1117           shape, operands[0], *fft_type, *fft_length));
1118       break;
1119     }
1120     case HloOpcode::kTriangularSolve: {
1121       TriangularSolveOptions options;
1122       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1123           !ParseAttributesAsProtoMessage(
1124               /*required_attrs=*/std::unordered_set<string>(), &options)) {
1125         return false;
1126       }
1127       instruction =
1128           builder->AddInstruction(HloInstruction::CreateTriangularSolve(
1129               shape, operands[0], operands[1], options));
1130       break;
1131     }
1132     case HloOpcode::kCompare: {
1133       optional<ComparisonDirection> direction;
1134       attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
1135                             &direction};
1136       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1137           !ParseAttributes(attrs)) {
1138         return false;
1139       }
1140       instruction = builder->AddInstruction(HloInstruction::CreateCompare(
1141           shape, operands[0], operands[1], *direction));
1142       break;
1143     }
1144     case HloOpcode::kCholesky: {
1145       CholeskyOptions options;
1146       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1147           !ParseAttributesAsProtoMessage(
1148               /*required_attrs=*/std::unordered_set<string>(), &options)) {
1149         return false;
1150       }
1151       instruction = builder->AddInstruction(
1152           HloInstruction::CreateCholesky(shape, operands[0], options));
1153       break;
1154     }
1155     case HloOpcode::kBroadcast: {
1156       optional<std::vector<int64>> broadcast_dimensions;
1157       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1158                              &broadcast_dimensions};
1159       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1160           !ParseAttributes(attrs)) {
1161         return false;
1162       }
1163       instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
1164           shape, operands[0], *broadcast_dimensions));
1165       break;
1166     }
1167     case HloOpcode::kConcatenate: {
1168       optional<std::vector<int64>> dimensions;
1169       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1170                              &dimensions};
1171       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1172           dimensions->size() != 1) {
1173         return false;
1174       }
1175       instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
1176           shape, operands, dimensions->at(0)));
1177       break;
1178     }
1179     case HloOpcode::kMap: {
1180       optional<HloComputation*> to_apply;
1181       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1182                            &to_apply};
1183       optional<std::vector<int64>> dimensions;
1184       attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1185                              &dimensions};
1186       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1187         return false;
1188       }
1189       instruction = builder->AddInstruction(
1190           HloInstruction::CreateMap(shape, operands, *to_apply));
1191       break;
1192     }
1193     case HloOpcode::kReduce: {
1194       auto loc = lexer_.GetLoc();
1195 
1196       optional<HloComputation*> reduce_computation;
1197       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1198                            &reduce_computation};
1199       optional<std::vector<int64>> dimensions_to_reduce;
1200       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1201                              &dimensions_to_reduce};
1202       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1203         return false;
1204       }
1205       if (operands.size() % 2) {
1206         return Error(loc, StrCat("expects an even number of operands, but has ",
1207                                  operands.size(), " operands"));
1208       }
1209       instruction = builder->AddInstruction(HloInstruction::CreateReduce(
1210           shape, /*operands=*/
1211           absl::Span<HloInstruction* const>(operands).subspan(
1212               0, operands.size() / 2),
1213           /*init_values=*/
1214           absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
1215                                                               2),
1216           *dimensions_to_reduce, *reduce_computation));
1217       break;
1218     }
1219     case HloOpcode::kReverse: {
1220       optional<std::vector<int64>> dimensions;
1221       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1222                              &dimensions};
1223       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1224           !ParseAttributes(attrs)) {
1225         return false;
1226       }
1227       instruction = builder->AddInstruction(
1228           HloInstruction::CreateReverse(shape, operands[0], *dimensions));
1229       break;
1230     }
1231     case HloOpcode::kSelectAndScatter: {
1232       optional<HloComputation*> select;
1233       attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
1234       optional<HloComputation*> scatter;
1235       attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
1236       optional<Window> window;
1237       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1238       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1239           !ParseAttributes(attrs)) {
1240         return false;
1241       }
1242       if (!window) {
1243         window.emplace();
1244       }
1245       instruction =
1246           builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
1247               shape, /*operand=*/operands[0], *select, *window,
1248               /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
1249       break;
1250     }
1251     case HloOpcode::kSlice: {
1252       optional<SliceRanges> slice_ranges;
1253       attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
1254       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1255           !ParseAttributes(attrs)) {
1256         return false;
1257       }
1258       instruction = builder->AddInstruction(HloInstruction::CreateSlice(
1259           shape, operands[0], slice_ranges->starts, slice_ranges->limits,
1260           slice_ranges->strides));
1261       break;
1262     }
1263     case HloOpcode::kDynamicSlice: {
1264       optional<std::vector<int64>> dynamic_slice_sizes;
1265       attrs["dynamic_slice_sizes"] = {
1266           /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
1267       LocTy loc = lexer_.GetLoc();
1268       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1269         return false;
1270       }
1271       if (operands.empty()) {
1272         return Error(loc, "Expected at least one operand.");
1273       }
1274       if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
1275           operands.size() != 1 + operands[0]->shape().rank()) {
1276         return Error(loc, "Wrong number of operands.");
1277       }
1278       instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
1279           shape, /*operand=*/operands[0],
1280           /*start_indices=*/absl::MakeSpan(operands).subspan(1),
1281           *dynamic_slice_sizes));
1282       break;
1283     }
1284     case HloOpcode::kDynamicUpdateSlice: {
1285       LocTy loc = lexer_.GetLoc();
1286       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1287         return false;
1288       }
1289       if (operands.size() < 2) {
1290         return Error(loc, "Expected at least two operands.");
1291       }
1292       if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
1293           operands.size() != 2 + operands[0]->shape().rank()) {
1294         return Error(loc, "Wrong number of operands.");
1295       }
1296       instruction =
1297           builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1298               shape, /*operand=*/operands[0], /*update=*/operands[1],
1299               /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
1300       break;
1301     }
1302     case HloOpcode::kTranspose: {
1303       optional<std::vector<int64>> dimensions;
1304       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1305                              &dimensions};
1306       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1307           !ParseAttributes(attrs)) {
1308         return false;
1309       }
1310       instruction = builder->AddInstruction(
1311           HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
1312       break;
1313     }
1314     case HloOpcode::kBatchNormTraining: {
1315       optional<float> epsilon;
1316       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
1317       optional<int64> feature_index;
1318       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
1319                                 &feature_index};
1320       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1321           !ParseAttributes(attrs)) {
1322         return false;
1323       }
1324       instruction =
1325           builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
1326               shape, /*operand=*/operands[0], /*scale=*/operands[1],
1327               /*offset=*/operands[2], *epsilon, *feature_index));
1328       break;
1329     }
1330     case HloOpcode::kBatchNormInference: {
1331       optional<float> epsilon;
1332       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
1333       optional<int64> feature_index;
1334       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
1335                                 &feature_index};
1336       if (!ParseOperands(&operands, /*expected_size=*/5) ||
1337           !ParseAttributes(attrs)) {
1338         return false;
1339       }
1340       instruction =
1341           builder->AddInstruction(HloInstruction::CreateBatchNormInference(
1342               shape, /*operand=*/operands[0], /*scale=*/operands[1],
1343               /*offset=*/operands[2], /*mean=*/operands[3],
1344               /*variance=*/operands[4], *epsilon, *feature_index));
1345       break;
1346     }
1347     case HloOpcode::kBatchNormGrad: {
1348       optional<float> epsilon;
1349       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
1350       optional<int64> feature_index;
1351       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
1352                                 &feature_index};
1353       if (!ParseOperands(&operands, /*expected_size=*/5) ||
1354           !ParseAttributes(attrs)) {
1355         return false;
1356       }
1357       instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
1358           shape, /*operand=*/operands[0], /*scale=*/operands[1],
1359           /*mean=*/operands[2], /*variance=*/operands[3],
1360           /*grad_output=*/operands[4], *epsilon, *feature_index));
1361       break;
1362     }
1363     case HloOpcode::kPad: {
1364       optional<PaddingConfig> padding;
1365       attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
1366       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1367           !ParseAttributes(attrs)) {
1368         return false;
1369       }
1370       instruction = builder->AddInstruction(HloInstruction::CreatePad(
1371           shape, operands[0], /*padding_value=*/operands[1], *padding));
1372       break;
1373     }
1374     case HloOpcode::kFusion: {
1375       optional<HloComputation*> fusion_computation;
1376       attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
1377                         &fusion_computation};
1378       optional<HloInstruction::FusionKind> fusion_kind;
1379       attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
1380       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1381         return false;
1382       }
1383       instruction = builder->AddInstruction(HloInstruction::CreateFusion(
1384           shape, *fusion_kind, operands, *fusion_computation));
1385       break;
1386     }
1387     case HloOpcode::kInfeed: {
1388       optional<string> config;
1389       attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
1390       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1391           !ParseAttributes(attrs)) {
1392         return false;
1393       }
1394       // We need to know the infeed data shape to construct the infeed
1395       // instruction. This is the zero-th element of the tuple-shaped output of
1396       // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
1397       // if the shape is not a non-empty tuple, so add guard so an error message
1398       // can be emitted instead of a check fail
1399       if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
1400         return Error(lexer_.GetLoc(),
1401                      "infeed must have a non-empty tuple shape");
1402       }
1403       instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
1404           ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
1405           config ? *config : ""));
1406       break;
1407     }
1408     case HloOpcode::kOutfeed: {
1409       optional<string> config;
1410       attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
1411       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1412           !ParseAttributes(attrs)) {
1413         return false;
1414       }
1415       instruction = builder->AddInstruction(
1416           HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
1417                                         operands[1], config ? *config : ""));
1418       break;
1419     }
1420     case HloOpcode::kRng: {
1421       optional<RandomDistribution> distribution;
1422       attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
1423                                &distribution};
1424       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1425         return false;
1426       }
1427       instruction = builder->AddInstruction(
1428           HloInstruction::CreateRng(shape, *distribution, operands));
1429       break;
1430     }
1431     case HloOpcode::kReducePrecision: {
1432       optional<int64> exponent_bits;
1433       optional<int64> mantissa_bits;
1434       attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
1435                                 &exponent_bits};
1436       attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
1437                                 &mantissa_bits};
1438       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1439           !ParseAttributes(attrs)) {
1440         return false;
1441       }
1442       instruction =
1443           builder->AddInstruction(HloInstruction::CreateReducePrecision(
1444               shape, operands[0], static_cast<int>(*exponent_bits),
1445               static_cast<int>(*mantissa_bits)));
1446       break;
1447     }
1448     case HloOpcode::kConditional: {
1449       optional<HloComputation*> true_computation;
1450       optional<HloComputation*> false_computation;
1451       optional<std::vector<HloComputation*>> branch_computations;
1452       if (!ParseOperands(&operands)) {
1453         return false;
1454       }
1455       const bool branch_index_is_bool =
1456           operands[0]->shape().element_type() == PRED;
1457       if (branch_index_is_bool) {
1458         attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
1459                                      &true_computation};
1460         attrs["false_computation"] = {
1461             /*required=*/true, AttrTy::kHloComputation, &false_computation};
1462       } else {
1463         attrs["branch_computations"] = {/*required=*/true,
1464                                         AttrTy::kBracedHloComputationList,
1465                                         &branch_computations};
1466       }
1467       if (!ParseAttributes(attrs)) {
1468         return false;
1469       }
1470       if (branch_index_is_bool) {
1471         branch_computations.emplace({*true_computation, *false_computation});
1472       }
1473       if (branch_computations->empty() ||
1474           operands.size() != branch_computations->size() + 1) {
1475         return false;
1476       }
1477       instruction = builder->AddInstruction(HloInstruction::CreateConditional(
1478           shape, /*branch_index=*/operands[0],
1479           absl::MakeSpan(*branch_computations),
1480           absl::MakeSpan(operands).subspan(1)));
1481       break;
1482     }
1483     case HloOpcode::kCustomCall: {
1484       optional<string> custom_call_target;
1485       optional<string> opaque;
1486       optional<Window> window;
1487       optional<ConvolutionDimensionNumbers> dnums;
1488       optional<int64> feature_group_count;
1489       optional<int64> batch_group_count;
1490       optional<std::vector<Shape>> operand_layout_constraints;
1491       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
1492                                      &custom_call_target};
1493       attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
1494       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1495       attrs["dim_labels"] = {/*required=*/false,
1496                              AttrTy::kConvolutionDimensionNumbers, &dnums};
1497       attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
1498                                       &feature_group_count};
1499       attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
1500                                     &batch_group_count};
1501       attrs["operand_layout_constraints"] = {
1502           /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
1503       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1504         return false;
1505       }
1506       if (operand_layout_constraints.has_value()) {
1507         if (!LayoutUtil::HasLayout(shape)) {
1508           return Error(lexer_.GetLoc(),
1509                        "Layout must be set on layout-constrained custom call");
1510         }
1511         if (operands.size() != operand_layout_constraints->size()) {
1512           return Error(lexer_.GetLoc(),
1513                        StrCat("Expected ", operands.size(),
1514                               " operand layout constraints, ",
1515                               operand_layout_constraints->size(), " given"));
1516         }
1517         for (int64 i = 0; i < operands.size(); ++i) {
1518           const Shape& operand_shape_with_layout =
1519               (*operand_layout_constraints)[i];
1520           if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
1521             return Error(lexer_.GetLoc(),
1522                          StrCat("Operand layout constraint shape ",
1523                                 ShapeUtil::HumanStringWithLayout(
1524                                     operand_shape_with_layout),
1525                                 " for operand ", i, " does not have a layout"));
1526           }
1527           if (!ShapeUtil::Compatible(operand_shape_with_layout,
1528                                      operands[i]->shape())) {
1529             return Error(
1530                 lexer_.GetLoc(),
1531                 StrCat(
1532                     "Operand layout constraint shape ",
1533                     ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
1534                     " for operand ", i,
1535                     " is not compatible with operand shape ",
1536                     ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
1537           }
1538         }
1539         instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
1540             shape, operands, *custom_call_target, *operand_layout_constraints,
1541             opaque.has_value() ? *opaque : ""));
1542       } else {
1543         instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
1544             shape, operands, *custom_call_target,
1545             opaque.has_value() ? *opaque : ""));
1546       }
1547       if (window.has_value()) {
1548         instruction->set_window(*window);
1549       }
1550       if (dnums.has_value()) {
1551         instruction->set_convolution_dimension_numbers(*dnums);
1552       }
1553       if (feature_group_count.has_value()) {
1554         instruction->set_feature_group_count(*feature_group_count);
1555       }
1556       if (batch_group_count.has_value()) {
1557         instruction->set_batch_group_count(*batch_group_count);
1558       }
1559       break;
1560     }
1561     case HloOpcode::kDot: {
1562       optional<std::vector<int64>> lhs_contracting_dims;
1563       attrs["lhs_contracting_dims"] = {
1564           /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
1565       optional<std::vector<int64>> rhs_contracting_dims;
1566       attrs["rhs_contracting_dims"] = {
1567           /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
1568       optional<std::vector<int64>> lhs_batch_dims;
1569       attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
1570                                  &lhs_batch_dims};
1571       optional<std::vector<int64>> rhs_batch_dims;
1572       attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
1573                                  &rhs_batch_dims};
1574       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
1575       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
1576                                     &operand_precision};
1577 
1578       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1579           !ParseAttributes(attrs)) {
1580         return false;
1581       }
1582 
1583       DotDimensionNumbers dnum;
1584       if (lhs_contracting_dims) {
1585         *dnum.mutable_lhs_contracting_dimensions() = {
1586             lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
1587       }
1588       if (rhs_contracting_dims) {
1589         *dnum.mutable_rhs_contracting_dimensions() = {
1590             rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
1591       }
1592       if (lhs_batch_dims) {
1593         *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
1594                                                 lhs_batch_dims->end()};
1595       }
1596       if (rhs_batch_dims) {
1597         *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
1598                                                 rhs_batch_dims->end()};
1599       }
1600 
1601       PrecisionConfig precision_config;
1602       if (operand_precision) {
1603         *precision_config.mutable_operand_precision() = {
1604             operand_precision->begin(), operand_precision->end()};
1605       } else {
1606         precision_config.mutable_operand_precision()->Resize(
1607             operands.size(), PrecisionConfig::DEFAULT);
1608       }
1609 
1610       instruction = builder->AddInstruction(HloInstruction::CreateDot(
1611           shape, operands[0], operands[1], dnum, precision_config));
1612       break;
1613     }
1614     case HloOpcode::kGather: {
1615       optional<std::vector<int64>> offset_dims;
1616       attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
1617                               &offset_dims};
1618       optional<std::vector<int64>> collapsed_slice_dims;
1619       attrs["collapsed_slice_dims"] = {
1620           /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
1621       optional<std::vector<int64>> start_index_map;
1622       attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
1623                                   &start_index_map};
1624       optional<int64> index_vector_dim;
1625       attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
1626                                    &index_vector_dim};
1627       optional<std::vector<int64>> slice_sizes;
1628       attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
1629                               &slice_sizes};
1630 
1631       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1632           !ParseAttributes(attrs)) {
1633         return false;
1634       }
1635 
1636       GatherDimensionNumbers dim_numbers =
1637           HloGatherInstruction::MakeGatherDimNumbers(
1638               /*offset_dims=*/*offset_dims,
1639               /*collapsed_slice_dims=*/*collapsed_slice_dims,
1640               /*start_index_map=*/*start_index_map,
1641               /*index_vector_dim=*/*index_vector_dim);
1642 
1643       instruction = builder->AddInstruction(HloInstruction::CreateGather(
1644           shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
1645           dim_numbers, *slice_sizes));
1646       break;
1647     }
1648     case HloOpcode::kScatter: {
1649       optional<std::vector<int64>> update_window_dims;
1650       attrs["update_window_dims"] = {
1651           /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
1652       optional<std::vector<int64>> inserted_window_dims;
1653       attrs["inserted_window_dims"] = {
1654           /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
1655       optional<std::vector<int64>> scatter_dims_to_operand_dims;
1656       attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
1657                                                AttrTy::kBracedInt64List,
1658                                                &scatter_dims_to_operand_dims};
1659       optional<int64> index_vector_dim;
1660       attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
1661                                    &index_vector_dim};
1662 
1663       optional<HloComputation*> update_computation;
1664       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1665                            &update_computation};
1666 
1667       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1668           !ParseAttributes(attrs)) {
1669         return false;
1670       }
1671 
1672       ScatterDimensionNumbers dim_numbers =
1673           HloScatterInstruction::MakeScatterDimNumbers(
1674               /*update_window_dims=*/*update_window_dims,
1675               /*inserted_window_dims=*/*inserted_window_dims,
1676               /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
1677               /*index_vector_dim=*/*index_vector_dim);
1678 
1679       instruction = builder->AddInstruction(HloInstruction::CreateScatter(
1680           shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
1681           /*updates=*/operands[2], *update_computation, dim_numbers));
1682       break;
1683     }
1684     case HloOpcode::kDomain: {
1685       DomainData domain;
1686       attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
1687       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1688           !ParseAttributes(attrs)) {
1689         return false;
1690       }
1691       instruction = builder->AddInstruction(HloInstruction::CreateDomain(
1692           shape, operands[0], std::move(domain.exit_metadata),
1693           std::move(domain.entry_metadata)));
1694       break;
1695     }
1696     case HloOpcode::kTrace:
1697       return TokenError(StrCat("parsing not yet implemented for op: ",
1698                                HloOpcodeString(opcode)));
1699     case HloOpcode::kGetDimensionSize:
1700       optional<std::vector<int64>> dimensions;
1701       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1702                              &dimensions};
1703       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1704           !ParseAttributes(attrs)) {
1705         return false;
1706       }
1707       instruction =
1708           builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
1709               shape, operands[0], (*dimensions)[0]));
1710       break;
1711   }
1712 
1713   instruction->SetAndSanitizeName(name);
1714   if (instruction->name() != name) {
1715     return Error(name_loc,
1716                  StrCat("illegal instruction name: ", name,
1717                         "; suggest renaming to: ", instruction->name()));
1718   }
1719 
1720   // Add shared attributes like metadata to the instruction, if they were seen.
1721   if (sharding) {
1722     instruction->set_sharding(
1723         HloSharding::FromProto(sharding.value()).ValueOrDie());
1724   }
1725   if (parameter_replication) {
1726     int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
1727     const auto& replicated =
1728         parameter_replication->replicated_at_leaf_buffers();
1729     if (leaf_count != replicated.size()) {
1730       return Error(lexer_.GetLoc(),
1731                    StrCat("parameter has ", leaf_count,
1732                           " leaf buffers, but parameter_replication has ",
1733                           replicated.size(), " elements."));
1734     }
1735     instruction->set_parameter_replicated_at_leaf_buffers(replicated);
1736   }
1737   if (predecessors) {
1738     for (auto* pre : *predecessors) {
1739       Status status = pre->AddControlDependencyTo(instruction);
1740       if (!status.ok()) {
1741         return Error(name_loc, StrCat("error adding control dependency for: ",
1742                                       name, " status: ", status.ToString()));
1743       }
1744     }
1745   }
1746   if (metadata) {
1747     instruction->set_metadata(*metadata);
1748   }
1749   if (backend_config) {
1750     instruction->set_raw_backend_config_string(std::move(*backend_config));
1751   }
1752   return AddInstruction(name, instruction, name_loc);
1753 }  // NOLINT(readability/fn_size)
1754 
1755 // ::= '{' (single_sharding | tuple_sharding) '}'
1756 //
1757 // tuple_sharding ::= single_sharding* (',' single_sharding)*
ParseSharding(OpSharding * sharding)1758 bool HloParser::ParseSharding(OpSharding* sharding) {
1759   // A single sharding starts with '{' and is not followed by '{'.
1760   // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
1761   // an empty tuple.
1762   if (!ParseToken(TokKind::kLbrace,
1763                   "expected '{' to start sharding attribute")) {
1764     return false;
1765   }
1766 
1767   if (lexer_.GetKind() != TokKind::kLbrace &&
1768       lexer_.GetKind() != TokKind::kRbrace) {
1769     return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
1770   }
1771 
1772   // Tuple sharding.
1773   // Allow empty tuple shardings.
1774   if (lexer_.GetKind() != TokKind::kRbrace) {
1775     do {
1776       if (!ParseSingleSharding(sharding->add_tuple_shardings(),
1777                                /*lbrace_pre_lexed=*/false)) {
1778         return false;
1779       }
1780     } while (EatIfPresent(TokKind::kComma));
1781   }
1782   sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
1783 
1784   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
1785 }
1786 
1787 //  ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
1788 //          ('devices=' ('[' dims ']')* device_list)? '}'
1789 // dims ::= int_list device_list ::= int_list
ParseSingleSharding(OpSharding * sharding,bool lbrace_pre_lexed)1790 bool HloParser::ParseSingleSharding(OpSharding* sharding,
1791                                     bool lbrace_pre_lexed) {
1792   if (!lbrace_pre_lexed &&
1793       !ParseToken(TokKind::kLbrace,
1794                   "expected '{' to start sharding attribute")) {
1795     return false;
1796   }
1797 
1798   LocTy loc = lexer_.GetLoc();
1799   bool maximal = false;
1800   bool replicated = false;
1801   std::vector<int64> devices;
1802   std::vector<int64> tile_assignment_dimensions;
1803   while (lexer_.GetKind() != TokKind::kRbrace) {
1804     switch (lexer_.GetKind()) {
1805       case TokKind::kw_maximal:
1806         maximal = true;
1807         lexer_.Lex();
1808         break;
1809       case TokKind::kw_replicated:
1810         replicated = true;
1811         lexer_.Lex();
1812         break;
1813       case TokKind::kAttributeName: {
1814         if (lexer_.GetStrVal() == "device") {
1815           if (lexer_.Lex() != TokKind::kInt) {
1816             return TokenError("device= attribute must be an integer");
1817           }
1818           devices = {lexer_.GetInt64Val()};
1819           lexer_.Lex();
1820         } else if (lexer_.GetStrVal() == "devices") {
1821           lexer_.Lex();
1822           if (!ParseToken(TokKind::kLsquare,
1823                           "expected '[' to start sharding devices shape")) {
1824             return false;
1825           }
1826 
1827           do {
1828             int64 dim;
1829             if (!ParseInt64(&dim)) {
1830               return false;
1831             }
1832             tile_assignment_dimensions.push_back(dim);
1833           } while (EatIfPresent(TokKind::kComma));
1834 
1835           if (!ParseToken(TokKind::kRsquare,
1836                           "expected ']' to start sharding devices shape")) {
1837             return false;
1838           }
1839           do {
1840             int64 device;
1841             if (!ParseInt64(&device)) {
1842               return false;
1843             }
1844             devices.push_back(device);
1845           } while (EatIfPresent(TokKind::kComma));
1846         } else {
1847           return TokenError(
1848               "unknown attribute in sharding: expected device= or devices=");
1849         }
1850         break;
1851       }
1852       case TokKind::kRbrace:
1853         break;
1854       default:
1855         return TokenError("unexpected token");
1856     }
1857   }
1858 
1859   if (replicated) {
1860     if (!devices.empty()) {
1861       return Error(loc,
1862                    "replicated shardings should not have any devices assigned");
1863     }
1864     sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
1865   } else if (maximal) {
1866     if (devices.size() != 1) {
1867       return Error(loc,
1868                    "maximal shardings should have exactly one device assigned");
1869     }
1870     sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
1871     sharding->add_tile_assignment_devices(devices[0]);
1872   } else {
1873     if (devices.size() <= 1) {
1874       return Error(
1875           loc, "non-maximal shardings must have more than one device assigned");
1876     }
1877     if (tile_assignment_dimensions.empty()) {
1878       return Error(
1879           loc,
1880           "non-maximal shardings must have a tile assignment list including "
1881           "dimensions");
1882     }
1883     sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
1884     for (int64 dim : tile_assignment_dimensions) {
1885       sharding->add_tile_assignment_dimensions(dim);
1886     }
1887     for (int64 device : devices) {
1888       sharding->add_tile_assignment_devices(device);
1889     }
1890   }
1891 
1892   lexer_.Lex();
1893   return true;
1894 }
1895 
1896 // parameter_replication ::=
1897 //   '{' ('true' | 'false')* (',' ('true' | 'false'))*  '}'
ParseParameterReplication(ParameterReplication * parameter_replication)1898 bool HloParser::ParseParameterReplication(
1899     ParameterReplication* parameter_replication) {
1900   if (!ParseToken(TokKind::kLbrace,
1901                   "expected '{' to start parameter_replication attribute")) {
1902     return false;
1903   }
1904 
1905   if (lexer_.GetKind() != TokKind::kRbrace) {
1906     do {
1907       if (lexer_.GetKind() == TokKind::kw_true) {
1908         parameter_replication->add_replicated_at_leaf_buffers(true);
1909       } else if (lexer_.GetKind() == TokKind::kw_false) {
1910         parameter_replication->add_replicated_at_leaf_buffers(false);
1911       } else {
1912         return false;
1913       }
1914       lexer_.Lex();
1915     } while (EatIfPresent(TokKind::kComma));
1916   }
1917 
1918   return ParseToken(TokKind::kRbrace,
1919                     "expected '}' to end parameter_replication attribute");
1920 }
1921 
1922 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
1923 //            'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)1924 bool HloParser::ParseDomain(DomainData* domain) {
1925   std::unordered_map<string, AttrConfig> attrs;
1926   optional<string> kind;
1927   optional<OpSharding> entry_sharding;
1928   optional<OpSharding> exit_sharding;
1929   attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
1930   attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
1931   attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
1932   if (!ParseSubAttributes(attrs)) {
1933     return false;
1934   }
1935   if (*kind == ShardingMetadata::KindName()) {
1936     auto entry_sharding_ptr = absl::make_unique<HloSharding>(
1937         HloSharding::FromProto(*entry_sharding).ValueOrDie());
1938     auto exit_sharding_ptr = absl::make_unique<HloSharding>(
1939         HloSharding::FromProto(*exit_sharding).ValueOrDie());
1940     domain->entry_metadata =
1941         absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
1942     domain->exit_metadata =
1943         absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
1944   } else {
1945     return TokenError(StrCat("unsupported domain kind: ", *kind));
1946   }
1947   return true;
1948 }
1949 
1950 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)1951 bool HloParser::ParseInstructionNames(
1952     std::vector<HloInstruction*>* instructions) {
1953   if (!ParseToken(TokKind::kLbrace,
1954                   "expects '{' at the beginning of instruction name list")) {
1955     return false;
1956   }
1957   LocTy loc = lexer_.GetLoc();
1958   do {
1959     string name;
1960     if (!ParseName(&name)) {
1961       return Error(loc, "expects a instruction name");
1962     }
1963     std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
1964     if (!instr) {
1965       return TokenError(StrFormat("instruction '%s' is not defined", name));
1966     }
1967     instructions->push_back(instr->first);
1968   } while (EatIfPresent(TokKind::kComma));
1969 
1970   return ParseToken(TokKind::kRbrace,
1971                     "expects '}' at the end of instruction name list");
1972 }
1973 
SetValueInLiteral(LocTy loc,int64 value,LinearOrMultiIndex index,Literal * literal)1974 bool HloParser::SetValueInLiteral(LocTy loc, int64 value,
1975                                   LinearOrMultiIndex index, Literal* literal) {
1976   const Shape& shape = literal->shape();
1977   switch (shape.element_type()) {
1978     case S8:
1979       return SetValueInLiteralHelper<int8>(loc, value, index, literal);
1980     case S16:
1981       return SetValueInLiteralHelper<int16>(loc, value, index, literal);
1982     case S32:
1983       return SetValueInLiteralHelper<int32>(loc, value, index, literal);
1984     case S64:
1985       return SetValueInLiteralHelper<int64>(loc, value, index, literal);
1986     case U8:
1987       return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
1988                                                         literal);
1989     case U16:
1990       return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
1991                                                          literal);
1992     case U32:
1993       return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
1994                                                          literal);
1995     case U64:
1996       return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
1997                                                          literal);
1998     case PRED:
1999       // Bool type literals with rank >= 1 are printed in 0s and 1s.
2000       return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
2001                                            literal);
2002     default:
2003       LOG(FATAL) << "unknown integral primitive type "
2004                  << PrimitiveType_Name(shape.element_type());
2005   }
2006 }
2007 
SetValueInLiteral(LocTy loc,double value,LinearOrMultiIndex index,Literal * literal)2008 bool HloParser::SetValueInLiteral(LocTy loc, double value,
2009                                   LinearOrMultiIndex index, Literal* literal) {
2010   const Shape& shape = literal->shape();
2011   switch (shape.element_type()) {
2012     case F16:
2013       return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
2014     case BF16:
2015       return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
2016                                                            literal);
2017     case F32:
2018       return SetValueInLiteralHelper<float>(loc, value, index, literal);
2019     case F64:
2020       return SetValueInLiteralHelper<double>(loc, value, index, literal);
2021     default:
2022       LOG(FATAL) << "unknown floating point primitive type "
2023                  << PrimitiveType_Name(shape.element_type());
2024   }
2025 }
2026 
SetValueInLiteral(LocTy loc,bool value,LinearOrMultiIndex index,Literal * literal)2027 bool HloParser::SetValueInLiteral(LocTy loc, bool value,
2028                                   LinearOrMultiIndex index, Literal* literal) {
2029   const Shape& shape = literal->shape();
2030   switch (shape.element_type()) {
2031     case PRED:
2032       return SetValueInLiteralHelper<bool>(loc, value, index, literal);
2033     default:
2034       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
2035                  << " is not PRED type";
2036   }
2037 }
2038 
SetValueInLiteral(LocTy loc,std::complex<double> value,LinearOrMultiIndex index,Literal * literal)2039 bool HloParser::SetValueInLiteral(LocTy loc, std::complex<double> value,
2040                                   LinearOrMultiIndex index, Literal* literal) {
2041   const Shape& shape = literal->shape();
2042   switch (shape.element_type()) {
2043     case C64:
2044       return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
2045                                                           literal);
2046     case C128:
2047       return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
2048                                                            literal);
2049     default:
2050       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
2051                  << " is not a complex type type";
2052   }
2053 }
2054 
2055 template <typename T>
StringifyValue(T val)2056 string StringifyValue(T val) {
2057   return StrCat(val);
2058 }
2059 template <>
StringifyValue(std::complex<double> val)2060 string StringifyValue(std::complex<double> val) {
2061   return StrFormat("(%f, %f)", std::real(val), std::imag(val));
2062 }
2063 
2064 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,LinearOrMultiIndex index,Literal * literal)2065 bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
2066                                         LinearOrMultiIndex index,
2067                                         Literal* literal) {
2068   if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
2069     return false;
2070   }
2071 
2072   // Check that the index is in range and assign into the literal
2073   if (auto* linear_index = absl::get_if<int64>(&index)) {
2074     if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
2075       return Error(loc, StrCat("trys to set value ", StringifyValue(value),
2076                                " to a literal in shape ",
2077                                ShapeUtil::HumanString(literal->shape()),
2078                                " at linear index ", *linear_index,
2079                                ", but the index is out of range"));
2080     }
2081     literal->data<LiteralNativeT>().at(*linear_index) =
2082         static_cast<LiteralNativeT>(value);
2083   } else {
2084     auto* multi_index = absl::get_if<absl::Span<const int64>>(&index);
2085     CHECK(multi_index != nullptr);
2086 
2087     auto invalid_idx = [&](string msg) {
2088       return Error(loc, StrFormat("Invalid sparse index [%s]. %s",
2089                                   absl::StrJoin(*multi_index, ", "), msg));
2090     };
2091 
2092     const auto& shape = literal->shape();
2093     if (shape.rank() != multi_index->size()) {
2094       return invalid_idx(
2095           StrFormat("Has rank %d, but constant has shape %s, which has rank %d",
2096                     multi_index->size(), shape.ToString(), shape.rank()));
2097     }
2098     for (int64 i = 0; i < shape.rank(); ++i) {
2099       auto idx = (*multi_index)[i];
2100       if (idx < 0) {
2101         return invalid_idx(StrFormat(
2102             "Sub-index value at %d, namely %d, cannot be negative.", i, idx));
2103       }
2104       if (idx >= shape.dimensions(i)) {
2105         return invalid_idx(
2106             StrFormat("Sub-index at %d, namely %d, doesn't fit within shape "
2107                       "dimension %d in %s",
2108                       i, idx, shape.dimensions(i), shape.ToString()));
2109       }
2110     }
2111     literal->AppendSparseElement(*multi_index,
2112                                  static_cast<LiteralNativeT>(value));
2113   }
2114   return true;
2115 }
2116 
2117 // literal
2118 //  ::= tuple
2119 //  ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)2120 bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
2121   return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
2122                          : ParseNonTupleLiteral(literal, shape);
2123 }
2124 
2125 // tuple
2126 //  ::= shape '(' literal_list ')'
2127 // literal_list
2128 //  ::= /*empty*/
2129 //  ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)2130 bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
2131   if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
2132     return false;
2133   }
2134   std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
2135 
2136   if (lexer_.GetKind() == TokKind::kRparen) {
2137     // empty
2138   } else {
2139     // literal, (',' literal)*
2140     for (int i = 0; i < elements.size(); i++) {
2141       if (i > 0) {
2142         ParseToken(TokKind::kComma, "exepcts ',' to separate tuple elements");
2143       }
2144       if (!ParseLiteral(&elements[i],
2145                         ShapeUtil::GetTupleElementShape(shape, i))) {
2146         return TokenError(StrCat("expects the ", i, "th element"));
2147       }
2148     }
2149   }
2150   *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
2151   return ParseToken(TokKind::kRparen,
2152                     StrCat("expects ')' at the end of the tuple with ",
2153                            ShapeUtil::TupleElementCount(shape), "elements"));
2154 }
2155 
2156 // non_tuple
2157 //   ::= rank01
2158 //   ::= rank2345
2159 // rank2345 ::= shape sparse_or_nested_array
ParseNonTupleLiteral(Literal * literal,const Shape & shape)2160 bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
2161   if (LayoutUtil::IsSparseArray(shape)) {
2162     return ParseSparseLiteral(literal, shape);
2163   }
2164 
2165   CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
2166   return ParseDenseLiteral(literal, shape);
2167 }
2168 
ParseDenseLiteral(Literal * literal,const Shape & shape)2169 bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
2170   // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
2171   // `rank` is an int64, that's an implicit narrowing conversion, which is
2172   // implementation-defined behavior.
2173   const int rank = static_cast<int>(shape.rank());
2174 
2175   // Create a literal with the given shape in default layout.
2176   *literal = LiteralUtil::CreateFromDimensions(
2177       shape.element_type(), AsInt64Slice(shape.dimensions()));
2178   int64 nest_level = 0;
2179   int64 linear_index = 0;
2180   // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
2181   // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
2182   // when we are parsing the 2nd '{' (right before '1'), we are seeing a
2183   // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
2184   // the first '}' (right after '3'), it means the sub-array ends, and the
2185   // sub-array is supposed to contain exactly 3 elements, so check if
2186   // elems_seen_per_dim[1] is 3.
2187   std::vector<int64> elems_seen_per_dim(rank);
2188   auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
2189     std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
2190                                             elems_seen_per_dim.begin() + dim);
2191     return StrCat("[",
2192                   StrJoin(elems_seen_until_dim, ",",
2193                           [](string* out, const int64& num_elems) {
2194                             StrAppend(out, num_elems - 1);
2195                           }),
2196                   "]");
2197   };
2198 
2199   auto add_one_elem_seen = [&] {
2200     if (rank > 0) {
2201       if (nest_level != rank) {
2202         return TokenError(absl::StrFormat(
2203             "expects nested array in rank %d, but sees %d", rank, nest_level));
2204       }
2205       elems_seen_per_dim[rank - 1]++;
2206       if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
2207         return TokenError(absl::StrFormat(
2208             "expects %d elements on the minor-most dimension, but "
2209             "sees more",
2210             shape.dimensions(rank - 1)));
2211       }
2212     }
2213     return true;
2214   };
2215 
2216   do {
2217     switch (lexer_.GetKind()) {
2218       default:
2219         return TokenError("unexpected token type in a literal");
2220       case TokKind::kLbrace: {
2221         nest_level++;
2222         if (nest_level > rank) {
2223           return TokenError(absl::StrFormat(
2224               "expects nested array in rank %d, but sees larger", rank));
2225         }
2226         if (nest_level > 1) {
2227           elems_seen_per_dim[nest_level - 2]++;
2228           if (elems_seen_per_dim[nest_level - 2] >
2229               shape.dimensions(nest_level - 2)) {
2230             return TokenError(absl::StrFormat(
2231                 "expects %d elements in the %sth element, but sees more",
2232                 shape.dimensions(nest_level - 2),
2233                 get_index_str(nest_level - 2)));
2234           }
2235         }
2236         lexer_.Lex();
2237         break;
2238       }
2239       case TokKind::kRbrace: {
2240         nest_level--;
2241         if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
2242           return TokenError(absl::StrFormat(
2243               "expects %d elements in the %sth element, but sees %d",
2244               shape.dimensions(nest_level), get_index_str(nest_level),
2245               elems_seen_per_dim[nest_level]));
2246         }
2247         elems_seen_per_dim[nest_level] = 0;
2248         lexer_.Lex();
2249         break;
2250       }
2251       case TokKind::kLparen: {
2252         if (!primitive_util::IsComplexType(shape.element_type())) {
2253           return TokenError(
2254               absl::StrFormat("unexpected '(' in literal.  Parens are only "
2255                               "valid for complex literals"));
2256         }
2257 
2258         std::complex<double> value;
2259         LocTy loc = lexer_.GetLoc();
2260         if (!add_one_elem_seen() || !ParseComplex(&value) ||
2261             !SetValueInLiteral(loc, value, linear_index++, literal)) {
2262           return false;
2263         }
2264         break;
2265       }
2266       case TokKind::kDots: {
2267         if (nest_level != 1) {
2268           return TokenError(absl::StrFormat(
2269               "expects `...` at nest level 1, but sees it at nest level %d",
2270               nest_level));
2271         }
2272         elems_seen_per_dim[0] = shape.dimensions(0);
2273         lexer_.Lex();
2274         break;
2275       }
2276       case TokKind::kComma:
2277         // Skip.
2278         lexer_.Lex();
2279         break;
2280       case TokKind::kw_true:
2281       case TokKind::kw_false:
2282       case TokKind::kInt:
2283       case TokKind::kDecimal:
2284       case TokKind::kw_nan:
2285       case TokKind::kw_inf:
2286       case TokKind::kNegInf: {
2287         add_one_elem_seen();
2288         if (lexer_.GetKind() == TokKind::kw_true ||
2289             lexer_.GetKind() == TokKind::kw_false) {
2290           if (!SetValueInLiteral(lexer_.GetLoc(),
2291                                  lexer_.GetKind() == TokKind::kw_true,
2292                                  linear_index++, literal)) {
2293             return false;
2294           }
2295           lexer_.Lex();
2296         } else if (primitive_util::IsIntegralType(shape.element_type()) ||
2297                    shape.element_type() == PRED) {
2298           LocTy loc = lexer_.GetLoc();
2299           int64 value;
2300           if (!ParseInt64(&value)) {
2301             return Error(loc, StrCat("expects integer for primitive type: ",
2302                                      PrimitiveType_Name(shape.element_type())));
2303           }
2304           if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
2305             return false;
2306           }
2307         } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
2308           LocTy loc = lexer_.GetLoc();
2309           double value;
2310           if (!ParseDouble(&value)) {
2311             return Error(
2312                 loc, StrCat("expect floating point value for primitive type: ",
2313                             PrimitiveType_Name(shape.element_type())));
2314           }
2315           if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
2316             return false;
2317           }
2318         } else {
2319           return TokenError(StrCat("unsupported primitive type ",
2320                                    PrimitiveType_Name(shape.element_type())));
2321         }
2322         break;
2323       }
2324     }  // end of switch
2325   } while (nest_level > 0);
2326 
2327   *literal = literal->Relayout(shape.layout());
2328   return true;
2329 }
2330 
ParseSparseLiteral(Literal * literal,const Shape & shape)2331 bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
2332   *literal = Literal(shape);
2333   if (!ParseToken(TokKind::kLbrace,
2334                   "expects '{' at the beginning of a sparse literal")) {
2335     return false;
2336   }
2337 
2338   for (;;) {
2339     if (lexer_.GetKind() == TokKind::kRbrace) {
2340       lexer_.Lex();
2341       break;
2342     }
2343 
2344     std::vector<int64> index;
2345     if (lexer_.GetKind() == TokKind::kInt) {
2346       int64 single_index = lexer_.GetInt64Val();
2347       lexer_.Lex();
2348       index.push_back(single_index);
2349     } else {
2350       if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
2351                           &index)) {
2352         return false;
2353       }
2354     }
2355     if (!ParseToken(TokKind::kColon,
2356                     "expects ':' after after the sparse array index and before "
2357                     "the sparse array value")) {
2358       return false;
2359     }
2360 
2361     LocTy value_loc = lexer_.GetLoc();
2362     if (lexer_.GetKind() == TokKind::kw_true ||
2363         lexer_.GetKind() == TokKind::kw_false) {
2364       bool value = lexer_.GetKind() == TokKind::kw_true;
2365       if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) {
2366         return false;
2367       }
2368       lexer_.Lex();
2369     } else if (primitive_util::IsIntegralType(shape.element_type())) {
2370       int64 value;
2371       if (!ParseInt64(&value)) {
2372         return Error(value_loc,
2373                      StrCat("expects integer for primitive type: ",
2374                             PrimitiveType_Name(shape.element_type())));
2375       }
2376       if (!SetValueInLiteral(value_loc, value, index, literal)) {
2377         return false;
2378       }
2379     } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
2380       double value;
2381       if (!ParseDouble(&value)) {
2382         return Error(value_loc,
2383                      StrCat("expects floating point value for primitive type: ",
2384                             PrimitiveType_Name(shape.element_type())));
2385       }
2386       if (!SetValueInLiteral(value_loc, value, index, literal)) {
2387         return false;
2388       }
2389     } else if (primitive_util::IsComplexType(shape.element_type())) {
2390       std::complex<double> value;
2391       if (!ParseComplex(&value)) {
2392         return Error(value_loc,
2393                      StrCat("expects complex value for primitive type: ",
2394                             PrimitiveType_Name(shape.element_type())));
2395       }
2396       if (!SetValueInLiteral(value_loc, value, index, literal)) {
2397         return false;
2398       }
2399     } else {
2400       LOG(FATAL) << "Unexpected element type: "
2401                  << PrimitiveType_Name(shape.element_type());
2402     }
2403 
2404     if (lexer_.GetKind() != TokKind::kRbrace &&
2405         !ParseToken(TokKind::kComma,
2406                     "expects ',' separator between sparse array elements")) {
2407       return false;
2408     }
2409 
2410     if (literal->sparse_element_count() + 1 ==
2411         LayoutUtil::MaxSparseElements(shape.layout())) {
2412       return Error(
2413           lexer_.GetLoc(),
2414           StrCat("number of sparse elements exceeds maximum for layout: ",
2415                  ShapeUtil::HumanStringWithLayout(shape)));
2416     }
2417   }
2418 
2419   literal->SortSparseElements();
2420   return true;
2421 }
2422 
// MinMaxFiniteValue<T> is a type-traits helper used by
// HloParser::CheckParsedValueIsInRange. max() and min() return
// std::numeric_limits<T>::max() and std::numeric_limits<T>::lowest(), i.e. the
// largest and the lowest finite values of T. Specializations below cover
// types that need a different computation.
template <typename T>
struct MinMaxFiniteValue {
  static T max() { return std::numeric_limits<T>::max(); }
  static T min() { return std::numeric_limits<T>::lowest(); }
};
2430 
2431 template <>
2432 struct MinMaxFiniteValue<Eigen::half> {
maxxla::__anonc071bf1f0111::MinMaxFiniteValue2433   static double max() {
2434     // Sadly this is not constexpr, so this forces `value` to be a method.
2435     return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
2436   }
minxla::__anonc071bf1f0111::MinMaxFiniteValue2437   static double min() { return -max(); }
2438 };
2439 
2440 template <>
2441 struct MinMaxFiniteValue<bfloat16> {
maxxla::__anonc071bf1f0111::MinMaxFiniteValue2442   static double max() { return static_cast<double>(bfloat16::highest()); }
minxla::__anonc071bf1f0111::MinMaxFiniteValue2443   static double min() { return -max(); }
2444 };
2445 
2446 template <typename LiteralNativeT, typename ParsedElemT>
CheckParsedValueIsInRange(LocTy loc,ParsedElemT value)2447 bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
2448   PrimitiveType literal_ty =
2449       primitive_util::NativeToPrimitiveType<LiteralNativeT>();
2450   if (std::isnan(value) ||
2451       (std::numeric_limits<ParsedElemT>::has_infinity &&
2452        (std::numeric_limits<ParsedElemT>::infinity() == value ||
2453         -std::numeric_limits<ParsedElemT>::infinity() == value))) {
2454     // Skip range checking for non-finite value.
2455   } else if (std::is_unsigned<LiteralNativeT>::value) {
2456     CHECK((std::is_same<ParsedElemT, int64>::value ||
2457            std::is_same<ParsedElemT, bool>::value))
2458         << "Unimplemented checking for ParsedElemT";
2459 
2460     ParsedElemT upper_bound;
2461     if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) {
2462       upper_bound = std::numeric_limits<ParsedElemT>::max();
2463     } else {
2464       upper_bound =
2465           static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max());
2466     }
2467     if (value > upper_bound || value < 0) {
2468       // Value is out of range for LiteralNativeT.
2469       return Error(loc, StrCat("value ", value,
2470                                " is out of range for literal's primitive type ",
2471                                PrimitiveType_Name(literal_ty), " namely [0, ",
2472                                upper_bound, "]."));
2473     }
2474   } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
2475              value < MinMaxFiniteValue<LiteralNativeT>::min()) {
2476     // Value is out of range for LiteralNativeT.
2477     return Error(loc, StrCat("value ", value,
2478                              " is out of range for literal's primitive type ",
2479                              PrimitiveType_Name(literal_ty), " namely [",
2480                              MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
2481                              MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
2482   }
2483   return true;
2484 }
2485 
2486 template <typename LiteralNativeT>
CheckParsedValueIsInRange(LocTy loc,std::complex<double> value)2487 bool HloParser::CheckParsedValueIsInRange(LocTy loc,
2488                                           std::complex<double> value) {
2489   // e.g. `float` for std::complex<float>
2490   using LiteralComplexComponentT =
2491       decltype(std::real(std::declval<LiteralNativeT>()));
2492 
2493   // We could do simply
2494   //
2495   //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
2496   //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
2497   //
2498   // but this would give bad error messages on failure.
2499 
2500   auto check_component = [&](absl::string_view name, double v) {
2501     if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
2502         v == -std::numeric_limits<double>::infinity()) {
2503       // Skip range-checking for non-finite values.
2504       return true;
2505     }
2506 
2507     double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
2508     double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
2509     if (v < min || v > max) {
2510       // Value is out of range for LitearlComplexComponentT.
2511       return Error(
2512           loc,
2513           StrCat(name, " part ", v,
2514                  " is out of range for literal's primitive type ",
2515                  PrimitiveType_Name(
2516                      primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
2517                  ", namely [", min, ", ", max, "]."));
2518     }
2519     return true;
2520   };
2521   return check_component("real", std::real(value)) &&
2522          check_component("imaginary", std::imag(value));
2523 }
2524 
2525 // operands ::= '(' operands1 ')'
2526 // operands1
2527 //   ::= /*empty*/
2528 //   ::= operand (, operand)*
2529 // operand ::= (shape)? name
ParseOperands(std::vector<HloInstruction * > * operands)2530 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
2531   CHECK(operands != nullptr);
2532   if (!ParseToken(TokKind::kLparen,
2533                   "expects '(' at the beginning of operands")) {
2534     return false;
2535   }
2536   if (lexer_.GetKind() == TokKind::kRparen) {
2537     // empty
2538   } else {
2539     do {
2540       LocTy loc = lexer_.GetLoc();
2541       string name;
2542       optional<Shape> shape;
2543       if (CanBeShape()) {
2544         shape.emplace();
2545         if (!ParseShape(&shape.value())) {
2546           return false;
2547         }
2548       }
2549       if (!ParseName(&name)) {
2550         // When parsing a single instruction (as opposed to a whole module), an
2551         // HLO may have one or more operands with a shape but no name:
2552         //
2553         //  foo = add(f32[10], f32[10])
2554         //
2555         // create_missing_instruction_ is always non-null when parsing a single
2556         // instruction, and is responsible for creating kParameter instructions
2557         // for these operands.
2558         if (shape.has_value() && create_missing_instruction_ != nullptr &&
2559             scoped_name_tables_.size() == 1) {
2560           name = "";
2561         } else {
2562           return false;
2563         }
2564       }
2565       std::pair<HloInstruction*, LocTy>* instruction =
2566           FindInstruction(name, shape);
2567       if (instruction == nullptr) {
2568         return Error(loc, StrCat("instruction does not exist: ", name));
2569       }
2570       operands->push_back(instruction->first);
2571     } while (EatIfPresent(TokKind::kComma));
2572   }
2573   return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
2574 }
2575 
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)2576 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
2577                               const int expected_size) {
2578   CHECK(operands != nullptr);
2579   LocTy loc = lexer_.GetLoc();
2580   if (!ParseOperands(operands)) {
2581     return false;
2582   }
2583   if (expected_size != operands->size()) {
2584     return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
2585                              operands->size(), " operands"));
2586   }
2587   return true;
2588 }
2589 
2590 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const std::unordered_map<string,AttrConfig> & attrs)2591 bool HloParser::ParseSubAttributes(
2592     const std::unordered_map<string, AttrConfig>& attrs) {
2593   LocTy loc = lexer_.GetLoc();
2594   if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
2595     return false;
2596   }
2597   std::unordered_set<string> seen_attrs;
2598   if (lexer_.GetKind() == TokKind::kRbrace) {
2599     // empty
2600   } else {
2601     do {
2602       EatIfPresent(TokKind::kComma);
2603       if (!ParseAttributeHelper(attrs, &seen_attrs)) {
2604         return false;
2605       }
2606     } while (lexer_.GetKind() != TokKind::kRbrace);
2607   }
2608   // Check that all required attrs were seen.
2609   for (const auto& attr_it : attrs) {
2610     if (attr_it.second.required &&
2611         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
2612       return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
2613                                   attr_it.first));
2614     }
2615   }
2616   return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
2617 }
2618 
2619 // attributes ::= (',' attribute)*
ParseAttributes(const std::unordered_map<string,AttrConfig> & attrs)2620 bool HloParser::ParseAttributes(
2621     const std::unordered_map<string, AttrConfig>& attrs) {
2622   LocTy loc = lexer_.GetLoc();
2623   std::unordered_set<string> seen_attrs;
2624   while (EatIfPresent(TokKind::kComma)) {
2625     if (!ParseAttributeHelper(attrs, &seen_attrs)) {
2626       return false;
2627     }
2628   }
2629   // Check that all required attrs were seen.
2630   for (const auto& attr_it : attrs) {
2631     if (attr_it.second.required &&
2632         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
2633       return Error(loc, StrFormat("attribute %s is expected but not seen",
2634                                   attr_it.first));
2635     }
2636   }
2637   return true;
2638 }
2639 
ParseAttributeHelper(const std::unordered_map<string,AttrConfig> & attrs,std::unordered_set<string> * seen_attrs)2640 bool HloParser::ParseAttributeHelper(
2641     const std::unordered_map<string, AttrConfig>& attrs,
2642     std::unordered_set<string>* seen_attrs) {
2643   LocTy loc = lexer_.GetLoc();
2644   string name;
2645   if (!ParseAttributeName(&name)) {
2646     return Error(loc, "error parsing attributes");
2647   }
2648   VLOG(1) << "Parsing attribute " << name;
2649   if (!seen_attrs->insert(name).second) {
2650     return Error(loc, StrFormat("attribute %s already exists", name));
2651   }
2652   auto attr_it = attrs.find(name);
2653   if (attr_it == attrs.end()) {
2654     string allowed_attrs;
2655     if (attrs.empty()) {
2656       allowed_attrs = "No attributes are allowed here.";
2657     } else {
2658       allowed_attrs = StrCat(
2659           "Allowed attributes: ",
2660           StrJoin(attrs, ", ",
2661                   [&](string* out, const std::pair<string, AttrConfig>& kv) {
2662                     StrAppend(out, kv.first);
2663                   }));
2664     }
2665     return Error(loc, StrFormat("unexpected attribute \"%s\".  %s", name,
2666                                 allowed_attrs));
2667   }
2668   AttrTy attr_type = attr_it->second.attr_type;
2669   void* attr_out_ptr = attr_it->second.result;
2670   bool success = [&] {
2671     LocTy attr_loc = lexer_.GetLoc();
2672     switch (attr_type) {
2673       case AttrTy::kBool: {
2674         bool result;
2675         if (!ParseBool(&result)) {
2676           return false;
2677         }
2678         static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
2679         return true;
2680       }
2681       case AttrTy::kInt64: {
2682         int64 result;
2683         if (!ParseInt64(&result)) {
2684           return false;
2685         }
2686         static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
2687         return true;
2688       }
2689       case AttrTy::kInt32: {
2690         int64 result;
2691         if (!ParseInt64(&result)) {
2692           return false;
2693         }
2694         if (result != static_cast<int32>(result)) {
2695           return Error(attr_loc, "value out of range for int32");
2696         }
2697         static_cast<optional<int32>*>(attr_out_ptr)
2698             ->emplace(static_cast<int32>(result));
2699         return true;
2700       }
2701       case AttrTy::kFloat: {
2702         double result;
2703         if (!ParseDouble(&result)) {
2704           return false;
2705         }
2706         if (result > std::numeric_limits<float>::max() ||
2707             result < std::numeric_limits<float>::lowest()) {
2708           return Error(attr_loc, "value out of range for float");
2709         }
2710         static_cast<optional<float>*>(attr_out_ptr)
2711             ->emplace(static_cast<float>(result));
2712         return true;
2713       }
2714       case AttrTy::kHloComputation: {
2715         HloComputation* result = nullptr;
2716         if (!ParseHloComputation(&result)) {
2717           return false;
2718         }
2719         static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
2720         return true;
2721       }
2722       case AttrTy::kBracedHloComputationList: {
2723         std::vector<HloComputation*> result;
2724         if (!ParseHloComputationList(&result)) {
2725           return false;
2726         }
2727         static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
2728             ->emplace(result);
2729         return true;
2730       }
2731       case AttrTy::kFftType: {
2732         FftType result;
2733         if (!ParseFftType(&result)) {
2734           return false;
2735         }
2736         static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
2737         return true;
2738       }
2739       case AttrTy::kComparisonDirection: {
2740         ComparisonDirection result;
2741         if (!ParseComparisonDirection(&result)) {
2742           return false;
2743         }
2744         static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
2745             ->emplace(result);
2746         return true;
2747       }
2748       case AttrTy::kWindow: {
2749         Window result;
2750         if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
2751           return false;
2752         }
2753         static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
2754         return true;
2755       }
2756       case AttrTy::kConvolutionDimensionNumbers: {
2757         ConvolutionDimensionNumbers result;
2758         if (!ParseConvolutionDimensionNumbers(&result)) {
2759           return false;
2760         }
2761         static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
2762             ->emplace(result);
2763         return true;
2764       }
2765       case AttrTy::kSharding: {
2766         OpSharding sharding;
2767         if (!ParseSharding(&sharding)) {
2768           return false;
2769         }
2770         static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
2771         return true;
2772       }
2773       case AttrTy::kParameterReplication: {
2774         ParameterReplication parameter_replication;
2775         if (!ParseParameterReplication(&parameter_replication)) {
2776           return false;
2777         }
2778         static_cast<optional<ParameterReplication>*>(attr_out_ptr)
2779             ->emplace(parameter_replication);
2780         return true;
2781       }
2782       case AttrTy::kInstructionList: {
2783         std::vector<HloInstruction*> result;
2784         if (!ParseInstructionNames(&result)) {
2785           return false;
2786         }
2787         static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
2788             ->emplace(result);
2789         return true;
2790       }
2791       case AttrTy::kFusionKind: {
2792         HloInstruction::FusionKind result;
2793         if (!ParseFusionKind(&result)) {
2794           return false;
2795         }
2796         static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
2797             ->emplace(result);
2798         return true;
2799       }
2800       case AttrTy::kBracedInt64List: {
2801         std::vector<int64> result;
2802         if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
2803                             &result)) {
2804           return false;
2805         }
2806         static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
2807             ->emplace(result);
2808         return true;
2809       }
2810       case AttrTy::kBracedInt64ListList: {
2811         std::vector<std::vector<int64>> result;
2812         auto parse_and_add_item = [&]() {
2813           std::vector<int64> item;
2814           if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace,
2815                               TokKind::kComma, &item)) {
2816             return false;
2817           }
2818           result.push_back(item);
2819           return true;
2820         };
2821         if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
2822                        parse_and_add_item)) {
2823           return false;
2824         }
2825         static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
2826             ->emplace(result);
2827         return true;
2828       }
2829       case AttrTy::kSliceRanges: {
2830         SliceRanges result;
2831         if (!ParseSliceRanges(&result)) {
2832           return false;
2833         }
2834         static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
2835         return true;
2836       }
2837       case AttrTy::kPaddingConfig: {
2838         PaddingConfig result;
2839         if (!ParsePaddingConfig(&result)) {
2840           return false;
2841         }
2842         static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
2843         return true;
2844       }
2845       case AttrTy::kString: {
2846         string result;
2847         if (!ParseString(&result)) {
2848           return false;
2849         }
2850         static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
2851         return true;
2852       }
2853       case AttrTy::kMetadata: {
2854         OpMetadata result;
2855         if (!ParseMetadata(&result)) {
2856           return false;
2857         }
2858         static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
2859         return true;
2860       }
2861       case AttrTy::kDistribution: {
2862         RandomDistribution result;
2863         if (!ParseRandomDistribution(&result)) {
2864           return false;
2865         }
2866         static_cast<optional<RandomDistribution>*>(attr_out_ptr)
2867             ->emplace(result);
2868         return true;
2869       }
2870       case AttrTy::kDomain: {
2871         return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
2872       }
2873       case AttrTy::kPrecisionList: {
2874         std::vector<PrecisionConfig::Precision> result;
2875         if (!ParsePrecisionList(&result)) {
2876           return false;
2877         }
2878         static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
2879             attr_out_ptr)
2880             ->emplace(result);
2881         return true;
2882       }
2883       case AttrTy::kShapeList: {
2884         std::vector<Shape> result;
2885         if (!ParseShapeList(&result)) {
2886           return false;
2887         }
2888         static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
2889             ->emplace(result);
2890         return true;
2891       }
2892     }
2893   }();
2894   if (!success) {
2895     return Error(loc, StrFormat("error parsing attribute %s", name));
2896   }
2897   return true;
2898 }
2899 
2900 // attributes ::= (',' attribute)*
ParseAttributesAsProtoMessage(const std::unordered_set<string> & required_attrs,tensorflow::protobuf::Message * message)2901 bool HloParser::ParseAttributesAsProtoMessage(
2902     const std::unordered_set<string>& required_attrs,
2903     tensorflow::protobuf::Message* message) {
2904   LocTy loc = lexer_.GetLoc();
2905   std::unordered_set<string> seen_attrs;
2906   while (EatIfPresent(TokKind::kComma)) {
2907     if (!ParseAttributeAsProtoMessageHelper(message, &seen_attrs)) {
2908       return false;
2909     }
2910   }
2911   // Check that all required attrs were seen.
2912   for (const string& attr : required_attrs) {
2913     if (seen_attrs.find(attr) == seen_attrs.end()) {
2914       return Error(loc,
2915                    StrFormat("attribute %s is expected but not seen", attr));
2916     }
2917   }
2918   return true;
2919 }
2920 
ParseAttributeAsProtoMessageHelper(tensorflow::protobuf::Message * message,std::unordered_set<string> * seen_attrs)2921 bool HloParser::ParseAttributeAsProtoMessageHelper(
2922     tensorflow::protobuf::Message* message,
2923     std::unordered_set<string>* seen_attrs) {
2924   LocTy loc = lexer_.GetLoc();
2925   string name;
2926   if (!ParseAttributeName(&name)) {
2927     return Error(loc, "error parsing attributes");
2928   }
2929   VLOG(1) << "Parsing attribute " << name;
2930   if (!seen_attrs->insert(name).second) {
2931     return Error(loc, StrFormat("attribute %s already exists", name));
2932   }
2933   const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
2934   const tensorflow::protobuf::FieldDescriptor* fd =
2935       descriptor->FindFieldByName(name);
2936   if (!fd) {
2937     string allowed_attrs = "Allowed attributes: ";
2938 
2939     for (int i = 0; i < descriptor->field_count(); ++i) {
2940       if (i == 0) {
2941         absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
2942       } else {
2943         absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
2944       }
2945     }
2946     return Error(loc, StrFormat("unexpected attribute \"%s\".  %s", name,
2947                                 allowed_attrs));
2948   }
2949   const tensorflow::protobuf::Reflection* reflection = message->GetReflection();
2950   CHECK(!fd->is_repeated());  // Repeated fields not implemented.
2951   bool success = [&] {
2952     switch (fd->type()) {
2953       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
2954         bool result;
2955         if (!ParseBool(&result)) {
2956           return false;
2957         }
2958         reflection->SetBool(message, fd, result);
2959         return true;
2960       }
2961       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
2962         if (lexer_.GetKind() != TokKind::kIdent) {
2963           return TokenError(
2964               StrFormat("expects %s type", fd->enum_type()->name()));
2965         }
2966         string val = lexer_.GetStrVal();
2967         const tensorflow::protobuf::EnumValueDescriptor* evd =
2968             fd->enum_type()->FindValueByName(val);
2969         if (evd == nullptr) {
2970           return TokenError(StrFormat("expects %s type but sees: %s",
2971                                       fd->enum_type()->name(), val));
2972         }
2973         reflection->SetEnum(message, fd, evd);
2974         lexer_.Lex();
2975         return true;
2976       }
2977       default:
2978         LOG(ERROR) << "Unimplemented protocol buffer type "
2979                    << fd->DebugString();
2980         return false;
2981     }
2982   }();
2983   if (!success) {
2984     return Error(loc, StrFormat("error parsing attribute %s", name));
2985   }
2986   return true;
2987 }
2988 
ParseComputationName(HloComputation ** value)2989 bool HloParser::ParseComputationName(HloComputation** value) {
2990   string name;
2991   LocTy loc = lexer_.GetLoc();
2992   if (!ParseName(&name)) {
2993     return Error(loc, "expects computation name");
2994   }
2995   std::pair<HloComputation*, LocTy>* computation =
2996       tensorflow::gtl::FindOrNull(computation_pool_, name);
2997   if (computation == nullptr) {
2998     return Error(loc, StrCat("computation does not exist: ", name));
2999   }
3000   *value = computation->first;
3001   return true;
3002 }
3003 
3004 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
3005 // The subattributes can appear in any order. 'size=' is required, others are
3006 // optional.
ParseWindow(Window * window,bool expect_outer_curlies)3007 bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) {
3008   LocTy loc = lexer_.GetLoc();
3009   if (expect_outer_curlies &&
3010       !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
3011     return false;
3012   }
3013 
3014   std::vector<int64> size;
3015   std::vector<int64> stride;
3016   std::vector<std::vector<int64>> pad;
3017   std::vector<int64> lhs_dilate;
3018   std::vector<int64> rhs_dilate;
3019   std::vector<int64> rhs_reversal;
3020   const auto end_token =
3021       expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
3022   while (lexer_.GetKind() != end_token) {
3023     LocTy attr_loc = lexer_.GetLoc();
3024     string field_name;
3025     if (!ParseAttributeName(&field_name)) {
3026       return Error(attr_loc, "expects sub-attributes in window");
3027     }
3028     bool ok = [&] {
3029       if (field_name == "size") {
3030         return ParseDxD("size", &size);
3031       }
3032       if (field_name == "stride") {
3033         return ParseDxD("stride", &stride);
3034       }
3035       if (field_name == "lhs_dilate") {
3036         return ParseDxD("lhs_dilate", &lhs_dilate);
3037       }
3038       if (field_name == "rhs_dilate") {
3039         return ParseDxD("rls_dilate", &rhs_dilate);
3040       }
3041       if (field_name == "pad") {
3042         return ParseWindowPad(&pad);
3043       }
3044       if (field_name == "rhs_reversal") {
3045         return ParseDxD("rhs_reversal", &rhs_reversal);
3046       }
3047       return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
3048     }();
3049     if (!ok) {
3050       return false;
3051     }
3052   }
3053 
3054   if (size.empty()) {
3055     return Error(loc,
3056                  "sub-attribute 'size=' is required in the window attribute");
3057   }
3058   if (!stride.empty() && stride.size() != size.size()) {
3059     return Error(loc, "expects 'stride=' has the same size as 'size='");
3060   }
3061   if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
3062     return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
3063   }
3064   if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
3065     return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
3066   }
3067   if (!pad.empty() && pad.size() != size.size()) {
3068     return Error(loc, "expects 'pad=' has the same size as 'size='");
3069   }
3070 
3071   for (int i = 0; i < size.size(); i++) {
3072     window->add_dimensions()->set_size(size[i]);
3073     if (!pad.empty()) {
3074       window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
3075       window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
3076     }
3077     // If some field is not present, it has the default value.
3078     window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
3079     window->mutable_dimensions(i)->set_base_dilation(
3080         lhs_dilate.empty() ? 1 : lhs_dilate[i]);
3081     window->mutable_dimensions(i)->set_window_dilation(
3082         rhs_dilate.empty() ? 1 : rhs_dilate[i]);
3083     window->mutable_dimensions(i)->set_window_reversal(
3084         rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
3085   }
3086   return !expect_outer_curlies ||
3087          ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
3088 }
3089 
3090 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
3091 // The string looks like "dim_labels=0bf_0io->0bf".
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)3092 bool HloParser::ParseConvolutionDimensionNumbers(
3093     ConvolutionDimensionNumbers* dnums) {
3094   if (lexer_.GetKind() != TokKind::kDimLabels) {
3095     return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
3096   }
3097   string str = lexer_.GetStrVal();
3098 
3099   // The str is expected to have 3 items, lhs, rhs, out, and it must look like
3100   // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
3101   std::vector<string> split1 = absl::StrSplit(str, '_');
3102   if (split1.size() != 2) {
3103     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
3104                << str;
3105   }
3106   std::vector<string> split2 = absl::StrSplit(split1[1], "->");
3107   if (split2.size() != 2) {
3108     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
3109                << str;
3110   }
3111   absl::string_view lhs = split1[0];
3112   absl::string_view rhs = split2[0];
3113   absl::string_view out = split2[1];
3114 
3115   const int64 rank = lhs.length();
3116   if (rank != rhs.length() || rank != out.length()) {
3117     return TokenError(
3118         "convolution lhs, rhs, and output must have the same rank");
3119   }
3120   if (rank < 2) {
3121     return TokenError("convolution rank must >=2");
3122   }
3123 
3124   auto is_unique = [](string str) -> bool {
3125     absl::c_sort(str);
3126     return std::unique(str.begin(), str.end()) == str.end();
3127   };
3128 
3129   // lhs
3130   {
3131     if (!is_unique(string(lhs))) {
3132       return TokenError(
3133           StrCat("expects unique lhs dimension numbers, but sees ", lhs));
3134     }
3135     for (int i = 0; i < rank - 2; i++) {
3136       dnums->add_input_spatial_dimensions(-1);
3137     }
3138     for (int i = 0; i < rank; i++) {
3139       char c = lhs[i];
3140       if (c == 'b') {
3141         dnums->set_input_batch_dimension(i);
3142       } else if (c == 'f') {
3143         dnums->set_input_feature_dimension(i);
3144       } else if (c < '0' + rank && c >= '0') {
3145         dnums->set_input_spatial_dimensions(c - '0', i);
3146       } else {
3147         return TokenError(
3148             StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
3149       }
3150     }
3151   }
3152   // rhs
3153   {
3154     if (!is_unique(string(rhs))) {
3155       return TokenError(
3156           StrCat("expects unique rhs dimension numbers, but sees ", rhs));
3157     }
3158     for (int i = 0; i < rank - 2; i++) {
3159       dnums->add_kernel_spatial_dimensions(-1);
3160     }
3161     for (int i = 0; i < rank; i++) {
3162       char c = rhs[i];
3163       if (c == 'i') {
3164         dnums->set_kernel_input_feature_dimension(i);
3165       } else if (c == 'o') {
3166         dnums->set_kernel_output_feature_dimension(i);
3167       } else if (c < '0' + rank && c >= '0') {
3168         dnums->set_kernel_spatial_dimensions(c - '0', i);
3169       } else {
3170         return TokenError(
3171             StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
3172       }
3173     }
3174   }
3175   // output
3176   {
3177     if (!is_unique(string(out))) {
3178       return TokenError(
3179           StrCat("expects unique output dimension numbers, but sees ", out));
3180     }
3181     for (int i = 0; i < rank - 2; i++) {
3182       dnums->add_output_spatial_dimensions(-1);
3183     }
3184     for (int i = 0; i < rank; i++) {
3185       char c = out[i];
3186       if (c == 'b') {
3187         dnums->set_output_batch_dimension(i);
3188       } else if (c == 'f') {
3189         dnums->set_output_feature_dimension(i);
3190       } else if (c < '0' + rank && c >= '0') {
3191         dnums->set_output_spatial_dimensions(c - '0', i);
3192       } else {
3193         return TokenError(StrFormat(
3194             "expects [0-%dbf] in output dimension numbers", rank - 1));
3195       }
3196     }
3197   }
3198 
3199   lexer_.Lex();
3200   return true;
3201 }
3202 
3203 // ::= '{' ranges '}'
3204 //   ::= /*empty*/
3205 //   ::= range (',' range)*
3206 // range ::= '[' start ':' limit (':' stride)? ']'
3207 //
3208 // The slice ranges are printed as:
3209 //
3210 //  {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
3211 //
3212 // This function extracts the starts, limits, and strides as 3 vectors to the
3213 // result. If stride is not present, stride is 1. For example, if the slice
3214 // ranges is printed as:
3215 //
3216 //  {[2:3:4], [5:6:7], [8:9]}
3217 //
3218 // The parsed result will be:
3219 //
3220 //  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
3221 //
ParseSliceRanges(SliceRanges * result)3222 bool HloParser::ParseSliceRanges(SliceRanges* result) {
3223   if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
3224     return false;
3225   }
3226   std::vector<std::vector<int64>> ranges;
3227   if (lexer_.GetKind() == TokKind::kRbrace) {
3228     // empty
3229     return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
3230   }
3231   do {
3232     LocTy loc = lexer_.GetLoc();
3233     ranges.emplace_back();
3234     if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
3235                         &ranges.back())) {
3236       return false;
3237     }
3238     const auto& range = ranges.back();
3239     if (range.size() != 2 && range.size() != 3) {
3240       return Error(loc,
3241                    StrFormat("expects [start:limit:step] or [start:limit], "
3242                              "but sees %d elements.",
3243                              range.size()));
3244     }
3245   } while (EatIfPresent(TokKind::kComma));
3246 
3247   for (const auto& range : ranges) {
3248     result->starts.push_back(range[0]);
3249     result->limits.push_back(range[1]);
3250     result->strides.push_back(range.size() == 3 ? range[2] : 1);
3251   }
3252   return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
3253 }
3254 
3255 // precisionlist ::= start precision_elements end
3256 // precision_elements
3257 //   ::= /*empty*/
3258 //   ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)3259 bool HloParser::ParsePrecisionList(
3260     std::vector<PrecisionConfig::Precision>* result) {
3261   auto parse_and_add_item = [&]() {
3262     PrecisionConfig::Precision item;
3263     if (!ParsePrecision(&item)) {
3264       return false;
3265     }
3266     result->push_back(item);
3267     return true;
3268   };
3269   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3270                    parse_and_add_item);
3271 }
3272 
ParseHloComputation(HloComputation ** result)3273 bool HloParser::ParseHloComputation(HloComputation** result) {
3274   if (lexer_.GetKind() == TokKind::kLbrace) {
3275     // This means it is a nested computation.
3276     return ParseInstructionList(result, /*computation_name=*/"_");
3277   }
3278   // This means it is a computation name.
3279   return ParseComputationName(result);
3280 }
3281 
ParseHloComputationList(std::vector<HloComputation * > * result)3282 bool HloParser::ParseHloComputationList(std::vector<HloComputation*>* result) {
3283   auto parse_and_add_item = [&]() {
3284     HloComputation* computation;
3285     if (!ParseHloComputation(&computation)) {
3286       return false;
3287     }
3288     LOG(INFO) << "parsed computation " << computation->name();
3289     result->push_back(computation);
3290     return true;
3291   };
3292   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3293                    parse_and_add_item);
3294 }
3295 
3296 // shapelist ::= '{' shapes '}'
3297 // precision_elements
3298 //   ::= /*empty*/
3299 //   ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)3300 bool HloParser::ParseShapeList(std::vector<Shape>* result) {
3301   auto parse_and_add_item = [&]() {
3302     Shape shape;
3303     if (!ParseShape(&shape)) {
3304       return false;
3305     }
3306     result->push_back(std::move(shape));
3307     return true;
3308   };
3309   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3310                    parse_and_add_item);
3311 }
3312 
3313 // int64list ::= start int64_elements end
3314 // int64_elements
3315 //   ::= /*empty*/
3316 //   ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)3317 bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
3318                                const TokKind delim,
3319                                std::vector<int64>* result) {
3320   auto parse_and_add_item = [&]() {
3321     int64 i;
3322     if (!ParseInt64(&i)) {
3323       return false;
3324     }
3325     result->push_back(i);
3326     return true;
3327   };
3328   return ParseList(start, end, delim, parse_and_add_item);
3329 }
3330 
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)3331 bool HloParser::ParseList(const TokKind start, const TokKind end,
3332                           const TokKind delim,
3333                           const std::function<bool()>& parse_and_add_item) {
3334   if (!ParseToken(start, StrCat("expects a list starting with ",
3335                                 TokKindToString(start)))) {
3336     return false;
3337   }
3338   if (lexer_.GetKind() == end) {
3339     // empty
3340   } else {
3341     do {
3342       if (!parse_and_add_item()) {
3343         return false;
3344       }
3345     } while (EatIfPresent(delim));
3346   }
3347   return ParseToken(
3348       end, StrCat("expects a list to end with ", TokKindToString(end)));
3349 }
3350 
3351 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)3352 bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
3353   if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
3354     return false;
3355   }
3356   *shape_loc = lexer_.GetLoc();
3357   return ParseShape(shape);
3358 }
3359 
CanBeParamListToShape()3360 bool HloParser::CanBeParamListToShape() {
3361   return lexer_.GetKind() == TokKind::kLparen;
3362 }
3363 
3364 // param_list ::= '(' param_list1 ')'
3365 // param_list1
3366 //   ::= /*empty*/
3367 //   ::= param (',' param)*
3368 // param ::= name shape
ParseParamList()3369 bool HloParser::ParseParamList() {
3370   if (!ParseToken(TokKind::kLparen,
3371                   "expects '(' at the beginning of param list")) {
3372     return false;
3373   }
3374 
3375   if (lexer_.GetKind() == TokKind::kRparen) {
3376     // empty
3377   } else {
3378     do {
3379       Shape shape;
3380       string name;
3381       if (!ParseName(&name) || !ParseShape(&shape)) {
3382         return false;
3383       }
3384     } while (EatIfPresent(TokKind::kComma));
3385   }
3386   return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
3387 }
3388 
3389 // dimension_sizes ::= '[' dimension_list ']'
3390 // dimension_list
3391 //   ::= /*empty*/
3392 //   ::= <=? int64 (',' param)*
3393 // param ::= name shape
ParseDimensionSizes(std::vector<int64> * dimension_sizes,std::vector<bool> * dynamic_dimensions)3394 bool HloParser::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
3395                                     std::vector<bool>* dynamic_dimensions) {
3396   auto parse_and_add_item = [&]() {
3397     int64 i;
3398     bool is_dynamic = false;
3399     if (lexer_.GetKind() == TokKind::kLeq) {
3400       is_dynamic = true;
3401       lexer_.Lex();
3402     }
3403     if (!ParseInt64(&i)) {
3404       return false;
3405     }
3406     dimension_sizes->push_back(i);
3407     dynamic_dimensions->push_back(is_dynamic);
3408     return true;
3409   };
3410   return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
3411                    parse_and_add_item);
3412 }
3413 
3414 // tiles
3415 //   ::= /*empty*/
3416 //   ::= 'T' '(' dim_list ')'
3417 // dim_list
3418 //   ::= /*empty*/
3419 //   ::= (int64 | '*') (',' (int64 | '*'))*
ParseTiles(std::vector<Tile> * tiles)3420 bool HloParser::ParseTiles(std::vector<Tile>* tiles) {
3421   auto parse_and_add_tile_dimension = [&]() {
3422     tensorflow::int64 i;
3423     if (ParseInt64(&i)) {
3424       tiles->back().add_dimensions(i);
3425       return true;
3426     }
3427     if (lexer_.GetKind() == TokKind::kAsterisk) {
3428       tiles->back().add_dimensions(Tile::kCombineDimension);
3429       lexer_.Lex();
3430       return true;
3431     }
3432     return false;
3433   };
3434 
3435   do {
3436     tiles->push_back(Tile());
3437     if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
3438                    parse_and_add_tile_dimension)) {
3439       return false;
3440     }
3441   } while (lexer_.GetKind() == TokKind::kLparen);
3442   return true;
3443 }
3444 
3445 // layout ::= '{' int64_list (':' tiles element_size_in_bits)? '}'
3446 // element_size_in_bits
3447 //   ::= /*empty*/
3448 //   ::= 'E' '(' int64 ')'
ParseLayout(Layout * layout)3449 bool HloParser::ParseLayout(Layout* layout) {
3450   std::vector<int64> minor_to_major;
3451   std::vector<Tile> tiles;
3452   tensorflow::int64 element_size_in_bits = 0;
3453 
3454   auto parse_and_add_item = [&]() {
3455     int64 i;
3456     if (!ParseInt64(&i)) {
3457       return false;
3458     }
3459     minor_to_major.push_back(i);
3460     return true;
3461   };
3462 
3463   if (!ParseToken(TokKind::kLbrace,
3464                   StrCat("expects layout to start with ",
3465                          TokKindToString(TokKind::kLbrace)))) {
3466     return false;
3467   }
3468   if (lexer_.GetKind() != TokKind::kRbrace) {
3469     if (lexer_.GetKind() == TokKind::kInt) {
3470       // Parse minor to major.
3471       do {
3472         if (!parse_and_add_item()) {
3473           return false;
3474         }
3475       } while (EatIfPresent(TokKind::kComma));
3476     }
3477 
3478     if (lexer_.GetKind() == TokKind::kColon) {
3479       lexer_.Lex();
3480       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
3481         lexer_.Lex();
3482         ParseTiles(&tiles);
3483       }
3484 
3485       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
3486         // Parse element size in bits.
3487         lexer_.Lex();
3488         if (!ParseToken(TokKind::kLparen,
3489                         StrCat("expects element size in bits to start with ",
3490                                TokKindToString(TokKind::kLparen)))) {
3491           return false;
3492         }
3493         if (!ParseInt64(&element_size_in_bits)) {
3494           return false;
3495         }
3496         if (!ParseToken(TokKind::kRparen,
3497                         StrCat("expects element size in bits to end with ",
3498                                TokKindToString(TokKind::kRparen)))) {
3499           return false;
3500         }
3501       }
3502     }
3503   }
3504   if (!ParseToken(TokKind::kRbrace,
3505                   StrCat("expects layout to end with ",
3506                          TokKindToString(TokKind::kRbrace)))) {
3507     return false;
3508   }
3509 
3510   std::vector<Tile> vec_tiles(tiles.size());
3511   for (int i = 0; i < tiles.size(); i++) {
3512     vec_tiles[i] = Tile(tiles[i]);
3513   }
3514   *layout =
3515       LayoutUtil::MakeLayout(minor_to_major, vec_tiles, element_size_in_bits);
3516   return true;
3517 }
3518 
3519 // shape ::= shape_val_
3520 // shape ::= '(' tuple_elements ')'
3521 // tuple_elements
3522 //   ::= /*empty*/
3523 //   ::= shape (',' shape)*
ParseShape(Shape * result)3524 bool HloParser::ParseShape(Shape* result) {
3525   if (EatIfPresent(TokKind::kLparen)) {  // Tuple
3526     std::vector<Shape> shapes;
3527     if (lexer_.GetKind() == TokKind::kRparen) {
3528       /*empty*/
3529     } else {
3530       // shape (',' shape)*
3531       do {
3532         shapes.emplace_back();
3533         if (!ParseShape(&shapes.back())) {
3534           return false;
3535         }
3536       } while (EatIfPresent(TokKind::kComma));
3537     }
3538     *result = ShapeUtil::MakeTupleShape(shapes);
3539     return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
3540   }
3541 
3542   if (lexer_.GetKind() != TokKind::kPrimitiveType) {
3543     return TokenError(absl::StrCat("expected primitive type, saw ",
3544                                    TokKindToString(lexer_.GetKind())));
3545   }
3546   PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
3547   lexer_.Lex();
3548 
3549   // Each element contains a dimension size and a bool indicating whether this
3550   // is a dynamic dimension.
3551   std::vector<int64> dimension_sizes;
3552   std::vector<bool> dynamic_dimensions;
3553   if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
3554     return false;
3555   }
3556   result->set_element_type(primitive_type);
3557   for (int i = 0; i < dimension_sizes.size(); ++i) {
3558     result->add_dimensions(dimension_sizes[i]);
3559     result->set_dynamic_dimension(i, dynamic_dimensions[i]);
3560   }
3561   LayoutUtil::SetToDefaultLayout(result);
3562 
3563   if (lexer_.GetKind() == TokKind::kw_sparse) {
3564     lexer_.Lex();
3565     const string message =
3566         "expects a brace-bracketed integer for sparse layout";
3567     int64 max_sparse_elements;
3568     if (!ParseToken(TokKind::kLbrace, message) ||
3569         !ParseInt64(&max_sparse_elements) ||
3570         !ParseToken(TokKind::kRbrace, message)) {
3571       return false;
3572     }
3573     *result->mutable_layout() =
3574         LayoutUtil::MakeSparseLayout(max_sparse_elements);
3575     return true;
3576   }
3577 
3578   // We need to lookahead to see if a following open brace is the start of a
3579   // layout. The specific problematic case is:
3580   //
3581   // ENTRY %foo (x: f32[42]) -> f32[123] {
3582   //  ...
3583   // }
3584   //
3585   // The open brace could either be the start of a computation or the start of a
3586   // layout for the f32[123] shape. We consider it the start of a layout if the
3587   // next token after the open brace is an integer or a colon.
3588   if (lexer_.GetKind() == TokKind::kLbrace &&
3589       (lexer_.LookAhead() == TokKind::kInt ||
3590        lexer_.LookAhead() == TokKind::kColon)) {
3591     Layout layout;
3592     if (!ParseLayout(&layout)) {
3593       return false;
3594     }
3595     if (layout.minor_to_major_size() != result->rank()) {
3596       return Error(
3597           lexer_.GetLoc(),
3598           StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
3599                     result->rank(), layout.minor_to_major_size()));
3600     }
3601     *result->mutable_layout() = layout;
3602   }
3603   return true;
3604 }
3605 
CanBeShape()3606 bool HloParser::CanBeShape() {
3607   // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
3608   // with '('.
3609   return lexer_.GetKind() == TokKind::kPrimitiveType ||
3610          lexer_.GetKind() == TokKind::kLparen;
3611 }
3612 
ParseName(string * result)3613 bool HloParser::ParseName(string* result) {
3614   VLOG(1) << "ParseName";
3615   if (lexer_.GetKind() != TokKind::kIdent &&
3616       lexer_.GetKind() != TokKind::kName) {
3617     return TokenError("expects name");
3618   }
3619   *result = lexer_.GetStrVal();
3620   lexer_.Lex();
3621   return true;
3622 }
3623 
ParseAttributeName(string * result)3624 bool HloParser::ParseAttributeName(string* result) {
3625   if (lexer_.GetKind() != TokKind::kAttributeName) {
3626     return TokenError("expects attribute name");
3627   }
3628   *result = lexer_.GetStrVal();
3629   lexer_.Lex();
3630   return true;
3631 }
3632 
ParseString(string * result)3633 bool HloParser::ParseString(string* result) {
3634   VLOG(1) << "ParseString";
3635   if (lexer_.GetKind() != TokKind::kString) {
3636     return TokenError("expects string");
3637   }
3638   *result = lexer_.GetStrVal();
3639   lexer_.Lex();
3640   return true;
3641 }
3642 
ParseDxD(const string & name,std::vector<int64> * result)3643 bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
3644   LocTy loc = lexer_.GetLoc();
3645   if (!result->empty()) {
3646     return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
3647   }
3648   // 1D
3649   if (lexer_.GetKind() == TokKind::kInt) {
3650     int64 number;
3651     if (!ParseInt64(&number)) {
3652       return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
3653     }
3654     result->push_back(number);
3655     return true;
3656   }
3657   // 2D or higher.
3658   if (lexer_.GetKind() == TokKind::kDxD) {
3659     string str = lexer_.GetStrVal();
3660     if (!SplitToInt64s(str, 'x', result)) {
3661       return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
3662     }
3663     lexer_.Lex();
3664     return true;
3665   }
3666   return TokenError("expects token type kInt or kDxD");
3667 }
3668 
ParseWindowPad(std::vector<std::vector<int64>> * pad)3669 bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
3670   LocTy loc = lexer_.GetLoc();
3671   if (!pad->empty()) {
3672     return Error(loc, "sub-attribute 'pad=' already exists");
3673   }
3674   if (lexer_.GetKind() != TokKind::kPad) {
3675     return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
3676   }
3677   string str = lexer_.GetStrVal();
3678   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
3679     std::vector<int64> low_high;
3680     if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
3681         low_high.size() != 2) {
3682       return Error(loc,
3683                    "expects padding_low and padding_high separated by '_'");
3684     }
3685     pad->push_back(low_high);
3686   }
3687   lexer_.Lex();
3688   return true;
3689 }
3690 
3691 // This is the inverse xla::ToString(PaddingConfig). The padding config string
3692 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
3693 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
3694 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)3695 bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
3696   if (lexer_.GetKind() != TokKind::kPad) {
3697     return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
3698   }
3699   LocTy loc = lexer_.GetLoc();
3700   string str = lexer_.GetStrVal();
3701   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
3702     std::vector<int64> padding_dim;
3703     if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
3704         (padding_dim.size() != 2 && padding_dim.size() != 3)) {
3705       return Error(loc,
3706                    "expects padding config pattern like 'low_high_interior' or "
3707                    "'low_high'");
3708     }
3709     auto* dim = padding->add_dimensions();
3710     dim->set_edge_padding_low(padding_dim[0]);
3711     dim->set_edge_padding_high(padding_dim[1]);
3712     dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
3713   }
3714   lexer_.Lex();
3715   return true;
3716 }
3717 
3718 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)3719 bool HloParser::ParseMetadata(OpMetadata* metadata) {
3720   std::unordered_map<string, AttrConfig> attrs;
3721   optional<string> op_type;
3722   optional<string> op_name;
3723   optional<string> source_file;
3724   optional<int32> source_line;
3725   attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
3726   attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
3727   attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
3728   attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
3729   if (!ParseSubAttributes(attrs)) {
3730     return false;
3731   }
3732   if (op_type) {
3733     metadata->set_op_type(*op_type);
3734   }
3735   if (op_name) {
3736     metadata->set_op_name(*op_name);
3737   }
3738   if (source_file) {
3739     metadata->set_source_file(*source_file);
3740   }
3741   if (source_line) {
3742     metadata->set_source_line(*source_line);
3743   }
3744   return true;
3745 }
3746 
ParseOpcode(HloOpcode * result)3747 bool HloParser::ParseOpcode(HloOpcode* result) {
3748   VLOG(1) << "ParseOpcode";
3749   if (lexer_.GetKind() != TokKind::kIdent) {
3750     return TokenError("expects opcode");
3751   }
3752   string val = lexer_.GetStrVal();
3753   auto status_or_result = StringToHloOpcode(val);
3754   if (!status_or_result.ok()) {
3755     return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
3756                                 status_or_result.status().error_message()));
3757   }
3758   *result = status_or_result.ValueOrDie();
3759   lexer_.Lex();
3760   return true;
3761 }
3762 
ParseFftType(FftType * result)3763 bool HloParser::ParseFftType(FftType* result) {
3764   VLOG(1) << "ParseFftType";
3765   if (lexer_.GetKind() != TokKind::kIdent) {
3766     return TokenError("expects fft type");
3767   }
3768   string val = lexer_.GetStrVal();
3769   if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
3770     return TokenError(StrFormat("expects fft type but sees: %s", val));
3771   }
3772   lexer_.Lex();
3773   return true;
3774 }
3775 
ParseComparisonDirection(ComparisonDirection * result)3776 bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
3777   VLOG(1) << "ParseComparisonDirection";
3778   if (lexer_.GetKind() != TokKind::kIdent) {
3779     return TokenError("expects comparison direction");
3780   }
3781   string val = lexer_.GetStrVal();
3782   auto status_or_result = StringToComparisonDirection(val);
3783   if (!status_or_result.ok()) {
3784     return TokenError(
3785         StrFormat("expects comparison direction but sees: %s", val));
3786   }
3787   *result = status_or_result.ValueOrDie();
3788   lexer_.Lex();
3789   return true;
3790 }
3791 
ParseFusionKind(HloInstruction::FusionKind * result)3792 bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
3793   VLOG(1) << "ParseFusionKind";
3794   if (lexer_.GetKind() != TokKind::kIdent) {
3795     return TokenError("expects fusion kind");
3796   }
3797   string val = lexer_.GetStrVal();
3798   auto status_or_result = StringToFusionKind(val);
3799   if (!status_or_result.ok()) {
3800     return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
3801                                 val,
3802                                 status_or_result.status().error_message()));
3803   }
3804   *result = status_or_result.ValueOrDie();
3805   lexer_.Lex();
3806   return true;
3807 }
3808 
ParseRandomDistribution(RandomDistribution * result)3809 bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
3810   VLOG(1) << "ParseRandomDistribution";
3811   if (lexer_.GetKind() != TokKind::kIdent) {
3812     return TokenError("expects random distribution");
3813   }
3814   string val = lexer_.GetStrVal();
3815   auto status_or_result = StringToRandomDistribution(val);
3816   if (!status_or_result.ok()) {
3817     return TokenError(
3818         StrFormat("expects random distribution but sees: %s, error: %s", val,
3819                   status_or_result.status().error_message()));
3820   }
3821   *result = status_or_result.ValueOrDie();
3822   lexer_.Lex();
3823   return true;
3824 }
3825 
ParsePrecision(PrecisionConfig::Precision * result)3826 bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
3827   VLOG(1) << "ParsePrecision";
3828   if (lexer_.GetKind() != TokKind::kIdent) {
3829     return TokenError("expects random distribution");
3830   }
3831   string val = lexer_.GetStrVal();
3832   auto status_or_result = StringToPrecision(val);
3833   if (!status_or_result.ok()) {
3834     return TokenError(StrFormat("expects precision but sees: %s, error: %s",
3835                                 val,
3836                                 status_or_result.status().error_message()));
3837   }
3838   *result = status_or_result.ValueOrDie();
3839   lexer_.Lex();
3840   return true;
3841 }
3842 
ParseInt64(int64 * result)3843 bool HloParser::ParseInt64(int64* result) {
3844   VLOG(1) << "ParseInt64";
3845   if (lexer_.GetKind() != TokKind::kInt) {
3846     return TokenError("expects integer");
3847   }
3848   *result = lexer_.GetInt64Val();
3849   lexer_.Lex();
3850   return true;
3851 }
3852 
ParseDouble(double * result)3853 bool HloParser::ParseDouble(double* result) {
3854   switch (lexer_.GetKind()) {
3855     case TokKind::kDecimal: {
3856       double val = lexer_.GetDecimalVal();
3857       // If GetDecimalVal returns +/-inf, that means that we overflowed
3858       // `double`.
3859       if (std::isinf(val)) {
3860         return TokenError(StrCat("Constant is out of range for double (+/-",
3861                                  std::numeric_limits<double>::max(),
3862                                  ") and so is unparsable."));
3863       }
3864       *result = val;
3865       break;
3866     }
3867     case TokKind::kInt:
3868       *result = static_cast<double>(lexer_.GetInt64Val());
3869       break;
3870     case TokKind::kw_nan:
3871       *result = std::numeric_limits<double>::quiet_NaN();
3872       break;
3873     case TokKind::kw_inf:
3874       *result = std::numeric_limits<double>::infinity();
3875       break;
3876     case TokKind::kNegInf:
3877       *result = -std::numeric_limits<double>::infinity();
3878       break;
3879     default:
3880       return TokenError("expects decimal or integer");
3881   }
3882   lexer_.Lex();
3883   return true;
3884 }
3885 
ParseComplex(std::complex<double> * result)3886 bool HloParser::ParseComplex(std::complex<double>* result) {
3887   if (lexer_.GetKind() != TokKind::kLparen) {
3888     return TokenError("expects '(' before complex number");
3889   }
3890   lexer_.Lex();
3891 
3892   double real;
3893   LocTy loc = lexer_.GetLoc();
3894   if (!ParseDouble(&real)) {
3895     return Error(loc,
3896                  "expect floating-point value for real part of complex number");
3897   }
3898 
3899   if (lexer_.GetKind() != TokKind::kComma) {
3900     return TokenError(
3901         absl::StrFormat("expect comma after real part of complex literal"));
3902   }
3903   lexer_.Lex();
3904 
3905   double imag;
3906   loc = lexer_.GetLoc();
3907   if (!ParseDouble(&imag)) {
3908     return Error(
3909         loc,
3910         "expect floating-point value for imaginary part of complex number");
3911   }
3912 
3913   if (lexer_.GetKind() != TokKind::kRparen) {
3914     return TokenError(absl::StrFormat("expect ')' after complex number"));
3915   }
3916 
3917   *result = std::complex<double>(real, imag);
3918   lexer_.Lex();
3919   return true;
3920 }
3921 
ParseBool(bool * result)3922 bool HloParser::ParseBool(bool* result) {
3923   if (lexer_.GetKind() != TokKind::kw_true &&
3924       lexer_.GetKind() != TokKind::kw_false) {
3925     return TokenError("expects true or false");
3926   }
3927   *result = lexer_.GetKind() == TokKind::kw_true;
3928   lexer_.Lex();
3929   return true;
3930 }
3931 
ParseToken(TokKind kind,const string & msg)3932 bool HloParser::ParseToken(TokKind kind, const string& msg) {
3933   VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg;
3934   if (lexer_.GetKind() != kind) {
3935     return TokenError(msg);
3936   }
3937   lexer_.Lex();
3938   return true;
3939 }
3940 
EatIfPresent(TokKind kind)3941 bool HloParser::EatIfPresent(TokKind kind) {
3942   if (lexer_.GetKind() != kind) {
3943     return false;
3944   }
3945   lexer_.Lex();
3946   return true;
3947 }
3948 
AddInstruction(const string & name,HloInstruction * instruction,LocTy name_loc)3949 bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
3950                                LocTy name_loc) {
3951   auto result = current_name_table().insert({name, {instruction, name_loc}});
3952   if (!result.second) {
3953     Error(name_loc, StrCat("instruction already exists: ", name));
3954     return Error(/*loc=*/result.first->second.second,
3955                  "instruction previously defined here");
3956   }
3957   return true;
3958 }
3959 
AddComputation(const string & name,HloComputation * computation,LocTy name_loc)3960 bool HloParser::AddComputation(const string& name, HloComputation* computation,
3961                                LocTy name_loc) {
3962   auto result = computation_pool_.insert({name, {computation, name_loc}});
3963   if (!result.second) {
3964     Error(name_loc, StrCat("computation already exists: ", name));
3965     return Error(/*loc=*/result.first->second.second,
3966                  "computation previously defined here");
3967   }
3968   return true;
3969 }
3970 
ParseShapeOnly()3971 StatusOr<Shape> HloParser::ParseShapeOnly() {
3972   lexer_.Lex();
3973   Shape shape;
3974   if (!ParseShape(&shape)) {
3975     return InvalidArgument("Syntax error:\n%s", GetError());
3976   }
3977   if (lexer_.GetKind() != TokKind::kEof) {
3978     return InvalidArgument("Syntax error:\nExtra content after shape");
3979   }
3980   return shape;
3981 }
3982 
ParseShardingOnly()3983 StatusOr<HloSharding> HloParser::ParseShardingOnly() {
3984   lexer_.Lex();
3985   OpSharding op_sharding;
3986   if (!ParseSharding(&op_sharding)) {
3987     return InvalidArgument("Syntax error:\n%s", GetError());
3988   }
3989   if (lexer_.GetKind() != TokKind::kEof) {
3990     return InvalidArgument("Syntax error:\nExtra content after sharding");
3991   }
3992   return HloSharding::FromProto(op_sharding);
3993 }
3994 
ParseParameterReplicationOnly()3995 StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
3996   lexer_.Lex();
3997   ParameterReplication parameter_replication;
3998   if (!ParseParameterReplication(&parameter_replication)) {
3999     return InvalidArgument("Syntax error:\n%s", GetError());
4000   }
4001   if (lexer_.GetKind() != TokKind::kEof) {
4002     return InvalidArgument(
4003         "Syntax error:\nExtra content after parameter replication");
4004   }
4005   return std::vector<bool>(
4006       parameter_replication.replicated_at_leaf_buffers().begin(),
4007       parameter_replication.replicated_at_leaf_buffers().end());
4008 }
4009 
ParseWindowOnly()4010 StatusOr<Window> HloParser::ParseWindowOnly() {
4011   lexer_.Lex();
4012   Window window;
4013   if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
4014     return InvalidArgument("Syntax error:\n%s", GetError());
4015   }
4016   if (lexer_.GetKind() != TokKind::kEof) {
4017     return InvalidArgument("Syntax error:\nExtra content after window");
4018   }
4019   return window;
4020 }
4021 
4022 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()4023 HloParser::ParseConvolutionDimensionNumbersOnly() {
4024   lexer_.Lex();
4025   ConvolutionDimensionNumbers dnums;
4026   if (!ParseConvolutionDimensionNumbers(&dnums)) {
4027     return InvalidArgument("Syntax error:\n%s", GetError());
4028   }
4029   if (lexer_.GetKind() != TokKind::kEof) {
4030     return InvalidArgument(
4031         "Syntax error:\nExtra content after convolution dnums");
4032   }
4033   return dnums;
4034 }
4035 
ParsePaddingConfigOnly()4036 StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
4037   lexer_.Lex();
4038   PaddingConfig padding_config;
4039   if (!ParsePaddingConfig(&padding_config)) {
4040     return InvalidArgument("Syntax error:\n%s", GetError());
4041   }
4042   if (lexer_.GetKind() != TokKind::kEof) {
4043     return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
4044   }
4045   return padding_config;
4046 }
4047 
ParseSingleInstruction(HloModule * module)4048 bool HloParser::ParseSingleInstruction(HloModule* module) {
4049   if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
4050     LOG(FATAL) << "Parser state is not clean. Please do not call any other "
4051                   "methods before calling ParseSingleInstruction.";
4052   }
4053   HloComputation::Builder builder(module->name());
4054 
4055   // The missing instruction hook we register creates the shaped instruction on
4056   // the fly as a parameter and returns it.
4057   int64 parameter_count = 0;
4058   create_missing_instruction_ =
4059       [this, &builder, &parameter_count](
4060           const string& name,
4061           const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
4062     string new_name = name.empty() ? StrCat("_", parameter_count) : name;
4063     HloInstruction* parameter = builder.AddInstruction(
4064         HloInstruction::CreateParameter(parameter_count++, shape, new_name));
4065     current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
4066     return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
4067   };
4068 
4069   // Parse the instruction with the registered hook.
4070   Scope scope(&scoped_name_tables_);
4071   if (CanBeShape()) {
4072     // This means that the instruction's left-hand side is probably omitted,
4073     // e.g.
4074     //
4075     //  f32[10] fusion(...), calls={...}
4076     if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
4077       return false;
4078     }
4079   } else {
4080     // This means that the instruction's left-hand side might exist, e.g.
4081     //
4082     //  foo = f32[10] fusion(...), calls={...}
4083     string root_name;
4084     if (!ParseInstruction(&builder, &root_name)) {
4085       return false;
4086     }
4087   }
4088 
4089   module->AddEntryComputation(builder.Build());
4090   for (auto& comp : computations_) {
4091     module->AddEmbeddedComputation(std::move(comp));
4092   }
4093   return true;
4094 }
4095 
4096 }  // namespace
4097 
ParseHloString(absl::string_view str,const HloModuleConfig & config)4098 StatusOr<std::unique_ptr<HloModule>> ParseHloString(
4099     absl::string_view str, const HloModuleConfig& config) {
4100   auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
4101   HloParser parser(str);
4102   TF_RETURN_IF_ERROR(parser.Run(module.get()));
4103   return std::move(module);
4104 }
4105 
ParseHloString(absl::string_view str)4106 StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
4107   auto module = absl::make_unique<HloModule>(/*name=*/"_", HloModuleConfig());
4108   HloParser parser(str);
4109   TF_RETURN_IF_ERROR(parser.Run(module.get()));
4110   return std::move(module);
4111 }
4112 
ParseHloString(absl::string_view str,HloModule * module)4113 Status ParseHloString(absl::string_view str, HloModule* module) {
4114   TF_RET_CHECK(module->computation_count() == 0);
4115   HloParser parser(str);
4116   TF_RETURN_IF_ERROR(parser.Run(module));
4117   return Status::OK();
4118 }
4119 
ParseSharding(absl::string_view str)4120 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
4121   HloParser parser(str);
4122   return parser.ParseShardingOnly();
4123 }
4124 
ParseParameterReplication(absl::string_view str)4125 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
4126   HloParser parser(str);
4127   return parser.ParseParameterReplicationOnly();
4128 }
4129 
ParseWindow(absl::string_view str)4130 StatusOr<Window> ParseWindow(absl::string_view str) {
4131   HloParser parser(str);
4132   return parser.ParseWindowOnly();
4133 }
4134 
ParseConvolutionDimensionNumbers(absl::string_view str)4135 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
4136     absl::string_view str) {
4137   HloParser parser(str);
4138   return parser.ParseConvolutionDimensionNumbersOnly();
4139 }
4140 
ParsePaddingConfig(absl::string_view str)4141 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
4142   HloParser parser(str);
4143   return parser.ParsePaddingConfigOnly();
4144 }
4145 
ParseShape(absl::string_view str)4146 StatusOr<Shape> ParseShape(absl::string_view str) {
4147   HloParser parser(str);
4148   return parser.ParseShapeOnly();
4149 }
4150 
4151 }  // namespace xla
4152