• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 
18 #include <memory>
19 #include <string>
20 #include <type_traits>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/ascii.h"
28 #include "absl/strings/numbers.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/types/span.h"
35 #include "absl/types/variant.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/literal_util.h"
38 #include "tensorflow/compiler/xla/primitive_util.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
41 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
48 #include "tensorflow/compiler/xla/service/shape_inference.h"
49 #include "tensorflow/compiler/xla/shape_util.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/compiler/xla/xla_data.pb.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53 #include "tensorflow/core/platform/macros.h"
54 #include "tensorflow/core/platform/protobuf.h"
55 
56 namespace xla {
57 
58 namespace {
59 
60 using absl::nullopt;
61 using absl::optional;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrFormat;
65 using absl::StrJoin;
66 
67 // Creates and returns a schedule created using the order of the instructions in
68 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)69 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
70   HloSchedule schedule(module);
71   for (HloComputation* computation : module->computations()) {
72     if (!computation->IsFusionComputation()) {
73       for (HloInstruction* instruction : computation->instructions()) {
74         schedule.GetOrCreateSequence(computation).push_back(instruction);
75       }
76     }
77   }
78   return schedule;
79 }
80 
CanInferShape(HloOpcode code)81 bool CanInferShape(HloOpcode code) {
82   switch (code) {
83     case HloOpcode::kAbs:
84     case HloOpcode::kAdd:
85     case HloOpcode::kAddDependency:
86     case HloOpcode::kAfterAll:
87     case HloOpcode::kAtan2:
88     case HloOpcode::kBatchNormGrad:
89     case HloOpcode::kBatchNormInference:
90     case HloOpcode::kBatchNormTraining:
91     case HloOpcode::kBroadcast:
92     case HloOpcode::kCall:
93     case HloOpcode::kCeil:
94     case HloOpcode::kCholesky:
95     case HloOpcode::kClamp:
96     case HloOpcode::kClz:
97     case HloOpcode::kCompare:
98     case HloOpcode::kComplex:
99     case HloOpcode::kConcatenate:
100     case HloOpcode::kConditional:
101     case HloOpcode::kConvolution:
102     case HloOpcode::kCopy:
103     case HloOpcode::kCos:
104     case HloOpcode::kDivide:
105     case HloOpcode::kDomain:
106     case HloOpcode::kDot:
107     case HloOpcode::kExp:
108     case HloOpcode::kExpm1:
109     case HloOpcode::kFft:
110     case HloOpcode::kFloor:
111     case HloOpcode::kGather:
112     case HloOpcode::kGetDimensionSize:
113     case HloOpcode::kSetDimensionSize:
114     case HloOpcode::kGetTupleElement:
115     case HloOpcode::kImag:
116     case HloOpcode::kIsFinite:
117     case HloOpcode::kLog:
118     case HloOpcode::kLog1p:
119     case HloOpcode::kLogistic:
120     case HloOpcode::kAnd:
121     case HloOpcode::kNot:
122     case HloOpcode::kOr:
123     case HloOpcode::kXor:
124     case HloOpcode::kMap:
125     case HloOpcode::kMaximum:
126     case HloOpcode::kMinimum:
127     case HloOpcode::kMultiply:
128     case HloOpcode::kNegate:
129     case HloOpcode::kPad:
130     case HloOpcode::kPartitionId:
131     case HloOpcode::kPopulationCount:
132     case HloOpcode::kPower:
133     case HloOpcode::kReal:
134     case HloOpcode::kReduce:
135     case HloOpcode::kRemainder:
136     case HloOpcode::kReplicaId:
137     case HloOpcode::kReverse:
138     case HloOpcode::kRoundNearestAfz:
139     case HloOpcode::kRsqrt:
140     case HloOpcode::kScatter:
141     case HloOpcode::kSelect:
142     case HloOpcode::kShiftLeft:
143     case HloOpcode::kShiftRightArithmetic:
144     case HloOpcode::kShiftRightLogical:
145     case HloOpcode::kSign:
146     case HloOpcode::kSin:
147     case HloOpcode::kSqrt:
148     case HloOpcode::kCbrt:
149     case HloOpcode::kReduceWindow:
150     case HloOpcode::kSelectAndScatter:
151     case HloOpcode::kSort:
152     case HloOpcode::kSubtract:
153     case HloOpcode::kTanh:
154     case HloOpcode::kTrace:
155     case HloOpcode::kTranspose:
156     case HloOpcode::kTriangularSolve:
157     case HloOpcode::kTuple:
158     case HloOpcode::kTupleSelect:
159     case HloOpcode::kWhile:
160       return true;
161     // Technically the following ops do not require an explicit result shape,
162     // but we made it so that we always write the shapes explicitly.
163     case HloOpcode::kAllGather:
164     case HloOpcode::kAllReduce:
165     case HloOpcode::kAllToAll:
166     case HloOpcode::kCollectivePermute:
167     case HloOpcode::kCollectivePermuteStart:
168     case HloOpcode::kCollectivePermuteDone:
169     case HloOpcode::kCopyDone:
170     case HloOpcode::kCopyStart:
171     case HloOpcode::kDynamicReshape:
172     case HloOpcode::kDynamicSlice:
173     case HloOpcode::kDynamicUpdateSlice:
174     case HloOpcode::kRecv:
175     case HloOpcode::kRecvDone:
176     case HloOpcode::kSend:
177     case HloOpcode::kSendDone:
178     case HloOpcode::kSlice:
179     // The following ops require an explicit result shape.
180     case HloOpcode::kBitcast:
181     case HloOpcode::kBitcastConvert:
182     case HloOpcode::kConstant:
183     case HloOpcode::kConvert:
184     case HloOpcode::kCustomCall:
185     case HloOpcode::kFusion:
186     case HloOpcode::kInfeed:
187     case HloOpcode::kIota:
188     case HloOpcode::kOutfeed:
189     case HloOpcode::kParameter:
190     case HloOpcode::kReducePrecision:
191     case HloOpcode::kReshape:
192     case HloOpcode::kRng:
193     case HloOpcode::kRngBitGenerator:
194     case HloOpcode::kRngGetAndUpdateState:
195       return false;
196   }
197 }
198 
199 // Parser for the HloModule::ToString() format text.
200 class HloParserImpl : public HloParser {
201  public:
202   using LocTy = HloLexer::LocTy;
203 
HloParserImpl(absl::string_view str)204   explicit HloParserImpl(absl::string_view str) : lexer_(str) {}
205 
206   // Runs the parser and constructs the resulting HLO in the given (empty)
207   // HloModule. Returns the error status in case an error occurred.
208   Status Run(HloModule* module) override;
209 
210   // Returns the error information.
GetError() const211   std::string GetError() const { return StrJoin(error_, "\n"); }
212 
213   // Stand alone parsing utils for various aggregate data types.
214   StatusOr<Shape> ParseShapeOnly();
215   StatusOr<HloSharding> ParseShardingOnly();
216   StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
217   StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
218   StatusOr<Window> ParseWindowOnly();
219   StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
220   StatusOr<PaddingConfig> ParsePaddingConfigOnly();
221   StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();
222 
223  private:
224   using InstrNameTable =
225       absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;
226 
227   // Returns the map from the instruction name to the instruction itself and its
228   // location in the current scope.
current_name_table()229   InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
230 
231   // Locates an instruction with the given name in the current_name_table() or
232   // returns nullptr.
233   //
234   // When the name is not found or name is empty, if create_missing_instruction_
235   // hook is registered and a "shape" is provided, the hook will be called to
236   // create an instruction. This is useful when we reify parameters as they're
237   // resolved; i.e. for ParseSingleInstruction.
238   std::pair<HloInstruction*, LocTy>* FindInstruction(
239       const std::string& name, const optional<Shape>& shape = nullopt);
240 
241   // Parse a single instruction worth of text.
242   bool ParseSingleInstruction(HloModule* module);
243 
244   // Parses a module, returning false if an error occurred.
245   bool ParseHloModule(HloModule* module);
246 
247   bool ParseComputations(HloModule* module);
248   bool ParseComputation(HloComputation** entry_computation);
249   bool ParseInstructionList(HloComputation** computation,
250                             const std::string& computation_name);
251   bool ParseInstruction(HloComputation::Builder* builder,
252                         std::string* root_name);
253   bool ParseInstructionRhs(HloComputation::Builder* builder,
254                            const std::string& name, LocTy name_loc);
255   bool ParseControlPredecessors(HloInstruction* instruction);
256   bool ParseLiteral(Literal* literal);
257   bool ParseLiteral(Literal* literal, const Shape& shape);
258   bool ParseTupleLiteral(Literal* literal, const Shape& shape);
259   bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
260   bool ParseDenseLiteral(Literal* literal, const Shape& shape);
261 
262   // Sets the sub-value of literal at the given linear index to the
263   // given value. If the literal is dense, it must have the default layout.
264   //
265   // `loc` should be the source location of the value.
266   bool SetValueInLiteral(LocTy loc, int64 value, int64 index, Literal* literal);
267   bool SetValueInLiteral(LocTy loc, double value, int64 index,
268                          Literal* literal);
269   bool SetValueInLiteral(LocTy loc, bool value, int64 index, Literal* literal);
270   bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64 index,
271                          Literal* literal);
272   // `loc` should be the source location of the value.
273   template <typename LiteralNativeT, typename ParsedElemT>
274   bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64 index,
275                                Literal* literal);
276 
277   // Checks whether the given value is within the range of LiteralNativeT.
278   // `loc` should be the source location of the value.
279   template <typename LiteralNativeT, typename ParsedElemT>
280   bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
281   template <typename LiteralNativeT>
282   bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);
283 
284   bool ParseOperands(std::vector<HloInstruction*>* operands);
285   // Fills parsed operands into 'operands' and expects a certain number of
286   // operands.
287   bool ParseOperands(std::vector<HloInstruction*>* operands,
288                      const int expected_size);
289 
290   // Describes the start, limit, and stride on every dimension of the operand
291   // being sliced.
292   struct SliceRanges {
293     std::vector<int64> starts;
294     std::vector<int64> limits;
295     std::vector<int64> strides;
296   };
297 
298   // The data parsed for the kDomain instruction.
299   struct DomainData {
300     std::unique_ptr<DomainMetadata> entry_metadata;
301     std::unique_ptr<DomainMetadata> exit_metadata;
302   };
303 
304   // Types of attributes.
305   enum class AttrTy {
306     kBool,
307     kInt64,
308     kInt32,
309     kFloat,
310     kString,
311     kLiteral,
312     kBracedInt64List,
313     kBracedInt64ListList,
314     kHloComputation,
315     kBracedHloComputationList,
316     kFftType,
317     kPaddingType,
318     kComparisonDirection,
319     kComparisonType,
320     kWindow,
321     kConvolutionDimensionNumbers,
322     kSharding,
323     kFrontendAttributes,
324     kParameterReplication,
325     kInstructionList,
326     kSliceRanges,
327     kPaddingConfig,
328     kMetadata,
329     kFusionKind,
330     kDistribution,
331     kDomain,
332     kPrecisionList,
333     kShape,
334     kShapeList,
335     kEnum,
336     kRandomAlgorithm,
337     kAliasing,
338     kInstructionAliasing,
339   };
340 
341   struct AttrConfig {
342     bool required;     // whether it's required or optional
343     AttrTy attr_type;  // what type it is
344     void* result;      // where to store the parsed result.
345   };
346 
347   // attributes ::= (',' attribute)*
348   //
349   // Parses attributes given names and configs of the attributes. Each parsed
350   // result is passed back through the result pointer in corresponding
351   // AttrConfig. Note that the result pointer must point to a optional<T> typed
352   // variable which outlives this function. Returns false on error. You should
353   // not use the any of the results if this function failed.
354   //
355   // Example usage:
356   //
357   //  absl::flat_hash_map<std::string, AttrConfig> attrs;
358   //  optional<int64> foo;
359   //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
360   //  optional<Window> bar;
361   //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
362   //  if (!ParseAttributes(attrs)) {
363   //    return false; // Do not use 'foo' 'bar' if failed.
364   //  }
365   //  // Do something with 'bar'.
366   //  if (foo) { // If attr foo is seen, do something with 'foo'. }
367   //
368   bool ParseAttributes(
369       const absl::flat_hash_map<std::string, AttrConfig>& attrs);
370 
371   // sub_attributes ::= '{' (','? attribute)* '}'
372   //
373   // Usage is the same as ParseAttributes. See immediately above.
374   bool ParseSubAttributes(
375       const absl::flat_hash_map<std::string, AttrConfig>& attrs);
376 
377   // Parses one attribute. If it has already been seen, return error. Returns
378   // true and adds to seen_attrs on success.
379   //
380   // Do not call this except in ParseAttributes or ParseSubAttributes.
381   bool ParseAttributeHelper(
382       const absl::flat_hash_map<std::string, AttrConfig>& attrs,
383       absl::flat_hash_set<std::string>* seen_attrs);
384 
385   // Copy attributes from `attrs` to `message`, unless the attribute name is in
386   // `non_proto_attrs`.
387   bool CopyAttributeToProtoMessage(
388       absl::flat_hash_set<std::string> non_proto_attrs,
389       const absl::flat_hash_map<std::string, AttrConfig>& attrs,
390       tensorflow::protobuf::Message* message);
391 
392   // Parses an attribute string into a protocol buffer `message`.
393   // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
394   // set of mandatory attributes.
395   // `non_proto_attrs` specifies attributes that are not written to the proto,
396   // but added to the HloInstruction.
397   bool ParseAttributesAsProtoMessage(
398       const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
399       tensorflow::protobuf::Message* message);
400 
401   // Parses a name and finds the corresponding hlo computation.
402   bool ParseComputationName(HloComputation** value);
403   // Parses a list of names and finds the corresponding hlo instructions.
404   bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
405   // Pass expect_outer_curlies == true when parsing a Window in the context of a
406   // larger computation.  Pass false when parsing a stand-alone Window string.
407   bool ParseWindow(Window* window, bool expect_outer_curlies);
408   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
409   bool ParsePaddingConfig(PaddingConfig* padding);
410   bool ParseMetadata(OpMetadata* metadata);
411   bool ParseSingleOrListMetadata(
412       tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
413   bool ParseSharding(OpSharding* sharding);
414   bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
415   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
416   bool ParseParameterReplication(ParameterReplication* parameter_replication);
417   bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
418 
419   // Parses the metadata behind a kDOmain instruction.
420   bool ParseDomain(DomainData* domain);
421 
422   // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
423   bool ParseDxD(const std::string& name, std::vector<int64>* result);
424   // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
425   bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
426 
427   bool ParseSliceRanges(SliceRanges* result);
428   bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
429   bool ParseHloComputation(HloComputation** result);
430   bool ParseHloComputationList(std::vector<HloComputation*>* result);
431   bool ParseShapeList(std::vector<Shape>* result);
432   bool ParseInt64List(const TokKind start, const TokKind end,
433                       const TokKind delim, std::vector<int64>* result);
434   bool ParseInt64ListList(const TokKind start, const TokKind end,
435                           const TokKind delim,
436                           std::vector<std::vector<int64>>* result);
437   // 'parse_and_add_item' is an lambda to parse an element in the list and add
438   // the parsed element to the result. It's supposed to capture the result.
439   bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
440                  const std::function<bool()>& parse_and_add_item);
441 
442   bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
443   bool ParseParamList();
444   bool ParseName(std::string* result);
445   bool ParseAttributeName(std::string* result);
446   bool ParseString(std::string* result);
447   bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
448                            std::vector<bool>* dynamic_dimensions);
449   bool ParseShape(Shape* result);
450   bool ParseLayout(Layout* layout);
451   bool ParseLayoutIntAttribute(int64* attr_value,
452                                absl::string_view attr_description);
453   bool ParseTiles(std::vector<Tile>* tiles);
454   bool ParseOpcode(HloOpcode* result);
455   bool ParseFftType(FftType* result);
456   bool ParsePaddingType(PaddingType* result);
457   bool ParseComparisonDirection(ComparisonDirection* result);
458   bool ParseComparisonType(Comparison::Type* result);
459   bool ParseFusionKind(HloInstruction::FusionKind* result);
460   bool ParseRandomDistribution(RandomDistribution* result);
461   bool ParseRandomAlgorithm(RandomAlgorithm* result);
462   bool ParsePrecision(PrecisionConfig::Precision* result);
463   bool ParseInt64(int64* result);
464   bool ParseDouble(double* result);
465   bool ParseComplex(std::complex<double>* result);
466   bool ParseBool(bool* result);
467   bool ParseToken(TokKind kind, const std::string& msg);
468 
469   using AliasingData =
470       absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;
471 
472   // Parses the aliasing information from string `s`, returns `false` if it
473   // fails.
474   bool ParseAliasing(AliasingData* data);
475 
476   // Parses the per-instruction aliasing information from string `s`, returns
477   // `false` if it fails.
478   bool ParseInstructionOutputOperandAliasing(
479       std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
480           aliasing_output_operand_pairs);
481 
482   bool ParseShapeIndex(ShapeIndex* out);
483 
484   // Returns true if the current token is the beginning of a shape.
485   bool CanBeShape();
486   // Returns true if the current token is the beginning of a
487   // param_list_to_shape.
488   bool CanBeParamListToShape();
489 
490   // Logs the current parsing line and the given message. Always returns false.
491   bool TokenError(absl::string_view msg);
492   bool Error(LocTy loc, absl::string_view msg);
493 
494   // If the current token is 'kind', eats it (i.e. lexes the next token) and
495   // returns true.
496   bool EatIfPresent(TokKind kind);
497 
498   // Adds the instruction to the pool. Returns false and emits an error if the
499   // instruction already exists.
500   bool AddInstruction(const std::string& name, HloInstruction* instruction,
501                       LocTy name_loc);
502   // Adds the computation to the pool. Returns false and emits an error if the
503   // computation already exists.
504   bool AddComputation(const std::string& name, HloComputation* computation,
505                       LocTy name_loc);
506 
507   HloLexer lexer_;
508 
509   // A stack for the instruction names. The top of the stack stores the
510   // instruction name table for the current scope.
511   //
512   // A instruction's name is unique among its scope (i.e. its parent
513   // computation), but it's not necessarily unique among all computations in the
514   // module. When there are multiple levels of nested computations, the same
515   // name could appear in both an outer computation and an inner computation. So
516   // we need a stack to make sure a name is only visible within its scope,
517   std::vector<InstrNameTable> scoped_name_tables_;
518 
519   // A helper class which pushes and pops to an InstrNameTable stack via RAII.
520   class Scope {
521    public:
Scope(std::vector<InstrNameTable> * scoped_name_tables)522     explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
523         : scoped_name_tables_(scoped_name_tables) {
524       scoped_name_tables_->emplace_back();
525     }
~Scope()526     ~Scope() { scoped_name_tables_->pop_back(); }
527 
528    private:
529     std::vector<InstrNameTable>* scoped_name_tables_;
530   };
531 
532   // Map from the computation name to the computation itself and its location.
533   absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
534       computation_pool_;
535 
536   std::vector<std::unique_ptr<HloComputation>> computations_;
537   std::vector<std::string> error_;
538 
539   // When an operand name cannot be resolved, this function is called to create
540   // a parameter instruction with the given name and shape. It registers the
541   // name, instruction, and a placeholder location in the name table. It returns
542   // the newly-created instruction and the placeholder location. If `name` is
543   // empty, this should create the parameter with a generated name. This is
544   // supposed to be set and used only in ParseSingleInstruction.
545   std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
546                                                    const Shape& shape)>
547       create_missing_instruction_;
548 };
549 
SplitToInt64s(absl::string_view s,char delim,std::vector<int64> * out)550 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
551   for (const auto& split : absl::StrSplit(s, delim)) {
552     int64 val;
553     if (!absl::SimpleAtoi(split, &val)) {
554       return false;
555     }
556     out->push_back(val);
557   }
558   return true;
559 }
560 
561 // Creates replica groups from the provided nested array. groups[i] represents
562 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64>> groups)563 std::vector<ReplicaGroup> CreateReplicaGroups(
564     absl::Span<const std::vector<int64>> groups) {
565   std::vector<ReplicaGroup> replica_groups;
566   absl::c_transform(groups, std::back_inserter(replica_groups),
567                     [](const std::vector<int64>& ids) {
568                       ReplicaGroup group;
569                       *group.mutable_replica_ids() = {ids.begin(), ids.end()};
570                       return group;
571                     });
572   return replica_groups;
573 }
574 
Error(LocTy loc,absl::string_view msg)575 bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
576   auto line_col = lexer_.GetLineAndColumn(loc);
577   const unsigned line = line_col.first;
578   const unsigned col = line_col.second;
579   std::vector<std::string> error_lines;
580   error_lines.push_back(
581       StrCat("was parsing ", line, ":", col, ": error: ", msg));
582   error_lines.emplace_back(lexer_.GetLine(loc));
583   error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
584 
585   error_.push_back(StrJoin(error_lines, "\n"));
586   VLOG(1) << "Error: " << error_.back();
587   return false;
588 }
589 
TokenError(absl::string_view msg)590 bool HloParserImpl::TokenError(absl::string_view msg) {
591   return Error(lexer_.GetLoc(), msg);
592 }
593 
Run(HloModule * module)594 Status HloParserImpl::Run(HloModule* module) {
595   lexer_.Lex();
596   if (lexer_.GetKind() == TokKind::kw_HloModule) {
597     // This means that the text contains a full HLO module.
598     if (!ParseHloModule(module)) {
599       return InvalidArgument(
600           "Syntax error when trying to parse the text as a HloModule:\n%s",
601           GetError());
602     }
603     return Status::OK();
604   }
605   // This means that the text is a single HLO instruction.
606   if (!ParseSingleInstruction(module)) {
607     return InvalidArgument(
608         "Syntax error when trying to parse the text as a single "
609         "HloInstruction:\n%s",
610         GetError());
611   }
612   return Status::OK();
613 }
614 
615 std::pair<HloInstruction*, HloParserImpl::LocTy>*
FindInstruction(const std::string & name,const optional<Shape> & shape)616 HloParserImpl::FindInstruction(const std::string& name,
617                                const optional<Shape>& shape) {
618   std::pair<HloInstruction*, LocTy>* instr = nullptr;
619   if (!name.empty()) {
620     instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
621   }
622 
623   // Potentially call the missing instruction hook.
624   if (instr == nullptr && create_missing_instruction_ != nullptr &&
625       scoped_name_tables_.size() == 1) {
626     if (!shape.has_value()) {
627       Error(lexer_.GetLoc(),
628             "Operand had no shape in HLO text; cannot create parameter for "
629             "single-instruction module.");
630       return nullptr;
631     }
632     return create_missing_instruction_(name, *shape);
633   }
634 
635   if (instr != nullptr && shape.has_value() &&
636       !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
637     Error(
638         lexer_.GetLoc(),
639         StrCat("The declared operand shape ",
640                ShapeUtil::HumanStringWithLayout(shape.value()),
641                " is not compatible with the shape of the operand instruction ",
642                ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
643     return nullptr;
644   }
645 
646   return instr;
647 }
648 
ParseShapeIndex(ShapeIndex * out)649 bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
650   if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
651     return false;
652   }
653 
654   std::vector<int64> idxs;
655   while (lexer_.GetKind() != TokKind::kRbrace) {
656     int64 idx;
657     if (!ParseInt64(&idx)) {
658       return false;
659     }
660     idxs.push_back(idx);
661     if (!EatIfPresent(TokKind::kComma)) {
662       break;
663     }
664   }
665   if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
666     return false;
667   }
668   *out = ShapeIndex(idxs.begin(), idxs.end());
669   return true;
670 }
671 
ParseAliasing(AliasingData * data)672 bool HloParserImpl::ParseAliasing(AliasingData* data) {
673   if (!ParseToken(TokKind::kLbrace,
674                   "Expects '{' at the start of aliasing description")) {
675     return false;
676   }
677 
678   while (lexer_.GetKind() != TokKind::kRbrace) {
679     ShapeIndex out;
680     if (!ParseShapeIndex(&out)) {
681       return false;
682     }
683     std::string errmsg =
684         "Expected format: <output_shape_index>: (<input_param>, "
685         "<input_param_shape_index>) OR <output_shape_index>: <input_param>";
686     if (!ParseToken(TokKind::kColon, errmsg)) {
687       return false;
688     }
689 
690     if (!ParseToken(TokKind::kLparen, errmsg)) {
691       return false;
692     }
693     int64 param_num;
694     ParseInt64(&param_num);
695     if (!ParseToken(TokKind::kComma, errmsg)) {
696       return false;
697     }
698     ShapeIndex param_idx;
699     if (!ParseShapeIndex(&param_idx)) {
700       return false;
701     }
702 
703     HloInputOutputAliasConfig::AliasKind alias_kind =
704         HloInputOutputAliasConfig::kMayAlias;
705     if (EatIfPresent(TokKind::kComma)) {
706       std::string type;
707       ParseName(&type);
708       if (type == "must-alias") {
709         alias_kind = HloInputOutputAliasConfig::kMustAlias;
710       } else if (type == "may-alias") {
711         alias_kind = HloInputOutputAliasConfig::kMayAlias;
712       } else {
713         return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
714       }
715     }
716 
717     data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
718                   std::forward_as_tuple(param_num, param_idx, alias_kind));
719     if (!ParseToken(TokKind::kRparen, errmsg)) {
720       return false;
721     }
722 
723     if (!EatIfPresent(TokKind::kComma)) {
724       break;
725     }
726   }
727   if (!ParseToken(TokKind::kRbrace,
728                   "Expects '}' at the end of aliasing description")) {
729     return false;
730   }
731   return true;
732 }
733 
ParseInstructionOutputOperandAliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> * aliasing_output_operand_pairs)734 bool HloParserImpl::ParseInstructionOutputOperandAliasing(
735     std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
736         aliasing_output_operand_pairs) {
737   if (!ParseToken(
738           TokKind::kLbrace,
739           "Expects '{' at the start of instruction aliasing description")) {
740     return false;
741   }
742 
743   while (lexer_.GetKind() != TokKind::kRbrace) {
744     ShapeIndex out;
745     if (!ParseShapeIndex(&out)) {
746       return false;
747     }
748     std::string errmsg =
749         "Expected format: <output_shape_index>: (<operand_index>, "
750         "<operand_shape_index>)";
751     if (!ParseToken(TokKind::kColon, errmsg)) {
752       return false;
753     }
754 
755     if (!ParseToken(TokKind::kLparen, errmsg)) {
756       return false;
757     }
758     int64 operand_index;
759     ParseInt64(&operand_index);
760     if (!ParseToken(TokKind::kComma, errmsg)) {
761       return false;
762     }
763     ShapeIndex operand_shape_index;
764     if (!ParseShapeIndex(&operand_shape_index)) {
765       return false;
766     }
767 
768     aliasing_output_operand_pairs->emplace_back(
769         out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
770     if (!ParseToken(TokKind::kRparen, errmsg)) {
771       return false;
772     }
773 
774     if (!EatIfPresent(TokKind::kComma)) {
775       break;
776     }
777   }
778   if (!ParseToken(
779           TokKind::kRbrace,
780           "Expects '}' at the end of instruction aliasing description")) {
781     return false;
782   }
783   return true;
784 }
785 
786 // ::= 'HloModule' name computations
ParseHloModule(HloModule * module)787 bool HloParserImpl::ParseHloModule(HloModule* module) {
788   if (lexer_.GetKind() != TokKind::kw_HloModule) {
789     return TokenError("expects HloModule");
790   }
791   // Eat 'HloModule'
792   lexer_.Lex();
793 
794   std::string name;
795   if (!ParseName(&name)) {
796     return false;
797   }
798 
799   absl::optional<bool> is_scheduled;
800   absl::optional<AliasingData> aliasing_data;
801   absl::flat_hash_map<std::string, AttrConfig> attrs;
802 
803   attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
804   attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
805                                  &aliasing_data};
806   if (!ParseAttributes(attrs)) {
807     return false;
808   }
809   module->set_name(name);
810   if (!ParseComputations(module)) {
811     return false;
812   }
813 
814   if (is_scheduled.has_value() && *is_scheduled) {
815     TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
816   }
817   if (aliasing_data) {
818     HloInputOutputAliasConfig alias_config(module->result_shape());
819     for (auto& p : *aliasing_data) {
820       Status st =
821           alias_config.SetUpAlias(p.first, p.second.parameter_number,
822                                   p.second.parameter_index, p.second.kind);
823       if (!st.ok()) {
824         return TokenError(st.error_message());
825       }
826     }
827     module->input_output_alias_config() = alias_config;
828   }
829 
830   return true;
831 }
832 
833 // computations ::= (computation)+
ParseComputations(HloModule * module)834 bool HloParserImpl::ParseComputations(HloModule* module) {
835   HloComputation* entry_computation = nullptr;
836   do {
837     if (!ParseComputation(&entry_computation)) {
838       return false;
839     }
840   } while (lexer_.GetKind() != TokKind::kEof);
841 
842   for (int i = 0; i < computations_.size(); i++) {
843     // If entry_computation is not nullptr, it means the computation it pointed
844     // to is marked with "ENTRY"; otherwise, no computation is marked with
845     // "ENTRY", and we use the last computation as the entry computation. We
846     // add the non-entry computations as embedded computations to the module.
847     if ((entry_computation != nullptr &&
848          computations_[i].get() != entry_computation) ||
849         (entry_computation == nullptr && i != computations_.size() - 1)) {
850       module->AddEmbeddedComputation(std::move(computations_[i]));
851       continue;
852     }
853     auto computation = module->AddEntryComputation(std::move(computations_[i]));
854     // The parameters and result layouts were set to default layout. Here we
855     // set the layouts to what the hlo text says.
856     for (int p = 0; p < computation->num_parameters(); p++) {
857       const Shape& param_shape = computation->parameter_instruction(p)->shape();
858       TF_CHECK_OK(module->mutable_entry_computation_layout()
859                       ->mutable_parameter_layout(p)
860                       ->CopyLayoutFromShape(param_shape));
861     }
862     const Shape& result_shape = computation->root_instruction()->shape();
863     TF_CHECK_OK(module->mutable_entry_computation_layout()
864                     ->mutable_result_layout()
865                     ->CopyLayoutFromShape(result_shape));
866   }
867   return true;
868 }
869 
870 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
ParseComputation(HloComputation ** entry_computation)871 bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
872   LocTy maybe_entry_loc = lexer_.GetLoc();
873   const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
874 
875   std::string name;
876   LocTy name_loc = lexer_.GetLoc();
877   if (!ParseName(&name)) {
878     return false;
879   }
880 
881   LocTy shape_loc = nullptr;
882   Shape shape;
883   if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
884     return false;
885   }
886 
887   HloComputation* computation = nullptr;
888   if (!ParseInstructionList(&computation, name)) {
889     return false;
890   }
891 
892   // If param_list_to_shape was present, check compatibility.
893   if (shape_loc != nullptr &&
894       !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
895     return Error(
896         shape_loc,
897         StrCat(
898             "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
899             ", is not compatible with that of its root instruction ",
900             computation->root_instruction()->name(), ", ",
901             ShapeUtil::HumanString(computation->root_instruction()->shape())));
902   }
903 
904   if (is_entry_computation) {
905     if (*entry_computation != nullptr) {
906       return Error(maybe_entry_loc, "expects only one ENTRY");
907     }
908     *entry_computation = computation;
909   }
910 
911   return AddComputation(name, computation, name_loc);
912 }
913 
914 // instruction_list ::= '{' instruction_list1 '}'
915 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation ** computation,const std::string & computation_name)916 bool HloParserImpl::ParseInstructionList(HloComputation** computation,
917                                          const std::string& computation_name) {
918   Scope scope(&scoped_name_tables_);
919   HloComputation::Builder builder(computation_name);
920   if (!ParseToken(TokKind::kLbrace,
921                   "expects '{' at the beginning of instruction list.")) {
922     return false;
923   }
924   std::string root_name;
925   do {
926     if (!ParseInstruction(&builder, &root_name)) {
927       return false;
928     }
929   } while (lexer_.GetKind() != TokKind::kRbrace);
930   if (!ParseToken(TokKind::kRbrace,
931                   "expects '}' at the end of instruction list.")) {
932     return false;
933   }
934   HloInstruction* root = nullptr;
935   if (!root_name.empty()) {
936     std::pair<HloInstruction*, LocTy>* root_node =
937         tensorflow::gtl::FindOrNull(current_name_table(), root_name);
938 
939     // This means some instruction was marked as ROOT but we didn't find it in
940     // the pool, which should not happen.
941     if (root_node == nullptr) {
942       // LOG(FATAL) crashes the program by calling abort().
943       LOG(FATAL) << "instruction " << root_name
944                  << " was marked as ROOT but the parser has not seen it before";
945     }
946     root = root_node->first;
947   }
948 
949   // Now root can be either an existing instruction or a nullptr. If it's a
950   // nullptr, the implementation of Builder will set the last instruction as
951   // the root instruction.
952   computations_.emplace_back(builder.Build(root));
953   *computation = computations_.back().get();
954   return true;
955 }
956 
957 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,std::string * root_name)958 bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
959                                      std::string* root_name) {
960   std::string name;
961   LocTy maybe_root_loc = lexer_.GetLoc();
962   bool is_root = EatIfPresent(TokKind::kw_ROOT);
963 
964   const LocTy name_loc = lexer_.GetLoc();
965   if (!ParseName(&name) ||
966       !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
967     return false;
968   }
969 
970   if (is_root) {
971     if (!root_name->empty()) {
972       return Error(maybe_root_loc, "one computation should have only one ROOT");
973     }
974     *root_name = name;
975   }
976 
977   return ParseInstructionRhs(builder, name, name_loc);
978 }
979 
ParseInstructionRhs(HloComputation::Builder * builder,const std::string & name,LocTy name_loc)980 bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
981                                         const std::string& name,
982                                         LocTy name_loc) {
983   Shape shape;
984   HloOpcode opcode;
985   std::vector<HloInstruction*> operands;
986 
987   const bool parse_shape = CanBeShape();
988   if ((parse_shape && !ParseShape(&shape)) || !ParseOpcode(&opcode)) {
989     return false;
990   }
991   if (!parse_shape && !CanInferShape(opcode)) {
992     return TokenError(StrFormat("cannot infer shape for opcode: %s",
993                                 HloOpcodeString(opcode)));
994   }
995 
996   // Add optional attributes. These are added to any HloInstruction type if
997   // present.
998   absl::flat_hash_map<std::string, AttrConfig> attrs;
999   optional<OpSharding> sharding;
1000   optional<FrontendAttributes> frontend_attributes;
1001   attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
1002   attrs["frontend_attributes"] = {
1003       /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
1004   optional<ParameterReplication> parameter_replication;
1005   attrs["parameter_replication"] = {/*required=*/false,
1006                                     AttrTy::kParameterReplication,
1007                                     &parameter_replication};
1008   optional<std::vector<HloInstruction*>> predecessors;
1009   attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
1010                                    &predecessors};
1011   optional<OpMetadata> metadata;
1012   attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
1013 
1014   optional<std::string> backend_config;
1015   attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
1016                              &backend_config};
1017   optional<std::vector<int64>> outer_dimension_partitions;
1018   attrs["outer_dimension_partitions"] = {/*required=*/false,
1019                                          AttrTy::kBracedInt64List,
1020                                          &outer_dimension_partitions};
1021   const auto maybe_infer_shape =
1022       [&](const std::function<StatusOr<Shape>()>& infer, Shape* shape) {
1023         if (parse_shape) {
1024           return true;
1025         }
1026         auto inferred = infer();
1027         if (!inferred.ok()) {
1028           return TokenError(StrFormat(
1029               "failed to infer shape for opcode: %s, error: %s",
1030               HloOpcodeString(opcode), inferred.status().error_message()));
1031         }
1032         *shape = std::move(inferred).ValueOrDie();
1033         return true;
1034       };
1035   HloInstruction* instruction;
1036   switch (opcode) {
1037     case HloOpcode::kParameter: {
1038       int64 parameter_number;
1039       if (!ParseToken(TokKind::kLparen,
1040                       "expects '(' before parameter number") ||
1041           !ParseInt64(&parameter_number)) {
1042         return false;
1043       }
1044       if (parameter_number < 0) {
1045         Error(lexer_.GetLoc(), "parameter number must be >= 0");
1046         return false;
1047       }
1048       if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
1049           !ParseAttributes(attrs)) {
1050         return false;
1051       }
1052       instruction = builder->AddInstruction(
1053           HloInstruction::CreateParameter(parameter_number, shape, name));
1054       break;
1055     }
1056     case HloOpcode::kConstant: {
1057       Literal literal;
1058       if (!ParseToken(TokKind::kLparen,
1059                       "expects '(' before constant literal") ||
1060           !ParseLiteral(&literal, shape) ||
1061           !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
1062           !ParseAttributes(attrs)) {
1063         return false;
1064       }
1065       instruction = builder->AddInstruction(
1066           HloInstruction::CreateConstant(std::move(literal)));
1067       break;
1068     }
1069     case HloOpcode::kIota: {
1070       optional<int64> iota_dimension;
1071       attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
1072                                  &iota_dimension};
1073       if (!ParseOperands(&operands, /*expected_size=*/0) ||
1074           !ParseAttributes(attrs)) {
1075         return false;
1076       }
1077       instruction = builder->AddInstruction(
1078           HloInstruction::CreateIota(shape, *iota_dimension));
1079       break;
1080     }
1081     // Unary ops.
1082     case HloOpcode::kAbs:
1083     case HloOpcode::kRoundNearestAfz:
1084     case HloOpcode::kBitcast:
1085     case HloOpcode::kCeil:
1086     case HloOpcode::kClz:
1087     case HloOpcode::kCollectivePermuteDone:
1088     case HloOpcode::kCopy:
1089     case HloOpcode::kCopyDone:
1090     case HloOpcode::kCos:
1091     case HloOpcode::kExp:
1092     case HloOpcode::kExpm1:
1093     case HloOpcode::kImag:
1094     case HloOpcode::kIsFinite:
1095     case HloOpcode::kFloor:
1096     case HloOpcode::kLog:
1097     case HloOpcode::kLog1p:
1098     case HloOpcode::kLogistic:
1099     case HloOpcode::kNot:
1100     case HloOpcode::kNegate:
1101     case HloOpcode::kPopulationCount:
1102     case HloOpcode::kReal:
1103     case HloOpcode::kRsqrt:
1104     case HloOpcode::kSign:
1105     case HloOpcode::kSin:
1106     case HloOpcode::kSqrt:
1107     case HloOpcode::kCbrt:
1108     case HloOpcode::kTanh: {
1109       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1110           !ParseAttributes(attrs)) {
1111         return false;
1112       }
1113       if (!maybe_infer_shape(
1114               [&] {
1115                 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
1116               },
1117               &shape)) {
1118         return false;
1119       }
1120       instruction = builder->AddInstruction(
1121           HloInstruction::CreateUnary(shape, opcode, operands[0]));
1122       break;
1123     }
1124     // Binary ops.
1125     case HloOpcode::kAdd:
1126     case HloOpcode::kDivide:
1127     case HloOpcode::kMultiply:
1128     case HloOpcode::kSubtract:
1129     case HloOpcode::kAtan2:
1130     case HloOpcode::kComplex:
1131     case HloOpcode::kMaximum:
1132     case HloOpcode::kMinimum:
1133     case HloOpcode::kPower:
1134     case HloOpcode::kRemainder:
1135     case HloOpcode::kAnd:
1136     case HloOpcode::kOr:
1137     case HloOpcode::kXor:
1138     case HloOpcode::kShiftLeft:
1139     case HloOpcode::kShiftRightArithmetic:
1140     case HloOpcode::kShiftRightLogical: {
1141       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1142           !ParseAttributes(attrs)) {
1143         return false;
1144       }
1145       if (!maybe_infer_shape(
1146               [&] {
1147                 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1148                                                           operands[1]);
1149               },
1150               &shape)) {
1151         return false;
1152       }
1153       instruction = builder->AddInstruction(HloInstruction::CreateBinary(
1154           shape, opcode, operands[0], operands[1]));
1155       break;
1156     }
1157     // Ternary ops.
1158     case HloOpcode::kClamp:
1159     case HloOpcode::kSelect:
1160     case HloOpcode::kTupleSelect: {
1161       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1162           !ParseAttributes(attrs)) {
1163         return false;
1164       }
1165       if (!maybe_infer_shape(
1166               [&] {
1167                 return ShapeInference::InferTernaryOpShape(
1168                     opcode, operands[0], operands[1], operands[2]);
1169               },
1170               &shape)) {
1171         return false;
1172       }
1173       instruction = builder->AddInstruction(HloInstruction::CreateTernary(
1174           shape, opcode, operands[0], operands[1], operands[2]));
1175       break;
1176     }
1177     // Other supported ops.
1178     case HloOpcode::kConvert: {
1179       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1180           !ParseAttributes(attrs)) {
1181         return false;
1182       }
1183       instruction = builder->AddInstruction(
1184           HloInstruction::CreateConvert(shape, operands[0]));
1185       break;
1186     }
1187     case HloOpcode::kBitcastConvert: {
1188       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1189           !ParseAttributes(attrs)) {
1190         return false;
1191       }
1192       instruction = builder->AddInstruction(
1193           HloInstruction::CreateBitcastConvert(shape, operands[0]));
1194       break;
1195     }
1196     case HloOpcode::kAllGather: {
1197       optional<std::vector<std::vector<int64>>> tmp_groups;
1198       optional<std::vector<int64>> replica_group_ids;
1199       optional<int64> channel_id;
1200       optional<std::vector<int64>> dimensions;
1201       optional<bool> constrain_layout;
1202       optional<bool> use_global_device_ids;
1203       attrs["replica_groups"] = {/*required=*/false,
1204                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1205       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1206       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1207                              &dimensions};
1208       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1209                                    &constrain_layout};
1210       attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1211                                         &use_global_device_ids};
1212       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1213         return false;
1214       }
1215       std::vector<ReplicaGroup> replica_groups;
1216       if (tmp_groups) {
1217         replica_groups = CreateReplicaGroups(*tmp_groups);
1218       }
1219       instruction = builder->AddInstruction(HloInstruction::CreateAllGather(
1220           shape, operands[0], dimensions->at(0), replica_groups,
1221           constrain_layout ? *constrain_layout : false, channel_id,
1222           use_global_device_ids ? *use_global_device_ids : false));
1223       break;
1224     }
1225     case HloOpcode::kAllReduce: {
1226       optional<std::vector<std::vector<int64>>> tmp_groups;
1227       optional<HloComputation*> to_apply;
1228       optional<std::vector<int64>> replica_group_ids;
1229       optional<int64> channel_id;
1230       optional<bool> constrain_layout;
1231       optional<bool> use_global_device_ids;
1232       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1233                            &to_apply};
1234       attrs["replica_groups"] = {/*required=*/false,
1235                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1236       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1237       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1238                                    &constrain_layout};
1239       attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1240                                         &use_global_device_ids};
1241       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1242         return false;
1243       }
1244       std::vector<ReplicaGroup> replica_groups;
1245       if (tmp_groups) {
1246         replica_groups = CreateReplicaGroups(*tmp_groups);
1247       }
1248       instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
1249           shape, operands, *to_apply, replica_groups,
1250           constrain_layout ? *constrain_layout : false, channel_id,
1251           use_global_device_ids ? *use_global_device_ids : false));
1252       break;
1253     }
1254     case HloOpcode::kAllToAll: {
1255       optional<std::vector<std::vector<int64>>> tmp_groups;
1256       attrs["replica_groups"] = {/*required=*/false,
1257                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1258       optional<int64> channel_id;
1259       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1260       optional<std::vector<int64>> dimensions;
1261       attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1262                              &dimensions};
1263       optional<bool> constrain_layout;
1264       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1265                                    &constrain_layout};
1266       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1267           (dimensions && dimensions->size() != 1)) {
1268         return false;
1269       }
1270       std::vector<ReplicaGroup> replica_groups;
1271       if (tmp_groups) {
1272         replica_groups = CreateReplicaGroups(*tmp_groups);
1273       }
1274       optional<int64> split_dimension;
1275       if (dimensions) {
1276         split_dimension = dimensions->at(0);
1277       }
1278       instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
1279           shape, operands, replica_groups,
1280           constrain_layout ? *constrain_layout : false, channel_id,
1281           split_dimension));
1282       break;
1283     }
1284     case HloOpcode::kCollectivePermute:
1285     case HloOpcode::kCollectivePermuteStart: {
1286       optional<std::vector<std::vector<int64>>> source_targets;
1287       attrs["source_target_pairs"] = {
1288           /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
1289       optional<int64> channel_id;
1290       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1291       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1292           !ParseAttributes(attrs)) {
1293         return false;
1294       }
1295       std::vector<std::pair<int64, int64>> pairs(source_targets->size());
1296       for (int i = 0; i < pairs.size(); i++) {
1297         if ((*source_targets)[i].size() != 2) {
1298           return TokenError(
1299               "expects 'source_target_pairs=' to be a list of pairs");
1300         }
1301         pairs[i].first = (*source_targets)[i][0];
1302         pairs[i].second = (*source_targets)[i][1];
1303       }
1304       if (opcode == HloOpcode::kCollectivePermute) {
1305         instruction =
1306             builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1307                 shape, operands[0], pairs, channel_id));
1308       } else if (opcode == HloOpcode::kCollectivePermuteStart) {
1309         instruction = builder->AddInstruction(
1310             HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
1311                                                          pairs, channel_id));
1312       } else {
1313         LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1314                       "CollectivePermuteStart, but got "
1315                    << HloOpcodeString(opcode);
1316       }
1317       break;
1318     }
1319     case HloOpcode::kCopyStart: {
1320       // If the is_cross_program_prefetch attribute is not present then default
1321       // to false.
1322       optional<bool> is_cross_program_prefetch = false;
1323       attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
1324                                             &is_cross_program_prefetch};
1325       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1326           !ParseAttributes(attrs)) {
1327         return false;
1328       }
1329       instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
1330           shape, operands[0], *is_cross_program_prefetch));
1331       break;
1332     }
1333     case HloOpcode::kReplicaId: {
1334       if (!ParseOperands(&operands, /*expected_size=*/0) ||
1335           !ParseAttributes(attrs)) {
1336         return false;
1337       }
1338       instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
1339       break;
1340     }
1341     case HloOpcode::kPartitionId: {
1342       if (!ParseOperands(&operands, /*expected_size=*/0) ||
1343           !ParseAttributes(attrs)) {
1344         return false;
1345       }
1346       instruction =
1347           builder->AddInstruction(HloInstruction::CreatePartitionId());
1348       break;
1349     }
1350     case HloOpcode::kDynamicReshape: {
1351       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1352         return false;
1353       }
1354       instruction =
1355           builder->AddInstruction(HloInstruction::CreateDynamicReshape(
1356               shape, operands[0],
1357               absl::Span<HloInstruction* const>(operands).subspan(1)));
1358       break;
1359     }
1360     case HloOpcode::kReshape: {
1361       optional<int64> inferred_dimension;
1362       attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
1363                                      &inferred_dimension};
1364       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1365           !ParseAttributes(attrs)) {
1366         return false;
1367       }
1368       instruction = builder->AddInstruction(HloInstruction::CreateReshape(
1369           shape, operands[0], inferred_dimension.value_or(-1)));
1370       break;
1371     }
1372     case HloOpcode::kAfterAll: {
1373       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1374         return false;
1375       }
1376       if (operands.empty()) {
1377         instruction = builder->AddInstruction(HloInstruction::CreateToken());
1378       } else {
1379         instruction =
1380             builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
1381       }
1382       break;
1383     }
1384     case HloOpcode::kAddDependency: {
1385       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1386           !ParseAttributes(attrs)) {
1387         return false;
1388       }
1389       instruction = builder->AddInstruction(
1390           HloInstruction::CreateAddDependency(operands[0], operands[1]));
1391       break;
1392     }
1393     case HloOpcode::kSort: {
1394       optional<std::vector<int64>> dimensions;
1395       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1396                              &dimensions};
1397       optional<bool> is_stable = false;
1398       attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
1399       optional<HloComputation*> to_apply;
1400       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1401                            &to_apply};
1402       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1403           dimensions->size() != 1) {
1404         return false;
1405       }
1406       if (!maybe_infer_shape(
1407               [&] {
1408                 absl::InlinedVector<const Shape*, 2> arg_shapes;
1409                 arg_shapes.reserve(operands.size());
1410                 for (auto* operand : operands) {
1411                   arg_shapes.push_back(&operand->shape());
1412                 }
1413                 return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
1414               },
1415               &shape)) {
1416         return false;
1417       }
1418       instruction = builder->AddInstruction(
1419           HloInstruction::CreateSort(shape, dimensions->at(0), operands,
1420                                      to_apply.value(), is_stable.value()));
1421       break;
1422     }
1423     case HloOpcode::kTuple: {
1424       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1425         return false;
1426       }
1427       instruction =
1428           builder->AddInstruction(HloInstruction::CreateTuple(operands));
1429       break;
1430     }
1431     case HloOpcode::kWhile: {
1432       optional<HloComputation*> condition;
1433       optional<HloComputation*> body;
1434       attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
1435                             &condition};
1436       attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
1437       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1438           !ParseAttributes(attrs)) {
1439         return false;
1440       }
1441       if (!maybe_infer_shape(
1442               [&] {
1443                 return ShapeInference::InferWhileShape(
1444                     condition.value()->ComputeProgramShape(),
1445                     body.value()->ComputeProgramShape(), operands[0]->shape());
1446               },
1447               &shape)) {
1448         return false;
1449       }
1450       instruction = builder->AddInstruction(HloInstruction::CreateWhile(
1451           shape, *condition, *body, /*init=*/operands[0]));
1452       break;
1453     }
1454     case HloOpcode::kRecv: {
1455       optional<int64> channel_id;
1456       // If the is_host_transfer attribute is not present then default to false.
1457       optional<bool> is_host_transfer = false;
1458       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1459       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1460                                    &is_host_transfer};
1461       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1462           !ParseAttributes(attrs)) {
1463         return false;
1464       }
1465       // If the is_host_transfer attribute is not present then default to false.
1466       instruction = builder->AddInstruction(HloInstruction::CreateRecv(
1467           shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
1468       break;
1469     }
1470     case HloOpcode::kRecvDone: {
1471       optional<int64> channel_id;
1472       // If the is_host_transfer attribute is not present then default to false.
1473       optional<bool> is_host_transfer = false;
1474       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1475       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1476                                    &is_host_transfer};
1477       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1478           !ParseAttributes(attrs)) {
1479         return false;
1480       }
1481       if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
1482         return false;
1483       }
1484       if (channel_id != operands[0]->channel_id()) {
1485         return false;
1486       }
1487       instruction = builder->AddInstruction(
1488           HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
1489       break;
1490     }
1491     case HloOpcode::kSend: {
1492       optional<int64> channel_id;
1493       // If the is_host_transfer attribute is not present then default to false.
1494       optional<bool> is_host_transfer = false;
1495       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1496       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1497                                    &is_host_transfer};
1498       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1499           !ParseAttributes(attrs)) {
1500         return false;
1501       }
1502       instruction = builder->AddInstruction(HloInstruction::CreateSend(
1503           operands[0], operands[1], *channel_id, *is_host_transfer));
1504       break;
1505     }
1506     case HloOpcode::kSendDone: {
1507       optional<int64> channel_id;
1508       // If the is_host_transfer attribute is not present then default to false.
1509       optional<bool> is_host_transfer = false;
1510       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1511       attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1512                                    &is_host_transfer};
1513       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1514           !ParseAttributes(attrs)) {
1515         return false;
1516       }
1517       if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
1518         return false;
1519       }
1520       if (channel_id != operands[0]->channel_id()) {
1521         return false;
1522       }
1523       instruction = builder->AddInstruction(
1524           HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
1525       break;
1526     }
1527     case HloOpcode::kGetTupleElement: {
1528       optional<int64> index;
1529       attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
1530       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1531           !ParseAttributes(attrs)) {
1532         return false;
1533       }
1534       if (!maybe_infer_shape(
1535               [&] {
1536                 return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
1537                                                        *index);
1538               },
1539               &shape)) {
1540         return false;
1541       }
1542       instruction = builder->AddInstruction(
1543           HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
1544       break;
1545     }
1546     case HloOpcode::kCall: {
1547       optional<HloComputation*> to_apply;
1548       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1549                            &to_apply};
1550       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1551         return false;
1552       }
1553       if (!maybe_infer_shape(
1554               [&] {
1555                 absl::InlinedVector<const Shape*, 2> arg_shapes;
1556                 arg_shapes.reserve(operands.size());
1557                 for (auto* operand : operands) {
1558                   arg_shapes.push_back(&operand->shape());
1559                 }
1560                 return ShapeInference::InferCallShape(
1561                     arg_shapes, to_apply.value()->ComputeProgramShape());
1562               },
1563               &shape)) {
1564         return false;
1565       }
1566       instruction = builder->AddInstruction(
1567           HloInstruction::CreateCall(shape, operands, *to_apply));
1568       break;
1569     }
1570     case HloOpcode::kReduceWindow: {
1571       optional<HloComputation*> reduce_computation;
1572       optional<Window> window;
1573       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1574       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1575                            &reduce_computation};
1576       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1577         return false;
1578       }
1579       if (!window) {
1580         window.emplace();
1581       }
1582       if (operands.size() % 2) {
1583         auto loc = lexer_.GetLoc();
1584         return Error(loc, StrCat("expects an even number of operands, but has ",
1585                                  operands.size(), " operands"));
1586       }
1587       if (!maybe_infer_shape(
1588               [&] {
1589                 return ShapeInference::InferReduceWindowShape(
1590                     operands[0]->shape(), operands[1]->shape(), *window,
1591                     reduce_computation.value()->ComputeProgramShape());
1592               },
1593               &shape)) {
1594         return false;
1595       }
1596       instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
1597           shape, /*operands=*/
1598           absl::Span<HloInstruction* const>(operands).subspan(
1599               0, operands.size() / 2),
1600           /*init_values=*/
1601           absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
1602                                                               2),
1603           *window, *reduce_computation));
1604       break;
1605     }
1606     case HloOpcode::kConvolution: {
1607       optional<Window> window;
1608       optional<ConvolutionDimensionNumbers> dnums;
1609       optional<int64> feature_group_count;
1610       optional<int64> batch_group_count;
1611       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1612       attrs["dim_labels"] = {/*required=*/true,
1613                              AttrTy::kConvolutionDimensionNumbers, &dnums};
1614       attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
1615                                       &feature_group_count};
1616       attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
1617                                     &batch_group_count};
1618       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
1619       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
1620                                     &operand_precision};
1621       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1622           !ParseAttributes(attrs)) {
1623         return false;
1624       }
1625       if (!window) {
1626         window.emplace();
1627       }
1628       if (!feature_group_count) {
1629         feature_group_count = 1;
1630       }
1631       if (!batch_group_count) {
1632         batch_group_count = 1;
1633       }
1634       PrecisionConfig precision_config;
1635       if (operand_precision) {
1636         *precision_config.mutable_operand_precision() = {
1637             operand_precision->begin(), operand_precision->end()};
1638       } else {
1639         precision_config.mutable_operand_precision()->Resize(
1640             operands.size(), PrecisionConfig::DEFAULT);
1641       }
1642       if (!maybe_infer_shape(
1643               [&] {
1644                 return ShapeInference::InferConvolveShape(
1645                     operands[0]->shape(), operands[1]->shape(),
1646                     *feature_group_count, *batch_group_count, *window, *dnums,
1647                     /*preferred_element_type=*/absl::nullopt);
1648               },
1649               &shape)) {
1650         return false;
1651       }
1652       instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
1653           shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
1654           feature_group_count.value(), batch_group_count.value(), *window,
1655           *dnums, precision_config));
1656       break;
1657     }
1658     case HloOpcode::kFft: {
1659       optional<FftType> fft_type;
1660       optional<std::vector<int64>> fft_length;
1661       attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
1662       attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
1663                              &fft_length};
1664       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1665           !ParseAttributes(attrs)) {
1666         return false;
1667       }
1668       if (!maybe_infer_shape(
1669               [&] {
1670                 return ShapeInference::InferFftShape(operands[0]->shape(),
1671                                                      *fft_type, *fft_length);
1672               },
1673               &shape)) {
1674         return false;
1675       }
1676       instruction = builder->AddInstruction(HloInstruction::CreateFft(
1677           shape, operands[0], *fft_type, *fft_length));
1678       break;
1679     }
1680     case HloOpcode::kTriangularSolve: {
1681       TriangularSolveOptions options;
1682       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1683           !ParseAttributesAsProtoMessage(
1684               /*non_proto_attrs=*/attrs, &options)) {
1685         return false;
1686       }
1687       if (!maybe_infer_shape(
1688               [&] {
1689                 return ShapeInference::InferTriangularSolveShape(
1690                     operands[0]->shape(), operands[1]->shape(), options);
1691               },
1692               &shape)) {
1693         return false;
1694       }
1695       instruction =
1696           builder->AddInstruction(HloInstruction::CreateTriangularSolve(
1697               shape, operands[0], operands[1], options));
1698       break;
1699     }
1700     case HloOpcode::kCompare: {
1701       optional<ComparisonDirection> direction;
1702       optional<Comparison::Type> type;
1703       attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
1704                             &direction};
1705       attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
1706       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1707           !ParseAttributes(attrs)) {
1708         return false;
1709       }
1710       if (!maybe_infer_shape(
1711               [&] {
1712                 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1713                                                           operands[1]);
1714               },
1715               &shape)) {
1716         return false;
1717       }
1718       instruction = builder->AddInstruction(HloInstruction::CreateCompare(
1719           shape, operands[0], operands[1], *direction, type));
1720       break;
1721     }
1722     case HloOpcode::kCholesky: {
1723       CholeskyOptions options;
1724       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1725           !ParseAttributesAsProtoMessage(
1726               /*non_proto_attrs=*/attrs, &options)) {
1727         return false;
1728       }
1729       if (!maybe_infer_shape(
1730               [&] {
1731                 return ShapeInference::InferCholeskyShape(operands[0]->shape());
1732               },
1733               &shape)) {
1734         return false;
1735       }
1736       instruction = builder->AddInstruction(
1737           HloInstruction::CreateCholesky(shape, operands[0], options));
1738       break;
1739     }
1740     case HloOpcode::kBroadcast: {
1741       optional<std::vector<int64>> broadcast_dimensions;
1742       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1743                              &broadcast_dimensions};
1744       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1745           !ParseAttributes(attrs)) {
1746         return false;
1747       }
1748       if (!maybe_infer_shape(
1749               [&] {
1750                 return ShapeInference::InferBroadcastShape(
1751                     operands[0]->shape(), *broadcast_dimensions);
1752               },
1753               &shape)) {
1754         return false;
1755       }
1756       instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
1757           shape, operands[0], *broadcast_dimensions));
1758       break;
1759     }
1760     case HloOpcode::kConcatenate: {
1761       optional<std::vector<int64>> dimensions;
1762       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1763                              &dimensions};
1764       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1765           dimensions->size() != 1) {
1766         return false;
1767       }
1768       if (!maybe_infer_shape(
1769               [&] {
1770                 absl::InlinedVector<const Shape*, 2> arg_shapes;
1771                 arg_shapes.reserve(operands.size());
1772                 for (auto* operand : operands) {
1773                   arg_shapes.push_back(&operand->shape());
1774                 }
1775                 return ShapeInference::InferConcatOpShape(arg_shapes,
1776                                                           dimensions->at(0));
1777               },
1778               &shape)) {
1779         return false;
1780       }
1781       instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
1782           shape, operands, dimensions->at(0)));
1783       break;
1784     }
1785     case HloOpcode::kMap: {
1786       optional<HloComputation*> to_apply;
1787       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1788                            &to_apply};
1789       optional<std::vector<int64>> dimensions;
1790       attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1791                              &dimensions};
1792       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1793         return false;
1794       }
1795       if (!maybe_infer_shape(
1796               [&] {
1797                 absl::InlinedVector<const Shape*, 2> arg_shapes;
1798                 arg_shapes.reserve(operands.size());
1799                 for (auto* operand : operands) {
1800                   arg_shapes.push_back(&operand->shape());
1801                 }
1802                 return ShapeInference::InferMapShape(
1803                     arg_shapes, to_apply.value()->ComputeProgramShape(),
1804                     *dimensions);
1805               },
1806               &shape)) {
1807         return false;
1808       }
1809       instruction = builder->AddInstruction(
1810           HloInstruction::CreateMap(shape, operands, *to_apply));
1811       break;
1812     }
1813     case HloOpcode::kReduce: {
1814       auto loc = lexer_.GetLoc();
1815 
1816       optional<HloComputation*> reduce_computation;
1817       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1818                            &reduce_computation};
1819       optional<std::vector<int64>> dimensions_to_reduce;
1820       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1821                              &dimensions_to_reduce};
1822       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1823         return false;
1824       }
1825       if (operands.size() % 2) {
1826         return Error(loc, StrCat("expects an even number of operands, but has ",
1827                                  operands.size(), " operands"));
1828       }
1829       if (!maybe_infer_shape(
1830               [&] {
1831                 absl::InlinedVector<const Shape*, 2> arg_shapes;
1832                 arg_shapes.reserve(operands.size());
1833                 for (auto* operand : operands) {
1834                   arg_shapes.push_back(&operand->shape());
1835                 }
1836                 return ShapeInference::InferReduceShape(
1837                     arg_shapes, *dimensions_to_reduce,
1838                     reduce_computation.value()->ComputeProgramShape());
1839               },
1840               &shape)) {
1841         return false;
1842       }
1843       instruction = builder->AddInstruction(HloInstruction::CreateReduce(
1844           shape, /*operands=*/
1845           absl::Span<HloInstruction* const>(operands).subspan(
1846               0, operands.size() / 2),
1847           /*init_values=*/
1848           absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
1849                                                               2),
1850           *dimensions_to_reduce, *reduce_computation));
1851       break;
1852     }
1853     case HloOpcode::kReverse: {
1854       optional<std::vector<int64>> dimensions;
1855       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1856                              &dimensions};
1857       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1858           !ParseAttributes(attrs)) {
1859         return false;
1860       }
1861       if (!maybe_infer_shape(
1862               [&] {
1863                 return ShapeInference::InferReverseShape(operands[0]->shape(),
1864                                                          *dimensions);
1865               },
1866               &shape)) {
1867         return false;
1868       }
1869       instruction = builder->AddInstruction(
1870           HloInstruction::CreateReverse(shape, operands[0], *dimensions));
1871       break;
1872     }
1873     case HloOpcode::kSelectAndScatter: {
1874       optional<HloComputation*> select;
1875       attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
1876       optional<HloComputation*> scatter;
1877       attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
1878       optional<Window> window;
1879       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1880       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1881           !ParseAttributes(attrs)) {
1882         return false;
1883       }
1884       if (!window) {
1885         window.emplace();
1886       }
1887       if (!maybe_infer_shape(
1888               [&] {
1889                 return ShapeInference::InferSelectAndScatterShape(
1890                     operands[0]->shape(), select.value()->ComputeProgramShape(),
1891                     *window, operands[1]->shape(), operands[2]->shape(),
1892                     scatter.value()->ComputeProgramShape());
1893               },
1894               &shape)) {
1895         return false;
1896       }
1897       instruction =
1898           builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
1899               shape, /*operand=*/operands[0], *select, *window,
1900               /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
1901       break;
1902     }
1903     case HloOpcode::kSlice: {
1904       optional<SliceRanges> slice_ranges;
1905       attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
1906       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1907           !ParseAttributes(attrs)) {
1908         return false;
1909       }
1910       instruction = builder->AddInstruction(HloInstruction::CreateSlice(
1911           shape, operands[0], slice_ranges->starts, slice_ranges->limits,
1912           slice_ranges->strides));
1913       break;
1914     }
1915     case HloOpcode::kDynamicSlice: {
1916       optional<std::vector<int64>> dynamic_slice_sizes;
1917       attrs["dynamic_slice_sizes"] = {
1918           /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
1919       LocTy loc = lexer_.GetLoc();
1920       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1921         return false;
1922       }
1923       if (operands.empty()) {
1924         return Error(loc, "Expected at least one operand.");
1925       }
1926       if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
1927           operands.size() != 1 + operands[0]->shape().rank()) {
1928         return Error(loc, "Wrong number of operands.");
1929       }
1930       instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
1931           shape, /*operand=*/operands[0],
1932           /*start_indices=*/absl::MakeSpan(operands).subspan(1),
1933           *dynamic_slice_sizes));
1934       break;
1935     }
1936     case HloOpcode::kDynamicUpdateSlice: {
1937       LocTy loc = lexer_.GetLoc();
1938       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1939         return false;
1940       }
1941       if (operands.size() < 2) {
1942         return Error(loc, "Expected at least two operands.");
1943       }
1944       if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
1945           operands.size() != 2 + operands[0]->shape().rank()) {
1946         return Error(loc, "Wrong number of operands.");
1947       }
1948       instruction =
1949           builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1950               shape, /*operand=*/operands[0], /*update=*/operands[1],
1951               /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
1952       break;
1953     }
1954     case HloOpcode::kTranspose: {
1955       optional<std::vector<int64>> dimensions;
1956       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1957                              &dimensions};
1958       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1959           !ParseAttributes(attrs)) {
1960         return false;
1961       }
1962       if (!maybe_infer_shape(
1963               [&] {
1964                 return ShapeInference::InferTransposeShape(operands[0]->shape(),
1965                                                            *dimensions);
1966               },
1967               &shape)) {
1968         return false;
1969       }
1970       instruction = builder->AddInstruction(
1971           HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
1972       break;
1973     }
1974     case HloOpcode::kBatchNormTraining: {
1975       optional<float> epsilon;
1976       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
1977       optional<int64> feature_index;
1978       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
1979                                 &feature_index};
1980       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1981           !ParseAttributes(attrs)) {
1982         return false;
1983       }
1984       if (!maybe_infer_shape(
1985               [&] {
1986                 return ShapeInference::InferBatchNormTrainingShape(
1987                     operands[0]->shape(), operands[1]->shape(),
1988                     operands[2]->shape(), *feature_index);
1989               },
1990               &shape)) {
1991         return false;
1992       }
1993       instruction =
1994           builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
1995               shape, /*operand=*/operands[0], /*scale=*/operands[1],
1996               /*offset=*/operands[2], *epsilon, *feature_index));
1997       break;
1998     }
1999     case HloOpcode::kBatchNormInference: {
2000       optional<float> epsilon;
2001       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
2002       optional<int64> feature_index;
2003       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
2004                                 &feature_index};
2005       if (!ParseOperands(&operands, /*expected_size=*/5) ||
2006           !ParseAttributes(attrs)) {
2007         return false;
2008       }
2009       if (!maybe_infer_shape(
2010               [&] {
2011                 return ShapeInference::InferBatchNormInferenceShape(
2012                     operands[0]->shape(), operands[1]->shape(),
2013                     operands[2]->shape(), operands[3]->shape(),
2014                     operands[4]->shape(), *feature_index);
2015               },
2016               &shape)) {
2017         return false;
2018       }
2019       instruction =
2020           builder->AddInstruction(HloInstruction::CreateBatchNormInference(
2021               shape, /*operand=*/operands[0], /*scale=*/operands[1],
2022               /*offset=*/operands[2], /*mean=*/operands[3],
2023               /*variance=*/operands[4], *epsilon, *feature_index));
2024       break;
2025     }
2026     case HloOpcode::kBatchNormGrad: {
2027       optional<float> epsilon;
2028       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
2029       optional<int64> feature_index;
2030       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
2031                                 &feature_index};
2032       if (!ParseOperands(&operands, /*expected_size=*/5) ||
2033           !ParseAttributes(attrs)) {
2034         return false;
2035       }
2036       if (!maybe_infer_shape(
2037               [&] {
2038                 return ShapeInference::InferBatchNormGradShape(
2039                     operands[0]->shape(), operands[1]->shape(),
2040                     operands[2]->shape(), operands[3]->shape(),
2041                     operands[4]->shape(), *feature_index);
2042               },
2043               &shape)) {
2044         return false;
2045       }
2046       instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
2047           shape, /*operand=*/operands[0], /*scale=*/operands[1],
2048           /*mean=*/operands[2], /*variance=*/operands[3],
2049           /*grad_output=*/operands[4], *epsilon, *feature_index));
2050       break;
2051     }
2052     case HloOpcode::kPad: {
2053       optional<PaddingConfig> padding;
2054       attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
2055       if (!ParseOperands(&operands, /*expected_size=*/2) ||
2056           !ParseAttributes(attrs)) {
2057         return false;
2058       }
2059       if (!maybe_infer_shape(
2060               [&] {
2061                 return ShapeInference::InferPadShape(
2062                     operands[0]->shape(), operands[1]->shape(), *padding);
2063               },
2064               &shape)) {
2065         return false;
2066       }
2067       instruction = builder->AddInstruction(HloInstruction::CreatePad(
2068           shape, operands[0], /*padding_value=*/operands[1], *padding));
2069       break;
2070     }
2071     case HloOpcode::kFusion: {
2072       optional<HloComputation*> fusion_computation;
2073       attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
2074                         &fusion_computation};
2075       optional<HloInstruction::FusionKind> fusion_kind;
2076       attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
2077       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2078         return false;
2079       }
2080       instruction = builder->AddInstruction(HloInstruction::CreateFusion(
2081           shape, *fusion_kind, operands, *fusion_computation));
2082       break;
2083     }
2084     case HloOpcode::kInfeed: {
2085       optional<std::string> config;
2086       attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2087       if (!ParseOperands(&operands, /*expected_size=*/1) ||
2088           !ParseAttributes(attrs)) {
2089         return false;
2090       }
2091       // We need to know the infeed data shape to construct the infeed
2092       // instruction. This is the zero-th element of the tuple-shaped output of
2093       // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
2094       // if the shape is not a non-empty tuple, so add guard so an error message
2095       // can be emitted instead of a check fail
2096       if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
2097         return Error(lexer_.GetLoc(),
2098                      "infeed must have a non-empty tuple shape");
2099       }
2100       instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
2101           ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
2102           config ? *config : ""));
2103       break;
2104     }
2105     case HloOpcode::kOutfeed: {
2106       optional<std::string> config;
2107       optional<Shape> outfeed_shape;
2108       attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2109       attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
2110                                 &outfeed_shape};
2111       if (!ParseOperands(&operands, /*expected_size=*/2) ||
2112           !ParseAttributes(attrs)) {
2113         return false;
2114       }
2115       HloInstruction* const outfeed_input = operands[0];
2116       HloInstruction* const outfeed_token = operands[1];
2117       const Shape shape =
2118           outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
2119       instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
2120           shape, outfeed_input, outfeed_token, config ? *config : ""));
2121       break;
2122     }
2123     case HloOpcode::kRng: {
2124       optional<RandomDistribution> distribution;
2125       attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
2126                                &distribution};
2127       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2128         return false;
2129       }
2130       instruction = builder->AddInstruction(
2131           HloInstruction::CreateRng(shape, *distribution, operands));
2132       break;
2133     }
2134     case HloOpcode::kRngGetAndUpdateState: {
2135       optional<int64> delta;
2136       attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
2137       if (!ParseOperands(&operands, /*expected_size=*/0) ||
2138           !ParseAttributes(attrs)) {
2139         return false;
2140       }
2141       instruction = builder->AddInstruction(
2142           HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
2143       break;
2144     }
2145     case HloOpcode::kRngBitGenerator: {
2146       optional<RandomAlgorithm> algorithm;
2147       attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
2148                             &algorithm};
2149       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2150         return false;
2151       }
2152       instruction =
2153           builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
2154               shape, operands[0], *algorithm));
2155       break;
2156     }
2157     case HloOpcode::kReducePrecision: {
2158       optional<int64> exponent_bits;
2159       optional<int64> mantissa_bits;
2160       attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
2161                                 &exponent_bits};
2162       attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
2163                                 &mantissa_bits};
2164       if (!ParseOperands(&operands, /*expected_size=*/1) ||
2165           !ParseAttributes(attrs)) {
2166         return false;
2167       }
2168       instruction =
2169           builder->AddInstruction(HloInstruction::CreateReducePrecision(
2170               shape, operands[0], static_cast<int>(*exponent_bits),
2171               static_cast<int>(*mantissa_bits)));
2172       break;
2173     }
2174     case HloOpcode::kConditional: {
2175       optional<HloComputation*> true_computation;
2176       optional<HloComputation*> false_computation;
2177       optional<std::vector<HloComputation*>> branch_computations;
2178       if (!ParseOperands(&operands)) {
2179         return false;
2180       }
2181       if (!ShapeUtil::IsScalar(operands[0]->shape())) {
2182         return Error(lexer_.GetLoc(), "The first operand must be a scalar");
2183       }
2184       const bool branch_index_is_bool =
2185           operands[0]->shape().element_type() == PRED;
2186       if (branch_index_is_bool) {
2187         attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
2188                                      &true_computation};
2189         attrs["false_computation"] = {
2190             /*required=*/true, AttrTy::kHloComputation, &false_computation};
2191       } else {
2192         if (operands[0]->shape().element_type() != S32) {
2193           return Error(lexer_.GetLoc(),
2194                        "The first operand must be a scalar of PRED or S32");
2195         }
2196         attrs["branch_computations"] = {/*required=*/true,
2197                                         AttrTy::kBracedHloComputationList,
2198                                         &branch_computations};
2199       }
2200       if (!ParseAttributes(attrs)) {
2201         return false;
2202       }
2203       if (branch_index_is_bool) {
2204         branch_computations.emplace({*true_computation, *false_computation});
2205       }
2206       if (branch_computations->empty() ||
2207           operands.size() != branch_computations->size() + 1) {
2208         return false;
2209       }
2210       if (!maybe_infer_shape(
2211               [&] {
2212                 absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
2213                 branch_computation_shapes.reserve(branch_computations->size());
2214                 for (auto* computation : *branch_computations) {
2215                   branch_computation_shapes.push_back(
2216                       computation->ComputeProgramShape());
2217                 }
2218                 absl::InlinedVector<Shape, 2> branch_operand_shapes;
2219                 branch_operand_shapes.reserve(operands.size() - 1);
2220                 for (int i = 1; i < operands.size(); ++i) {
2221                   branch_operand_shapes.push_back(operands[i]->shape());
2222                 }
2223                 return ShapeInference::InferConditionalShape(
2224                     operands[0]->shape(), branch_computation_shapes,
2225                     branch_operand_shapes);
2226               },
2227               &shape)) {
2228         return false;
2229       }
2230       instruction = builder->AddInstruction(HloInstruction::CreateConditional(
2231           shape, /*branch_index=*/operands[0],
2232           absl::MakeSpan(*branch_computations),
2233           absl::MakeSpan(operands).subspan(1)));
2234       break;
2235     }
2236     case HloOpcode::kCustomCall: {
2237       optional<std::string> custom_call_target;
2238       optional<Window> window;
2239       optional<ConvolutionDimensionNumbers> dnums;
2240       optional<int64> feature_group_count;
2241       optional<int64> batch_group_count;
2242       optional<std::vector<Shape>> operand_layout_constraints;
2243       optional<bool> custom_call_has_side_effect;
2244       optional<HloComputation*> to_apply;
2245       optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
2246           output_to_operand_aliasing;
2247       optional<PaddingType> padding_type;
2248       optional<std::vector<HloComputation*>> called_computations;
2249       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
2250                                      &custom_call_target};
2251       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
2252       attrs["dim_labels"] = {/*required=*/false,
2253                              AttrTy::kConvolutionDimensionNumbers, &dnums};
2254       attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
2255                                       &feature_group_count};
2256       attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
2257                                     &batch_group_count};
2258       attrs["operand_layout_constraints"] = {
2259           /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
2260       attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
2261                                               &custom_call_has_side_effect};
2262       attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
2263                            &to_apply};
2264       attrs["called_computations"] = {/*required=*/false,
2265                                       AttrTy::kBracedHloComputationList,
2266                                       &called_computations};
2267       attrs["output_to_operand_aliasing"] = {/*required=*/false,
2268                                              AttrTy::kInstructionAliasing,
2269                                              &output_to_operand_aliasing};
2270 
2271       attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
2272                                &padding_type};
2273 
2274       optional<Literal> literal;
2275       attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
2276       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2277       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2278                                     &operand_precision};
2279       if (called_computations.has_value() && to_apply.has_value()) {
2280         return Error(lexer_.GetLoc(),
2281                      "A single instruction can't have both to_apply and "
2282                      "calls field");
2283       }
2284       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2285         return false;
2286       }
2287       if (operand_layout_constraints.has_value()) {
2288         if (!LayoutUtil::HasLayout(shape)) {
2289           return Error(lexer_.GetLoc(),
2290                        "Layout must be set on layout-constrained custom call");
2291         }
2292         if (operands.size() != operand_layout_constraints->size()) {
2293           return Error(lexer_.GetLoc(),
2294                        StrCat("Expected ", operands.size(),
2295                               " operand layout constraints, ",
2296                               operand_layout_constraints->size(), " given"));
2297         }
2298         for (int64 i = 0; i < operands.size(); ++i) {
2299           const Shape& operand_shape_with_layout =
2300               (*operand_layout_constraints)[i];
2301           if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
2302             return Error(lexer_.GetLoc(),
2303                          StrCat("Operand layout constraint shape ",
2304                                 ShapeUtil::HumanStringWithLayout(
2305                                     operand_shape_with_layout),
2306                                 " for operand ", i, " does not have a layout"));
2307           }
2308           if (!ShapeUtil::Compatible(operand_shape_with_layout,
2309                                      operands[i]->shape())) {
2310             return Error(
2311                 lexer_.GetLoc(),
2312                 StrCat(
2313                     "Operand layout constraint shape ",
2314                     ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
2315                     " for operand ", i,
2316                     " is not compatible with operand shape ",
2317                     ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
2318           }
2319         }
2320         instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
2321             shape, operands, *custom_call_target, *operand_layout_constraints,
2322             backend_config ? *backend_config : ""));
2323       } else {
2324         if (to_apply.has_value()) {
2325           instruction =
2326               builder->AddInstruction(HloInstruction::CreateCustomCall(
2327                   shape, operands, *to_apply, *custom_call_target,
2328                   backend_config ? *backend_config : ""));
2329         } else if (called_computations.has_value()) {
2330           instruction =
2331               builder->AddInstruction(HloInstruction::CreateCustomCall(
2332                   shape, operands, *called_computations, *custom_call_target,
2333                   backend_config ? *backend_config : ""));
2334         } else {
2335           instruction =
2336               builder->AddInstruction(HloInstruction::CreateCustomCall(
2337                   shape, operands, *custom_call_target,
2338                   backend_config ? *backend_config : ""));
2339         }
2340       }
2341       auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
2342       if (window.has_value()) {
2343         custom_call_instr->set_window(*window);
2344       }
2345       if (dnums.has_value()) {
2346         custom_call_instr->set_convolution_dimension_numbers(*dnums);
2347       }
2348       if (feature_group_count.has_value()) {
2349         custom_call_instr->set_feature_group_count(*feature_group_count);
2350       }
2351       if (batch_group_count.has_value()) {
2352         custom_call_instr->set_batch_group_count(*batch_group_count);
2353       }
2354       if (padding_type.has_value()) {
2355         custom_call_instr->set_padding_type(*padding_type);
2356       }
2357       if (custom_call_has_side_effect.has_value()) {
2358         custom_call_instr->set_custom_call_has_side_effect(
2359             *custom_call_has_side_effect);
2360       }
2361       if (output_to_operand_aliasing.has_value()) {
2362         custom_call_instr->set_output_to_operand_aliasing(
2363             std::move(*output_to_operand_aliasing));
2364       }
2365       if (literal.has_value()) {
2366         custom_call_instr->set_literal(std::move(*literal));
2367       }
2368       PrecisionConfig precision_config;
2369       if (operand_precision) {
2370         *precision_config.mutable_operand_precision() = {
2371             operand_precision->begin(), operand_precision->end()};
2372       } else {
2373         precision_config.mutable_operand_precision()->Resize(
2374             operands.size(), PrecisionConfig::DEFAULT);
2375       }
2376       *custom_call_instr->mutable_precision_config() = precision_config;
2377       break;
2378     }
2379     case HloOpcode::kDot: {
2380       optional<std::vector<int64>> lhs_contracting_dims;
2381       attrs["lhs_contracting_dims"] = {
2382           /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
2383       optional<std::vector<int64>> rhs_contracting_dims;
2384       attrs["rhs_contracting_dims"] = {
2385           /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
2386       optional<std::vector<int64>> lhs_batch_dims;
2387       attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
2388                                  &lhs_batch_dims};
2389       optional<std::vector<int64>> rhs_batch_dims;
2390       attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
2391                                  &rhs_batch_dims};
2392       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2393       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2394                                     &operand_precision};
2395 
2396       if (!ParseOperands(&operands, /*expected_size=*/2) ||
2397           !ParseAttributes(attrs)) {
2398         return false;
2399       }
2400 
2401       DotDimensionNumbers dnum;
2402       if (lhs_contracting_dims) {
2403         *dnum.mutable_lhs_contracting_dimensions() = {
2404             lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
2405       }
2406       if (rhs_contracting_dims) {
2407         *dnum.mutable_rhs_contracting_dimensions() = {
2408             rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
2409       }
2410       if (lhs_batch_dims) {
2411         *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
2412                                                 lhs_batch_dims->end()};
2413       }
2414       if (rhs_batch_dims) {
2415         *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
2416                                                 rhs_batch_dims->end()};
2417       }
2418 
2419       PrecisionConfig precision_config;
2420       if (operand_precision) {
2421         *precision_config.mutable_operand_precision() = {
2422             operand_precision->begin(), operand_precision->end()};
2423       } else {
2424         precision_config.mutable_operand_precision()->Resize(
2425             operands.size(), PrecisionConfig::DEFAULT);
2426       }
2427       if (!maybe_infer_shape(
2428               [&] {
2429                 return ShapeInference::InferDotOpShape(
2430                     operands[0]->shape(), operands[1]->shape(), dnum,
2431                     /*preferred_element_type=*/absl::nullopt);
2432               },
2433               &shape)) {
2434         return false;
2435       }
2436       instruction = builder->AddInstruction(HloInstruction::CreateDot(
2437           shape, operands[0], operands[1], dnum, precision_config));
2438       break;
2439     }
2440     case HloOpcode::kGather: {
2441       optional<std::vector<int64>> offset_dims;
2442       attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
2443                               &offset_dims};
2444       optional<std::vector<int64>> collapsed_slice_dims;
2445       attrs["collapsed_slice_dims"] = {
2446           /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
2447       optional<std::vector<int64>> start_index_map;
2448       attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
2449                                   &start_index_map};
2450       optional<int64> index_vector_dim;
2451       attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2452                                    &index_vector_dim};
2453       optional<std::vector<int64>> slice_sizes;
2454       attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
2455                               &slice_sizes};
2456       optional<bool> indices_are_sorted = false;
2457       attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2458                                      &indices_are_sorted};
2459 
2460       if (!ParseOperands(&operands, /*expected_size=*/2) ||
2461           !ParseAttributes(attrs)) {
2462         return false;
2463       }
2464 
2465       GatherDimensionNumbers dim_numbers =
2466           HloGatherInstruction::MakeGatherDimNumbers(
2467               /*offset_dims=*/*offset_dims,
2468               /*collapsed_slice_dims=*/*collapsed_slice_dims,
2469               /*start_index_map=*/*start_index_map,
2470               /*index_vector_dim=*/*index_vector_dim);
2471       if (!maybe_infer_shape(
2472               [&] {
2473                 return ShapeInference::InferGatherShape(
2474                     operands[0]->shape(), operands[1]->shape(), dim_numbers,
2475                     *slice_sizes);
2476               },
2477               &shape)) {
2478         return false;
2479       }
2480       instruction = builder->AddInstruction(HloInstruction::CreateGather(
2481           shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
2482           dim_numbers, *slice_sizes, indices_are_sorted.value()));
2483       break;
2484     }
2485     case HloOpcode::kScatter: {
2486       optional<std::vector<int64>> update_window_dims;
2487       attrs["update_window_dims"] = {
2488           /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
2489       optional<std::vector<int64>> inserted_window_dims;
2490       attrs["inserted_window_dims"] = {
2491           /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
2492       optional<std::vector<int64>> scatter_dims_to_operand_dims;
2493       attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
2494                                                AttrTy::kBracedInt64List,
2495                                                &scatter_dims_to_operand_dims};
2496       optional<int64> index_vector_dim;
2497       attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2498                                    &index_vector_dim};
2499 
2500       optional<HloComputation*> update_computation;
2501       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
2502                            &update_computation};
2503       optional<bool> indices_are_sorted = false;
2504       attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2505                                      &indices_are_sorted};
2506       optional<bool> unique_indices = false;
2507       attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
2508                                  &unique_indices};
2509 
2510       if (!ParseOperands(&operands, /*expected_size=*/3) ||
2511           !ParseAttributes(attrs)) {
2512         return false;
2513       }
2514 
2515       ScatterDimensionNumbers dim_numbers =
2516           HloScatterInstruction::MakeScatterDimNumbers(
2517               /*update_window_dims=*/*update_window_dims,
2518               /*inserted_window_dims=*/*inserted_window_dims,
2519               /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
2520               /*index_vector_dim=*/*index_vector_dim);
2521 
2522       if (!maybe_infer_shape(
2523               [&] {
2524                 return ShapeInference::InferScatterShape(
2525                     operands[0]->shape(), operands[1]->shape(),
2526                     operands[2]->shape(),
2527                     update_computation.value()->ComputeProgramShape(),
2528                     dim_numbers);
2529               },
2530               &shape)) {
2531         return false;
2532       }
2533       instruction = builder->AddInstruction(HloInstruction::CreateScatter(
2534           shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
2535           /*updates=*/operands[2], *update_computation, dim_numbers,
2536           indices_are_sorted.value(), unique_indices.value()));
2537       break;
2538     }
2539     case HloOpcode::kDomain: {
2540       DomainData domain;
2541       attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
2542       if (!ParseOperands(&operands, /*expected_size=*/1) ||
2543           !ParseAttributes(attrs)) {
2544         return false;
2545       }
2546       if (!maybe_infer_shape(
2547               [&] {
2548                 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
2549               },
2550               &shape)) {
2551         return false;
2552       }
2553       instruction = builder->AddInstruction(HloInstruction::CreateDomain(
2554           shape, operands[0], std::move(domain.exit_metadata),
2555           std::move(domain.entry_metadata)));
2556       break;
2557     }
2558     case HloOpcode::kTrace:
2559       return TokenError(StrCat("parsing not yet implemented for op: ",
2560                                HloOpcodeString(opcode)));
2561     case HloOpcode::kGetDimensionSize: {
2562       optional<std::vector<int64>> dimensions;
2563       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2564                              &dimensions};
2565       if (!ParseOperands(&operands, /*expected_size=*/1) ||
2566           !ParseAttributes(attrs)) {
2567         return false;
2568       }
2569       if (!maybe_infer_shape(
2570               [&] {
2571                 return ShapeInference::InferGetDimensionSizeShape(
2572                     operands[0]->shape(), dimensions->at(0));
2573               },
2574               &shape)) {
2575         return false;
2576       }
2577       instruction =
2578           builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
2579               shape, operands[0], (*dimensions)[0]));
2580       break;
2581     }
2582     case HloOpcode::kSetDimensionSize: {
2583       optional<std::vector<int64>> dimensions;
2584       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2585                              &dimensions};
2586       if (!ParseOperands(&operands, /*expected_size=*/2) ||
2587           !ParseAttributes(attrs)) {
2588         return false;
2589       }
2590       if (!maybe_infer_shape(
2591               [&] {
2592                 return ShapeInference::InferSetDimensionSizeShape(
2593                     operands[0]->shape(), operands[1]->shape(),
2594                     dimensions->at(0));
2595               },
2596               &shape)) {
2597         return false;
2598       }
2599       instruction =
2600           builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
2601               shape, operands[0], operands[1], (*dimensions)[0]));
2602       break;
2603     }
2604   }
2605 
2606   instruction->SetAndSanitizeName(name);
2607   if (instruction->name() != name) {
2608     return Error(name_loc,
2609                  StrCat("illegal instruction name: ", name,
2610                         "; suggest renaming to: ", instruction->name()));
2611   }
2612 
2613   // Add shared attributes like metadata to the instruction, if they were seen.
2614   if (sharding) {
2615     instruction->set_sharding(
2616         HloSharding::FromProto(sharding.value()).ValueOrDie());
2617   }
2618   if (parameter_replication) {
2619     int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
2620     const auto& replicated =
2621         parameter_replication->replicated_at_leaf_buffers();
2622     if (leaf_count != replicated.size()) {
2623       return Error(lexer_.GetLoc(),
2624                    StrCat("parameter has ", leaf_count,
2625                           " leaf buffers, but parameter_replication has ",
2626                           replicated.size(), " elements."));
2627     }
2628     instruction->set_parameter_replicated_at_leaf_buffers(replicated);
2629   }
2630   if (predecessors) {
2631     for (auto* pre : *predecessors) {
2632       Status status = pre->AddControlDependencyTo(instruction);
2633       if (!status.ok()) {
2634         return Error(name_loc, StrCat("error adding control dependency for: ",
2635                                       name, " status: ", status.ToString()));
2636       }
2637     }
2638   }
2639   if (metadata) {
2640     instruction->set_metadata(*metadata);
2641   }
2642   if (backend_config) {
2643     instruction->set_raw_backend_config_string(std::move(*backend_config));
2644   }
2645   if (outer_dimension_partitions) {
2646     instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
2647   }
2648   if (frontend_attributes) {
2649     instruction->set_frontend_attributes(*frontend_attributes);
2650   }
2651   return AddInstruction(name, instruction, name_loc);
2652 }  // NOLINT(readability/fn_size)
2653 
2654 // ::= '{' (single_sharding | tuple_sharding) '}'
2655 //
2656 // tuple_sharding ::= single_sharding* (',' single_sharding)*
ParseSharding(OpSharding * sharding)2657 bool HloParserImpl::ParseSharding(OpSharding* sharding) {
2658   // A single sharding starts with '{' and is not followed by '{'.
2659   // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
2660   // an empty tuple.
2661   if (!ParseToken(TokKind::kLbrace,
2662                   "expected '{' to start sharding attribute")) {
2663     return false;
2664   }
2665 
2666   if (lexer_.GetKind() != TokKind::kLbrace &&
2667       lexer_.GetKind() != TokKind::kRbrace) {
2668     return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
2669   }
2670 
2671   // Tuple sharding.
2672   // Allow empty tuple shardings.
2673   if (lexer_.GetKind() != TokKind::kRbrace) {
2674     do {
2675       if (!ParseSingleSharding(sharding->add_tuple_shardings(),
2676                                /*lbrace_pre_lexed=*/false)) {
2677         return false;
2678       }
2679     } while (EatIfPresent(TokKind::kComma));
2680   }
2681   sharding->set_type(OpSharding::TUPLE);
2682 
2683   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
2684 }
2685 
2686 // frontend_attributes ::= '{' attributes '}'
2687 // attributes
2688 //   ::= /*empty*/
2689 //   ::= attribute '=' value (',' attribute '=' value)*
ParseFrontendAttributes(FrontendAttributes * frontend_attributes)2690 bool HloParserImpl::ParseFrontendAttributes(
2691     FrontendAttributes* frontend_attributes) {
2692   CHECK(frontend_attributes != nullptr);
2693   if (!ParseToken(TokKind::kLbrace,
2694                   "expected '{' to start frontend attributes")) {
2695     return false;
2696   }
2697   if (lexer_.GetKind() == TokKind::kRbrace) {
2698     // empty
2699   } else {
2700     do {
2701       std::string attribute;
2702       if (!ParseAttributeName(&attribute)) {
2703         return false;
2704       }
2705       if (lexer_.GetKind() != TokKind::kString) {
2706         return false;
2707       }
2708       (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
2709       lexer_.Lex();
2710     } while (EatIfPresent(TokKind::kComma));
2711   }
2712   return ParseToken(TokKind::kRbrace,
2713                     "expects '}' at the end of frontend attributes");
2714 }
2715 
2716 // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
2717 //         ('devices=' ('[' dims ']')* device_list)?
2718 //         ('metadata=' metadata)* '}'
2719 //
2720 // dims ::= int_list device_list ::= int_list
2721 // metadata ::= single_metadata |
2722 //              ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleSharding(OpSharding * sharding,bool lbrace_pre_lexed)2723 bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
2724                                         bool lbrace_pre_lexed) {
2725   if (!lbrace_pre_lexed &&
2726       !ParseToken(TokKind::kLbrace,
2727                   "expected '{' to start sharding attribute")) {
2728     return false;
2729   }
2730 
2731   LocTy loc = lexer_.GetLoc();
2732   bool maximal = false;
2733   bool replicated = false;
2734   bool manual = false;
2735   bool last_tile_dim_replicate = false;
2736   std::vector<int64> devices;
2737   std::vector<int64> tile_assignment_dimensions;
2738   while (lexer_.GetKind() != TokKind::kRbrace) {
2739     switch (lexer_.GetKind()) {
2740       case TokKind::kw_maximal:
2741         maximal = true;
2742         lexer_.Lex();
2743         break;
2744       case TokKind::kw_replicated:
2745         replicated = true;
2746         lexer_.Lex();
2747         break;
2748       case TokKind::kw_manual:
2749         manual = true;
2750         lexer_.Lex();
2751         break;
2752       case TokKind::kAttributeName: {
2753         if (lexer_.GetStrVal() == "device") {
2754           if (lexer_.Lex() != TokKind::kInt) {
2755             return TokenError("device= attribute must be an integer");
2756           }
2757           devices = {lexer_.GetInt64Val()};
2758           lexer_.Lex();
2759         } else if (lexer_.GetStrVal() == "devices") {
2760           lexer_.Lex();
2761           if (!ParseToken(TokKind::kLsquare,
2762                           "expected '[' to start sharding devices shape")) {
2763             return false;
2764           }
2765 
2766           do {
2767             int64 dim;
2768             if (!ParseInt64(&dim)) {
2769               return false;
2770             }
2771             tile_assignment_dimensions.push_back(dim);
2772           } while (EatIfPresent(TokKind::kComma));
2773 
2774           if (!ParseToken(TokKind::kRsquare,
2775                           "expected ']' to start sharding devices shape")) {
2776             return false;
2777           }
2778           do {
2779             int64 device;
2780             if (!ParseInt64(&device)) {
2781               return false;
2782             }
2783             devices.push_back(device);
2784           } while (EatIfPresent(TokKind::kComma));
2785         } else if (lexer_.GetStrVal() == "metadata") {
2786           lexer_.Lex();
2787           if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
2788             return false;
2789           }
2790         } else {
2791           return TokenError(
2792               "unknown attribute in sharding: expected device=, devices= or "
2793               "metadata=");
2794         }
2795         break;
2796       }
2797       case TokKind::kw_last_tile_dim_replicate:
2798         last_tile_dim_replicate = true;
2799         lexer_.Lex();
2800         break;
2801       case TokKind::kRbrace:
2802         break;
2803       default:
2804         return TokenError("unexpected token");
2805     }
2806   }
2807 
2808   if (replicated) {
2809     if (!devices.empty()) {
2810       return Error(loc,
2811                    "replicated shardings should not have any devices assigned");
2812     }
2813     sharding->set_type(OpSharding::REPLICATED);
2814   } else if (maximal) {
2815     if (devices.size() != 1) {
2816       return Error(loc,
2817                    "maximal shardings should have exactly one device assigned");
2818     }
2819     sharding->set_type(OpSharding::MAXIMAL);
2820     sharding->add_tile_assignment_devices(devices[0]);
2821   } else if (manual) {
2822     if (!devices.empty()) {
2823       return Error(loc,
2824                    "manual shardings should not have any devices assigned");
2825     }
2826     sharding->set_type(OpSharding::MANUAL);
2827   } else {
2828     if (devices.size() <= 1) {
2829       return Error(
2830           loc, "non-maximal shardings must have more than one device assigned");
2831     }
2832     if (tile_assignment_dimensions.empty()) {
2833       return Error(
2834           loc,
2835           "non-maximal shardings must have a tile assignment list including "
2836           "dimensions");
2837     }
2838     sharding->set_type(OpSharding::OTHER);
2839     for (int64 dim : tile_assignment_dimensions) {
2840       sharding->add_tile_assignment_dimensions(dim);
2841     }
2842     for (int64 device : devices) {
2843       sharding->add_tile_assignment_devices(device);
2844     }
2845     sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
2846   }
2847 
2848   lexer_.Lex();
2849   return true;
2850 }
2851 
2852 // parameter_replication ::=
2853 //   '{' ('true' | 'false')* (',' ('true' | 'false'))*  '}'
ParseParameterReplication(ParameterReplication * parameter_replication)2854 bool HloParserImpl::ParseParameterReplication(
2855     ParameterReplication* parameter_replication) {
2856   if (!ParseToken(TokKind::kLbrace,
2857                   "expected '{' to start parameter_replication attribute")) {
2858     return false;
2859   }
2860 
2861   if (lexer_.GetKind() != TokKind::kRbrace) {
2862     do {
2863       if (lexer_.GetKind() == TokKind::kw_true) {
2864         parameter_replication->add_replicated_at_leaf_buffers(true);
2865       } else if (lexer_.GetKind() == TokKind::kw_false) {
2866         parameter_replication->add_replicated_at_leaf_buffers(false);
2867       } else {
2868         return false;
2869       }
2870       lexer_.Lex();
2871     } while (EatIfPresent(TokKind::kComma));
2872   }
2873 
2874   return ParseToken(TokKind::kRbrace,
2875                     "expected '}' to end parameter_replication attribute");
2876 }
2877 
2878 // replica_groups ::='{' int64list_elements '}'
2879 // int64list_elements
2880 //   ::= /*empty*/
2881 //   ::= int64list (',' int64list)*
2882 // int64list ::= '{' int64_elements '}'
2883 // int64_elements
2884 //   ::= /*empty*/
2885 //   ::= int64_val (',' int64_val)*
ParseReplicaGroupsOnly(std::vector<ReplicaGroup> * replica_groups)2886 bool HloParserImpl::ParseReplicaGroupsOnly(
2887     std::vector<ReplicaGroup>* replica_groups) {
2888   std::vector<std::vector<int64>> result;
2889   if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
2890                           &result)) {
2891     return false;
2892   }
2893   *replica_groups = CreateReplicaGroups(result);
2894   return true;
2895 }
2896 
2897 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
2898 //            'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)2899 bool HloParserImpl::ParseDomain(DomainData* domain) {
2900   absl::flat_hash_map<std::string, AttrConfig> attrs;
2901   optional<std::string> kind;
2902   optional<OpSharding> entry_sharding;
2903   optional<OpSharding> exit_sharding;
2904   attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
2905   attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
2906   attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
2907   if (!ParseSubAttributes(attrs)) {
2908     return false;
2909   }
2910   if (*kind == ShardingMetadata::KindName()) {
2911     auto entry_sharding_ptr = absl::make_unique<HloSharding>(
2912         HloSharding::FromProto(*entry_sharding).ValueOrDie());
2913     auto exit_sharding_ptr = absl::make_unique<HloSharding>(
2914         HloSharding::FromProto(*exit_sharding).ValueOrDie());
2915     domain->entry_metadata =
2916         absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
2917     domain->exit_metadata =
2918         absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
2919   } else {
2920     return TokenError(StrCat("unsupported domain kind: ", *kind));
2921   }
2922   return true;
2923 }
2924 
2925 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)2926 bool HloParserImpl::ParseInstructionNames(
2927     std::vector<HloInstruction*>* instructions) {
2928   if (!ParseToken(TokKind::kLbrace,
2929                   "expects '{' at the beginning of instruction name list")) {
2930     return false;
2931   }
2932   LocTy loc = lexer_.GetLoc();
2933   do {
2934     std::string name;
2935     if (!ParseName(&name)) {
2936       return Error(loc, "expects a instruction name");
2937     }
2938     std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
2939     if (!instr) {
2940       return TokenError(StrFormat("instruction '%s' is not defined", name));
2941     }
2942     instructions->push_back(instr->first);
2943   } while (EatIfPresent(TokKind::kComma));
2944 
2945   return ParseToken(TokKind::kRbrace,
2946                     "expects '}' at the end of instruction name list");
2947 }
2948 
SetValueInLiteral(LocTy loc,int64 value,int64 index,Literal * literal)2949 bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, int64 index,
2950                                       Literal* literal) {
2951   const Shape& shape = literal->shape();
2952   switch (shape.element_type()) {
2953     case S8:
2954       return SetValueInLiteralHelper<int8>(loc, value, index, literal);
2955     case S16:
2956       return SetValueInLiteralHelper<int16>(loc, value, index, literal);
2957     case S32:
2958       return SetValueInLiteralHelper<int32>(loc, value, index, literal);
2959     case S64:
2960       return SetValueInLiteralHelper<int64>(loc, value, index, literal);
2961     case U8:
2962       return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
2963                                                         literal);
2964     case U16:
2965       return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
2966                                                          literal);
2967     case U32:
2968       return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
2969                                                          literal);
2970     case U64:
2971       return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
2972                                                          literal);
2973     case PRED:
2974       // Bool type literals with rank >= 1 are printed in 0s and 1s.
2975       return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
2976                                            literal);
2977     default:
2978       LOG(FATAL) << "unknown integral primitive type "
2979                  << PrimitiveType_Name(shape.element_type());
2980   }
2981 }
2982 
SetValueInLiteral(LocTy loc,double value,int64 index,Literal * literal)2983 bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64 index,
2984                                       Literal* literal) {
2985   const Shape& shape = literal->shape();
2986   switch (shape.element_type()) {
2987     case F16:
2988       return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
2989     case BF16:
2990       return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
2991                                                            literal);
2992     case F32:
2993       return SetValueInLiteralHelper<float>(loc, value, index, literal);
2994     case F64:
2995       return SetValueInLiteralHelper<double>(loc, value, index, literal);
2996     default:
2997       LOG(FATAL) << "unknown floating point primitive type "
2998                  << PrimitiveType_Name(shape.element_type());
2999   }
3000 }
3001 
SetValueInLiteral(LocTy loc,bool value,int64 index,Literal * literal)3002 bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64 index,
3003                                       Literal* literal) {
3004   const Shape& shape = literal->shape();
3005   switch (shape.element_type()) {
3006     case PRED:
3007       return SetValueInLiteralHelper<bool>(loc, value, index, literal);
3008     default:
3009       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3010                  << " is not PRED type";
3011   }
3012 }
3013 
SetValueInLiteral(LocTy loc,std::complex<double> value,int64 index,Literal * literal)3014 bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
3015                                       int64 index, Literal* literal) {
3016   const Shape& shape = literal->shape();
3017   switch (shape.element_type()) {
3018     case C64:
3019       return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
3020                                                           literal);
3021     case C128:
3022       return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
3023                                                            literal);
3024     default:
3025       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3026                  << " is not a complex type";
3027   }
3028 }
3029 
3030 template <typename T>
StringifyValue(T val)3031 std::string StringifyValue(T val) {
3032   return StrCat(val);
3033 }
3034 template <>
StringifyValue(std::complex<double> val)3035 std::string StringifyValue(std::complex<double> val) {
3036   return StrFormat("(%f, %f)", std::real(val), std::imag(val));
3037 }
3038 
3039 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,int64 index,Literal * literal)3040 bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
3041                                             int64 index, Literal* literal) {
3042   if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
3043     return false;
3044   }
3045 
3046   // Check that the index is in range and assign into the literal
3047   if (index >= ShapeUtil::ElementsIn(literal->shape())) {
3048     return Error(loc, StrCat("tries to set value ", StringifyValue(value),
3049                              " to a literal in shape ",
3050                              ShapeUtil::HumanString(literal->shape()),
3051                              " at linear index ", index,
3052                              ", but the index is out of range"));
3053   }
3054   literal->data<LiteralNativeT>().at(index) =
3055       static_cast<LiteralNativeT>(value);
3056   return true;
3057 }
3058 
ParseLiteral(Literal * literal)3059 bool HloParserImpl::ParseLiteral(Literal* literal) {
3060   Shape literal_shape;
3061   if (!ParseShape(&literal_shape)) {
3062     return false;
3063   }
3064   return ParseLiteral(literal, literal_shape);
3065 }
3066 
3067 // literal
3068 //  ::= tuple
3069 //  ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)3070 bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
3071   return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
3072                          : ParseNonTupleLiteral(literal, shape);
3073 }
3074 
3075 // tuple
3076 //  ::= shape '(' literal_list ')'
3077 // literal_list
3078 //  ::= /*empty*/
3079 //  ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)3080 bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
3081   if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
3082     return false;
3083   }
3084   std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
3085 
3086   if (lexer_.GetKind() == TokKind::kRparen) {
3087     // empty
3088   } else {
3089     // literal, (',' literal)*
3090     for (int i = 0; i < elements.size(); i++) {
3091       if (i > 0) {
3092         ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3093       }
3094       if (!ParseLiteral(&elements[i],
3095                         ShapeUtil::GetTupleElementShape(shape, i))) {
3096         return TokenError(StrCat("expects the ", i, "th element"));
3097       }
3098     }
3099   }
3100   *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3101   return ParseToken(TokKind::kRparen,
3102                     StrCat("expects ')' at the end of the tuple with ",
3103                            ShapeUtil::TupleElementCount(shape), "elements"));
3104 }
3105 
3106 // non_tuple
3107 //   ::= rank01
3108 //   ::= rank2345
3109 // rank2345 ::= shape nested_array
ParseNonTupleLiteral(Literal * literal,const Shape & shape)3110 bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
3111   CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
3112   return ParseDenseLiteral(literal, shape);
3113 }
3114 
ParseDenseLiteral(Literal * literal,const Shape & shape)3115 bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
3116   // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
3117   // `rank` is an int64, that's an implicit narrowing conversion, which is
3118   // implementation-defined behavior.
3119   const int rank = static_cast<int>(shape.rank());
3120 
3121   // Create a literal with the given shape in default layout.
3122   *literal = LiteralUtil::CreateFromDimensions(
3123       shape.element_type(), AsInt64Slice(shape.dimensions()));
3124   int64 nest_level = 0;
3125   int64 linear_index = 0;
3126   // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
3127   // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
3128   // when we are parsing the 2nd '{' (right before '1'), we are seeing a
3129   // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
3130   // the first '}' (right after '3'), it means the sub-array ends, and the
3131   // sub-array is supposed to contain exactly 3 elements, so check if
3132   // elems_seen_per_dim[1] is 3.
3133   std::vector<int64> elems_seen_per_dim(rank);
3134   auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
3135     std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
3136                                             elems_seen_per_dim.begin() + dim);
3137     return StrCat("[",
3138                   StrJoin(elems_seen_until_dim, ",",
3139                           [](std::string* out, const int64 num_elems) {
3140                             StrAppend(out, num_elems - 1);
3141                           }),
3142                   "]");
3143   };
3144 
3145   auto add_one_elem_seen = [&] {
3146     if (rank > 0) {
3147       if (nest_level != rank) {
3148         return TokenError(absl::StrFormat(
3149             "expects nested array in rank %d, but sees %d", rank, nest_level));
3150       }
3151       elems_seen_per_dim[rank - 1]++;
3152       if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
3153         return TokenError(absl::StrFormat(
3154             "expects %d elements on the minor-most dimension, but "
3155             "sees more",
3156             shape.dimensions(rank - 1)));
3157       }
3158     }
3159     return true;
3160   };
3161 
3162   do {
3163     switch (lexer_.GetKind()) {
3164       default:
3165         return TokenError("unexpected token type in a literal");
3166       case TokKind::kLbrace: {
3167         nest_level++;
3168         if (nest_level > rank) {
3169           return TokenError(absl::StrFormat(
3170               "expects nested array in rank %d, but sees larger", rank));
3171         }
3172         if (nest_level > 1) {
3173           elems_seen_per_dim[nest_level - 2]++;
3174           if (elems_seen_per_dim[nest_level - 2] >
3175               shape.dimensions(nest_level - 2)) {
3176             return TokenError(absl::StrFormat(
3177                 "expects %d elements in the %sth element, but sees more",
3178                 shape.dimensions(nest_level - 2),
3179                 get_index_str(nest_level - 2)));
3180           }
3181         }
3182         lexer_.Lex();
3183         break;
3184       }
3185       case TokKind::kRbrace: {
3186         nest_level--;
3187         if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
3188           return TokenError(absl::StrFormat(
3189               "expects %d elements in the %sth element, but sees %d",
3190               shape.dimensions(nest_level), get_index_str(nest_level),
3191               elems_seen_per_dim[nest_level]));
3192         }
3193         elems_seen_per_dim[nest_level] = 0;
3194         lexer_.Lex();
3195         break;
3196       }
3197       case TokKind::kLparen: {
3198         if (!primitive_util::IsComplexType(shape.element_type())) {
3199           return TokenError(
3200               absl::StrFormat("unexpected '(' in literal.  Parens are only "
3201                               "valid for complex literals"));
3202         }
3203 
3204         std::complex<double> value;
3205         LocTy loc = lexer_.GetLoc();
3206         if (!add_one_elem_seen() || !ParseComplex(&value) ||
3207             !SetValueInLiteral(loc, value, linear_index++, literal)) {
3208           return false;
3209         }
3210         break;
3211       }
3212       case TokKind::kDots: {
3213         if (nest_level != 1) {
3214           return TokenError(absl::StrFormat(
3215               "expects `...` at nest level 1, but sees it at nest level %d",
3216               nest_level));
3217         }
3218         elems_seen_per_dim[0] = shape.dimensions(0);
3219         lexer_.Lex();
3220         // Fill data with deterministic (garbage) values. Use static to avoid
3221         // creating identical constants which could potentially got CSE'ed
3222         // away. This is a best-effort approach to make sure replaying a HLO
3223         // gives us same optimized HLO graph.
3224         static uint32 data = 0;
3225         uint32* raw_data = static_cast<uint32*>(literal->untyped_data());
3226         for (int64 i = 0; i < literal->size_bytes() / 4; ++i) {
3227           raw_data[i] = data++;
3228         }
3229         uint8* raw_data_int8 = static_cast<uint8*>(literal->untyped_data());
3230         static uint8 data_int8 = 0;
3231         for (int64 i = 0; i < literal->size_bytes() % 4; ++i) {
3232           raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
3233         }
3234         break;
3235       }
3236       case TokKind::kComma:
3237         // Skip.
3238         lexer_.Lex();
3239         break;
3240       case TokKind::kw_true:
3241       case TokKind::kw_false:
3242       case TokKind::kInt:
3243       case TokKind::kDecimal:
3244       case TokKind::kw_nan:
3245       case TokKind::kNegNan:
3246       case TokKind::kw_inf:
3247       case TokKind::kNegInf: {
3248         add_one_elem_seen();
3249         if (lexer_.GetKind() == TokKind::kw_true ||
3250             lexer_.GetKind() == TokKind::kw_false) {
3251           if (!SetValueInLiteral(lexer_.GetLoc(),
3252                                  lexer_.GetKind() == TokKind::kw_true,
3253                                  linear_index++, literal)) {
3254             return false;
3255           }
3256           lexer_.Lex();
3257         } else if (primitive_util::IsIntegralType(shape.element_type()) ||
3258                    shape.element_type() == PRED) {
3259           LocTy loc = lexer_.GetLoc();
3260           int64 value;
3261           if (!ParseInt64(&value)) {
3262             return Error(loc, StrCat("expects integer for primitive type: ",
3263                                      PrimitiveType_Name(shape.element_type())));
3264           }
3265           if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3266             return false;
3267           }
3268         } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
3269           LocTy loc = lexer_.GetLoc();
3270           double value;
3271           if (!ParseDouble(&value)) {
3272             return Error(
3273                 loc, StrCat("expect floating point value for primitive type: ",
3274                             PrimitiveType_Name(shape.element_type())));
3275           }
3276           if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3277             return false;
3278           }
3279         } else {
3280           return TokenError(StrCat("unsupported primitive type ",
3281                                    PrimitiveType_Name(shape.element_type())));
3282         }
3283         break;
3284       }
3285     }  // end of switch
3286   } while (nest_level > 0);
3287 
3288   *literal = literal->Relayout(shape.layout());
3289   return true;
3290 }
3291 
// MinMaxFiniteValue is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange. For a generic T the bounds come
// straight from std::numeric_limits<T>: max() is the largest value and min()
// is the lowest (most negative) value.
template <typename T>
struct MinMaxFiniteValue {
  static T max() {
    const T largest = std::numeric_limits<T>::max();
    return largest;
  }
  static T min() {
    const T smallest = std::numeric_limits<T>::lowest();
    return smallest;
  }
};
3299 
3300 template <>
3301 struct MinMaxFiniteValue<Eigen::half> {
maxxla::__anonf806807d0111::MinMaxFiniteValue3302   static double max() {
3303     // Sadly this is not constexpr, so this forces `value` to be a method.
3304     return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
3305   }
minxla::__anonf806807d0111::MinMaxFiniteValue3306   static double min() { return -max(); }
3307 };
3308 
3309 template <>
3310 struct MinMaxFiniteValue<bfloat16> {
maxxla::__anonf806807d0111::MinMaxFiniteValue3311   static double max() {
3312     return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
3313   }
minxla::__anonf806807d0111::MinMaxFiniteValue3314   static double min() { return -max(); }
3315 };
3316 
// MSVC's standard C++ library does not define isnan/isfinite for integer types.
// To work around that we will need to provide our own.
//
// Floating-point overload: forwards to std::isfinite.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, bool>::type IsFinite(
    T val) {
  const bool finite = std::isfinite(val);
  return finite;
}
// Floating-point overload: forwards to std::isnan.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, bool>::type IsNaN(
    T val) {
  const bool not_a_number = std::isnan(val);
  return not_a_number;
}
// Integral overload: converts to double first, then asks std::isfinite.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type IsFinite(
    T val) {
  const double as_double = static_cast<double>(val);
  return std::isfinite(as_double);
}
// Integral overload: converts to double first, then asks std::isnan.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type IsNaN(T val) {
  const double as_double = static_cast<double>(val);
  return std::isnan(as_double);
}
3335 
// Checks that a parsed literal element `value` (of type ParsedElemT) can be
// stored in a literal whose native element type is LiteralNativeT. Returns
// true if the value is acceptable; otherwise returns the result of
// Error(loc, ...) with a message naming the value and the allowed range.
// NaN and infinite values are never range-checked.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
  // For floating-point input, round-trip the value through LiteralNativeT.
  // The round-tripped value replaces `value` unless a finite input became
  // non-finite in the conversion; in that case the original is kept so the
  // range check below can reject it.
  if (std::is_floating_point<ParsedElemT>::value) {
    auto value_as_native_t = static_cast<LiteralNativeT>(value);
    auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
    if (!IsFinite(value) || IsFinite(value_double_converted)) {
      value = value_double_converted;
    }
  }
  PrimitiveType literal_ty =
      primitive_util::NativeToPrimitiveType<LiteralNativeT>();
  if (IsNaN(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (std::is_unsigned<LiteralNativeT>::value) {
    // Unsigned targets are only supported for int64 or bool parsed values.
    CHECK((std::is_same<ParsedElemT, int64>::value ||
           std::is_same<ParsedElemT, bool>::value))
        << "Unimplemented checking for ParsedElemT";

    // Compare as uint64 against the target type's maximum. A negative int64
    // converts to a large uint64 here and so exceeds the bound of the
    // narrower unsigned types.
    // NOTE(review): when the bound is the uint64 maximum itself, a negative
    // int64 passes this check -- confirm that this is intended.
    const uint64 unsigned_value = value;
    const uint64 upper_bound =
        static_cast<uint64>(std::numeric_limits<LiteralNativeT>::max());
    if (unsigned_value > upper_bound) {
      // Value is out of range for LiteralNativeT.
      return Error(loc, StrCat("value ", value,
                               " is out of range for literal's primitive type ",
                               PrimitiveType_Name(literal_ty), " namely [0, ",
                               upper_bound, "]."));
    }
  } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
             value < MinMaxFiniteValue<LiteralNativeT>::min()) {
    // Value is out of range for LiteralNativeT.
    return Error(loc, StrCat("value ", value,
                             " is out of range for literal's primitive type ",
                             PrimitiveType_Name(literal_ty), " namely [",
                             MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
                             MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
  }
  return true;
}
3378 
3379 template <typename LiteralNativeT>
CheckParsedValueIsInRange(LocTy loc,std::complex<double> value)3380 bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
3381                                               std::complex<double> value) {
3382   // e.g. `float` for std::complex<float>
3383   using LiteralComplexComponentT =
3384       decltype(std::real(std::declval<LiteralNativeT>()));
3385 
3386   // We could do simply
3387   //
3388   //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
3389   //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
3390   //
3391   // but this would give bad error messages on failure.
3392 
3393   auto check_component = [&](absl::string_view name, double v) {
3394     if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
3395         v == -std::numeric_limits<double>::infinity()) {
3396       // Skip range-checking for non-finite values.
3397       return true;
3398     }
3399 
3400     double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
3401     double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
3402     if (v < min || v > max) {
3403       // Value is out of range for LitearlComplexComponentT.
3404       return Error(
3405           loc,
3406           StrCat(name, " part ", v,
3407                  " is out of range for literal's primitive type ",
3408                  PrimitiveType_Name(
3409                      primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
3410                  ", namely [", min, ", ", max, "]."));
3411     }
3412     return true;
3413   };
3414   return check_component("real", std::real(value)) &&
3415          check_component("imaginary", std::imag(value));
3416 }
3417 
3418 // operands ::= '(' operands1 ')'
3419 // operands1
3420 //   ::= /*empty*/
3421 //   ::= operand (, operand)*
3422 // operand ::= (shape)? name
ParseOperands(std::vector<HloInstruction * > * operands)3423 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands) {
3424   CHECK(operands != nullptr);
3425   if (!ParseToken(TokKind::kLparen,
3426                   "expects '(' at the beginning of operands")) {
3427     return false;
3428   }
3429   if (lexer_.GetKind() == TokKind::kRparen) {
3430     // empty
3431   } else {
3432     do {
3433       LocTy loc = lexer_.GetLoc();
3434       std::string name;
3435       optional<Shape> shape;
3436       if (CanBeShape()) {
3437         shape.emplace();
3438         if (!ParseShape(&shape.value())) {
3439           return false;
3440         }
3441       }
3442       if (!ParseName(&name)) {
3443         // When parsing a single instruction (as opposed to a whole module), an
3444         // HLO may have one or more operands with a shape but no name:
3445         //
3446         //  foo = add(f32[10], f32[10])
3447         //
3448         // create_missing_instruction_ is always non-null when parsing a single
3449         // instruction, and is responsible for creating kParameter instructions
3450         // for these operands.
3451         if (shape.has_value() && create_missing_instruction_ != nullptr &&
3452             scoped_name_tables_.size() == 1) {
3453           name = "";
3454         } else {
3455           return false;
3456         }
3457       }
3458       std::pair<HloInstruction*, LocTy>* instruction =
3459           FindInstruction(name, shape);
3460       if (instruction == nullptr) {
3461         return Error(loc, StrCat("instruction does not exist: ", name));
3462       }
3463       operands->push_back(instruction->first);
3464     } while (EatIfPresent(TokKind::kComma));
3465   }
3466   return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
3467 }
3468 
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)3469 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
3470                                   const int expected_size) {
3471   CHECK(operands != nullptr);
3472   LocTy loc = lexer_.GetLoc();
3473   if (!ParseOperands(operands)) {
3474     return false;
3475   }
3476   if (expected_size != operands->size()) {
3477     return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
3478                              operands->size(), " operands"));
3479   }
3480   return true;
3481 }
3482 
3483 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3484 bool HloParserImpl::ParseSubAttributes(
3485     const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3486   LocTy loc = lexer_.GetLoc();
3487   if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
3488     return false;
3489   }
3490   absl::flat_hash_set<std::string> seen_attrs;
3491   if (lexer_.GetKind() == TokKind::kRbrace) {
3492     // empty
3493   } else {
3494     do {
3495       EatIfPresent(TokKind::kComma);
3496       if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3497         return false;
3498       }
3499     } while (lexer_.GetKind() != TokKind::kRbrace);
3500   }
3501   // Check that all required attrs were seen.
3502   for (const auto& attr_it : attrs) {
3503     if (attr_it.second.required &&
3504         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3505       return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
3506                                   attr_it.first));
3507     }
3508   }
3509   return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
3510 }
3511 
3512 // attributes ::= (',' attribute)*
ParseAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3513 bool HloParserImpl::ParseAttributes(
3514     const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3515   LocTy loc = lexer_.GetLoc();
3516   absl::flat_hash_set<std::string> seen_attrs;
3517   while (EatIfPresent(TokKind::kComma)) {
3518     if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3519       return false;
3520     }
3521   }
3522   // Check that all required attrs were seen.
3523   for (const auto& attr_it : attrs) {
3524     if (attr_it.second.required &&
3525         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3526       return Error(loc, StrFormat("attribute %s is expected but not seen",
3527                                   attr_it.first));
3528     }
3529   }
3530   return true;
3531 }
3532 
ParseAttributeHelper(const absl::flat_hash_map<std::string,AttrConfig> & attrs,absl::flat_hash_set<std::string> * seen_attrs)3533 bool HloParserImpl::ParseAttributeHelper(
3534     const absl::flat_hash_map<std::string, AttrConfig>& attrs,
3535     absl::flat_hash_set<std::string>* seen_attrs) {
3536   LocTy loc = lexer_.GetLoc();
3537   std::string name;
3538   if (!ParseAttributeName(&name)) {
3539     return Error(loc, "error parsing attributes");
3540   }
3541   VLOG(3) << "Parsing attribute " << name;
3542   if (!seen_attrs->insert(name).second) {
3543     return Error(loc, StrFormat("attribute %s already exists", name));
3544   }
3545   auto attr_it = attrs.find(name);
3546   if (attr_it == attrs.end()) {
3547     std::string allowed_attrs;
3548     if (attrs.empty()) {
3549       allowed_attrs = "No attributes are allowed here.";
3550     } else {
3551       allowed_attrs =
3552           StrCat("Allowed attributes: ",
3553                  StrJoin(attrs, ", ",
3554                          [&](std::string* out,
3555                              const std::pair<std::string, AttrConfig>& kv) {
3556                            StrAppend(out, kv.first);
3557                          }));
3558     }
3559     return Error(loc, StrFormat("unexpected attribute \"%s\".  %s", name,
3560                                 allowed_attrs));
3561   }
3562   AttrTy attr_type = attr_it->second.attr_type;
3563   void* attr_out_ptr = attr_it->second.result;
3564   bool success = [&] {
3565     LocTy attr_loc = lexer_.GetLoc();
3566     switch (attr_type) {
3567       case AttrTy::kBool: {
3568         bool result;
3569         if (!ParseBool(&result)) {
3570           return false;
3571         }
3572         static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
3573         return true;
3574       }
3575       case AttrTy::kInt64: {
3576         int64 result;
3577         if (!ParseInt64(&result)) {
3578           return false;
3579         }
3580         static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
3581         return true;
3582       }
3583       case AttrTy::kInt32: {
3584         int64 result;
3585         if (!ParseInt64(&result)) {
3586           return false;
3587         }
3588         if (result != static_cast<int32>(result)) {
3589           return Error(attr_loc, "value out of range for int32");
3590         }
3591         static_cast<optional<int32>*>(attr_out_ptr)
3592             ->emplace(static_cast<int32>(result));
3593         return true;
3594       }
3595       case AttrTy::kFloat: {
3596         double result;
3597         if (!ParseDouble(&result)) {
3598           return false;
3599         }
3600         if (result > std::numeric_limits<float>::max() ||
3601             result < std::numeric_limits<float>::lowest()) {
3602           return Error(attr_loc, "value out of range for float");
3603         }
3604         static_cast<optional<float>*>(attr_out_ptr)
3605             ->emplace(static_cast<float>(result));
3606         return true;
3607       }
3608       case AttrTy::kHloComputation: {
3609         HloComputation* result = nullptr;
3610         if (!ParseHloComputation(&result)) {
3611           return false;
3612         }
3613         static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
3614         return true;
3615       }
3616       case AttrTy::kBracedHloComputationList: {
3617         std::vector<HloComputation*> result;
3618         if (!ParseHloComputationList(&result)) {
3619           return false;
3620         }
3621         static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
3622             ->emplace(result);
3623         return true;
3624       }
3625       case AttrTy::kFftType: {
3626         FftType result;
3627         if (!ParseFftType(&result)) {
3628           return false;
3629         }
3630         static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
3631         return true;
3632       }
3633       case AttrTy::kPaddingType: {
3634         PaddingType result;
3635         if (!ParsePaddingType(&result)) {
3636           return false;
3637         }
3638         static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
3639         return true;
3640       }
3641       case AttrTy::kComparisonDirection: {
3642         ComparisonDirection result;
3643         if (!ParseComparisonDirection(&result)) {
3644           return false;
3645         }
3646         static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
3647             ->emplace(result);
3648         return true;
3649       }
3650       case AttrTy::kComparisonType: {
3651         Comparison::Type result;
3652         if (!ParseComparisonType(&result)) {
3653           return false;
3654         }
3655         static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
3656         return true;
3657       }
3658       case AttrTy::kEnum: {
3659         if (lexer_.GetKind() != TokKind::kIdent) {
3660           return TokenError("expects an enumeration value");
3661         }
3662         std::string result = lexer_.GetStrVal();
3663         lexer_.Lex();
3664         static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
3665         return true;
3666       }
3667       case AttrTy::kWindow: {
3668         Window result;
3669         if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
3670           return false;
3671         }
3672         static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
3673         return true;
3674       }
3675       case AttrTy::kConvolutionDimensionNumbers: {
3676         ConvolutionDimensionNumbers result;
3677         if (!ParseConvolutionDimensionNumbers(&result)) {
3678           return false;
3679         }
3680         static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
3681             ->emplace(result);
3682         return true;
3683       }
3684       case AttrTy::kSharding: {
3685         OpSharding sharding;
3686         if (!ParseSharding(&sharding)) {
3687           return false;
3688         }
3689         static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
3690         return true;
3691       }
3692       case AttrTy::kFrontendAttributes: {
3693         FrontendAttributes frontend_attributes;
3694         if (!ParseFrontendAttributes(&frontend_attributes)) {
3695           return false;
3696         }
3697         static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
3698             ->emplace(frontend_attributes);
3699         return true;
3700       }
3701       case AttrTy::kParameterReplication: {
3702         ParameterReplication parameter_replication;
3703         if (!ParseParameterReplication(&parameter_replication)) {
3704           return false;
3705         }
3706         static_cast<optional<ParameterReplication>*>(attr_out_ptr)
3707             ->emplace(parameter_replication);
3708         return true;
3709       }
3710       case AttrTy::kInstructionList: {
3711         std::vector<HloInstruction*> result;
3712         if (!ParseInstructionNames(&result)) {
3713           return false;
3714         }
3715         static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
3716             ->emplace(result);
3717         return true;
3718       }
3719       case AttrTy::kFusionKind: {
3720         HloInstruction::FusionKind result;
3721         if (!ParseFusionKind(&result)) {
3722           return false;
3723         }
3724         static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
3725             ->emplace(result);
3726         return true;
3727       }
3728       case AttrTy::kBracedInt64List: {
3729         std::vector<int64> result;
3730         if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3731                             &result)) {
3732           return false;
3733         }
3734         static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
3735             ->emplace(result);
3736         return true;
3737       }
3738       case AttrTy::kBracedInt64ListList: {
3739         std::vector<std::vector<int64>> result;
3740         if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
3741                                 TokKind::kComma, &result)) {
3742           return false;
3743         }
3744         static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
3745             ->emplace(result);
3746         return true;
3747       }
3748       case AttrTy::kSliceRanges: {
3749         SliceRanges result;
3750         if (!ParseSliceRanges(&result)) {
3751           return false;
3752         }
3753         static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
3754         return true;
3755       }
3756       case AttrTy::kPaddingConfig: {
3757         PaddingConfig result;
3758         if (!ParsePaddingConfig(&result)) {
3759           return false;
3760         }
3761         static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
3762         return true;
3763       }
3764       case AttrTy::kString: {
3765         std::string result;
3766         if (!ParseString(&result)) {
3767           return false;
3768         }
3769         static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
3770         return true;
3771       }
3772       case AttrTy::kMetadata: {
3773         OpMetadata result;
3774         if (!ParseMetadata(&result)) {
3775           return false;
3776         }
3777         static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
3778         return true;
3779       }
3780       case AttrTy::kDistribution: {
3781         RandomDistribution result;
3782         if (!ParseRandomDistribution(&result)) {
3783           return false;
3784         }
3785         static_cast<optional<RandomDistribution>*>(attr_out_ptr)
3786             ->emplace(result);
3787         return true;
3788       }
3789       case AttrTy::kDomain: {
3790         return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
3791       }
3792       case AttrTy::kPrecisionList: {
3793         std::vector<PrecisionConfig::Precision> result;
3794         if (!ParsePrecisionList(&result)) {
3795           return false;
3796         }
3797         static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
3798             attr_out_ptr)
3799             ->emplace(result);
3800         return true;
3801       }
3802       case AttrTy::kShape: {
3803         Shape result;
3804         if (!ParseShape(&result)) {
3805           return false;
3806         }
3807         static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
3808         return true;
3809       }
3810       case AttrTy::kShapeList: {
3811         std::vector<Shape> result;
3812         if (!ParseShapeList(&result)) {
3813           return false;
3814         }
3815         static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
3816             ->emplace(result);
3817         return true;
3818       }
3819       case AttrTy::kRandomAlgorithm: {
3820         RandomAlgorithm result;
3821         if (!ParseRandomAlgorithm(&result)) {
3822           return false;
3823         }
3824         static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
3825         return true;
3826       }
3827       case AttrTy::kAliasing: {
3828         AliasingData aliasing_data;
3829         if (!ParseAliasing(&aliasing_data)) {
3830           return false;
3831         }
3832         static_cast<optional<AliasingData>*>(attr_out_ptr)
3833             ->emplace(aliasing_data);
3834         return true;
3835       }
3836       case AttrTy::kInstructionAliasing: {
3837         std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
3838             aliasing_output_operand_pairs;
3839         if (!ParseInstructionOutputOperandAliasing(
3840                 &aliasing_output_operand_pairs)) {
3841           return false;
3842         }
3843         static_cast<optional<
3844             std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
3845             attr_out_ptr)
3846             ->emplace(std::move(aliasing_output_operand_pairs));
3847         return true;
3848       }
3849       case AttrTy::kLiteral: {
3850         if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
3851           return false;
3852         }
3853         Literal result;
3854         if (!ParseLiteral(&result)) {
3855           return false;
3856         }
3857         if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
3858           return false;
3859         }
3860         static_cast<optional<Literal>*>(attr_out_ptr)
3861             ->emplace(std::move(result));
3862         return true;
3863       }
3864     }
3865   }();
3866   if (!success) {
3867     return Error(loc, StrFormat("error parsing attribute %s", name));
3868   }
3869   return true;
3870 }
3871 
CopyAttributeToProtoMessage(absl::flat_hash_set<std::string> non_proto_attrs,const absl::flat_hash_map<std::string,AttrConfig> & attrs,tensorflow::protobuf::Message * message)3872 bool HloParserImpl::CopyAttributeToProtoMessage(
3873     absl::flat_hash_set<std::string> non_proto_attrs,
3874     const absl::flat_hash_map<std::string, AttrConfig>& attrs,
3875     tensorflow::protobuf::Message* message) {
3876   const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
3877   const tensorflow::protobuf::Reflection* reflection = message->GetReflection();
3878 
3879   for (const auto& p : attrs) {
3880     const std::string& name = p.first;
3881     if (non_proto_attrs.find(name) != non_proto_attrs.end()) {
3882       continue;
3883     }
3884     const tensorflow::protobuf::FieldDescriptor* fd =
3885         descriptor->FindFieldByName(name);
3886     if (!fd) {
3887       std::string allowed_attrs = "Allowed attributes: ";
3888 
3889       for (int i = 0; i < descriptor->field_count(); ++i) {
3890         if (i == 0) {
3891           absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
3892         } else {
3893           absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
3894         }
3895       }
3896       return TokenError(
3897           StrFormat("unexpected attribute \"%s\".  %s", name, allowed_attrs));
3898     }
3899 
3900     CHECK(!fd->is_repeated());  // Repeated fields not implemented.
3901     bool success = [&] {
3902       switch (fd->type()) {
3903         case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
3904           auto attr_value = static_cast<optional<bool>*>(p.second.result);
3905           if (attr_value->has_value()) {
3906             reflection->SetBool(message, fd, **attr_value);
3907           }
3908           return true;
3909         }
3910         case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
3911           auto attr_value =
3912               static_cast<optional<std::string>*>(p.second.result);
3913           if (attr_value->has_value()) {
3914             const tensorflow::protobuf::EnumValueDescriptor* evd =
3915                 fd->enum_type()->FindValueByName(**attr_value);
3916             reflection->SetEnum(message, fd, evd);
3917           }
3918           return true;
3919         }
3920         default:
3921           return false;
3922       }
3923     }();
3924 
3925     if (!success) {
3926       return TokenError(StrFormat("error parsing attribute %s", name));
3927     }
3928   }
3929 
3930   return true;
3931 }
3932 
3933 // attributes ::= (',' attribute)*
ParseAttributesAsProtoMessage(const absl::flat_hash_map<std::string,AttrConfig> & non_proto_attrs,tensorflow::protobuf::Message * message)3934 bool HloParserImpl::ParseAttributesAsProtoMessage(
3935     const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
3936     tensorflow::protobuf::Message* message) {
3937   const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
3938   absl::flat_hash_map<std::string, AttrConfig> attrs;
3939 
3940   // Storage for attributes.
3941   std::vector<optional<bool>> bool_params;
3942   std::vector<optional<std::string>> string_params;
3943   // Reserve enough capacity to make sure that the vector is not growing, so we
3944   // can rely on the pointers to stay valid.
3945   bool_params.reserve(descriptor->field_count());
3946   string_params.reserve(descriptor->field_count());
3947 
3948   // Populate the storage of expected attributes from the protobuf description.
3949   for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
3950     const tensorflow::protobuf::FieldDescriptor* fd =
3951         descriptor->field(field_idx);
3952     const std::string& field_name = fd->name();
3953     switch (fd->type()) {
3954       case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
3955         bool_params.emplace_back(absl::nullopt);
3956         attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
3957                              &bool_params.back()};
3958         break;
3959       }
3960       case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
3961         string_params.emplace_back(absl::nullopt);
3962         attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
3963                              &string_params.back()};
3964         break;
3965       }
3966       default:
3967         return TokenError(absl::StrFormat(
3968             "Unexpected protocol buffer type: %s ", fd->DebugString()));
3969     }
3970   }
3971 
3972   absl::flat_hash_set<std::string> non_proto_attrs_names;
3973   non_proto_attrs_names.reserve(non_proto_attrs.size());
3974   for (const auto& p : non_proto_attrs) {
3975     const std::string& attr_name = p.first;
3976     // If an attribute is both specified within 'non_proto_attrs' and an
3977     // attribute of the proto message, we prefer the attribute of the proto
3978     // message.
3979     if (attrs.find(attr_name) == attrs.end()) {
3980       non_proto_attrs_names.insert(attr_name);
3981       attrs[attr_name] = p.second;
3982     }
3983   }
3984 
3985   if (!ParseAttributes(attrs)) {
3986     return false;
3987   }
3988 
3989   return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
3990 }
3991 
ParseComputationName(HloComputation ** value)3992 bool HloParserImpl::ParseComputationName(HloComputation** value) {
3993   std::string name;
3994   LocTy loc = lexer_.GetLoc();
3995   if (!ParseName(&name)) {
3996     return Error(loc, "expects computation name");
3997   }
3998   std::pair<HloComputation*, LocTy>* computation =
3999       tensorflow::gtl::FindOrNull(computation_pool_, name);
4000   if (computation == nullptr) {
4001     return Error(loc, StrCat("computation does not exist: ", name));
4002   }
4003   *value = computation->first;
4004   return true;
4005 }
4006 
4007 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
4008 // The subattributes can appear in any order. 'size=' is required, others are
4009 // optional.
ParseWindow(Window * window,bool expect_outer_curlies)4010 bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
4011   LocTy loc = lexer_.GetLoc();
4012   if (expect_outer_curlies &&
4013       !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
4014     return false;
4015   }
4016 
4017   std::vector<int64> size;
4018   std::vector<int64> stride;
4019   std::vector<std::vector<int64>> pad;
4020   std::vector<int64> lhs_dilate;
4021   std::vector<int64> rhs_dilate;
4022   std::vector<int64> rhs_reversal;
4023   const auto end_token =
4024       expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
4025   while (lexer_.GetKind() != end_token) {
4026     LocTy attr_loc = lexer_.GetLoc();
4027     std::string field_name;
4028     if (!ParseAttributeName(&field_name)) {
4029       return Error(attr_loc, "expects sub-attributes in window");
4030     }
4031     bool ok = [&] {
4032       if (field_name == "size") {
4033         return ParseDxD("size", &size);
4034       }
4035       if (field_name == "stride") {
4036         return ParseDxD("stride", &stride);
4037       }
4038       if (field_name == "lhs_dilate") {
4039         return ParseDxD("lhs_dilate", &lhs_dilate);
4040       }
4041       if (field_name == "rhs_dilate") {
4042         return ParseDxD("rls_dilate", &rhs_dilate);
4043       }
4044       if (field_name == "pad") {
4045         return ParseWindowPad(&pad);
4046       }
4047       if (field_name == "rhs_reversal") {
4048         return ParseDxD("rhs_reversal", &rhs_reversal);
4049       }
4050       return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
4051     }();
4052     if (!ok) {
4053       return false;
4054     }
4055   }
4056 
4057   if (!stride.empty() && stride.size() != size.size()) {
4058     return Error(loc, "expects 'stride=' has the same size as 'size='");
4059   }
4060   if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
4061     return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
4062   }
4063   if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
4064     return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
4065   }
4066   if (!pad.empty() && pad.size() != size.size()) {
4067     return Error(loc, "expects 'pad=' has the same size as 'size='");
4068   }
4069 
4070   for (int i = 0; i < size.size(); i++) {
4071     window->add_dimensions()->set_size(size[i]);
4072     if (!pad.empty()) {
4073       window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
4074       window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
4075     }
4076     // If some field is not present, it has the default value.
4077     window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
4078     window->mutable_dimensions(i)->set_base_dilation(
4079         lhs_dilate.empty() ? 1 : lhs_dilate[i]);
4080     window->mutable_dimensions(i)->set_window_dilation(
4081         rhs_dilate.empty() ? 1 : rhs_dilate[i]);
4082     window->mutable_dimensions(i)->set_window_reversal(
4083         rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
4084   }
4085   return !expect_outer_curlies ||
4086          ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
4087 }
4088 
4089 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
4090 // Thestring looks like "dim_labels=0bf_0io->0bf".
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)4091 bool HloParserImpl::ParseConvolutionDimensionNumbers(
4092     ConvolutionDimensionNumbers* dnums) {
4093   if (lexer_.GetKind() != TokKind::kDimLabels) {
4094     return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
4095   }
4096   std::string str = lexer_.GetStrVal();
4097 
4098   // The str is expected to have 3 items, lhs, rhs, out, and it must look like
4099   // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
4100   std::vector<std::string> split1 = absl::StrSplit(str, '_');
4101   if (split1.size() != 2) {
4102     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4103                << str;
4104   }
4105   std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
4106   if (split2.size() != 2) {
4107     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4108                << str;
4109   }
4110   absl::string_view lhs = split1[0];
4111   absl::string_view rhs = split2[0];
4112   absl::string_view out = split2[1];
4113 
4114   const int64 rank = lhs.length();
4115   if (rank != rhs.length() || rank != out.length()) {
4116     return TokenError(
4117         "convolution lhs, rhs, and output must have the same rank");
4118   }
4119   if (rank < 2) {
4120     return TokenError("convolution rank must >=2");
4121   }
4122 
4123   auto is_unique = [](std::string str) -> bool {
4124     absl::c_sort(str);
4125     return std::unique(str.begin(), str.end()) == str.end();
4126   };
4127 
4128   // lhs
4129   {
4130     if (!is_unique(std::string(lhs))) {
4131       return TokenError(
4132           StrCat("expects unique lhs dimension numbers, but sees ", lhs));
4133     }
4134     for (int i = 0; i < rank - 2; i++) {
4135       dnums->add_input_spatial_dimensions(-1);
4136     }
4137     for (int i = 0; i < rank; i++) {
4138       char c = lhs[i];
4139       if (c == 'b') {
4140         dnums->set_input_batch_dimension(i);
4141       } else if (c == 'f') {
4142         dnums->set_input_feature_dimension(i);
4143       } else if (c < '0' + rank && c >= '0') {
4144         dnums->set_input_spatial_dimensions(c - '0', i);
4145       } else {
4146         return TokenError(
4147             StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
4148       }
4149     }
4150   }
4151   // rhs
4152   {
4153     if (!is_unique(std::string(rhs))) {
4154       return TokenError(
4155           StrCat("expects unique rhs dimension numbers, but sees ", rhs));
4156     }
4157     for (int i = 0; i < rank - 2; i++) {
4158       dnums->add_kernel_spatial_dimensions(-1);
4159     }
4160     for (int i = 0; i < rank; i++) {
4161       char c = rhs[i];
4162       if (c == 'i') {
4163         dnums->set_kernel_input_feature_dimension(i);
4164       } else if (c == 'o') {
4165         dnums->set_kernel_output_feature_dimension(i);
4166       } else if (c < '0' + rank && c >= '0') {
4167         dnums->set_kernel_spatial_dimensions(c - '0', i);
4168       } else {
4169         return TokenError(
4170             StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
4171       }
4172     }
4173   }
4174   // output
4175   {
4176     if (!is_unique(std::string(out))) {
4177       return TokenError(
4178           StrCat("expects unique output dimension numbers, but sees ", out));
4179     }
4180     for (int i = 0; i < rank - 2; i++) {
4181       dnums->add_output_spatial_dimensions(-1);
4182     }
4183     for (int i = 0; i < rank; i++) {
4184       char c = out[i];
4185       if (c == 'b') {
4186         dnums->set_output_batch_dimension(i);
4187       } else if (c == 'f') {
4188         dnums->set_output_feature_dimension(i);
4189       } else if (c < '0' + rank && c >= '0') {
4190         dnums->set_output_spatial_dimensions(c - '0', i);
4191       } else {
4192         return TokenError(StrFormat(
4193             "expects [0-%dbf] in output dimension numbers", rank - 1));
4194       }
4195     }
4196   }
4197 
4198   lexer_.Lex();
4199   return true;
4200 }
4201 
4202 // ::= '{' ranges '}'
4203 //   ::= /*empty*/
4204 //   ::= range (',' range)*
4205 // range ::= '[' start ':' limit (':' stride)? ']'
4206 //
4207 // The slice ranges are printed as:
4208 //
4209 //  {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
4210 //
4211 // This function extracts the starts, limits, and strides as 3 vectors to the
4212 // result. If stride is not present, stride is 1. For example, if the slice
4213 // ranges is printed as:
4214 //
4215 //  {[2:3:4], [5:6:7], [8:9]}
4216 //
4217 // The parsed result will be:
4218 //
4219 //  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
4220 //
ParseSliceRanges(SliceRanges * result)4221 bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
4222   if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
4223     return false;
4224   }
4225   std::vector<std::vector<int64>> ranges;
4226   if (lexer_.GetKind() == TokKind::kRbrace) {
4227     // empty
4228     return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4229   }
4230   do {
4231     LocTy loc = lexer_.GetLoc();
4232     ranges.emplace_back();
4233     if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
4234                         &ranges.back())) {
4235       return false;
4236     }
4237     const auto& range = ranges.back();
4238     if (range.size() != 2 && range.size() != 3) {
4239       return Error(loc,
4240                    StrFormat("expects [start:limit:step] or [start:limit], "
4241                              "but sees %d elements.",
4242                              range.size()));
4243     }
4244   } while (EatIfPresent(TokKind::kComma));
4245 
4246   for (const auto& range : ranges) {
4247     result->starts.push_back(range[0]);
4248     result->limits.push_back(range[1]);
4249     result->strides.push_back(range.size() == 3 ? range[2] : 1);
4250   }
4251   return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4252 }
4253 
4254 // precisionlist ::= start precision_elements end
4255 // precision_elements
4256 //   ::= /*empty*/
4257 //   ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)4258 bool HloParserImpl::ParsePrecisionList(
4259     std::vector<PrecisionConfig::Precision>* result) {
4260   auto parse_and_add_item = [&]() {
4261     PrecisionConfig::Precision item;
4262     if (!ParsePrecision(&item)) {
4263       return false;
4264     }
4265     result->push_back(item);
4266     return true;
4267   };
4268   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4269                    parse_and_add_item);
4270 }
4271 
ParseHloComputation(HloComputation ** result)4272 bool HloParserImpl::ParseHloComputation(HloComputation** result) {
4273   if (lexer_.GetKind() == TokKind::kLbrace) {
4274     // This means it is a nested computation.
4275     return ParseInstructionList(result, /*computation_name=*/"_");
4276   }
4277   // This means it is a computation name.
4278   return ParseComputationName(result);
4279 }
4280 
ParseHloComputationList(std::vector<HloComputation * > * result)4281 bool HloParserImpl::ParseHloComputationList(
4282     std::vector<HloComputation*>* result) {
4283   auto parse_and_add_item = [&]() {
4284     HloComputation* computation;
4285     if (!ParseHloComputation(&computation)) {
4286       return false;
4287     }
4288     VLOG(3) << "parsed computation " << computation->name();
4289     result->push_back(computation);
4290     return true;
4291   };
4292   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4293                    parse_and_add_item);
4294 }
4295 
4296 // shapelist ::= '{' shapes '}'
4297 // precision_elements
4298 //   ::= /*empty*/
4299 //   ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)4300 bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
4301   auto parse_and_add_item = [&]() {
4302     Shape shape;
4303     if (!ParseShape(&shape)) {
4304       return false;
4305     }
4306     result->push_back(std::move(shape));
4307     return true;
4308   };
4309   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4310                    parse_and_add_item);
4311 }
4312 
4313 // int64list ::= start int64_elements end
4314 // int64_elements
4315 //   ::= /*empty*/
4316 //   ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)4317 bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
4318                                    const TokKind delim,
4319                                    std::vector<int64>* result) {
4320   auto parse_and_add_item = [&]() {
4321     int64 i;
4322     if (!ParseInt64(&i)) {
4323       return false;
4324     }
4325     result->push_back(i);
4326     return true;
4327   };
4328   return ParseList(start, end, delim, parse_and_add_item);
4329 }
4330 
4331 // int64listlist ::= start int64list_elements end
4332 // int64list_elements
4333 //   ::= /*empty*/
4334 //   ::= int64list (delim int64list)*
4335 // int64list ::= start int64_elements end
4336 // int64_elements
4337 //   ::= /*empty*/
4338 //   ::= int64_val (delim int64_val)*
ParseInt64ListList(const TokKind start,const TokKind end,const TokKind delim,std::vector<std::vector<int64>> * result)4339 bool HloParserImpl::ParseInt64ListList(
4340     const TokKind start, const TokKind end, const TokKind delim,
4341     std::vector<std::vector<int64>>* result) {
4342   auto parse_and_add_item = [&]() {
4343     std::vector<int64> item;
4344     if (!ParseInt64List(start, end, delim, &item)) {
4345       return false;
4346     }
4347     result->push_back(item);
4348     return true;
4349   };
4350   return ParseList(start, end, delim, parse_and_add_item);
4351 }
4352 
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)4353 bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
4354                               const TokKind delim,
4355                               const std::function<bool()>& parse_and_add_item) {
4356   if (!ParseToken(start, StrCat("expects a list starting with ",
4357                                 TokKindToString(start)))) {
4358     return false;
4359   }
4360   if (lexer_.GetKind() == end) {
4361     // empty
4362   } else {
4363     do {
4364       if (!parse_and_add_item()) {
4365         return false;
4366       }
4367     } while (EatIfPresent(delim));
4368   }
4369   return ParseToken(
4370       end, StrCat("expects a list to end with ", TokKindToString(end)));
4371 }
4372 
4373 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)4374 bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
4375   if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
4376     return false;
4377   }
4378   *shape_loc = lexer_.GetLoc();
4379   return ParseShape(shape);
4380 }
4381 
CanBeParamListToShape()4382 bool HloParserImpl::CanBeParamListToShape() {
4383   return lexer_.GetKind() == TokKind::kLparen;
4384 }
4385 
4386 // param_list ::= '(' param_list1 ')'
4387 // param_list1
4388 //   ::= /*empty*/
4389 //   ::= param (',' param)*
4390 // param ::= name shape
ParseParamList()4391 bool HloParserImpl::ParseParamList() {
4392   if (!ParseToken(TokKind::kLparen,
4393                   "expects '(' at the beginning of param list")) {
4394     return false;
4395   }
4396 
4397   if (lexer_.GetKind() == TokKind::kRparen) {
4398     // empty
4399   } else {
4400     do {
4401       Shape shape;
4402       std::string name;
4403       if (!ParseName(&name) || !ParseShape(&shape)) {
4404         return false;
4405       }
4406     } while (EatIfPresent(TokKind::kComma));
4407   }
4408   return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
4409 }
4410 
4411 // dimension_sizes ::= '[' dimension_list ']'
4412 // dimension_list
4413 //   ::= /*empty*/
4414 //   ::= <=? int64 (',' param)*
4415 // param ::= name shape
ParseDimensionSizes(std::vector<int64> * dimension_sizes,std::vector<bool> * dynamic_dimensions)4416 bool HloParserImpl::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
4417                                         std::vector<bool>* dynamic_dimensions) {
4418   auto parse_and_add_item = [&]() {
4419     int64 i;
4420     bool is_dynamic = false;
4421     if (lexer_.GetKind() == TokKind::kLeq) {
4422       is_dynamic = true;
4423       lexer_.Lex();
4424     }
4425     if (!ParseInt64(&i)) {
4426       return false;
4427     }
4428     dimension_sizes->push_back(i);
4429     dynamic_dimensions->push_back(is_dynamic);
4430     return true;
4431   };
4432   return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
4433                    parse_and_add_item);
4434 }
4435 
4436 // tiles
4437 //   ::= /*empty*/
4438 //   ::= 'T' '(' dim_list ')'
4439 // dim_list
4440 //   ::= /*empty*/
4441 //   ::= (int64 | '*') (',' (int64 | '*'))*
ParseTiles(std::vector<Tile> * tiles)4442 bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
4443   auto parse_and_add_tile_dimension = [&]() {
4444     tensorflow::int64 i;
4445     if (ParseInt64(&i)) {
4446       tiles->back().add_dimensions(i);
4447       return true;
4448     }
4449     if (lexer_.GetKind() == TokKind::kAsterisk) {
4450       tiles->back().add_dimensions(Tile::kCombineDimension);
4451       lexer_.Lex();
4452       return true;
4453     }
4454     return false;
4455   };
4456 
4457   do {
4458     tiles->push_back(Tile());
4459     if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
4460                    parse_and_add_tile_dimension)) {
4461       return false;
4462     }
4463   } while (lexer_.GetKind() == TokKind::kLparen);
4464   return true;
4465 }
4466 
4467 // int_attribute
4468 //   ::= /*empty*/
4469 //   ::= attr_token '(' attr_value ')'
4470 // attr_token
4471 //   ::= 'E' | 'S'
4472 // attr_value
4473 //   ::= int64
ParseLayoutIntAttribute(int64 * attr_value,absl::string_view attr_description)4474 bool HloParserImpl::ParseLayoutIntAttribute(
4475     int64* attr_value, absl::string_view attr_description) {
4476   if (!ParseToken(TokKind::kLparen,
4477                   StrCat("expects ", attr_description, " to start with ",
4478                          TokKindToString(TokKind::kLparen)))) {
4479     return false;
4480   }
4481   if (!ParseInt64(attr_value)) {
4482     return false;
4483   }
4484   if (!ParseToken(TokKind::kRparen,
4485                   StrCat("expects ", attr_description, " to end with ",
4486                          TokKindToString(TokKind::kRparen)))) {
4487     return false;
4488   }
4489   return true;
4490 }
4491 
4492 // layout ::= '{' int64_list (':' tiles element_size_in_bits memory_space)? '}'
4493 // element_size_in_bits
4494 //   ::= /*empty*/
4495 //   ::= 'E' '(' int64 ')'
4496 // memory_space
4497 //   ::= /*empty*/
4498 //   ::= 'S' '(' int64 ')'
ParseLayout(Layout * layout)4499 bool HloParserImpl::ParseLayout(Layout* layout) {
4500   std::vector<int64> minor_to_major;
4501   std::vector<Tile> tiles;
4502   tensorflow::int64 element_size_in_bits = 0;
4503   tensorflow::int64 memory_space = 0;
4504 
4505   auto parse_and_add_item = [&]() {
4506     int64 i;
4507     if (!ParseInt64(&i)) {
4508       return false;
4509     }
4510     minor_to_major.push_back(i);
4511     return true;
4512   };
4513 
4514   if (!ParseToken(TokKind::kLbrace,
4515                   StrCat("expects layout to start with ",
4516                          TokKindToString(TokKind::kLbrace)))) {
4517     return false;
4518   }
4519   if (lexer_.GetKind() != TokKind::kRbrace) {
4520     if (lexer_.GetKind() == TokKind::kInt) {
4521       // Parse minor to major.
4522       do {
4523         if (!parse_and_add_item()) {
4524           return false;
4525         }
4526       } while (EatIfPresent(TokKind::kComma));
4527     }
4528 
4529     if (lexer_.GetKind() == TokKind::kColon) {
4530       lexer_.Lex();
4531       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
4532         lexer_.Lex();
4533         ParseTiles(&tiles);
4534       }
4535 
4536       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
4537         lexer_.Lex();
4538         ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
4539       }
4540 
4541       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
4542         lexer_.Lex();
4543         ParseLayoutIntAttribute(&memory_space, "memory space");
4544       }
4545     }
4546   }
4547   if (!ParseToken(TokKind::kRbrace,
4548                   StrCat("expects layout to end with ",
4549                          TokKindToString(TokKind::kRbrace)))) {
4550     return false;
4551   }
4552 
4553   std::vector<Tile> vec_tiles(tiles.size());
4554   for (int i = 0; i < tiles.size(); i++) {
4555     vec_tiles[i] = Tile(tiles[i]);
4556   }
4557   *layout = LayoutUtil::MakeLayout(minor_to_major, vec_tiles,
4558                                    element_size_in_bits, memory_space);
4559   return true;
4560 }
4561 
4562 // shape ::= shape_val_
4563 // shape ::= '(' tuple_elements ')'
4564 // tuple_elements
4565 //   ::= /*empty*/
4566 //   ::= shape (',' shape)*
ParseShape(Shape * result)4567 bool HloParserImpl::ParseShape(Shape* result) {
4568   if (EatIfPresent(TokKind::kLparen)) {  // Tuple
4569     std::vector<Shape> shapes;
4570     if (lexer_.GetKind() == TokKind::kRparen) {
4571       /*empty*/
4572     } else {
4573       // shape (',' shape)*
4574       do {
4575         shapes.emplace_back();
4576         if (!ParseShape(&shapes.back())) {
4577           return false;
4578         }
4579       } while (EatIfPresent(TokKind::kComma));
4580     }
4581     *result = ShapeUtil::MakeTupleShape(shapes);
4582     return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
4583   }
4584 
4585   if (lexer_.GetKind() != TokKind::kPrimitiveType) {
4586     return TokenError(absl::StrCat("expected primitive type, saw ",
4587                                    TokKindToString(lexer_.GetKind())));
4588   }
4589   PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
4590   lexer_.Lex();
4591 
4592   // Each element contains a dimension size and a bool indicating whether this
4593   // is a dynamic dimension.
4594   std::vector<int64> dimension_sizes;
4595   std::vector<bool> dynamic_dimensions;
4596   if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
4597     return false;
4598   }
4599   result->set_element_type(primitive_type);
4600   for (int i = 0; i < dimension_sizes.size(); ++i) {
4601     result->add_dimensions(dimension_sizes[i]);
4602     result->set_dynamic_dimension(i, dynamic_dimensions[i]);
4603   }
4604   LayoutUtil::SetToDefaultLayout(result);
4605 
4606   // We need to lookahead to see if a following open brace is the start of a
4607   // layout. The specific problematic case is:
4608   //
4609   // ENTRY %foo (x: f32[42]) -> f32[123] {
4610   //  ...
4611   // }
4612   //
4613   // The open brace could either be the start of a computation or the start of a
4614   // layout for the f32[123] shape. We consider it the start of a layout if the
4615   // next token after the open brace is an integer or a colon.
4616   if (lexer_.GetKind() == TokKind::kLbrace &&
4617       (lexer_.LookAhead() == TokKind::kInt ||
4618        lexer_.LookAhead() == TokKind::kColon)) {
4619     Layout layout;
4620     if (!ParseLayout(&layout)) {
4621       return false;
4622     }
4623     if (layout.minor_to_major_size() != result->rank()) {
4624       return Error(
4625           lexer_.GetLoc(),
4626           StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
4627                     result->rank(), layout.minor_to_major_size()));
4628     }
4629     *result->mutable_layout() = layout;
4630   }
4631   return true;
4632 }
4633 
CanBeShape()4634 bool HloParserImpl::CanBeShape() {
4635   // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
4636   // with '('.
4637   return lexer_.GetKind() == TokKind::kPrimitiveType ||
4638          lexer_.GetKind() == TokKind::kLparen;
4639 }
4640 
ParseName(std::string * result)4641 bool HloParserImpl::ParseName(std::string* result) {
4642   VLOG(3) << "ParseName";
4643   if (lexer_.GetKind() != TokKind::kIdent &&
4644       lexer_.GetKind() != TokKind::kName) {
4645     return TokenError("expects name");
4646   }
4647   *result = lexer_.GetStrVal();
4648   lexer_.Lex();
4649   return true;
4650 }
4651 
ParseAttributeName(std::string * result)4652 bool HloParserImpl::ParseAttributeName(std::string* result) {
4653   if (lexer_.GetKind() != TokKind::kAttributeName) {
4654     return TokenError("expects attribute name");
4655   }
4656   *result = lexer_.GetStrVal();
4657   lexer_.Lex();
4658   return true;
4659 }
4660 
ParseString(std::string * result)4661 bool HloParserImpl::ParseString(std::string* result) {
4662   VLOG(3) << "ParseString";
4663   if (lexer_.GetKind() != TokKind::kString) {
4664     return TokenError("expects string");
4665   }
4666   *result = lexer_.GetStrVal();
4667   lexer_.Lex();
4668   return true;
4669 }
4670 
ParseDxD(const std::string & name,std::vector<int64> * result)4671 bool HloParserImpl::ParseDxD(const std::string& name,
4672                              std::vector<int64>* result) {
4673   LocTy loc = lexer_.GetLoc();
4674   if (!result->empty()) {
4675     return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
4676   }
4677   // 1D
4678   if (lexer_.GetKind() == TokKind::kInt) {
4679     int64 number;
4680     if (!ParseInt64(&number)) {
4681       return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
4682     }
4683     result->push_back(number);
4684     return true;
4685   }
4686   // 2D or higher.
4687   if (lexer_.GetKind() == TokKind::kDxD) {
4688     std::string str = lexer_.GetStrVal();
4689     if (!SplitToInt64s(str, 'x', result)) {
4690       return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
4691     }
4692     lexer_.Lex();
4693     return true;
4694   }
4695   return TokenError("expects token type kInt or kDxD");
4696 }
4697 
ParseWindowPad(std::vector<std::vector<int64>> * pad)4698 bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
4699   LocTy loc = lexer_.GetLoc();
4700   if (!pad->empty()) {
4701     return Error(loc, "sub-attribute 'pad=' already exists");
4702   }
4703   if (lexer_.GetKind() != TokKind::kPad) {
4704     return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
4705   }
4706   std::string str = lexer_.GetStrVal();
4707   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
4708     std::vector<int64> low_high;
4709     if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
4710         low_high.size() != 2) {
4711       return Error(loc,
4712                    "expects padding_low and padding_high separated by '_'");
4713     }
4714     pad->push_back(low_high);
4715   }
4716   lexer_.Lex();
4717   return true;
4718 }
4719 
4720 // This is the inverse xla::ToString(PaddingConfig). The padding config string
4721 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
4722 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
4723 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)4724 bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
4725   if (lexer_.GetKind() != TokKind::kPad) {
4726     return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
4727   }
4728   LocTy loc = lexer_.GetLoc();
4729   std::string str = lexer_.GetStrVal();
4730   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
4731     std::vector<int64> padding_dim;
4732     if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
4733         (padding_dim.size() != 2 && padding_dim.size() != 3)) {
4734       return Error(loc,
4735                    "expects padding config pattern like 'low_high_interior' or "
4736                    "'low_high'");
4737     }
4738     auto* dim = padding->add_dimensions();
4739     dim->set_edge_padding_low(padding_dim[0]);
4740     dim->set_edge_padding_high(padding_dim[1]);
4741     dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
4742   }
4743   lexer_.Lex();
4744   return true;
4745 }
4746 
4747 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)4748 bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
4749   absl::flat_hash_map<std::string, AttrConfig> attrs;
4750   optional<std::string> op_type;
4751   optional<std::string> op_name;
4752   optional<std::string> source_file;
4753   optional<int32> source_line;
4754   optional<std::vector<int64>> profile_type;
4755   attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
4756   attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
4757   attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
4758   attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
4759   attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
4760                            &profile_type};
4761   if (!ParseSubAttributes(attrs)) {
4762     return false;
4763   }
4764   if (op_type) {
4765     metadata->set_op_type(*op_type);
4766   }
4767   if (op_name) {
4768     metadata->set_op_name(*op_name);
4769   }
4770   if (source_file) {
4771     metadata->set_source_file(*source_file);
4772   }
4773   if (source_line) {
4774     metadata->set_source_line(*source_line);
4775   }
4776   if (profile_type) {
4777     for (const auto& type : *profile_type) {
4778       if (!ProfileType_IsValid(type)) {
4779         return false;
4780       }
4781       metadata->add_profile_type(static_cast<ProfileType>(type));
4782     }
4783   }
4784   return true;
4785 }
4786 
4787 // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleOrListMetadata(tensorflow::protobuf::RepeatedPtrField<OpMetadata> * metadata)4788 bool HloParserImpl::ParseSingleOrListMetadata(
4789     tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
4790   if (lexer_.GetKind() == TokKind::kLbrace &&
4791       lexer_.LookAhead() == TokKind::kLbrace) {
4792     if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
4793       return false;
4794     }
4795 
4796     if (lexer_.GetKind() != TokKind::kRbrace) {
4797       do {
4798         if (!ParseMetadata(metadata->Add())) {
4799           return false;
4800         }
4801       } while (EatIfPresent(TokKind::kComma));
4802     }
4803 
4804     return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
4805   }
4806 
4807   return ParseMetadata(metadata->Add());
4808 }
4809 
ParseOpcode(HloOpcode * result)4810 bool HloParserImpl::ParseOpcode(HloOpcode* result) {
4811   VLOG(3) << "ParseOpcode";
4812   if (lexer_.GetKind() != TokKind::kIdent) {
4813     return TokenError("expects opcode");
4814   }
4815   std::string val = lexer_.GetStrVal();
4816   auto status_or_result = StringToHloOpcode(val);
4817   if (!status_or_result.ok()) {
4818     return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
4819                                 status_or_result.status().error_message()));
4820   }
4821   *result = status_or_result.ValueOrDie();
4822   lexer_.Lex();
4823   return true;
4824 }
4825 
ParseFftType(FftType * result)4826 bool HloParserImpl::ParseFftType(FftType* result) {
4827   VLOG(3) << "ParseFftType";
4828   if (lexer_.GetKind() != TokKind::kIdent) {
4829     return TokenError("expects fft type");
4830   }
4831   std::string val = lexer_.GetStrVal();
4832   if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
4833     return TokenError(StrFormat("expects fft type but sees: %s", val));
4834   }
4835   lexer_.Lex();
4836   return true;
4837 }
4838 
ParsePaddingType(PaddingType * result)4839 bool HloParserImpl::ParsePaddingType(PaddingType* result) {
4840   VLOG(3) << "ParsePaddingType";
4841   if (lexer_.GetKind() != TokKind::kIdent) {
4842     return TokenError("expects padding type");
4843   }
4844   std::string val = lexer_.GetStrVal();
4845   if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
4846     return TokenError(StrFormat("expects padding type but sees: %s", val));
4847   }
4848   lexer_.Lex();
4849   return true;
4850 }
4851 
ParseComparisonDirection(ComparisonDirection * result)4852 bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
4853   VLOG(3) << "ParseComparisonDirection";
4854   if (lexer_.GetKind() != TokKind::kIdent) {
4855     return TokenError("expects comparison direction");
4856   }
4857   std::string val = lexer_.GetStrVal();
4858   auto status_or_result = StringToComparisonDirection(val);
4859   if (!status_or_result.ok()) {
4860     return TokenError(
4861         StrFormat("expects comparison direction but sees: %s", val));
4862   }
4863   *result = status_or_result.ValueOrDie();
4864   lexer_.Lex();
4865   return true;
4866 }
4867 
ParseComparisonType(Comparison::Type * result)4868 bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
4869   VLOG(1) << "ParseComparisonType";
4870   if (lexer_.GetKind() != TokKind::kIdent) {
4871     return TokenError("expects comparison type");
4872   }
4873   std::string val = lexer_.GetStrVal();
4874   auto status_or_result = StringToComparisonType(val);
4875   if (!status_or_result.ok()) {
4876     return TokenError(StrFormat("expects comparison type but sees: %s", val));
4877   }
4878   *result = status_or_result.ValueOrDie();
4879   lexer_.Lex();
4880   return true;
4881 }
4882 
ParseFusionKind(HloInstruction::FusionKind * result)4883 bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
4884   VLOG(3) << "ParseFusionKind";
4885   if (lexer_.GetKind() != TokKind::kIdent) {
4886     return TokenError("expects fusion kind");
4887   }
4888   std::string val = lexer_.GetStrVal();
4889   auto status_or_result = StringToFusionKind(val);
4890   if (!status_or_result.ok()) {
4891     return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
4892                                 val,
4893                                 status_or_result.status().error_message()));
4894   }
4895   *result = status_or_result.ValueOrDie();
4896   lexer_.Lex();
4897   return true;
4898 }
4899 
ParseRandomDistribution(RandomDistribution * result)4900 bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
4901   VLOG(3) << "ParseRandomDistribution";
4902   if (lexer_.GetKind() != TokKind::kIdent) {
4903     return TokenError("expects random distribution");
4904   }
4905   std::string val = lexer_.GetStrVal();
4906   auto status_or_result = StringToRandomDistribution(val);
4907   if (!status_or_result.ok()) {
4908     return TokenError(
4909         StrFormat("expects random distribution but sees: %s, error: %s", val,
4910                   status_or_result.status().error_message()));
4911   }
4912   *result = status_or_result.ValueOrDie();
4913   lexer_.Lex();
4914   return true;
4915 }
4916 
ParseRandomAlgorithm(RandomAlgorithm * result)4917 bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
4918   VLOG(3) << "ParseRandomAlgorithm";
4919   if (lexer_.GetKind() != TokKind::kIdent) {
4920     return TokenError("expects random algorithm");
4921   }
4922   std::string val = lexer_.GetStrVal();
4923   auto status_or_result = StringToRandomAlgorithm(val);
4924   if (!status_or_result.ok()) {
4925     return TokenError(
4926         StrFormat("expects random algorithm but sees: %s, error: %s", val,
4927                   status_or_result.status().error_message()));
4928   }
4929   *result = status_or_result.ValueOrDie();
4930   lexer_.Lex();
4931   return true;
4932 }
4933 
ParsePrecision(PrecisionConfig::Precision * result)4934 bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
4935   VLOG(3) << "ParsePrecision";
4936   if (lexer_.GetKind() != TokKind::kIdent) {
4937     return TokenError("expects random distribution");
4938   }
4939   std::string val = lexer_.GetStrVal();
4940   auto status_or_result = StringToPrecision(val);
4941   if (!status_or_result.ok()) {
4942     return TokenError(StrFormat("expects precision but sees: %s, error: %s",
4943                                 val,
4944                                 status_or_result.status().error_message()));
4945   }
4946   *result = status_or_result.ValueOrDie();
4947   lexer_.Lex();
4948   return true;
4949 }
4950 
ParseInt64(int64 * result)4951 bool HloParserImpl::ParseInt64(int64* result) {
4952   VLOG(3) << "ParseInt64";
4953   if (lexer_.GetKind() != TokKind::kInt) {
4954     return TokenError("expects integer");
4955   }
4956   *result = lexer_.GetInt64Val();
4957   lexer_.Lex();
4958   return true;
4959 }
4960 
ParseDouble(double * result)4961 bool HloParserImpl::ParseDouble(double* result) {
4962   switch (lexer_.GetKind()) {
4963     case TokKind::kDecimal: {
4964       double val = lexer_.GetDecimalVal();
4965       // If GetDecimalVal returns +/-inf, that means that we overflowed
4966       // `double`.
4967       if (std::isinf(val)) {
4968         return TokenError(StrCat("Constant is out of range for double (+/-",
4969                                  std::numeric_limits<double>::max(),
4970                                  ") and so is unparsable."));
4971       }
4972       *result = val;
4973       break;
4974     }
4975     case TokKind::kInt:
4976       *result = static_cast<double>(lexer_.GetInt64Val());
4977       break;
4978     case TokKind::kw_nan:
4979       *result = std::numeric_limits<double>::quiet_NaN();
4980       break;
4981     case TokKind::kNegNan:
4982       *result = -std::numeric_limits<double>::quiet_NaN();
4983       break;
4984     case TokKind::kw_inf:
4985       *result = std::numeric_limits<double>::infinity();
4986       break;
4987     case TokKind::kNegInf:
4988       *result = -std::numeric_limits<double>::infinity();
4989       break;
4990     default:
4991       return TokenError("expects decimal or integer");
4992   }
4993   lexer_.Lex();
4994   return true;
4995 }
4996 
ParseComplex(std::complex<double> * result)4997 bool HloParserImpl::ParseComplex(std::complex<double>* result) {
4998   if (lexer_.GetKind() != TokKind::kLparen) {
4999     return TokenError("expects '(' before complex number");
5000   }
5001   lexer_.Lex();
5002 
5003   double real;
5004   LocTy loc = lexer_.GetLoc();
5005   if (!ParseDouble(&real)) {
5006     return Error(loc,
5007                  "expect floating-point value for real part of complex number");
5008   }
5009 
5010   if (lexer_.GetKind() != TokKind::kComma) {
5011     return TokenError(
5012         absl::StrFormat("expect comma after real part of complex literal"));
5013   }
5014   lexer_.Lex();
5015 
5016   double imag;
5017   loc = lexer_.GetLoc();
5018   if (!ParseDouble(&imag)) {
5019     return Error(
5020         loc,
5021         "expect floating-point value for imaginary part of complex number");
5022   }
5023 
5024   if (lexer_.GetKind() != TokKind::kRparen) {
5025     return TokenError(absl::StrFormat("expect ')' after complex number"));
5026   }
5027 
5028   *result = std::complex<double>(real, imag);
5029   lexer_.Lex();
5030   return true;
5031 }
5032 
ParseBool(bool * result)5033 bool HloParserImpl::ParseBool(bool* result) {
5034   if (lexer_.GetKind() != TokKind::kw_true &&
5035       lexer_.GetKind() != TokKind::kw_false) {
5036     return TokenError("expects true or false");
5037   }
5038   *result = lexer_.GetKind() == TokKind::kw_true;
5039   lexer_.Lex();
5040   return true;
5041 }
5042 
ParseToken(TokKind kind,const std::string & msg)5043 bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
5044   VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
5045   if (lexer_.GetKind() != kind) {
5046     return TokenError(msg);
5047   }
5048   lexer_.Lex();
5049   return true;
5050 }
5051 
EatIfPresent(TokKind kind)5052 bool HloParserImpl::EatIfPresent(TokKind kind) {
5053   if (lexer_.GetKind() != kind) {
5054     return false;
5055   }
5056   lexer_.Lex();
5057   return true;
5058 }
5059 
AddInstruction(const std::string & name,HloInstruction * instruction,LocTy name_loc)5060 bool HloParserImpl::AddInstruction(const std::string& name,
5061                                    HloInstruction* instruction,
5062                                    LocTy name_loc) {
5063   auto result = current_name_table().insert({name, {instruction, name_loc}});
5064   if (!result.second) {
5065     Error(name_loc, StrCat("instruction already exists: ", name));
5066     return Error(/*loc=*/result.first->second.second,
5067                  "instruction previously defined here");
5068   }
5069   return true;
5070 }
5071 
AddComputation(const std::string & name,HloComputation * computation,LocTy name_loc)5072 bool HloParserImpl::AddComputation(const std::string& name,
5073                                    HloComputation* computation,
5074                                    LocTy name_loc) {
5075   auto result = computation_pool_.insert({name, {computation, name_loc}});
5076   if (!result.second) {
5077     Error(name_loc, StrCat("computation already exists: ", name));
5078     return Error(/*loc=*/result.first->second.second,
5079                  "computation previously defined here");
5080   }
5081   return true;
5082 }
5083 
ParseShapeOnly()5084 StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
5085   lexer_.Lex();
5086   Shape shape;
5087   if (!ParseShape(&shape)) {
5088     return InvalidArgument("Syntax error:\n%s", GetError());
5089   }
5090   if (lexer_.GetKind() != TokKind::kEof) {
5091     return InvalidArgument("Syntax error:\nExtra content after shape");
5092   }
5093   return shape;
5094 }
5095 
ParseShardingOnly()5096 StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
5097   lexer_.Lex();
5098   OpSharding op_sharding;
5099   if (!ParseSharding(&op_sharding)) {
5100     return InvalidArgument("Syntax error:\n%s", GetError());
5101   }
5102   if (lexer_.GetKind() != TokKind::kEof) {
5103     return InvalidArgument("Syntax error:\nExtra content after sharding");
5104   }
5105   return HloSharding::FromProto(op_sharding);
5106 }
5107 
ParseFrontendAttributesOnly()5108 StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
5109   lexer_.Lex();
5110   FrontendAttributes attributes;
5111   if (!ParseFrontendAttributes(&attributes)) {
5112     return InvalidArgument("Syntax error:\n%s", GetError());
5113   }
5114   if (lexer_.GetKind() != TokKind::kEof) {
5115     return InvalidArgument(
5116         "Syntax error:\nExtra content after frontend attributes");
5117   }
5118   return attributes;
5119 }
5120 
ParseParameterReplicationOnly()5121 StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
5122   lexer_.Lex();
5123   ParameterReplication parameter_replication;
5124   if (!ParseParameterReplication(&parameter_replication)) {
5125     return InvalidArgument("Syntax error:\n%s", GetError());
5126   }
5127   if (lexer_.GetKind() != TokKind::kEof) {
5128     return InvalidArgument(
5129         "Syntax error:\nExtra content after parameter replication");
5130   }
5131   return std::vector<bool>(
5132       parameter_replication.replicated_at_leaf_buffers().begin(),
5133       parameter_replication.replicated_at_leaf_buffers().end());
5134 }
5135 
ParseReplicaGroupsOnly()5136 StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
5137   lexer_.Lex();
5138   std::vector<ReplicaGroup> replica_groups;
5139   if (!ParseReplicaGroupsOnly(&replica_groups)) {
5140     return InvalidArgument("Syntax error:\n%s", GetError());
5141   }
5142   if (lexer_.GetKind() != TokKind::kEof) {
5143     return InvalidArgument("Syntax error:\nExtra content after replica groups");
5144   }
5145   return replica_groups;
5146 }
5147 
ParseWindowOnly()5148 StatusOr<Window> HloParserImpl::ParseWindowOnly() {
5149   lexer_.Lex();
5150   Window window;
5151   if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
5152     return InvalidArgument("Syntax error:\n%s", GetError());
5153   }
5154   if (lexer_.GetKind() != TokKind::kEof) {
5155     return InvalidArgument("Syntax error:\nExtra content after window");
5156   }
5157   return window;
5158 }
5159 
5160 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()5161 HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
5162   lexer_.Lex();
5163   ConvolutionDimensionNumbers dnums;
5164   if (!ParseConvolutionDimensionNumbers(&dnums)) {
5165     return InvalidArgument("Syntax error:\n%s", GetError());
5166   }
5167   if (lexer_.GetKind() != TokKind::kEof) {
5168     return InvalidArgument(
5169         "Syntax error:\nExtra content after convolution dnums");
5170   }
5171   return dnums;
5172 }
5173 
ParsePaddingConfigOnly()5174 StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
5175   lexer_.Lex();
5176   PaddingConfig padding_config;
5177   if (!ParsePaddingConfig(&padding_config)) {
5178     return InvalidArgument("Syntax error:\n%s", GetError());
5179   }
5180   if (lexer_.GetKind() != TokKind::kEof) {
5181     return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
5182   }
5183   return padding_config;
5184 }
5185 
ParseSingleInstruction(HloModule * module)5186 bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
5187   if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
5188     LOG(FATAL) << "Parser state is not clean. Please do not call any other "
5189                   "methods before calling ParseSingleInstruction.";
5190   }
5191   HloComputation::Builder builder(module->name());
5192 
5193   // The missing instruction hook we register creates the shaped instruction on
5194   // the fly as a parameter and returns it.
5195   int64 parameter_count = 0;
5196   create_missing_instruction_ =
5197       [this, &builder, &parameter_count](
5198           const std::string& name,
5199           const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
5200     std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
5201     HloInstruction* parameter = builder.AddInstruction(
5202         HloInstruction::CreateParameter(parameter_count++, shape, new_name));
5203     current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
5204     return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
5205   };
5206 
5207   // Parse the instruction with the registered hook.
5208   Scope scope(&scoped_name_tables_);
5209   if (CanBeShape()) {
5210     // This means that the instruction's left-hand side is probably omitted,
5211     // e.g.
5212     //
5213     //  f32[10] fusion(...), calls={...}
5214     if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
5215       return false;
5216     }
5217   } else {
5218     // This means that the instruction's left-hand side might exist, e.g.
5219     //
5220     //  foo = f32[10] fusion(...), calls={...}
5221     std::string root_name;
5222     if (!ParseInstruction(&builder, &root_name)) {
5223       return false;
5224     }
5225   }
5226 
5227   if (lexer_.GetKind() != TokKind::kEof) {
5228     Error(
5229         lexer_.GetLoc(),
5230         "Syntax error:\nExpected eof after parsing single instruction.  Did "
5231         "you mean to write an HLO module and forget the \"HloModule\" header?");
5232     return false;
5233   }
5234 
5235   module->AddEntryComputation(builder.Build());
5236   for (auto& comp : computations_) {
5237     module->AddEmbeddedComputation(std::move(comp));
5238   }
5239   TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
5240   return true;
5241 }
5242 
5243 }  // namespace
5244 
ParseAndReturnUnverifiedModule(absl::string_view str,const HloModuleConfig & config)5245 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5246     absl::string_view str, const HloModuleConfig& config) {
5247   auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
5248   HloParserImpl parser(str);
5249   TF_RETURN_IF_ERROR(parser.Run(module.get()));
5250   return std::move(module);
5251 }
5252 
ParseAndReturnUnverifiedModule(absl::string_view str)5253 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5254     absl::string_view str) {
5255   return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
5256 }
5257 
ParseSharding(absl::string_view str)5258 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
5259   HloParserImpl parser(str);
5260   return parser.ParseShardingOnly();
5261 }
5262 
ParseFrontendAttributes(absl::string_view str)5263 StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
5264   HloParserImpl parser(str);
5265   return parser.ParseFrontendAttributesOnly();
5266 }
5267 
ParseParameterReplication(absl::string_view str)5268 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
5269   HloParserImpl parser(str);
5270   return parser.ParseParameterReplicationOnly();
5271 }
5272 
ParseReplicaGroupsOnly(absl::string_view str)5273 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
5274     absl::string_view str) {
5275   HloParserImpl parser(str);
5276   return parser.ParseReplicaGroupsOnly();
5277 }
5278 
ParseWindow(absl::string_view str)5279 StatusOr<Window> ParseWindow(absl::string_view str) {
5280   HloParserImpl parser(str);
5281   return parser.ParseWindowOnly();
5282 }
5283 
ParseConvolutionDimensionNumbers(absl::string_view str)5284 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
5285     absl::string_view str) {
5286   HloParserImpl parser(str);
5287   return parser.ParseConvolutionDimensionNumbersOnly();
5288 }
5289 
ParsePaddingConfig(absl::string_view str)5290 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
5291   HloParserImpl parser(str);
5292   return parser.ParsePaddingConfigOnly();
5293 }
5294 
ParseShape(absl::string_view str)5295 StatusOr<Shape> ParseShape(absl::string_view str) {
5296   HloParserImpl parser(str);
5297   return parser.ParseShapeOnly();
5298 }
5299 
CreateHloParserForTests(absl::string_view str)5300 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
5301     absl::string_view str) {
5302   return absl::make_unique<HloParserImpl>(str);
5303 }
5304 
5305 }  // namespace xla
5306