• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <type_traits>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/base/casts.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "absl/types/variant.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/literal_util.h"
41 #include "tensorflow/compiler/xla/primitive_util.h"
42 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
43 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
44 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
45 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
46 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
47 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
50 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
51 #include "tensorflow/compiler/xla/service/shape_inference.h"
52 #include "tensorflow/compiler/xla/shape_util.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/xla_data.pb.h"
55 #include "tensorflow/core/lib/gtl/map_util.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/protobuf.h"
58 
59 namespace xla {
60 
61 namespace {
62 
63 using absl::nullopt;
64 using absl::optional;
65 using absl::StrAppend;
66 using absl::StrCat;
67 using absl::StrFormat;
68 using absl::StrJoin;
69 
70 // Creates and returns a schedule created using the order of the instructions in
71 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)72 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
73   HloSchedule schedule(module);
74   for (HloComputation* computation : module->computations()) {
75     if (!computation->IsFusionComputation()) {
76       for (HloInstruction* instruction : computation->instructions()) {
77         schedule.GetOrCreateSequence(computation).push_back(instruction);
78       }
79     }
80   }
81   return schedule;
82 }
83 
// Returns true if the result shape of an instruction with the given opcode
// can be inferred from its operands, i.e. the parser may accept HLO text that
// omits an explicit result shape for it. Returns false for opcodes where the
// text must always carry the shape.
bool CanInferShape(HloOpcode code) {
  switch (code) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAtan2:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCeil:
    case HloOpcode::kCholesky:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kGather:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPad:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReduce:
    case HloOpcode::kRemainder:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
      return true;
    // Technically the following ops do not require an explicit result shape,
    // but we made it so that we always write the shapes explicitly.
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCopyStart:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kSlice:
    // The following ops require an explicit result shape.
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kConstant:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kFusion:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
      return false;
  }
}
206 
// Parser for the HloModule::ToString() format text.
class HloParserImpl : public HloParser {
 public:
  using LocTy = HloLexer::LocTy;

  explicit HloParserImpl(absl::string_view str) : lexer_(str) {}

  // Runs the parser and constructs the resulting HLO in the given (empty)
  // HloModule. Returns the error status in case an error occurred.
  Status Run(HloModule* module) override;

  // Returns the accumulated error information, one entry per recorded error,
  // joined by newlines.
  std::string GetError() const { return StrJoin(error_, "\n"); }

  // Stand alone parsing utils for various aggregate data types.
  StatusOr<Shape> ParseShapeOnly();
  StatusOr<HloSharding> ParseShardingOnly();
  StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
  StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
  StatusOr<Window> ParseWindowOnly();
  StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
  StatusOr<PaddingConfig> ParsePaddingConfigOnly();
  StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();

 private:
  // Maps an instruction name to the instruction itself plus the source
  // location where it was declared.
  using InstrNameTable =
      absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;

  // Returns the map from the instruction name to the instruction itself and its
  // location in the current scope.
  InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }

  // Locates an instruction with the given name in the current_name_table() or
  // returns nullptr.
  //
  // When the name is not found or name is empty, if create_missing_instruction_
  // hook is registered and a "shape" is provided, the hook will be called to
  // create an instruction. This is useful when we reify parameters as they're
  // resolved; i.e. for ParseSingleInstruction.
  std::pair<HloInstruction*, LocTy>* FindInstruction(
      const std::string& name, const optional<Shape>& shape = nullopt);

  // Parse a single instruction worth of text.
  bool ParseSingleInstruction(HloModule* module);

  // Parses a module, returning false if an error occurred.
  bool ParseHloModule(HloModule* module);

  bool ParseComputations(HloModule* module);
  bool ParseComputation(HloComputation** entry_computation);
  bool ParseInstructionList(HloComputation** computation,
                            const std::string& computation_name);
  bool ParseInstruction(HloComputation::Builder* builder,
                        std::string* root_name);
  bool ParseInstructionRhs(HloComputation::Builder* builder,
                           const std::string& name, LocTy name_loc);
  bool ParseControlPredecessors(HloInstruction* instruction);
  bool ParseLiteral(Literal* literal);
  bool ParseLiteral(Literal* literal, const Shape& shape);
  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseDenseLiteral(Literal* literal, const Shape& shape);

  // Sets the sub-value of literal at the given linear index to the
  // given value. If the literal is dense, it must have the default layout.
  //
  // `loc` should be the source location of the value.
  bool SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, double value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, bool value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64_t index,
                         Literal* literal);
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64_t index,
                               Literal* literal);

  // Checks whether the given value is within the range of LiteralNativeT.
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
  template <typename LiteralNativeT>
  bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);

  bool ParseOperands(std::vector<HloInstruction*>* operands);
  // Fills parsed operands into 'operands' and expects a certain number of
  // operands.
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     const int expected_size);

  // Describes the start, limit, and stride on every dimension of the operand
  // being sliced.
  struct SliceRanges {
    std::vector<int64> starts;
    std::vector<int64> limits;
    std::vector<int64> strides;
  };

  // The data parsed for the kDomain instruction.
  struct DomainData {
    std::unique_ptr<DomainMetadata> entry_metadata;
    std::unique_ptr<DomainMetadata> exit_metadata;
  };

  // Types of attributes.
  enum class AttrTy {
    kBool,
    kInt64,
    kInt32,
    kFloat,
    kString,
    kLiteral,
    kBracedInt64List,
    kBracedInt64ListList,
    kHloComputation,
    kBracedHloComputationList,
    kFftType,
    kPaddingType,
    kComparisonDirection,
    kComparisonType,
    kWindow,
    kConvolutionDimensionNumbers,
    kSharding,
    kFrontendAttributes,
    kParameterReplication,
    kInstructionList,
    kSliceRanges,
    kPaddingConfig,
    kMetadata,
    kFusionKind,
    kDistribution,
    kDomain,
    kPrecisionList,
    kShape,
    kShapeList,
    kEnum,
    kRandomAlgorithm,
    kAliasing,
    kInstructionAliasing,
    kCustomCallSchedule,
    kCustomCallApiVersion,
  };

  struct AttrConfig {
    bool required;     // whether it's required or optional
    AttrTy attr_type;  // what type it is
    void* result;      // where to store the parsed result.
  };

  // attributes ::= (',' attribute)*
  //
  // Parses attributes given names and configs of the attributes. Each parsed
  // result is passed back through the result pointer in corresponding
  // AttrConfig. Note that the result pointer must point to an optional<T>
  // typed variable which outlives this function. Returns false on error. You
  // should not use any of the results if this function failed.
  //
  // Example usage:
  //
  //  absl::flat_hash_map<std::string, AttrConfig> attrs;
  //  optional<int64> foo;
  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
  //  optional<Window> bar;
  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
  //  if (!ParseAttributes(attrs)) {
  //    return false; // Do not use 'foo' 'bar' if failed.
  //  }
  //  // Do something with 'bar'.
  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
  //
  bool ParseAttributes(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs);

  // sub_attributes ::= '{' (','? attribute)* '}'
  //
  // Usage is the same as ParseAttributes. See immediately above.
  bool ParseSubAttributes(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributes or ParseSubAttributes.
  bool ParseAttributeHelper(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      absl::flat_hash_set<std::string>* seen_attrs);

  // Copy attributes from `attrs` to `message`, unless the attribute name is in
  // `non_proto_attrs`.
  bool CopyAttributeToProtoMessage(
      absl::flat_hash_set<std::string> non_proto_attrs,
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      tensorflow::protobuf::Message* message);

  // Parses an attribute string into a protocol buffer `message`.
  // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
  // set of mandatory attributes.
  // `non_proto_attrs` specifies attributes that are not written to the proto,
  // but added to the HloInstruction.
  bool ParseAttributesAsProtoMessage(
      const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
      tensorflow::protobuf::Message* message);

  // Parses a name and finds the corresponding hlo computation.
  bool ParseComputationName(HloComputation** value);
  // Parses a list of names and finds the corresponding hlo instructions.
  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
  // Pass expect_outer_curlies == true when parsing a Window in the context of a
  // larger computation.  Pass false when parsing a stand-alone Window string.
  bool ParseWindow(Window* window, bool expect_outer_curlies);
  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
  bool ParsePaddingConfig(PaddingConfig* padding);
  bool ParseMetadata(OpMetadata* metadata);
  bool ParseSingleOrListMetadata(
      tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
  bool ParseOpShardingType(OpSharding::Type* type);
  bool ParseListShardingType(std::vector<OpSharding::Type>* types);
  bool ParseSharding(OpSharding* sharding);
  bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
  bool ParseParameterReplication(ParameterReplication* parameter_replication);
  bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);

  // Parses the metadata behind a kDomain instruction.
  bool ParseDomain(DomainData* domain);

  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
  bool ParseDxD(const std::string& name, std::vector<int64>* result);
  // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
  bool ParseWindowPad(std::vector<std::vector<int64>>* pad);

  bool ParseSliceRanges(SliceRanges* result);
  bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
  bool ParseHloComputation(HloComputation** result);
  bool ParseHloComputationList(std::vector<HloComputation*>* result);
  bool ParseShapeList(std::vector<Shape>* result);
  bool ParseInt64List(const TokKind start, const TokKind end,
                      const TokKind delim, std::vector<int64>* result);
  bool ParseInt64ListList(const TokKind start, const TokKind end,
                          const TokKind delim,
                          std::vector<std::vector<int64>>* result);
  // 'parse_and_add_item' is a lambda to parse an element in the list and add
  // the parsed element to the result. It's supposed to capture the result.
  bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
                 const std::function<bool()>& parse_and_add_item);

  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
  bool ParseParamList();
  bool ParseName(std::string* result);
  bool ParseAttributeName(std::string* result);
  bool ParseString(std::string* result);
  bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
                           std::vector<bool>* dynamic_dimensions);
  bool ParseShape(Shape* result);
  bool ParseLayout(Layout* layout);
  bool ParseLayoutIntAttribute(int64* attr_value,
                               absl::string_view attr_description);
  bool ParseTiles(std::vector<Tile>* tiles);
  bool ParseOpcode(HloOpcode* result);
  bool ParseFftType(FftType* result);
  bool ParsePaddingType(PaddingType* result);
  bool ParseComparisonDirection(ComparisonDirection* result);
  bool ParseComparisonType(Comparison::Type* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
  bool ParseRandomDistribution(RandomDistribution* result);
  bool ParseRandomAlgorithm(RandomAlgorithm* result);
  bool ParsePrecision(PrecisionConfig::Precision* result);
  bool ParseInt64(int64* result);
  bool ParseDouble(double* result);
  bool ParseComplex(std::complex<double>* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const std::string& msg);

  // Maps an output ShapeIndex to its input/output alias description.
  using AliasingData =
      absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;

  // Parses the aliasing information from string `s`, returns `false` if it
  // fails.
  bool ParseAliasing(AliasingData* data);

  // Parses the per-instruction aliasing information from string `s`, returns
  // `false` if it fails.
  bool ParseInstructionOutputOperandAliasing(
      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
          aliasing_output_operand_pairs);

  bool ParseCustomCallSchedule(CustomCallSchedule* result);
  bool ParseCustomCallApiVersion(CustomCallApiVersion* result);
  bool ParseShapeIndex(ShapeIndex* out);

  // Returns true if the current token is the beginning of a shape.
  bool CanBeShape();
  // Returns true if the current token is the beginning of a
  // param_list_to_shape.
  bool CanBeParamListToShape();

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(absl::string_view msg);
  bool Error(LocTy loc, absl::string_view msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const std::string& name, HloInstruction* instruction,
                      LocTy name_loc);
  // Adds the computation to the pool. Returns false and emits an error if the
  // computation already exists.
  bool AddComputation(const std::string& name, HloComputation* computation,
                      LocTy name_loc);

  // Tokenizer over the input text.
  HloLexer lexer_;

  // A stack for the instruction names. The top of the stack stores the
  // instruction name table for the current scope.
  //
  // An instruction's name is unique among its scope (i.e. its parent
  // computation), but it's not necessarily unique among all computations in the
  // module. When there are multiple levels of nested computations, the same
  // name could appear in both an outer computation and an inner computation. So
  // we need a stack to make sure a name is only visible within its scope.
  std::vector<InstrNameTable> scoped_name_tables_;

  // A helper class which pushes and pops to an InstrNameTable stack via RAII.
  class Scope {
   public:
    explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
        : scoped_name_tables_(scoped_name_tables) {
      scoped_name_tables_->emplace_back();
    }
    ~Scope() { scoped_name_tables_->pop_back(); }

   private:
    std::vector<InstrNameTable>* scoped_name_tables_;
  };

  // Map from the computation name to the computation itself and its location.
  absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
      computation_pool_;

  // Computations constructed so far, in parse order.
  std::vector<std::unique_ptr<HloComputation>> computations_;
  // Accumulated error messages; see GetError().
  std::vector<std::string> error_;

  // When an operand name cannot be resolved, this function is called to create
  // a parameter instruction with the given name and shape. It registers the
  // name, instruction, and a placeholder location in the name table. It returns
  // the newly-created instruction and the placeholder location. If `name` is
  // empty, this should create the parameter with a generated name. This is
  // supposed to be set and used only in ParseSingleInstruction.
  std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
                                                   const Shape& shape)>
      create_missing_instruction_;
};
565 
SplitToInt64s(absl::string_view s,char delim,std::vector<int64> * out)566 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
567   for (const auto& split : absl::StrSplit(s, delim)) {
568     int64_t val;
569     if (!absl::SimpleAtoi(split, &val)) {
570       return false;
571     }
572     out->push_back(val);
573   }
574   return true;
575 }
576 
577 // Creates replica groups from the provided nested array. groups[i] represents
578 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64>> groups)579 std::vector<ReplicaGroup> CreateReplicaGroups(
580     absl::Span<const std::vector<int64>> groups) {
581   std::vector<ReplicaGroup> replica_groups;
582   absl::c_transform(groups, std::back_inserter(replica_groups),
583                     [](const std::vector<int64>& ids) {
584                       ReplicaGroup group;
585                       *group.mutable_replica_ids() = {ids.begin(), ids.end()};
586                       return group;
587                     });
588   return replica_groups;
589 }
590 
Error(LocTy loc,absl::string_view msg)591 bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
592   auto line_col = lexer_.GetLineAndColumn(loc);
593   const unsigned line = line_col.first;
594   const unsigned col = line_col.second;
595   std::vector<std::string> error_lines;
596   error_lines.push_back(
597       StrCat("was parsing ", line, ":", col, ": error: ", msg));
598   error_lines.emplace_back(lexer_.GetLine(loc));
599   error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
600 
601   error_.push_back(StrJoin(error_lines, "\n"));
602   VLOG(1) << "Error: " << error_.back();
603   return false;
604 }
605 
TokenError(absl::string_view msg)606 bool HloParserImpl::TokenError(absl::string_view msg) {
607   return Error(lexer_.GetLoc(), msg);
608 }
609 
Run(HloModule * module)610 Status HloParserImpl::Run(HloModule* module) {
611   lexer_.Lex();
612   if (lexer_.GetKind() == TokKind::kw_HloModule) {
613     // This means that the text contains a full HLO module.
614     if (!ParseHloModule(module)) {
615       return InvalidArgument(
616           "Syntax error when trying to parse the text as a HloModule:\n%s",
617           GetError());
618     }
619     return Status::OK();
620   }
621   // This means that the text is a single HLO instruction.
622   if (!ParseSingleInstruction(module)) {
623     return InvalidArgument(
624         "Syntax error when trying to parse the text as a single "
625         "HloInstruction:\n%s",
626         GetError());
627   }
628   return Status::OK();
629 }
630 
// Looks up `name` in the current scope's name table. Three outcomes:
//   1. Found and (if a shape was declared) compatible -> returns the entry.
//   2. Not found, but the create_missing_instruction_ hook is set and we are
//      in the outermost scope -> the hook creates a parameter (requires a
//      declared shape; errors otherwise).
//   3. Found but shape-incompatible, or hook preconditions unmet -> nullptr
//      (with an error recorded for the incompatible-shape case).
std::pair<HloInstruction*, HloParserImpl::LocTy>*
HloParserImpl::FindInstruction(const std::string& name,
                               const optional<Shape>& shape) {
  std::pair<HloInstruction*, LocTy>* instr = nullptr;
  if (!name.empty()) {
    instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
  }

  // Potentially call the missing instruction hook. Only done in the outermost
  // scope (scoped_name_tables_.size() == 1), i.e. for single-instruction
  // parsing, not inside nested computations.
  if (instr == nullptr && create_missing_instruction_ != nullptr &&
      scoped_name_tables_.size() == 1) {
    if (!shape.has_value()) {
      Error(lexer_.GetLoc(),
            "Operand had no shape in HLO text; cannot create parameter for "
            "single-instruction module.");
      return nullptr;
    }
    return create_missing_instruction_(name, *shape);
  }

  // If the operand was declared with a shape, it must be compatible with the
  // shape of the instruction it resolves to.
  if (instr != nullptr && shape.has_value() &&
      !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
    Error(
        lexer_.GetLoc(),
        StrCat("The declared operand shape ",
               ShapeUtil::HumanStringWithLayout(shape.value()),
               " is not compatible with the shape of the operand instruction ",
               ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
    return nullptr;
  }

  return instr;
}
664 
ParseShapeIndex(ShapeIndex * out)665 bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
666   if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
667     return false;
668   }
669 
670   std::vector<int64> idxs;
671   while (lexer_.GetKind() != TokKind::kRbrace) {
672     int64_t idx;
673     if (!ParseInt64(&idx)) {
674       return false;
675     }
676     idxs.push_back(idx);
677     if (!EatIfPresent(TokKind::kComma)) {
678       break;
679     }
680   }
681   if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
682     return false;
683   }
684   *out = ShapeIndex(idxs.begin(), idxs.end());
685   return true;
686 }
687 
ParseAliasing(AliasingData * data)688 bool HloParserImpl::ParseAliasing(AliasingData* data) {
689   if (!ParseToken(TokKind::kLbrace,
690                   "Expects '{' at the start of aliasing description")) {
691     return false;
692   }
693 
694   while (lexer_.GetKind() != TokKind::kRbrace) {
695     ShapeIndex out;
696     if (!ParseShapeIndex(&out)) {
697       return false;
698     }
699     std::string errmsg =
700         "Expected format: <output_shape_index>: (<input_param>, "
701         "<input_param_shape_index>) OR <output_shape_index>: <input_param>";
702     if (!ParseToken(TokKind::kColon, errmsg)) {
703       return false;
704     }
705 
706     if (!ParseToken(TokKind::kLparen, errmsg)) {
707       return false;
708     }
709     int64_t param_num;
710     ParseInt64(&param_num);
711     if (!ParseToken(TokKind::kComma, errmsg)) {
712       return false;
713     }
714     ShapeIndex param_idx;
715     if (!ParseShapeIndex(&param_idx)) {
716       return false;
717     }
718 
719     HloInputOutputAliasConfig::AliasKind alias_kind =
720         HloInputOutputAliasConfig::kMayAlias;
721     if (EatIfPresent(TokKind::kComma)) {
722       std::string type;
723       ParseName(&type);
724       if (type == "must-alias") {
725         alias_kind = HloInputOutputAliasConfig::kMustAlias;
726       } else if (type == "may-alias") {
727         alias_kind = HloInputOutputAliasConfig::kMayAlias;
728       } else {
729         return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
730       }
731     }
732 
733     data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
734                   std::forward_as_tuple(param_num, param_idx, alias_kind));
735     if (!ParseToken(TokKind::kRparen, errmsg)) {
736       return false;
737     }
738 
739     if (!EatIfPresent(TokKind::kComma)) {
740       break;
741     }
742   }
743   if (!ParseToken(TokKind::kRbrace,
744                   "Expects '}' at the end of aliasing description")) {
745     return false;
746   }
747   return true;
748 }
749 
ParseInstructionOutputOperandAliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> * aliasing_output_operand_pairs)750 bool HloParserImpl::ParseInstructionOutputOperandAliasing(
751     std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
752         aliasing_output_operand_pairs) {
753   if (!ParseToken(
754           TokKind::kLbrace,
755           "Expects '{' at the start of instruction aliasing description")) {
756     return false;
757   }
758 
759   while (lexer_.GetKind() != TokKind::kRbrace) {
760     ShapeIndex out;
761     if (!ParseShapeIndex(&out)) {
762       return false;
763     }
764     std::string errmsg =
765         "Expected format: <output_shape_index>: (<operand_index>, "
766         "<operand_shape_index>)";
767     if (!ParseToken(TokKind::kColon, errmsg)) {
768       return false;
769     }
770 
771     if (!ParseToken(TokKind::kLparen, errmsg)) {
772       return false;
773     }
774     int64_t operand_index;
775     ParseInt64(&operand_index);
776     if (!ParseToken(TokKind::kComma, errmsg)) {
777       return false;
778     }
779     ShapeIndex operand_shape_index;
780     if (!ParseShapeIndex(&operand_shape_index)) {
781       return false;
782     }
783 
784     aliasing_output_operand_pairs->emplace_back(
785         out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
786     if (!ParseToken(TokKind::kRparen, errmsg)) {
787       return false;
788     }
789 
790     if (!EatIfPresent(TokKind::kComma)) {
791       break;
792     }
793   }
794   if (!ParseToken(
795           TokKind::kRbrace,
796           "Expects '}' at the end of instruction aliasing description")) {
797     return false;
798   }
799   return true;
800 }
801 
ParseCustomCallSchedule(CustomCallSchedule * result)802 bool HloParserImpl::ParseCustomCallSchedule(CustomCallSchedule* result) {
803   VLOG(3) << "ParseCustomCallSchedule";
804   if (lexer_.GetKind() != TokKind::kIdent) {
805     return TokenError("expects custom-call schedule");
806   }
807   std::string val = lexer_.GetStrVal();
808   auto status_or_result = StringToCustomCallSchedule(val);
809   if (!status_or_result.ok()) {
810     return TokenError(
811         StrFormat("expects custom-call schedule but sees: %s, error: %s", val,
812                   status_or_result.status().error_message()));
813   }
814   *result = status_or_result.ValueOrDie();
815   lexer_.Lex();
816   return true;
817 }
818 
ParseCustomCallApiVersion(CustomCallApiVersion * result)819 bool HloParserImpl::ParseCustomCallApiVersion(CustomCallApiVersion* result) {
820   VLOG(3) << "ParseCustomCallApiVersion";
821   if (lexer_.GetKind() != TokKind::kIdent) {
822     return TokenError("expects custom-call API version");
823   }
824   std::string val = lexer_.GetStrVal();
825   auto status_or_result = StringToCustomCallApiVersion(val);
826   if (!status_or_result.ok()) {
827     return TokenError(
828         StrFormat("expects custom-call API version but sees: %s, error: %s",
829                   val, status_or_result.status().error_message()));
830   }
831   *result = status_or_result.ValueOrDie();
832   lexer_.Lex();
833   return true;
834 }
835 
836 // ::= 'HloModule' name computations
ParseHloModule(HloModule * module)837 bool HloParserImpl::ParseHloModule(HloModule* module) {
838   if (lexer_.GetKind() != TokKind::kw_HloModule) {
839     return TokenError("expects HloModule");
840   }
841   // Eat 'HloModule'
842   lexer_.Lex();
843 
844   std::string name;
845   if (!ParseName(&name)) {
846     return false;
847   }
848 
849   absl::optional<bool> is_scheduled;
850   absl::optional<AliasingData> aliasing_data;
851   absl::optional<bool> alias_passthrough_params;
852   absl::flat_hash_map<std::string, AttrConfig> attrs;
853 
854   attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
855   attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
856                                  &aliasing_data};
857   attrs["alias_passthrough_params"] = {/*required=*/false, AttrTy::kBool,
858                                        &alias_passthrough_params};
859   if (!ParseAttributes(attrs)) {
860     return false;
861   }
862   module->set_name(name);
863   if (!ParseComputations(module)) {
864     return false;
865   }
866 
867   if (is_scheduled.has_value() && *is_scheduled) {
868     TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
869   }
870   if (alias_passthrough_params.has_value() && *alias_passthrough_params) {
871     HloModuleConfig config = module->config();
872     config.set_alias_passthrough_params(true);
873     module->set_config(config);
874   }
875   if (aliasing_data) {
876     HloInputOutputAliasConfig alias_config(module->result_shape());
877     for (auto& p : *aliasing_data) {
878       Status st =
879           alias_config.SetUpAlias(p.first, p.second.parameter_number,
880                                   p.second.parameter_index, p.second.kind);
881       if (!st.ok()) {
882         return TokenError(st.error_message());
883       }
884     }
885     module->input_output_alias_config() = alias_config;
886   }
887 
888   return true;
889 }
890 
891 // computations ::= (computation)+
ParseComputations(HloModule * module)892 bool HloParserImpl::ParseComputations(HloModule* module) {
893   HloComputation* entry_computation = nullptr;
894   do {
895     if (!ParseComputation(&entry_computation)) {
896       return false;
897     }
898   } while (lexer_.GetKind() != TokKind::kEof);
899 
900   for (int i = 0; i < computations_.size(); i++) {
901     // If entry_computation is not nullptr, it means the computation it pointed
902     // to is marked with "ENTRY"; otherwise, no computation is marked with
903     // "ENTRY", and we use the last computation as the entry computation. We
904     // add the non-entry computations as embedded computations to the module.
905     if ((entry_computation != nullptr &&
906          computations_[i].get() != entry_computation) ||
907         (entry_computation == nullptr && i != computations_.size() - 1)) {
908       module->AddEmbeddedComputation(std::move(computations_[i]));
909       continue;
910     }
911     auto computation = module->AddEntryComputation(std::move(computations_[i]));
912     // The parameters and result layouts were set to default layout. Here we
913     // set the layouts to what the hlo text says.
914     for (int p = 0; p < computation->num_parameters(); p++) {
915       const Shape& param_shape = computation->parameter_instruction(p)->shape();
916       TF_CHECK_OK(module->mutable_entry_computation_layout()
917                       ->mutable_parameter_layout(p)
918                       ->CopyLayoutFromShape(param_shape));
919     }
920     const Shape& result_shape = computation->root_instruction()->shape();
921     TF_CHECK_OK(module->mutable_entry_computation_layout()
922                     ->mutable_result_layout()
923                     ->CopyLayoutFromShape(result_shape));
924   }
925   return true;
926 }
927 
928 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
ParseComputation(HloComputation ** entry_computation)929 bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
930   LocTy maybe_entry_loc = lexer_.GetLoc();
931   const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
932 
933   std::string name;
934   LocTy name_loc = lexer_.GetLoc();
935   if (!ParseName(&name)) {
936     return false;
937   }
938 
939   LocTy shape_loc = nullptr;
940   Shape shape;
941   if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
942     return false;
943   }
944 
945   HloComputation* computation = nullptr;
946   if (!ParseInstructionList(&computation, name)) {
947     return false;
948   }
949 
950   // If param_list_to_shape was present, check compatibility.
951   if (shape_loc != nullptr &&
952       !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
953     return Error(
954         shape_loc,
955         StrCat(
956             "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
957             ", is not compatible with that of its root instruction ",
958             computation->root_instruction()->name(), ", ",
959             ShapeUtil::HumanString(computation->root_instruction()->shape())));
960   }
961 
962   if (is_entry_computation) {
963     if (*entry_computation != nullptr) {
964       return Error(maybe_entry_loc, "expects only one ENTRY");
965     }
966     *entry_computation = computation;
967   }
968 
969   return AddComputation(name, computation, name_loc);
970 }
971 
972 // instruction_list ::= '{' instruction_list1 '}'
973 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation ** computation,const std::string & computation_name)974 bool HloParserImpl::ParseInstructionList(HloComputation** computation,
975                                          const std::string& computation_name) {
976   Scope scope(&scoped_name_tables_);
977   HloComputation::Builder builder(computation_name);
978   if (!ParseToken(TokKind::kLbrace,
979                   "expects '{' at the beginning of instruction list.")) {
980     return false;
981   }
982   std::string root_name;
983   do {
984     if (!ParseInstruction(&builder, &root_name)) {
985       return false;
986     }
987   } while (lexer_.GetKind() != TokKind::kRbrace);
988   if (!ParseToken(TokKind::kRbrace,
989                   "expects '}' at the end of instruction list.")) {
990     return false;
991   }
992   HloInstruction* root = nullptr;
993   if (!root_name.empty()) {
994     std::pair<HloInstruction*, LocTy>* root_node =
995         tensorflow::gtl::FindOrNull(current_name_table(), root_name);
996 
997     // This means some instruction was marked as ROOT but we didn't find it in
998     // the pool, which should not happen.
999     if (root_node == nullptr) {
1000       // LOG(FATAL) crashes the program by calling abort().
1001       LOG(FATAL) << "instruction " << root_name
1002                  << " was marked as ROOT but the parser has not seen it before";
1003     }
1004     root = root_node->first;
1005   }
1006 
1007   // Now root can be either an existing instruction or a nullptr. If it's a
1008   // nullptr, the implementation of Builder will set the last instruction as
1009   // the root instruction.
1010   computations_.emplace_back(builder.Build(root));
1011   *computation = computations_.back().get();
1012   return true;
1013 }
1014 
1015 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,std::string * root_name)1016 bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
1017                                      std::string* root_name) {
1018   std::string name;
1019   LocTy maybe_root_loc = lexer_.GetLoc();
1020   bool is_root = EatIfPresent(TokKind::kw_ROOT);
1021 
1022   const LocTy name_loc = lexer_.GetLoc();
1023   if (!ParseName(&name) ||
1024       !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
1025     return false;
1026   }
1027 
1028   if (is_root) {
1029     if (!root_name->empty()) {
1030       return Error(maybe_root_loc, "one computation should have only one ROOT");
1031     }
1032     *root_name = name;
1033   }
1034 
1035   return ParseInstructionRhs(builder, name, name_loc);
1036 }
1037 
ParseInstructionRhs(HloComputation::Builder * builder,const std::string & name,LocTy name_loc)1038 bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
1039                                         const std::string& name,
1040                                         LocTy name_loc) {
1041   Shape shape;
1042   HloOpcode opcode;
1043   std::vector<HloInstruction*> operands;
1044 
1045   const bool parse_shape = CanBeShape();
1046   if ((parse_shape && !ParseShape(&shape)) || !ParseOpcode(&opcode)) {
1047     return false;
1048   }
1049   if (!parse_shape && !CanInferShape(opcode)) {
1050     return TokenError(StrFormat("cannot infer shape for opcode: %s",
1051                                 HloOpcodeString(opcode)));
1052   }
1053 
1054   // Add optional attributes. These are added to any HloInstruction type if
1055   // present.
1056   absl::flat_hash_map<std::string, AttrConfig> attrs;
1057   optional<OpSharding> sharding;
1058   optional<FrontendAttributes> frontend_attributes;
1059   attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
1060   attrs["frontend_attributes"] = {
1061       /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
1062   optional<ParameterReplication> parameter_replication;
1063   attrs["parameter_replication"] = {/*required=*/false,
1064                                     AttrTy::kParameterReplication,
1065                                     &parameter_replication};
1066   optional<std::vector<HloInstruction*>> predecessors;
1067   attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
1068                                    &predecessors};
1069   optional<OpMetadata> metadata;
1070   attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
1071 
1072   optional<std::string> backend_config;
1073   attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
1074                              &backend_config};
1075   optional<std::vector<int64>> outer_dimension_partitions;
1076   attrs["outer_dimension_partitions"] = {/*required=*/false,
1077                                          AttrTy::kBracedInt64List,
1078                                          &outer_dimension_partitions};
1079   const auto maybe_infer_shape =
1080       [&](const std::function<StatusOr<Shape>()>& infer, Shape* shape) {
1081         if (parse_shape) {
1082           return true;
1083         }
1084         auto inferred = infer();
1085         if (!inferred.ok()) {
1086           return TokenError(StrFormat(
1087               "failed to infer shape for opcode: %s, error: %s",
1088               HloOpcodeString(opcode), inferred.status().error_message()));
1089         }
1090         *shape = std::move(inferred).ValueOrDie();
1091         return true;
1092       };
1093   HloInstruction* instruction;
1094   switch (opcode) {
1095     case HloOpcode::kParameter: {
1096       int64_t parameter_number;
1097       if (!ParseToken(TokKind::kLparen,
1098                       "expects '(' before parameter number") ||
1099           !ParseInt64(&parameter_number)) {
1100         return false;
1101       }
1102       if (parameter_number < 0) {
1103         Error(lexer_.GetLoc(), "parameter number must be >= 0");
1104         return false;
1105       }
1106       if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
1107           !ParseAttributes(attrs)) {
1108         return false;
1109       }
1110       instruction = builder->AddInstruction(
1111           HloInstruction::CreateParameter(parameter_number, shape, name));
1112       break;
1113     }
1114     case HloOpcode::kConstant: {
1115       Literal literal;
1116       if (!ParseToken(TokKind::kLparen,
1117                       "expects '(' before constant literal") ||
1118           !ParseLiteral(&literal, shape) ||
1119           !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
1120           !ParseAttributes(attrs)) {
1121         return false;
1122       }
1123       instruction = builder->AddInstruction(
1124           HloInstruction::CreateConstant(std::move(literal)));
1125       break;
1126     }
1127     case HloOpcode::kIota: {
1128       optional<int64> iota_dimension;
1129       attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
1130                                  &iota_dimension};
1131       if (!ParseOperands(&operands, /*expected_size=*/0) ||
1132           !ParseAttributes(attrs)) {
1133         return false;
1134       }
1135       instruction = builder->AddInstruction(
1136           HloInstruction::CreateIota(shape, *iota_dimension));
1137       break;
1138     }
1139     // Unary ops.
1140     case HloOpcode::kAbs:
1141     case HloOpcode::kAllGatherDone:
1142     case HloOpcode::kAllReduceDone:
1143     case HloOpcode::kRoundNearestAfz:
1144     case HloOpcode::kBitcast:
1145     case HloOpcode::kCeil:
1146     case HloOpcode::kClz:
1147     case HloOpcode::kCollectivePermuteDone:
1148     case HloOpcode::kCopy:
1149     case HloOpcode::kCopyDone:
1150     case HloOpcode::kCos:
1151     case HloOpcode::kExp:
1152     case HloOpcode::kExpm1:
1153     case HloOpcode::kImag:
1154     case HloOpcode::kIsFinite:
1155     case HloOpcode::kFloor:
1156     case HloOpcode::kLog:
1157     case HloOpcode::kLog1p:
1158     case HloOpcode::kLogistic:
1159     case HloOpcode::kNot:
1160     case HloOpcode::kNegate:
1161     case HloOpcode::kPopulationCount:
1162     case HloOpcode::kReal:
1163     case HloOpcode::kRsqrt:
1164     case HloOpcode::kSign:
1165     case HloOpcode::kSin:
1166     case HloOpcode::kSqrt:
1167     case HloOpcode::kCbrt:
1168     case HloOpcode::kTanh: {
1169       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1170           !ParseAttributes(attrs)) {
1171         return false;
1172       }
1173       if (!maybe_infer_shape(
1174               [&] {
1175                 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
1176               },
1177               &shape)) {
1178         return false;
1179       }
1180       instruction = builder->AddInstruction(
1181           HloInstruction::CreateUnary(shape, opcode, operands[0]));
1182       break;
1183     }
1184     // Binary ops.
1185     case HloOpcode::kAdd:
1186     case HloOpcode::kDivide:
1187     case HloOpcode::kMultiply:
1188     case HloOpcode::kSubtract:
1189     case HloOpcode::kAtan2:
1190     case HloOpcode::kComplex:
1191     case HloOpcode::kMaximum:
1192     case HloOpcode::kMinimum:
1193     case HloOpcode::kPower:
1194     case HloOpcode::kRemainder:
1195     case HloOpcode::kAnd:
1196     case HloOpcode::kOr:
1197     case HloOpcode::kXor:
1198     case HloOpcode::kShiftLeft:
1199     case HloOpcode::kShiftRightArithmetic:
1200     case HloOpcode::kShiftRightLogical: {
1201       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1202           !ParseAttributes(attrs)) {
1203         return false;
1204       }
1205       if (!maybe_infer_shape(
1206               [&] {
1207                 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1208                                                           operands[1]);
1209               },
1210               &shape)) {
1211         return false;
1212       }
1213       instruction = builder->AddInstruction(HloInstruction::CreateBinary(
1214           shape, opcode, operands[0], operands[1]));
1215       break;
1216     }
1217     // Ternary ops.
1218     case HloOpcode::kClamp:
1219     case HloOpcode::kSelect:
1220     case HloOpcode::kTupleSelect: {
1221       if (!ParseOperands(&operands, /*expected_size=*/3) ||
1222           !ParseAttributes(attrs)) {
1223         return false;
1224       }
1225       if (!maybe_infer_shape(
1226               [&] {
1227                 return ShapeInference::InferTernaryOpShape(
1228                     opcode, operands[0], operands[1], operands[2]);
1229               },
1230               &shape)) {
1231         return false;
1232       }
1233       instruction = builder->AddInstruction(HloInstruction::CreateTernary(
1234           shape, opcode, operands[0], operands[1], operands[2]));
1235       break;
1236     }
1237     // Other supported ops.
1238     case HloOpcode::kConvert: {
1239       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1240           !ParseAttributes(attrs)) {
1241         return false;
1242       }
1243       instruction = builder->AddInstruction(
1244           HloInstruction::CreateConvert(shape, operands[0]));
1245       break;
1246     }
1247     case HloOpcode::kBitcastConvert: {
1248       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1249           !ParseAttributes(attrs)) {
1250         return false;
1251       }
1252       instruction = builder->AddInstruction(
1253           HloInstruction::CreateBitcastConvert(shape, operands[0]));
1254       break;
1255     }
1256     case HloOpcode::kAllGather:
1257     case HloOpcode::kAllGatherStart: {
1258       optional<std::vector<std::vector<int64>>> tmp_groups;
1259       optional<std::vector<int64>> replica_group_ids;
1260       optional<int64> channel_id;
1261       optional<std::vector<int64>> dimensions;
1262       optional<bool> constrain_layout;
1263       optional<bool> use_global_device_ids;
1264       attrs["replica_groups"] = {/*required=*/false,
1265                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1266       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1267       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1268                              &dimensions};
1269       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1270                                    &constrain_layout};
1271       attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1272                                         &use_global_device_ids};
1273       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1274         return false;
1275       }
1276       std::vector<ReplicaGroup> replica_groups;
1277       if (tmp_groups) {
1278         replica_groups = CreateReplicaGroups(*tmp_groups);
1279       }
1280       if (opcode == HloOpcode::kAllGather) {
1281         instruction = builder->AddInstruction(HloInstruction::CreateAllGather(
1282             shape, operands, dimensions->at(0), replica_groups,
1283             constrain_layout ? *constrain_layout : false, channel_id,
1284             use_global_device_ids ? *use_global_device_ids : false));
1285       } else {
1286         instruction =
1287             builder->AddInstruction(HloInstruction::CreateAllGatherStart(
1288                 shape, operands, dimensions->at(0), replica_groups,
1289                 constrain_layout ? *constrain_layout : false, channel_id,
1290                 use_global_device_ids ? *use_global_device_ids : false));
1291       }
1292       break;
1293     }
1294     case HloOpcode::kAllReduce:
1295     case HloOpcode::kAllReduceStart:
1296     case HloOpcode::kReduceScatter: {
1297       optional<std::vector<std::vector<int64>>> tmp_groups;
1298       optional<HloComputation*> to_apply;
1299       optional<std::vector<int64>> replica_group_ids;
1300       optional<int64> channel_id;
1301       optional<bool> constrain_layout;
1302       optional<bool> use_global_device_ids;
1303       optional<std::vector<int64>> dimensions;
1304       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1305                            &to_apply};
1306       attrs["replica_groups"] = {/*required=*/false,
1307                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1308       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1309       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1310                                    &constrain_layout};
1311       attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1312                                         &use_global_device_ids};
1313       if (opcode == HloOpcode::kReduceScatter) {
1314         attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1315                                &dimensions};
1316       }
1317       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1318         return false;
1319       }
1320       std::vector<ReplicaGroup> replica_groups;
1321       if (tmp_groups) {
1322         replica_groups = CreateReplicaGroups(*tmp_groups);
1323       }
1324       if (opcode == HloOpcode::kAllReduce) {
1325         instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
1326             shape, operands, *to_apply, replica_groups,
1327             constrain_layout ? *constrain_layout : false, channel_id,
1328             use_global_device_ids ? *use_global_device_ids : false));
1329       } else if (opcode == HloOpcode::kReduceScatter) {
1330         instruction =
1331             builder->AddInstruction(HloInstruction::CreateReduceScatter(
1332                 shape, operands, *to_apply, replica_groups,
1333                 constrain_layout ? *constrain_layout : false, channel_id,
1334                 use_global_device_ids ? *use_global_device_ids : false,
1335                 dimensions->at(0)));
1336       } else {
1337         instruction =
1338             builder->AddInstruction(HloInstruction::CreateAllReduceStart(
1339                 shape, operands, *to_apply, replica_groups,
1340                 constrain_layout ? *constrain_layout : false, channel_id,
1341                 use_global_device_ids ? *use_global_device_ids : false));
1342       }
1343       break;
1344     }
1345     case HloOpcode::kAllToAll: {
1346       optional<std::vector<std::vector<int64>>> tmp_groups;
1347       attrs["replica_groups"] = {/*required=*/false,
1348                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1349       optional<int64> channel_id;
1350       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1351       optional<std::vector<int64>> dimensions;
1352       attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1353                              &dimensions};
1354       optional<bool> constrain_layout;
1355       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1356                                    &constrain_layout};
1357       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1358           (dimensions && dimensions->size() != 1)) {
1359         return false;
1360       }
1361       std::vector<ReplicaGroup> replica_groups;
1362       if (tmp_groups) {
1363         replica_groups = CreateReplicaGroups(*tmp_groups);
1364       }
1365       optional<int64> split_dimension;
1366       if (dimensions) {
1367         split_dimension = dimensions->at(0);
1368       }
1369       instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
1370           shape, operands, replica_groups,
1371           constrain_layout ? *constrain_layout : false, channel_id,
1372           split_dimension));
1373       break;
1374     }
1375     case HloOpcode::kCollectivePermute:
1376     case HloOpcode::kCollectivePermuteStart: {
1377       optional<std::vector<std::vector<int64>>> source_targets;
1378       attrs["source_target_pairs"] = {
1379           /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
1380       optional<int64> channel_id;
1381       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1382       optional<std::vector<std::vector<int64_t>>> slice_sizes;
1383       attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList,
1384                               &slice_sizes};
1385       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1386         return false;
1387       }
1388       std::vector<std::pair<int64, int64>> pairs(source_targets->size());
1389       for (int i = 0; i < pairs.size(); i++) {
1390         if ((*source_targets)[i].size() != 2) {
1391           return TokenError(
1392               "expects 'source_target_pairs=' to be a list of pairs");
1393         }
1394         pairs[i].first = (*source_targets)[i][0];
1395         pairs[i].second = (*source_targets)[i][1];
1396       }
1397       if (!slice_sizes.has_value()) {
1398         if (operands.size() != 1) {
1399           return TokenError(
1400               "CollectivePermute and CollectivePermuteStart must "
1401               "have exactly  one operand (input buffer) unless "
1402               "it performs dynamic-slice and in-place update.");
1403         }
1404         if (opcode == HloOpcode::kCollectivePermute) {
1405           instruction =
1406               builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1407                   shape, operands[0], pairs, channel_id));
1408         } else if (opcode == HloOpcode::kCollectivePermuteStart) {
1409           instruction = builder->AddInstruction(
1410               HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
1411                                                            pairs, channel_id));
1412         } else {
1413           LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1414                         "CollectivePermuteStart, but got "
1415                      << HloOpcodeString(opcode);
1416         }
1417       } else {
1418         if (operands.size() != 4) {
1419           return TokenError(
1420               "CollectivePermute and CollectivePermuteStart must "
1421               "have exactly four operands for dynamic-slice and "
1422               "in-place update.");
1423         }
1424         if (opcode == HloOpcode::kCollectivePermute) {
1425           instruction =
1426               builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1427                   shape, operands[0], operands[1], operands[2], operands[3],
1428                   pairs, *slice_sizes, channel_id));
1429         } else if (opcode == HloOpcode::kCollectivePermuteStart) {
1430           instruction = builder->AddInstruction(
1431               HloInstruction::CreateCollectivePermuteStart(
1432                   shape, operands[0], operands[1], operands[2], operands[3],
1433                   pairs, *slice_sizes, channel_id));
1434         } else {
1435           LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1436                         "CollectivePermuteStart, but got "
1437                      << HloOpcodeString(opcode);
1438         }
1439       }
1440       break;
1441     }
1442     case HloOpcode::kCopyStart: {
1443       // If the is_cross_program_prefetch attribute is not present then default
1444       // to false.
1445       optional<bool> is_cross_program_prefetch = false;
1446       attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
1447                                             &is_cross_program_prefetch};
1448       if (!ParseOperands(&operands, /*expected_size=*/1) ||
1449           !ParseAttributes(attrs)) {
1450         return false;
1451       }
1452       instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
1453           shape, operands[0], *is_cross_program_prefetch));
1454       break;
1455     }
    case HloOpcode::kReplicaId: {
      // Nullary op: exactly zero operands; only the generic attributes.
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
      break;
    }
    case HloOpcode::kPartitionId: {
      // Nullary op: exactly zero operands; only the generic attributes.
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreatePartitionId());
      break;
    }
    case HloOpcode::kDynamicReshape: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      // Operand 0 is the data operand; every remaining operand is forwarded
      // as a dynamic dimension size.
      instruction =
          builder->AddInstruction(HloInstruction::CreateDynamicReshape(
              shape, operands[0],
              absl::Span<HloInstruction* const>(operands).subspan(1)));
      break;
    }
    case HloOpcode::kReshape: {
      // Optional attribute; when absent, -1 is passed to CreateReshape to
      // signal "no inferred dimension".
      optional<int64> inferred_dimension;
      attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
                                     &inferred_dimension};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReshape(
          shape, operands[0], inferred_dimension.value_or(-1)));
      break;
    }
    case HloOpcode::kAfterAll: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      // An after-all with no operands is represented as a plain token
      // creation; otherwise build the variadic after-all.
      if (operands.empty()) {
        instruction = builder->AddInstruction(HloInstruction::CreateToken());
      } else {
        instruction =
            builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
      }
      break;
    }
    case HloOpcode::kAddDependency: {
      // Requires exactly two operands, forwarded to CreateAddDependency.
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateAddDependency(operands[0], operands[1]));
      break;
    }
    case HloOpcode::kSort: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      // Defaults to an unstable sort when the attribute is absent.
      optional<bool> is_stable = false;
      attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      // Sort accepts exactly one sort dimension.
      // NOTE(review): a wrong-sized dimension list returns false without
      // emitting an Error message — consider a diagnostic here; the silent
      // failure matches the kConcatenate case below.
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      // If shape inference is enabled, recompute the result shape from the
      // operand shapes instead of trusting the parsed one.
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateSort(shape, dimensions->at(0), operands,
                                     to_apply.value(), is_stable.value()));
      break;
    }
    case HloOpcode::kTuple: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
              },
              &shape)) {
        return false;
      }
      // HloInstruction::CreateTuple() infers the shape of the tuple from
      // operands and should not be used here.
      instruction = builder->AddInstruction(
          HloInstruction::CreateVariadic(shape, HloOpcode::kTuple, operands));
      break;
    }
    case HloOpcode::kWhile: {
      // Both sub-computations are mandatory attributes; the single operand is
      // the loop's initial value.
      optional<HloComputation*> condition;
      optional<HloComputation*> body;
      attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
                            &condition};
      attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferWhileShape(
                    condition.value()->ComputeProgramShape(),
                    body.value()->ComputeProgramShape(), operands[0]->shape());
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateWhile(
          shape, *condition, *body, /*init=*/operands[0]));
      break;
    }
    case HloOpcode::kRecv: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // If the is_host_transfer attribute is not present then default to false.
      // The data shape passed to CreateRecv is the zero-th element of the
      // parsed (tuple) result shape; assumes `shape` is a non-empty tuple —
      // TODO confirm the parser has validated that by this point.
      instruction = builder->AddInstruction(HloInstruction::CreateRecv(
          shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
      break;
    }
    case HloOpcode::kRecvDone: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // The operand must itself be a channel instruction whose channel id
      // matches ours. NOTE(review): both failures return false without an
      // Error message.
      if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
      break;
    }
    case HloOpcode::kSend: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSend(
          operands[0], operands[1], *channel_id, *is_host_transfer));
      break;
    }
    case HloOpcode::kSendDone: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // Same channel validation as kRecvDone: the operand must be a channel
      // instruction carrying the same channel id.
      if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
      break;
    }
    case HloOpcode::kGetTupleElement: {
      // Requires an `index` attribute selecting the tuple element.
      optional<int64> index;
      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
                                                       *index);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
      break;
    }
    case HloOpcode::kCall: {
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferCallShape(
                    arg_shapes, to_apply.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCall(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduceWindow: {
      optional<HloComputation*> reduce_computation;
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      // A missing window attribute means a default-constructed (empty) window.
      if (!window) {
        window.emplace();
      }
      // Operands are laid out as N inputs followed by N init values, so the
      // total count must be even.
      if (operands.size() % 2) {
        auto loc = lexer_.GetLoc();
        return Error(loc, StrCat("expects an even number of operands, but has ",
                                 operands.size(), " operands"));
      }
      // NOTE(review): inference only consults operands[0]/operands[1] even in
      // the variadic (N > 1) case — presumably InferReduceWindowShape handles
      // the single-input form here; confirm for multi-input reduce-window.
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferReduceWindowShape(
                    operands[0]->shape(), operands[1]->shape(), *window,
                    reduce_computation.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      // First half of the operand list are the inputs, second half the
      // corresponding init values.
      instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
          shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *window, *reduce_computation));
      break;
    }
    case HloOpcode::kConvolution: {
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64> feature_group_count;
      optional<int64> batch_group_count;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/true,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // Fill in defaults for the optional attributes: empty window, and
      // ungrouped (count == 1) feature/batch dimensions.
      if (!window) {
        window.emplace();
      }
      if (!feature_group_count) {
        feature_group_count = 1;
      }
      if (!batch_group_count) {
        batch_group_count = 1;
      }
      // Either take the explicit per-operand precision list, or default every
      // operand to PrecisionConfig::DEFAULT.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferConvolveShape(
                    operands[0]->shape(), operands[1]->shape(),
                    *feature_group_count, *batch_group_count, *window, *dnums,
                    /*preferred_element_type=*/absl::nullopt);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
          shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
          feature_group_count.value(), batch_group_count.value(), *window,
          *dnums, precision_config));
      break;
    }
    case HloOpcode::kFft: {
      // Both the transform type and the per-dimension lengths are mandatory.
      optional<FftType> fft_type;
      optional<std::vector<int64>> fft_length;
      attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
      attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &fft_length};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferFftShape(operands[0]->shape(),
                                                     *fft_type, *fft_length);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFft(
          shape, operands[0], *fft_type, *fft_length));
      break;
    }
    case HloOpcode::kTriangularSolve: {
      // Attributes are parsed directly into the TriangularSolveOptions proto
      // rather than through the attrs map.
      TriangularSolveOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributesAsProtoMessage(
              /*non_proto_attrs=*/attrs, &options)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferTriangularSolveShape(
                    operands[0]->shape(), operands[1]->shape(), options);
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateTriangularSolve(
              shape, operands[0], operands[1], options));
      break;
    }
    case HloOpcode::kCompare: {
      // `direction` is mandatory; `type` is optional and forwarded as an
      // optional to CreateCompare (which then picks a default).
      optional<ComparisonDirection> direction;
      optional<Comparison::Type> type;
      attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
                            &direction};
      attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBinaryOpShape(opcode, operands[0],
                                                          operands[1]);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateCompare(
          shape, operands[0], operands[1], *direction, type));
      break;
    }
    case HloOpcode::kCholesky: {
      // Attributes are parsed directly into the CholeskyOptions proto.
      CholeskyOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributesAsProtoMessage(
              /*non_proto_attrs=*/attrs, &options)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferCholeskyShape(operands[0]->shape());
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCholesky(shape, operands[0], options));
      break;
    }
    case HloOpcode::kBroadcast: {
      // `dimensions` maps each operand dimension to a result dimension.
      optional<std::vector<int64>> broadcast_dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &broadcast_dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBroadcastShape(
                    operands[0]->shape(), *broadcast_dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
          shape, operands[0], *broadcast_dimensions));
      break;
    }
    case HloOpcode::kConcatenate: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      // Concatenate takes exactly one concat dimension.
      // NOTE(review): a wrong-sized dimension list fails silently with no
      // Error message (same pattern as kSort).
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferConcatOpShape(arg_shapes,
                                                          dimensions->at(0));
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
          shape, operands, dimensions->at(0)));
      break;
    }
    case HloOpcode::kMap: {
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      // `dimensions` is only consumed by shape inference below; it is not
      // forwarded to CreateMap.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      // NOTE(review): `*dimensions` is dereferenced inside the lambda even
      // though the attribute is optional — presumably maybe_infer_shape only
      // invokes the lambda when inference is requested; confirm `dimensions`
      // is always set on that path.
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferMapShape(
                    arg_shapes, to_apply.value()->ComputeProgramShape(),
                    *dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateMap(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduce: {
      auto loc = lexer_.GetLoc();

      optional<HloComputation*> reduce_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      optional<std::vector<int64>> dimensions_to_reduce;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions_to_reduce};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      // Operands are N inputs followed by N matching init values, so the
      // total count must be even.
      if (operands.size() % 2) {
        return Error(loc, StrCat("expects an even number of operands, but has ",
                                 operands.size(), " operands"));
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferReduceShape(
                    arg_shapes, *dimensions_to_reduce,
                    reduce_computation.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      // Split the parsed operand list: first half inputs, second half init
      // values.
      instruction = builder->AddInstruction(HloInstruction::CreateReduce(
          shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *dimensions_to_reduce, *reduce_computation));
      break;
    }
    case HloOpcode::kReverse: {
      // `dimensions` lists the axes to reverse.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferReverseShape(operands[0]->shape(),
                                                         *dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateReverse(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kSelectAndScatter: {
      optional<HloComputation*> select;
      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
      optional<HloComputation*> scatter;
      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      // Operands: (operand, source, init_value).
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // A missing window attribute means a default-constructed (empty) window.
      if (!window) {
        window.emplace();
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferSelectAndScatterShape(
                    operands[0]->shape(), select.value()->ComputeProgramShape(),
                    *window, operands[1]->shape(), operands[2]->shape(),
                    scatter.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
              shape, /*operand=*/operands[0], *select, *window,
              /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
      break;
    }
    case HloOpcode::kSlice: {
      // The slice attribute bundles starts/limits/strides per dimension.
      optional<SliceRanges> slice_ranges;
      attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSlice(
          shape, operands[0], slice_ranges->starts, slice_ranges->limits,
          slice_ranges->strides));
      break;
    }
    case HloOpcode::kDynamicSlice: {
      optional<std::vector<int64>> dynamic_slice_sizes;
      attrs["dynamic_slice_sizes"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.empty()) {
        return Error(loc, "Expected at least one operand.");
      }
      // Two accepted operand layouts: a single rank-1 start-indices operand,
      // or one scalar start index per dimension of the sliced operand.
      if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
          operands.size() != 1 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
          shape, /*operand=*/operands[0],
          /*start_indices=*/absl::MakeSpan(operands).subspan(1),
          *dynamic_slice_sizes));
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() < 2) {
        return Error(loc, "Expected at least two operands.");
      }
      // Same two layouts as kDynamicSlice, shifted by the extra `update`
      // operand: a single rank-1 start-indices operand, or one scalar start
      // index per dimension.
      if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
          operands.size() != 2 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              shape, /*operand=*/operands[0], /*update=*/operands[1],
              /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
      break;
    }
    case HloOpcode::kTranspose: {
      // `dimensions` is the permutation applied to the operand's axes.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferTransposeShape(operands[0]->shape(),
                                                           *dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kBatchNormTraining: {
      // Operands: (operand, scale, offset).
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBatchNormTrainingShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(), *feature_index);
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormInference: {
      // Operands: (operand, scale, offset, mean, variance).
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBatchNormInferenceShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(), operands[3]->shape(),
                    operands[4]->shape(), *feature_index);
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormInference(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], /*mean=*/operands[3],
              /*variance=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormGrad: {
      // Operands: (operand, scale, mean, variance, grad_output).
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBatchNormGradShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(), operands[3]->shape(),
                    operands[4]->shape(), *feature_index);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
          shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*mean=*/operands[2], /*variance=*/operands[3],
          /*grad_output=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kPad: {
      // Operands: (operand, padding_value); padding config is an attribute.
      optional<PaddingConfig> padding;
      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferPadShape(
                    operands[0]->shape(), operands[1]->shape(), *padding);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreatePad(
          shape, operands[0], /*padding_value=*/operands[1], *padding));
      break;
    }
    case HloOpcode::kFusion: {
      // `calls` names the fused computation, `kind` the fusion kind; both are
      // mandatory.
      optional<HloComputation*> fusion_computation;
      attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
                        &fusion_computation};
      optional<HloInstruction::FusionKind> fusion_kind;
      attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFusion(
          shape, *fusion_kind, operands, *fusion_computation));
      break;
    }
2221     case HloOpcode::kInfeed: {
2222       optional<std::string> config;
2223       attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2224       if (!ParseOperands(&operands, /*expected_size=*/1) ||
2225           !ParseAttributes(attrs)) {
2226         return false;
2227       }
2228       // We need to know the infeed data shape to construct the infeed
2229       // instruction. This is the zero-th element of the tuple-shaped output of
2230       // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
2231       // if the shape is not a non-empty tuple, so add guard so an error message
2232       // can be emitted instead of a check fail
2233       if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
2234         return Error(lexer_.GetLoc(),
2235                      "infeed must have a non-empty tuple shape");
2236       }
2237       instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
2238           ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
2239           config ? *config : ""));
2240       break;
2241     }
    case HloOpcode::kOutfeed: {
      // outfeed(data, token) with optional config string and an optional
      // explicit outfeed shape (defaults to the data operand's shape).
      optional<std::string> config;
      optional<Shape> outfeed_shape;
      attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
                                &outfeed_shape};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      HloInstruction* const outfeed_input = operands[0];
      HloInstruction* const outfeed_token = operands[1];
      // NOTE(review): this local `shape` intentionally shadows the
      // instruction-level `shape` parsed earlier; only the outfeed shape is
      // used below.
      const Shape shape =
          outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
      instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
          shape, outfeed_input, outfeed_token, config ? *config : ""));
      break;
    }
    case HloOpcode::kRng: {
      // rng(params...) with a required distribution attribute; the operands
      // are the distribution parameters (count presumably validated when the
      // instruction is constructed/verified — not checked here).
      optional<RandomDistribution> distribution;
      attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
                               &distribution};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRng(shape, *distribution, operands));
      break;
    }
    case HloOpcode::kRngGetAndUpdateState: {
      // rng-get-and-update-state: zero operands; `delta` is the amount the
      // backing RNG state is advanced by.
      optional<int64> delta;
      attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
      break;
    }
    case HloOpcode::kRngBitGenerator: {
      // rng-bit-generator(initial_state) with a required algorithm attribute.
      // Only operands[0] is used; extra operands (if any) are ignored here.
      optional<RandomAlgorithm> algorithm;
      attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
                            &algorithm};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
              shape, operands[0], *algorithm));
      break;
    }
    case HloOpcode::kReducePrecision: {
      // reduce-precision(operand) with required exponent/mantissa bit counts.
      optional<int64> exponent_bits;
      optional<int64> mantissa_bits;
      attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &exponent_bits};
      attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &mantissa_bits};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // The factory takes plain ints; the int64 attribute values are narrowed
      // here (no range check at parse time).
      instruction =
          builder->AddInstruction(HloInstruction::CreateReducePrecision(
              shape, operands[0], static_cast<int>(*exponent_bits),
              static_cast<int>(*mantissa_bits)));
      break;
    }
    case HloOpcode::kConditional: {
      // conditional: two textual forms depending on the branch-index operand:
      //   * PRED scalar  -> true_computation= / false_computation= attributes
      //   * S32 scalar   -> branch_computations={...} attribute
      // Which attributes are registered depends on operand 0's element type,
      // so operands must be parsed before attributes here.
      optional<HloComputation*> true_computation;
      optional<HloComputation*> false_computation;
      optional<std::vector<HloComputation*>> branch_computations;
      if (!ParseOperands(&operands)) {
        return false;
      }
      if (!ShapeUtil::IsScalar(operands[0]->shape())) {
        return Error(lexer_.GetLoc(), "The first operand must be a scalar");
      }
      const bool branch_index_is_bool =
          operands[0]->shape().element_type() == PRED;
      if (branch_index_is_bool) {
        attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
                                     &true_computation};
        attrs["false_computation"] = {
            /*required=*/true, AttrTy::kHloComputation, &false_computation};
      } else {
        if (operands[0]->shape().element_type() != S32) {
          return Error(lexer_.GetLoc(),
                       "The first operand must be a scalar of PRED or S32");
        }
        attrs["branch_computations"] = {/*required=*/true,
                                        AttrTy::kBracedHloComputationList,
                                        &branch_computations};
      }
      if (!ParseAttributes(attrs)) {
        return false;
      }
      // Normalize the PRED form into the general branch list ({true, false}),
      // so the rest of the code handles one representation.
      if (branch_index_is_bool) {
        branch_computations.emplace({*true_computation, *false_computation});
      }
      // NOTE(review): a branch/operand count mismatch fails silently here
      // (no Error() message) — TODO confirm this is intended.
      if (branch_computations->empty() ||
          operands.size() != branch_computations->size() + 1) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
                branch_computation_shapes.reserve(branch_computations->size());
                for (auto* computation : *branch_computations) {
                  branch_computation_shapes.push_back(
                      computation->ComputeProgramShape());
                }
                absl::InlinedVector<Shape, 2> branch_operand_shapes;
                branch_operand_shapes.reserve(operands.size() - 1);
                // Operand 0 is the branch index; operands 1..N are the
                // per-branch arguments.
                for (int i = 1; i < operands.size(); ++i) {
                  branch_operand_shapes.push_back(operands[i]->shape());
                }
                return ShapeInference::InferConditionalShape(
                    operands[0]->shape(), branch_computation_shapes,
                    branch_operand_shapes);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConditional(
          shape, /*branch_index=*/operands[0],
          absl::MakeSpan(*branch_computations),
          absl::MakeSpan(operands).subspan(1)));
      break;
    }
2373     case HloOpcode::kCustomCall: {
2374       optional<std::string> custom_call_target;
2375       optional<Window> window;
2376       optional<ConvolutionDimensionNumbers> dnums;
2377       optional<int64> feature_group_count;
2378       optional<int64> batch_group_count;
2379       optional<std::vector<Shape>> operand_layout_constraints;
2380       optional<bool> custom_call_has_side_effect;
2381       optional<HloComputation*> to_apply;
2382       optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
2383           output_to_operand_aliasing;
2384       optional<PaddingType> padding_type;
2385       optional<std::vector<HloComputation*>> called_computations;
2386       optional<CustomCallSchedule> custom_call_schedule;
2387       optional<CustomCallApiVersion> api_version;
2388       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
2389                                      &custom_call_target};
2390       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
2391       attrs["dim_labels"] = {/*required=*/false,
2392                              AttrTy::kConvolutionDimensionNumbers, &dnums};
2393       attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
2394                                       &feature_group_count};
2395       attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
2396                                     &batch_group_count};
2397       attrs["operand_layout_constraints"] = {
2398           /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
2399       attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
2400                                               &custom_call_has_side_effect};
2401       attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
2402                            &to_apply};
2403       attrs["called_computations"] = {/*required=*/false,
2404                                       AttrTy::kBracedHloComputationList,
2405                                       &called_computations};
2406       attrs["output_to_operand_aliasing"] = {/*required=*/false,
2407                                              AttrTy::kInstructionAliasing,
2408                                              &output_to_operand_aliasing};
2409 
2410       attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
2411                                &padding_type};
2412 
2413       optional<Literal> literal;
2414       attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
2415       optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2416       attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2417                                     &operand_precision};
2418       if (called_computations.has_value() && to_apply.has_value()) {
2419         return Error(lexer_.GetLoc(),
2420                      "A single instruction can't have both to_apply and "
2421                      "calls field");
2422       }
2423       attrs["schedule"] = {/*required=*/false, AttrTy::kCustomCallSchedule,
2424                            &custom_call_schedule};
2425       attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion,
2426                               &api_version};
2427       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2428         return false;
2429       }
2430 
2431       if (api_version.has_value() &&
2432           *api_version == CustomCallApiVersion::API_VERSION_UNSPECIFIED) {
2433         return Error(lexer_.GetLoc(),
2434                      StrCat("Invalid API version: ",
2435                             CustomCallApiVersion_Name(*api_version)));
2436       }
2437       if (operand_layout_constraints.has_value()) {
2438         if (!LayoutUtil::HasLayout(shape)) {
2439           return Error(lexer_.GetLoc(),
2440                        "Layout must be set on layout-constrained custom call");
2441         }
2442         if (operands.size() != operand_layout_constraints->size()) {
2443           return Error(lexer_.GetLoc(),
2444                        StrCat("Expected ", operands.size(),
2445                               " operand layout constraints, ",
2446                               operand_layout_constraints->size(), " given"));
2447         }
2448         for (int64_t i = 0; i < operands.size(); ++i) {
2449           const Shape& operand_shape_with_layout =
2450               (*operand_layout_constraints)[i];
2451           if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
2452             return Error(lexer_.GetLoc(),
2453                          StrCat("Operand layout constraint shape ",
2454                                 ShapeUtil::HumanStringWithLayout(
2455                                     operand_shape_with_layout),
2456                                 " for operand ", i, " does not have a layout"));
2457           }
2458           if (!ShapeUtil::Compatible(operand_shape_with_layout,
2459                                      operands[i]->shape())) {
2460             return Error(
2461                 lexer_.GetLoc(),
2462                 StrCat(
2463                     "Operand layout constraint shape ",
2464                     ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
2465                     " for operand ", i,
2466                     " is not compatible with operand shape ",
2467                     ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
2468           }
2469         }
2470         instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
2471             shape, operands, *custom_call_target, *operand_layout_constraints,
2472             backend_config ? *backend_config : ""));
2473       } else {
2474         if (to_apply.has_value()) {
2475           instruction =
2476               builder->AddInstruction(HloInstruction::CreateCustomCall(
2477                   shape, operands, *to_apply, *custom_call_target,
2478                   backend_config ? *backend_config : ""));
2479         } else if (called_computations.has_value()) {
2480           instruction =
2481               builder->AddInstruction(HloInstruction::CreateCustomCall(
2482                   shape, operands, *called_computations, *custom_call_target,
2483                   backend_config ? *backend_config : ""));
2484         } else {
2485           instruction =
2486               builder->AddInstruction(HloInstruction::CreateCustomCall(
2487                   shape, operands, *custom_call_target,
2488                   backend_config ? *backend_config : ""));
2489         }
2490       }
2491       auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
2492       if (window.has_value()) {
2493         custom_call_instr->set_window(*window);
2494       }
2495       if (dnums.has_value()) {
2496         custom_call_instr->set_convolution_dimension_numbers(*dnums);
2497       }
2498       if (feature_group_count.has_value()) {
2499         custom_call_instr->set_feature_group_count(*feature_group_count);
2500       }
2501       if (batch_group_count.has_value()) {
2502         custom_call_instr->set_batch_group_count(*batch_group_count);
2503       }
2504       if (padding_type.has_value()) {
2505         custom_call_instr->set_padding_type(*padding_type);
2506       }
2507       if (custom_call_has_side_effect.has_value()) {
2508         custom_call_instr->set_custom_call_has_side_effect(
2509             *custom_call_has_side_effect);
2510       }
2511       if (custom_call_schedule.has_value()) {
2512         custom_call_instr->set_custom_call_schedule(*custom_call_schedule);
2513       }
2514       if (api_version.has_value()) {
2515         custom_call_instr->set_api_version(*api_version);
2516       }
2517       if (output_to_operand_aliasing.has_value()) {
2518         custom_call_instr->set_output_to_operand_aliasing(
2519             std::move(*output_to_operand_aliasing));
2520       }
2521       if (literal.has_value()) {
2522         custom_call_instr->set_literal(std::move(*literal));
2523       }
2524       PrecisionConfig precision_config;
2525       if (operand_precision) {
2526         *precision_config.mutable_operand_precision() = {
2527             operand_precision->begin(), operand_precision->end()};
2528       } else {
2529         precision_config.mutable_operand_precision()->Resize(
2530             operands.size(), PrecisionConfig::DEFAULT);
2531       }
2532       *custom_call_instr->mutable_precision_config() = precision_config;
2533       break;
2534     }
    case HloOpcode::kDot: {
      // dot(lhs, rhs): all dimension-number attributes are optional and
      // default to empty lists; operand precision defaults to DEFAULT.
      optional<std::vector<int64>> lhs_contracting_dims;
      attrs["lhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
      optional<std::vector<int64>> rhs_contracting_dims;
      attrs["rhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
      optional<std::vector<int64>> lhs_batch_dims;
      attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &lhs_batch_dims};
      optional<std::vector<int64>> rhs_batch_dims;
      attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &rhs_batch_dims};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};

      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }

      // Assemble DotDimensionNumbers from whichever lists were provided.
      DotDimensionNumbers dnum;
      if (lhs_contracting_dims) {
        *dnum.mutable_lhs_contracting_dimensions() = {
            lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
      }
      if (rhs_contracting_dims) {
        *dnum.mutable_rhs_contracting_dimensions() = {
            rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
      }
      if (lhs_batch_dims) {
        *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
                                                lhs_batch_dims->end()};
      }
      if (rhs_batch_dims) {
        *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
                                                rhs_batch_dims->end()};
      }

      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferDotOpShape(
                    operands[0]->shape(), operands[1]->shape(), dnum,
                    /*preferred_element_type=*/absl::nullopt);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDot(
          shape, operands[0], operands[1], dnum, precision_config));
      break;
    }
    case HloOpcode::kGather: {
      // gather(operand, start_indices) with required dimension-number
      // attributes and slice sizes; indices_are_sorted defaults to false.
      optional<std::vector<int64>> offset_dims;
      attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &offset_dims};
      optional<std::vector<int64>> collapsed_slice_dims;
      attrs["collapsed_slice_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
      optional<std::vector<int64>> start_index_map;
      attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
                                  &start_index_map};
      optional<int64> index_vector_dim;
      attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
                                   &index_vector_dim};
      optional<std::vector<int64>> slice_sizes;
      attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &slice_sizes};
      optional<bool> indices_are_sorted = false;
      attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                     &indices_are_sorted};

      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }

      GatherDimensionNumbers dim_numbers =
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/*offset_dims,
              /*collapsed_slice_dims=*/*collapsed_slice_dims,
              /*start_index_map=*/*start_index_map,
              /*index_vector_dim=*/*index_vector_dim);
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferGatherShape(
                    operands[0]->shape(), operands[1]->shape(), dim_numbers,
                    *slice_sizes);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateGather(
          shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
          dim_numbers, *slice_sizes, indices_are_sorted.value()));
      break;
    }
    case HloOpcode::kScatter: {
      // scatter(operand, scatter_indices, updates) with required dimension
      // numbers and an update computation ("to_apply"); the two boolean flags
      // default to false.
      optional<std::vector<int64>> update_window_dims;
      attrs["update_window_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
      optional<std::vector<int64>> inserted_window_dims;
      attrs["inserted_window_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
      optional<std::vector<int64>> scatter_dims_to_operand_dims;
      attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
                                               AttrTy::kBracedInt64List,
                                               &scatter_dims_to_operand_dims};
      optional<int64> index_vector_dim;
      attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
                                   &index_vector_dim};

      optional<HloComputation*> update_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &update_computation};
      optional<bool> indices_are_sorted = false;
      attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                     &indices_are_sorted};
      optional<bool> unique_indices = false;
      attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
                                 &unique_indices};

      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }

      ScatterDimensionNumbers dim_numbers =
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/*update_window_dims,
              /*inserted_window_dims=*/*inserted_window_dims,
              /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
              /*index_vector_dim=*/*index_vector_dim);

      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferScatterShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(),
                    update_computation.value()->ComputeProgramShape(),
                    dim_numbers);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateScatter(
          shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
          /*updates=*/operands[2], *update_computation, dim_numbers,
          indices_are_sorted.value(), unique_indices.value()));
      break;
    }
    case HloOpcode::kDomain: {
      // domain(operand) with a required domain attribute holding entry/exit
      // metadata; the result shape is that of a unary op on the operand.
      DomainData domain;
      attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDomain(
          shape, operands[0], std::move(domain.exit_metadata),
          std::move(domain.entry_metadata)));
      break;
    }
    case HloOpcode::kTrace:
      // Trace has no textual form supported by this parser.
      return TokenError(StrCat("parsing not yet implemented for op: ",
                               HloOpcodeString(opcode)));
    case HloOpcode::kGetDimensionSize: {
      // get-dimension-size(operand): the "dimensions" attribute is a braced
      // list but only its first element is used.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // dimensions->at(0) throws/asserts on an empty list — presumably the
      // attribute parser guarantees at least one element; TODO confirm.
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferGetDimensionSizeShape(
                    operands[0]->shape(), dimensions->at(0));
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
              shape, operands[0], (*dimensions)[0]));
      break;
    }
    case HloOpcode::kSetDimensionSize: {
      // set-dimension-size(operand, size): like get-dimension-size, only the
      // first element of the "dimensions" list is used.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferSetDimensionSizeShape(
                    operands[0]->shape(), operands[1]->shape(),
                    dimensions->at(0));
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
              shape, operands[0], operands[1], (*dimensions)[0]));
      break;
    }
  }

  // Sanitize the parsed instruction name; if sanitization changed it the
  // original name was illegal, so report it along with the sanitized
  // suggestion rather than silently renaming.
  instruction->SetAndSanitizeName(name);
  if (instruction->name() != name) {
    return Error(name_loc,
                 StrCat("illegal instruction name: ", name,
                        "; suggest renaming to: ", instruction->name()));
  }

  // Add shared attributes like metadata to the instruction, if they were seen.
  if (sharding) {
    // NOTE(review): ValueOrDie() crashes on an invalid sharding proto rather
    // than producing a parse error — TODO confirm intended.
    instruction->set_sharding(
        HloSharding::FromProto(sharding.value()).ValueOrDie());
  }
  if (parameter_replication) {
    // parameter_replication must supply exactly one flag per leaf buffer of
    // the instruction's shape.
    int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
    const auto& replicated =
        parameter_replication->replicated_at_leaf_buffers();
    if (leaf_count != replicated.size()) {
      return Error(lexer_.GetLoc(),
                   StrCat("parameter has ", leaf_count,
                          " leaf buffers, but parameter_replication has ",
                          replicated.size(), " elements."));
    }
    instruction->set_parameter_replicated_at_leaf_buffers(replicated);
  }
  if (predecessors) {
    // control-predecessors: wire each named predecessor to this instruction.
    for (auto* pre : *predecessors) {
      Status status = pre->AddControlDependencyTo(instruction);
      if (!status.ok()) {
        return Error(name_loc, StrCat("error adding control dependency for: ",
                                      name, " status: ", status.ToString()));
      }
    }
  }
  if (metadata) {
    instruction->set_metadata(*metadata);
  }
  if (backend_config) {
    instruction->set_raw_backend_config_string(std::move(*backend_config));
  }
  if (outer_dimension_partitions) {
    instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
  }
  if (frontend_attributes) {
    instruction->set_frontend_attributes(*frontend_attributes);
  }
  return AddInstruction(name, instruction, name_loc);
}  // NOLINT(readability/fn_size)
2809 
2810 // ::= '{' (single_sharding | tuple_sharding) '}'
2811 //
// tuple_sharding ::= /*empty*/ | single_sharding (',' single_sharding)*
ParseSharding(OpSharding * sharding)2813 bool HloParserImpl::ParseSharding(OpSharding* sharding) {
2814   // A single sharding starts with '{' and is not followed by '{'.
2815   // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
2816   // an empty tuple.
2817   if (!ParseToken(TokKind::kLbrace,
2818                   "expected '{' to start sharding attribute")) {
2819     return false;
2820   }
2821 
2822   if (lexer_.GetKind() != TokKind::kLbrace &&
2823       lexer_.GetKind() != TokKind::kRbrace) {
2824     return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
2825   }
2826 
2827   // Tuple sharding.
2828   // Allow empty tuple shardings.
2829   if (lexer_.GetKind() != TokKind::kRbrace) {
2830     do {
2831       if (!ParseSingleSharding(sharding->add_tuple_shardings(),
2832                                /*lbrace_pre_lexed=*/false)) {
2833         return false;
2834       }
2835     } while (EatIfPresent(TokKind::kComma));
2836   }
2837   sharding->set_type(OpSharding::TUPLE);
2838 
2839   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
2840 }
2841 
2842 // frontend_attributes ::= '{' attributes '}'
2843 // attributes
2844 //   ::= /*empty*/
2845 //   ::= attribute '=' value (',' attribute '=' value)*
ParseFrontendAttributes(FrontendAttributes * frontend_attributes)2846 bool HloParserImpl::ParseFrontendAttributes(
2847     FrontendAttributes* frontend_attributes) {
2848   CHECK(frontend_attributes != nullptr);
2849   if (!ParseToken(TokKind::kLbrace,
2850                   "expected '{' to start frontend attributes")) {
2851     return false;
2852   }
2853   if (lexer_.GetKind() == TokKind::kRbrace) {
2854     // empty
2855   } else {
2856     do {
2857       std::string attribute;
2858       if (!ParseAttributeName(&attribute)) {
2859         return false;
2860       }
2861       if (lexer_.GetKind() != TokKind::kString) {
2862         return false;
2863       }
2864       (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
2865       lexer_.Lex();
2866     } while (EatIfPresent(TokKind::kComma));
2867   }
2868   return ParseToken(TokKind::kRbrace,
2869                     "expects '}' at the end of frontend attributes");
2870 }
2871 
// ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
//         ('devices=' ('[' dims ']')* device_list)?
//         ('metadata=' metadata)* '}'
//
// dims ::= int_list device_list ::= int_list
// metadata ::= single_metadata |
//              ('{' [single_metadata (',' single_metadata)*] '}')
// last_tile_dims ::= sharding_type_list
//
// Parses one (non-tuple) sharding attribute body into `sharding`. When
// `lbrace_pre_lexed` is true the caller has already consumed the opening '{'
// (e.g. while disambiguating tuple vs. single shardings). Tokens are first
// accumulated into locals so the attributes may appear in any order; the
// combination is validated only after the closing '}' is reached.
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
                                        bool lbrace_pre_lexed) {
  if (!lbrace_pre_lexed &&
      !ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }

  // Location of the start of the body; used by the post-loop validation
  // errors so they point at the sharding as a whole.
  LocTy loc = lexer_.GetLoc();
  bool maximal = false;
  bool replicated = false;
  bool manual = false;
  bool last_tile_dim_replicate = false;
  bool last_tile_dims = false;
  std::vector<int64> devices;
  std::vector<int64> tile_assignment_dimensions;
  std::vector<OpSharding::Type> sharding_types;
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kw_manual:
        manual = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          // 'device=N' — single-device (maximal) assignment.
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          // 'devices=[d0,d1,...]i0,i1,...' — tile-assignment dimensions in
          // brackets, followed by the flattened device id list.
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }

          do {
            int64_t dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));

          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64_t device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else if (lexer_.GetStrVal() == "metadata") {
          // Metadata is written straight into the proto; it is valid for any
          // sharding type.
          lexer_.Lex();
          if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
            return false;
          }
        } else if (lexer_.GetStrVal() == "last_tile_dims") {
          last_tile_dims = true;
          lexer_.Lex();
          if (!ParseListShardingType(&sharding_types)) {
            return false;
          }
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device=, devices= "
              "metadata= or last_tile_dims= ");
        }
        break;
      }
      case TokKind::kw_last_tile_dim_replicate:
        last_tile_dim_replicate = true;
        lexer_.Lex();
        break;
      case TokKind::kRbrace:
        break;
      default:
        return TokenError("unexpected token");
    }
  }

  // Validate the accumulated attributes and populate the proto. The four
  // sharding kinds (replicated / maximal / manual / tiled) are mutually
  // exclusive in what they allow.
  if (replicated) {
    if (!devices.empty()) {
      return Error(loc,
                   "replicated shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
    sharding->set_type(OpSharding::MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else if (manual) {
    if (!devices.empty()) {
      return Error(loc,
                   "manual shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::MANUAL);
  } else {
    // Tiled ("OTHER") sharding: requires a multi-device assignment plus the
    // tile shape parsed from 'devices=[...]'.
    if (devices.size() <= 1) {
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding->set_type(OpSharding::OTHER);
    for (int64_t dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
    for (int64_t device : devices) {
      sharding->add_tile_assignment_devices(device);
    }

    if (last_tile_dims) {
      for (OpSharding::Type type : sharding_types) {
        sharding->add_last_tile_dims(type);
      }
    } else {
      // last_tile_dim_replicate is only meaningful without explicit
      // last_tile_dims.
      sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
    }
  }

  // Consume the closing '}'.
  lexer_.Lex();
  return true;
}
3023 
3024 // parameter_replication ::=
3025 //   '{' ('true' | 'false')* (',' ('true' | 'false'))*  '}'
ParseParameterReplication(ParameterReplication * parameter_replication)3026 bool HloParserImpl::ParseParameterReplication(
3027     ParameterReplication* parameter_replication) {
3028   if (!ParseToken(TokKind::kLbrace,
3029                   "expected '{' to start parameter_replication attribute")) {
3030     return false;
3031   }
3032 
3033   if (lexer_.GetKind() != TokKind::kRbrace) {
3034     do {
3035       if (lexer_.GetKind() == TokKind::kw_true) {
3036         parameter_replication->add_replicated_at_leaf_buffers(true);
3037       } else if (lexer_.GetKind() == TokKind::kw_false) {
3038         parameter_replication->add_replicated_at_leaf_buffers(false);
3039       } else {
3040         return false;
3041       }
3042       lexer_.Lex();
3043     } while (EatIfPresent(TokKind::kComma));
3044   }
3045 
3046   return ParseToken(TokKind::kRbrace,
3047                     "expected '}' to end parameter_replication attribute");
3048 }
3049 
3050 // replica_groups ::='{' int64list_elements '}'
3051 // int64list_elements
3052 //   ::= /*empty*/
3053 //   ::= int64list (',' int64list)*
3054 // int64list ::= '{' int64_elements '}'
3055 // int64_elements
3056 //   ::= /*empty*/
3057 //   ::= int64_val (',' int64_val)*
ParseReplicaGroupsOnly(std::vector<ReplicaGroup> * replica_groups)3058 bool HloParserImpl::ParseReplicaGroupsOnly(
3059     std::vector<ReplicaGroup>* replica_groups) {
3060   std::vector<std::vector<int64>> result;
3061   if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3062                           &result)) {
3063     return false;
3064   }
3065   *replica_groups = CreateReplicaGroups(result);
3066   return true;
3067 }
3068 
3069 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
3070 //            'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)3071 bool HloParserImpl::ParseDomain(DomainData* domain) {
3072   absl::flat_hash_map<std::string, AttrConfig> attrs;
3073   optional<std::string> kind;
3074   optional<OpSharding> entry_sharding;
3075   optional<OpSharding> exit_sharding;
3076   attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
3077   attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
3078   attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
3079   if (!ParseSubAttributes(attrs)) {
3080     return false;
3081   }
3082   if (*kind == ShardingMetadata::KindName()) {
3083     auto entry_sharding_ptr = absl::make_unique<HloSharding>(
3084         HloSharding::FromProto(*entry_sharding).ValueOrDie());
3085     auto exit_sharding_ptr = absl::make_unique<HloSharding>(
3086         HloSharding::FromProto(*exit_sharding).ValueOrDie());
3087     domain->entry_metadata =
3088         absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
3089     domain->exit_metadata =
3090         absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
3091   } else {
3092     return TokenError(StrCat("unsupported domain kind: ", *kind));
3093   }
3094   return true;
3095 }
3096 
3097 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)3098 bool HloParserImpl::ParseInstructionNames(
3099     std::vector<HloInstruction*>* instructions) {
3100   if (!ParseToken(TokKind::kLbrace,
3101                   "expects '{' at the beginning of instruction name list")) {
3102     return false;
3103   }
3104   LocTy loc = lexer_.GetLoc();
3105   do {
3106     std::string name;
3107     if (!ParseName(&name)) {
3108       return Error(loc, "expects a instruction name");
3109     }
3110     std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
3111     if (!instr) {
3112       return TokenError(StrFormat("instruction '%s' is not defined", name));
3113     }
3114     instructions->push_back(instr->first);
3115   } while (EatIfPresent(TokKind::kComma));
3116 
3117   return ParseToken(TokKind::kRbrace,
3118                     "expects '}' at the end of instruction name list");
3119 }
3120 
SetValueInLiteral(LocTy loc,int64_t value,int64_t index,Literal * literal)3121 bool HloParserImpl::SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
3122                                       Literal* literal) {
3123   const Shape& shape = literal->shape();
3124   switch (shape.element_type()) {
3125     case S8:
3126       return SetValueInLiteralHelper<int8>(loc, value, index, literal);
3127     case S16:
3128       return SetValueInLiteralHelper<int16>(loc, value, index, literal);
3129     case S32:
3130       return SetValueInLiteralHelper<int32>(loc, value, index, literal);
3131     case S64:
3132       return SetValueInLiteralHelper<int64>(loc, value, index, literal);
3133     case U8:
3134       return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
3135                                                         literal);
3136     case U16:
3137       return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
3138                                                          literal);
3139     case U32:
3140       return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
3141                                                          literal);
3142     case U64:
3143       return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
3144                                                          literal);
3145     case PRED:
3146       // Bool type literals with rank >= 1 are printed in 0s and 1s.
3147       return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
3148                                            literal);
3149     default:
3150       LOG(FATAL) << "unknown integral primitive type "
3151                  << PrimitiveType_Name(shape.element_type());
3152   }
3153 }
3154 
SetValueInLiteral(LocTy loc,double value,int64_t index,Literal * literal)3155 bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64_t index,
3156                                       Literal* literal) {
3157   const Shape& shape = literal->shape();
3158   switch (shape.element_type()) {
3159     case F16:
3160       return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
3161     case BF16:
3162       return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
3163                                                            literal);
3164     case F32:
3165       return SetValueInLiteralHelper<float>(loc, value, index, literal);
3166     case F64:
3167       return SetValueInLiteralHelper<double>(loc, value, index, literal);
3168     default:
3169       LOG(FATAL) << "unknown floating point primitive type "
3170                  << PrimitiveType_Name(shape.element_type());
3171   }
3172 }
3173 
SetValueInLiteral(LocTy loc,bool value,int64_t index,Literal * literal)3174 bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64_t index,
3175                                       Literal* literal) {
3176   const Shape& shape = literal->shape();
3177   switch (shape.element_type()) {
3178     case PRED:
3179       return SetValueInLiteralHelper<bool>(loc, value, index, literal);
3180     default:
3181       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3182                  << " is not PRED type";
3183   }
3184 }
3185 
SetValueInLiteral(LocTy loc,std::complex<double> value,int64_t index,Literal * literal)3186 bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
3187                                       int64_t index, Literal* literal) {
3188   const Shape& shape = literal->shape();
3189   switch (shape.element_type()) {
3190     case C64:
3191       return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
3192                                                           literal);
3193     case C128:
3194       return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
3195                                                            literal);
3196     default:
3197       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3198                  << " is not a complex type";
3199   }
3200 }
3201 
3202 template <typename T>
StringifyValue(T val)3203 std::string StringifyValue(T val) {
3204   return StrCat(val);
3205 }
3206 template <>
StringifyValue(std::complex<double> val)3207 std::string StringifyValue(std::complex<double> val) {
3208   return StrFormat("(%f, %f)", std::real(val), std::imag(val));
3209 }
3210 
3211 // Evaluates to V when T == U.
3212 template <typename T, typename U, typename V>
3213 using EnableIfSameWithType = std::enable_if_t<std::is_same<T, U>::value, V>;
3214 
3215 template <class T, EnableIfSameWithType<T, bool, bool> = false>
GetNanPayload(T val)3216 uint64_t GetNanPayload(T val) {
3217   return 0;
3218 }
3219 
3220 template <class T, EnableIfSameWithType<T, int64_t, bool> = false>
GetNanPayload(T val)3221 uint64_t GetNanPayload(T val) {
3222   return 0;
3223 }
3224 
3225 template <class T, EnableIfSameWithType<T, double, bool> = false>
GetNanPayload(T val)3226 uint64_t GetNanPayload(T val) {
3227   auto rep = absl::bit_cast<uint64_t>(val);
3228   if (auto payload = rep & NanPayloadBitMask<double>()) {
3229     return payload;
3230   }
3231   return QuietNanWithoutPayload<double>();
3232 }
3233 
// Builds a real-typed literal element from its components; the imaginary
// part is ignored. Selected when LiteralNativeT is itself the component type.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, LiteralComponentT, LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return real;
}

// Builds a complex-typed literal element from its real and imaginary parts.
// Selected when LiteralNativeT == std::complex<LiteralComponentT>.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, std::complex<LiteralComponentT>,
                     LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return LiteralNativeT(real, imag);
}
3246 
// ComponentType<T>::Type is T's scalar component type: T itself for real
// types, and the underlying scalar for std::complex<T>.
template <typename T>
struct ComponentType {
  using Type = T;
};

template <typename T>
struct ComponentType<std::complex<T>> {
  using Type = T;
};
3256 
// Returns the real component of a scalar: the value itself for real types.
template <typename T>
T GetReal(T value) {
  return value;
}

// Returns the real component of a complex value.
template <typename T>
T GetReal(std::complex<T> value) {
  return std::real(value);
}

// Returns the imaginary component of a scalar: always zero for real types.
template <typename T>
T GetImag(T value) {
  return T(0);
}

// Returns the imaginary component of a complex value.
template <typename T>
T GetImag(std::complex<T> value) {
  return std::imag(value);
}
3276 
3277 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,int64_t index,Literal * literal)3278 bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
3279                                             int64_t index, Literal* literal) {
3280   if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
3281     return false;
3282   }
3283 
3284   // Check that the index is in range and assign into the literal
3285   if (index >= ShapeUtil::ElementsIn(literal->shape())) {
3286     return Error(loc, StrCat("tries to set value ", StringifyValue(value),
3287                              " to a literal in shape ",
3288                              ShapeUtil::HumanString(literal->shape()),
3289                              " at linear index ", index,
3290                              ", but the index is out of range"));
3291   }
3292   using ParsedElemComponentT = typename ComponentType<ParsedElemT>::Type;
3293   using LiteralNativeComponentT = typename ComponentType<LiteralNativeT>::Type;
3294   const auto handle_nan = [this, literal, index, loc](
3295                               ParsedElemComponentT parsed_value_component,
3296                               LiteralNativeComponentT*
3297                                   literal_value_component) {
3298     if (!std::isnan(static_cast<double>(parsed_value_component))) {
3299       return true;
3300     }
3301     auto nan_payload = GetNanPayload(parsed_value_component);
3302     if (nan_payload == QuietNanWithoutPayload<double>()) {
3303       nan_payload = QuietNanWithoutPayload<LiteralNativeComponentT>();
3304     }
3305     const auto kLargestPayload = NanPayloadBitMask<LiteralNativeComponentT>();
3306     if (nan_payload > kLargestPayload) {
3307       return Error(
3308           loc,
3309           StrCat("tries to set NaN payload 0x", absl::Hex(nan_payload),
3310                  " to a literal in shape ",
3311                  ShapeUtil::HumanString(literal->shape()), " at linear index ",
3312                  index, ", but the NaN payload is out of range (0x",
3313                  absl::Hex(kLargestPayload), ")"));
3314     }
3315     *literal_value_component = NanWithSignAndPayload<LiteralNativeComponentT>(
3316         /*sign=*/std::signbit(static_cast<double>(parsed_value_component)),
3317         /*nan_payload=*/nan_payload);
3318     return true;
3319   };
3320   const ParsedElemComponentT parsed_real_value = GetReal(value);
3321   auto literal_real_value =
3322       static_cast<LiteralNativeComponentT>(parsed_real_value);
3323   if (std::is_floating_point<ParsedElemT>::value ||
3324       std::is_same<ParsedElemT, std::complex<double>>::value) {
3325     if (!handle_nan(parsed_real_value, &literal_real_value)) {
3326       return false;
3327     }
3328   }
3329   const ParsedElemComponentT parsed_imag_value = GetImag(value);
3330   auto literal_imag_value =
3331       static_cast<LiteralNativeComponentT>(parsed_imag_value);
3332   if (std::is_same<ParsedElemT, std::complex<double>>::value) {
3333     if (!handle_nan(parsed_real_value, &literal_imag_value)) {
3334       return false;
3335     }
3336   }
3337   literal->data<LiteralNativeT>().at(index) =
3338       LiteralNativeFromRealImag<LiteralNativeT>(literal_real_value,
3339                                                 literal_imag_value);
3340   return true;
3341 }
3342 
3343 // Similar to ParseLiteral(Literal* literal, const Shape& shape), but parse the
3344 // shape instead of accepting one as argument.
ParseLiteral(Literal * literal)3345 bool HloParserImpl::ParseLiteral(Literal* literal) {
3346   if (lexer_.GetKind() == TokKind::kLparen) {
3347     // Consume Lparen
3348     lexer_.Lex();
3349     std::vector<Literal> elements;
3350     while (lexer_.GetKind() != TokKind::kRparen) {
3351       Literal element;
3352       if (!ParseLiteral(&element)) {
3353         return TokenError("Fails when parsing tuple element");
3354       }
3355       elements.emplace_back(std::move(element));
3356       if (lexer_.GetKind() != TokKind::kRparen) {
3357         ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3358       }
3359     }
3360 
3361     *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3362     // Consume Rparen
3363     return ParseToken(TokKind::kRparen, "expects ')' to close a tuple literal");
3364   }
3365   Shape literal_shape;
3366   if (!ParseShape(&literal_shape)) {
3367     return false;
3368   }
3369   return ParseLiteral(literal, literal_shape);
3370 }
3371 
3372 // literal
3373 //  ::= tuple
3374 //  ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)3375 bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
3376   return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
3377                          : ParseNonTupleLiteral(literal, shape);
3378 }
3379 
3380 // tuple
3381 //  ::= shape '(' literal_list ')'
3382 // literal_list
3383 //  ::= /*empty*/
3384 //  ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)3385 bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
3386   if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
3387     return false;
3388   }
3389   std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
3390 
3391   if (lexer_.GetKind() == TokKind::kRparen) {
3392     // empty
3393   } else {
3394     // literal, (',' literal)*
3395     for (int i = 0; i < elements.size(); i++) {
3396       if (i > 0) {
3397         ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3398       }
3399       if (!ParseLiteral(&elements[i],
3400                         ShapeUtil::GetTupleElementShape(shape, i))) {
3401         return TokenError(StrCat("expects the ", i, "th element"));
3402       }
3403     }
3404   }
3405   *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3406   return ParseToken(TokKind::kRparen,
3407                     StrCat("expects ')' at the end of the tuple with ",
3408                            ShapeUtil::TupleElementCount(shape), "elements"));
3409 }
3410 
// non_tuple
//   ::= rank01
//   ::= rank2345
// rank2345 ::= shape nested_array
//
// Non-tuple literals must have a dense array layout; anything else is a
// caller bug, hence the CHECK rather than a parse error.
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
  return ParseDenseLiteral(literal, shape);
}
3419 
// Parses a dense array literal body (scalars, nested '{...}' per dimension,
// or the '...' fill shorthand) into `literal`. The literal is created with
// `shape`'s dimensions in the default layout and relaid out to `shape`'s
// layout at the end.
bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
  // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
  // `rank` is an int64, that's an implicit narrowing conversion, which is
  // implementation-defined behavior.
  const int rank = static_cast<int>(shape.rank());

  // Create a literal with the given shape in default layout.
  *literal = LiteralUtil::CreateFromDimensions(
      shape.element_type(), AsInt64Slice(shape.dimensions()));
  // Current '{' nesting depth; 0 means we are done (see the do-while below).
  int64_t nest_level = 0;
  // Next flat element slot to fill; elements arrive in row-major text order.
  int64_t linear_index = 0;
  // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
  // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
  // when we are parsing the 2nd '{' (right before '1'), we are seeing a
  // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
  // the first '}' (right after '3'), it means the sub-array ends, and the
  // sub-array is supposed to contain exactly 3 elements, so check if
  // elems_seen_per_dim[1] is 3.
  std::vector<int64> elems_seen_per_dim(rank);
  // Renders the multi-dimensional index of the element currently being
  // parsed, for error messages.
  auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
    std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
                                            elems_seen_per_dim.begin() + dim);
    return StrCat("[",
                  StrJoin(elems_seen_until_dim, ",",
                          [](std::string* out, const int64_t num_elems) {
                            StrAppend(out, num_elems - 1);
                          }),
                  "]");
  };

  // Accounts for one scalar element at the innermost nesting level; errors if
  // the element appears at the wrong depth or overflows the minor dimension.
  auto add_one_elem_seen = [&] {
    if (rank > 0) {
      if (nest_level != rank) {
        return TokenError(absl::StrFormat(
            "expects nested array in rank %d, but sees %d", rank, nest_level));
      }
      elems_seen_per_dim[rank - 1]++;
      if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
        return TokenError(absl::StrFormat(
            "expects %d elements on the minor-most dimension, but "
            "sees more",
            shape.dimensions(rank - 1)));
      }
    }
    return true;
  };

  do {
    switch (lexer_.GetKind()) {
      default:
        return TokenError("unexpected token type in a literal");
      case TokKind::kLbrace: {
        // Opening a sub-array: one more element of the parent dimension.
        nest_level++;
        if (nest_level > rank) {
          return TokenError(absl::StrFormat(
              "expects nested array in rank %d, but sees larger", rank));
        }
        if (nest_level > 1) {
          elems_seen_per_dim[nest_level - 2]++;
          if (elems_seen_per_dim[nest_level - 2] >
              shape.dimensions(nest_level - 2)) {
            return TokenError(absl::StrFormat(
                "expects %d elements in the %sth element, but sees more",
                shape.dimensions(nest_level - 2),
                get_index_str(nest_level - 2)));
          }
        }
        lexer_.Lex();
        break;
      }
      case TokKind::kRbrace: {
        // Closing a sub-array: it must have been exactly full, and its
        // counter resets for the next sibling at the same depth.
        nest_level--;
        if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
          return TokenError(absl::StrFormat(
              "expects %d elements in the %sth element, but sees %d",
              shape.dimensions(nest_level), get_index_str(nest_level),
              elems_seen_per_dim[nest_level]));
        }
        elems_seen_per_dim[nest_level] = 0;
        lexer_.Lex();
        break;
      }
      case TokKind::kLparen: {
        // '(re, im)' — only complex element types may use parens.
        if (!primitive_util::IsComplexType(shape.element_type())) {
          return TokenError(
              absl::StrFormat("unexpected '(' in literal.  Parens are only "
                              "valid for complex literals"));
        }

        std::complex<double> value;
        LocTy loc = lexer_.GetLoc();
        if (!add_one_elem_seen() || !ParseComplex(&value) ||
            !SetValueInLiteral(loc, value, linear_index++, literal)) {
          return false;
        }
        break;
      }
      case TokKind::kDots: {
        // '...' fill shorthand: mark the outer dimension complete and stuff
        // the buffer with deterministic garbage.
        if (nest_level != 1) {
          return TokenError(absl::StrFormat(
              "expects `...` at nest level 1, but sees it at nest level %d",
              nest_level));
        }
        elems_seen_per_dim[0] = shape.dimensions(0);
        lexer_.Lex();
        // Fill data with deterministic (garbage) values. Use static to avoid
        // creating identical constants which could potentially got CSE'ed
        // away. This is a best-effort approach to make sure replaying a HLO
        // gives us same optimized HLO graph.
        static uint32 data = 0;
        uint32* raw_data = static_cast<uint32*>(literal->untyped_data());
        for (int64_t i = 0; i < literal->size_bytes() / 4; ++i) {
          raw_data[i] = data++;
        }
        // Fill the trailing (size % 4) bytes one byte at a time.
        uint8* raw_data_int8 = static_cast<uint8*>(literal->untyped_data());
        static uint8 data_int8 = 0;
        for (int64_t i = 0; i < literal->size_bytes() % 4; ++i) {
          raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
        }
        break;
      }
      case TokKind::kComma:
        // Skip.
        lexer_.Lex();
        break;
      case TokKind::kw_true:
      case TokKind::kw_false:
      case TokKind::kInt:
      case TokKind::kDecimal:
      case TokKind::kw_inf:
      case TokKind::kNegInf: {
        // NOTE(review): the return value of add_one_elem_seen() is discarded
        // here (unlike the complex case above). A depth/count error is
        // recorded but parsing continues — possibly deliberate to accept
        // brace-less rank>=1 literals; confirm before tightening.
        add_one_elem_seen();
        if (lexer_.GetKind() == TokKind::kw_true ||
            lexer_.GetKind() == TokKind::kw_false) {
          if (!SetValueInLiteral(lexer_.GetLoc(),
                                 lexer_.GetKind() == TokKind::kw_true,
                                 linear_index++, literal)) {
            return false;
          }
          lexer_.Lex();
        } else if (primitive_util::IsIntegralType(shape.element_type()) ||
                   shape.element_type() == PRED) {
          LocTy loc = lexer_.GetLoc();
          int64_t value;
          if (!ParseInt64(&value)) {
            return Error(loc, StrCat("expects integer for primitive type: ",
                                     PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
            return false;
          }
        } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
          LocTy loc = lexer_.GetLoc();
          double value;
          if (!ParseDouble(&value)) {
            return Error(
                loc, StrCat("expect floating point value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
            return false;
          }
        } else {
          return TokenError(StrCat("unsupported primitive type ",
                                   PrimitiveType_Name(shape.element_type())));
        }
        break;
      }
    }  // end of switch
  } while (nest_level > 0);

  // Restore the caller-requested layout (parsing used the default layout).
  *literal = literal->Relayout(shape.layout());
  return true;
}
3594 
// MaxFiniteValue is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange.
template <typename T>
struct MinMaxFiniteValue {
  static T max() { return std::numeric_limits<T>::max(); }
  static T min() { return std::numeric_limits<T>::lowest(); }
};

// Eigen::half lacks a std::numeric_limits specialization here, so its bounds
// come from Eigen's NumTraits, widened to double for range comparisons.
template <>
struct MinMaxFiniteValue<Eigen::half> {
  static double max() {
    // Sadly this is not constexpr, so this forces `value` to be a method.
    return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
  }
  // The half range is symmetric about zero, so min is simply -max.
  static double min() { return -max(); }
};

// Same treatment for bfloat16.
template <>
struct MinMaxFiniteValue<bfloat16> {
  static double max() {
    return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
  }
  static double min() { return -max(); }
};
3619 
// MSVC's standard C++ library does not define isnan/isfinite for integer types.
// To work around that we will need to provide our own.
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T value) {
  return std::isfinite(value);
}
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T value) {
  return std::isnan(value);
}
// Integral values are widened to double first; they are always finite and
// never NaN, but routing through std::isfinite/std::isnan keeps the two
// overload families symmetric.
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T value) {
  return std::isfinite(static_cast<double>(value));
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T value) {
  return std::isnan(static_cast<double>(value));
}
3638 
// Checks that `value` (parsed as ParsedElemT) is representable within the
// finite range of the literal's native type LiteralNativeT.  Returns true on
// success; otherwise reports an error at `loc` and returns false.  NaN and
// +/-infinity skip the range check entirely.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
  if (std::is_floating_point<ParsedElemT>::value) {
    // Round-trip the parsed value through the native literal type.  If the
    // round-tripped value is finite (or the original was already non-finite),
    // adopt it so that the range check below sees the value after rounding to
    // the literal type's precision.
    auto value_as_native_t = static_cast<LiteralNativeT>(value);
    auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
    if (!IsFinite(value) || IsFinite(value_double_converted)) {
      value = value_double_converted;
    }
  }
  PrimitiveType literal_ty =
      primitive_util::NativeToPrimitiveType<LiteralNativeT>();
  if (IsNaN(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (std::is_unsigned<LiteralNativeT>::value) {
    // Unsigned literal types: compare in uint64 arithmetic against [0, max].
    // Only int64 and bool parsed types are supported here.
    CHECK((std::is_same<ParsedElemT, int64>::value ||
           std::is_same<ParsedElemT, bool>::value))
        << "Unimplemented checking for ParsedElemT";

    const uint64 unsigned_value = value;
    const uint64 upper_bound =
        static_cast<uint64>(std::numeric_limits<LiteralNativeT>::max());
    if (unsigned_value > upper_bound) {
      // Value is out of range for LiteralNativeT.
      return Error(loc, StrCat("value ", value,
                               " is out of range for literal's primitive type ",
                               PrimitiveType_Name(literal_ty), " namely [0, ",
                               upper_bound, "]."));
    }
  } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
             value < MinMaxFiniteValue<LiteralNativeT>::min()) {
    // Value is out of range for LiteralNativeT.
    return Error(loc, StrCat("value ", value,
                             " is out of range for literal's primitive type ",
                             PrimitiveType_Name(literal_ty), " namely [",
                             MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
                             MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
  }
  return true;
}
3681 
// Complex overload: checks that both the real and imaginary components of
// `value` fit into the component type of the complex literal type
// LiteralNativeT.  Non-finite components skip the range check.
template <typename LiteralNativeT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
                                              std::complex<double> value) {
  // e.g. `float` for std::complex<float>
  using LiteralComplexComponentT =
      decltype(std::real(std::declval<LiteralNativeT>()));

  // We could do simply
  //
  //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
  //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
  //
  // but this would give bad error messages on failure.

  auto check_component = [&](absl::string_view name, double v) {
    if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
        v == -std::numeric_limits<double>::infinity()) {
      // Skip range-checking for non-finite values.
      return true;
    }

    double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
    double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
    if (v < min || v > max) {
      // Value is out of range for LiteralComplexComponentT.
      return Error(
          loc,
          StrCat(name, " part ", v,
                 " is out of range for literal's primitive type ",
                 PrimitiveType_Name(
                     primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
                 ", namely [", min, ", ", max, "]."));
    }
    return true;
  };
  return check_component("real", std::real(value)) &&
         check_component("imaginary", std::imag(value));
}
3720 
// operands ::= '(' operands1 ')'
// operands1
//   ::= /*empty*/
//   ::= operand (, operand)*
// operand ::= (shape)? name
//
// Parses a parenthesized, comma-separated operand list and appends the
// resolved instruction pointers to `operands`.  Returns false (with an error
// already reported) on any parse or lookup failure.
bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands) {
  CHECK(operands != nullptr);
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of operands")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      LocTy loc = lexer_.GetLoc();
      std::string name;
      optional<Shape> shape;
      // An operand may optionally be prefixed by its shape.
      if (CanBeShape()) {
        shape.emplace();
        if (!ParseShape(&shape.value())) {
          return false;
        }
      }
      if (!ParseName(&name)) {
        // When parsing a single instruction (as opposed to a whole module), an
        // HLO may have one or more operands with a shape but no name:
        //
        //  foo = add(f32[10], f32[10])
        //
        // create_missing_instruction_ is always non-null when parsing a single
        // instruction, and is responsible for creating kParameter instructions
        // for these operands.
        if (shape.has_value() && create_missing_instruction_ != nullptr &&
            scoped_name_tables_.size() == 1) {
          name = "";
        } else {
          return false;
        }
      }
      // Resolves the (possibly empty) name to an existing or newly created
      // instruction; nullptr means the lookup/creation failed.
      std::pair<HloInstruction*, LocTy>* instruction =
          FindInstruction(name, shape);
      if (instruction == nullptr) {
        return Error(loc, StrCat("instruction does not exist: ", name));
      }
      operands->push_back(instruction->first);
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
3771 
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)3772 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
3773                                   const int expected_size) {
3774   CHECK(operands != nullptr);
3775   LocTy loc = lexer_.GetLoc();
3776   if (!ParseOperands(operands)) {
3777     return false;
3778   }
3779   if (expected_size != operands->size()) {
3780     return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
3781                              operands->size(), " operands"));
3782   }
3783   return true;
3784 }
3785 
3786 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3787 bool HloParserImpl::ParseSubAttributes(
3788     const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3789   LocTy loc = lexer_.GetLoc();
3790   if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
3791     return false;
3792   }
3793   absl::flat_hash_set<std::string> seen_attrs;
3794   if (lexer_.GetKind() == TokKind::kRbrace) {
3795     // empty
3796   } else {
3797     do {
3798       EatIfPresent(TokKind::kComma);
3799       if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3800         return false;
3801       }
3802     } while (lexer_.GetKind() != TokKind::kRbrace);
3803   }
3804   // Check that all required attrs were seen.
3805   for (const auto& attr_it : attrs) {
3806     if (attr_it.second.required &&
3807         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3808       return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
3809                                   attr_it.first));
3810     }
3811   }
3812   return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
3813 }
3814 
3815 // attributes ::= (',' attribute)*
ParseAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3816 bool HloParserImpl::ParseAttributes(
3817     const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3818   LocTy loc = lexer_.GetLoc();
3819   absl::flat_hash_set<std::string> seen_attrs;
3820   while (EatIfPresent(TokKind::kComma)) {
3821     if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3822       return false;
3823     }
3824   }
3825   // Check that all required attrs were seen.
3826   for (const auto& attr_it : attrs) {
3827     if (attr_it.second.required &&
3828         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3829       return Error(loc, StrFormat("attribute %s is expected but not seen",
3830                                   attr_it.first));
3831     }
3832   }
3833   return true;
3834 }
3835 
// Parses a single `name=value` attribute.  `attrs` maps each legal attribute
// name to its expected type (AttrTy) and a type-erased output location;
// `seen_attrs` accumulates the names parsed so far so duplicates are
// rejected.  Returns false (with an error reported) on any failure.
bool HloParserImpl::ParseAttributeHelper(
    const absl::flat_hash_map<std::string, AttrConfig>& attrs,
    absl::flat_hash_set<std::string>* seen_attrs) {
  LocTy loc = lexer_.GetLoc();
  std::string name;
  if (!ParseAttributeName(&name)) {
    return Error(loc, "error parsing attributes");
  }
  VLOG(3) << "Parsing attribute " << name;
  if (!seen_attrs->insert(name).second) {
    return Error(loc, StrFormat("attribute %s already exists", name));
  }
  auto attr_it = attrs.find(name);
  if (attr_it == attrs.end()) {
    // Unknown attribute: build a helpful error listing the legal names.
    std::string allowed_attrs;
    if (attrs.empty()) {
      allowed_attrs = "No attributes are allowed here.";
    } else {
      allowed_attrs =
          StrCat("Allowed attributes: ",
                 StrJoin(attrs, ", ",
                         [&](std::string* out,
                             const std::pair<std::string, AttrConfig>& kv) {
                           StrAppend(out, kv.first);
                         }));
    }
    return Error(loc, StrFormat("unexpected attribute \"%s\".  %s", name,
                                allowed_attrs));
  }
  AttrTy attr_type = attr_it->second.attr_type;
  void* attr_out_ptr = attr_it->second.result;
  // Dispatch on the expected attribute type.  Each case parses a value of the
  // corresponding kind and stores it through `attr_out_ptr`, which points at
  // an optional<T> of the matching type.
  bool success = [&] {
    LocTy attr_loc = lexer_.GetLoc();
    switch (attr_type) {
      case AttrTy::kBool: {
        bool result;
        if (!ParseBool(&result)) {
          return false;
        }
        static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt64: {
        int64_t result;
        if (!ParseInt64(&result)) {
          return false;
        }
        static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt32: {
        int64_t result;
        if (!ParseInt64(&result)) {
          return false;
        }
        // Parsed as int64; reject values that do not round-trip through
        // int32.
        if (result != static_cast<int32>(result)) {
          return Error(attr_loc, "value out of range for int32");
        }
        static_cast<optional<int32>*>(attr_out_ptr)
            ->emplace(static_cast<int32>(result));
        return true;
      }
      case AttrTy::kFloat: {
        double result;
        if (!ParseDouble(&result)) {
          return false;
        }
        // Parsed as double; reject values outside the finite float range.
        if (result > std::numeric_limits<float>::max() ||
            result < std::numeric_limits<float>::lowest()) {
          return Error(attr_loc, "value out of range for float");
        }
        static_cast<optional<float>*>(attr_out_ptr)
            ->emplace(static_cast<float>(result));
        return true;
      }
      case AttrTy::kHloComputation: {
        HloComputation* result = nullptr;
        if (!ParseHloComputation(&result)) {
          return false;
        }
        static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kBracedHloComputationList: {
        std::vector<HloComputation*> result;
        if (!ParseHloComputationList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFftType: {
        FftType result;
        if (!ParseFftType(&result)) {
          return false;
        }
        static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kPaddingType: {
        PaddingType result;
        if (!ParsePaddingType(&result)) {
          return false;
        }
        static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kComparisonDirection: {
        ComparisonDirection result;
        if (!ParseComparisonDirection(&result)) {
          return false;
        }
        static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kComparisonType: {
        Comparison::Type result;
        if (!ParseComparisonType(&result)) {
          return false;
        }
        static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kEnum: {
        // Enums are consumed as bare identifiers; validation against the
        // actual enum happens later (see CopyAttributeToProtoMessage).
        if (lexer_.GetKind() != TokKind::kIdent) {
          return TokenError("expects an enumeration value");
        }
        std::string result = lexer_.GetStrVal();
        lexer_.Lex();
        static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kWindow: {
        Window result;
        if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
          return false;
        }
        static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kConvolutionDimensionNumbers: {
        ConvolutionDimensionNumbers result;
        if (!ParseConvolutionDimensionNumbers(&result)) {
          return false;
        }
        static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSharding: {
        OpSharding sharding;
        if (!ParseSharding(&sharding)) {
          return false;
        }
        static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
        return true;
      }
      case AttrTy::kFrontendAttributes: {
        FrontendAttributes frontend_attributes;
        if (!ParseFrontendAttributes(&frontend_attributes)) {
          return false;
        }
        static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
            ->emplace(frontend_attributes);
        return true;
      }
      case AttrTy::kParameterReplication: {
        ParameterReplication parameter_replication;
        if (!ParseParameterReplication(&parameter_replication)) {
          return false;
        }
        static_cast<optional<ParameterReplication>*>(attr_out_ptr)
            ->emplace(parameter_replication);
        return true;
      }
      case AttrTy::kInstructionList: {
        std::vector<HloInstruction*> result;
        if (!ParseInstructionNames(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFusionKind: {
        HloInstruction::FusionKind result;
        if (!ParseFusionKind(&result)) {
          return false;
        }
        static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64List: {
        std::vector<int64> result;
        if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                            &result)) {
          return false;
        }
        static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64ListList: {
        std::vector<std::vector<int64>> result;
        if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
                                TokKind::kComma, &result)) {
          return false;
        }
        static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSliceRanges: {
        SliceRanges result;
        if (!ParseSliceRanges(&result)) {
          return false;
        }
        static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kPaddingConfig: {
        PaddingConfig result;
        if (!ParsePaddingConfig(&result)) {
          return false;
        }
        static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kString: {
        std::string result;
        if (!ParseString(&result)) {
          return false;
        }
        static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kMetadata: {
        OpMetadata result;
        if (!ParseMetadata(&result)) {
          return false;
        }
        static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kDistribution: {
        RandomDistribution result;
        if (!ParseRandomDistribution(&result)) {
          return false;
        }
        static_cast<optional<RandomDistribution>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kDomain: {
        // Unlike the other cases, the output location is a DomainData, not an
        // optional<T>.
        return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
      }
      case AttrTy::kPrecisionList: {
        std::vector<PrecisionConfig::Precision> result;
        if (!ParsePrecisionList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
            attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kShape: {
        Shape result;
        if (!ParseShape(&result)) {
          return false;
        }
        static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kShapeList: {
        std::vector<Shape> result;
        if (!ParseShapeList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kRandomAlgorithm: {
        RandomAlgorithm result;
        if (!ParseRandomAlgorithm(&result)) {
          return false;
        }
        static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kAliasing: {
        AliasingData aliasing_data;
        if (!ParseAliasing(&aliasing_data)) {
          return false;
        }
        static_cast<optional<AliasingData>*>(attr_out_ptr)
            ->emplace(aliasing_data);
        return true;
      }
      case AttrTy::kInstructionAliasing: {
        std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
            aliasing_output_operand_pairs;
        if (!ParseInstructionOutputOperandAliasing(
                &aliasing_output_operand_pairs)) {
          return false;
        }
        static_cast<optional<
            std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
            attr_out_ptr)
            ->emplace(std::move(aliasing_output_operand_pairs));
        return true;
      }
      case AttrTy::kLiteral: {
        // Literals can be large; moved rather than copied into the optional.
        Literal result;
        if (!ParseLiteral(&result)) {
          return false;
        }
        static_cast<optional<Literal>*>(attr_out_ptr)
            ->emplace(std::move(result));
        return true;
      }
      case AttrTy::kCustomCallSchedule: {
        CustomCallSchedule result;
        if (!ParseCustomCallSchedule(&result)) {
          return false;
        }
        static_cast<optional<CustomCallSchedule>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kCustomCallApiVersion: {
        CustomCallApiVersion result;
        if (!ParseCustomCallApiVersion(&result)) {
          return false;
        }
        static_cast<optional<CustomCallApiVersion>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
    }
  }();
  if (!success) {
    return Error(loc, StrFormat("error parsing attribute %s", name));
  }
  return true;
}
4186 
// Copies parsed attribute values from `attrs` into the proto `message` via
// protobuf reflection.  Attribute names listed in `non_proto_attrs` are
// skipped (they are handled outside the proto); every remaining attribute
// name must correspond to a singular bool or enum field of the message.
bool HloParserImpl::CopyAttributeToProtoMessage(
    absl::flat_hash_set<std::string> non_proto_attrs,
    const absl::flat_hash_map<std::string, AttrConfig>& attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  const tensorflow::protobuf::Reflection* reflection = message->GetReflection();

  for (const auto& p : attrs) {
    const std::string& name = p.first;
    if (non_proto_attrs.find(name) != non_proto_attrs.end()) {
      continue;
    }
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->FindFieldByName(name);
    if (!fd) {
      // Unknown field: build an error message listing the message's fields.
      std::string allowed_attrs = "Allowed attributes: ";

      for (int i = 0; i < descriptor->field_count(); ++i) {
        if (i == 0) {
          absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
        } else {
          absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
        }
      }
      return TokenError(
          StrFormat("unexpected attribute \"%s\".  %s", name, allowed_attrs));
    }

    CHECK(!fd->is_repeated());  // Repeated fields not implemented.
    // Only bool and enum fields are supported; the attribute's result pointer
    // holds an optional of the matching C++ type (see
    // ParseAttributesAsProtoMessage, which sets these up).
    bool success = [&] {
      switch (fd->type()) {
        case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
          auto attr_value = static_cast<optional<bool>*>(p.second.result);
          if (attr_value->has_value()) {
            reflection->SetBool(message, fd, **attr_value);
          }
          return true;
        }
        case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
          auto attr_value =
              static_cast<optional<std::string>*>(p.second.result);
          if (attr_value->has_value()) {
            // NOTE(review): FindValueByName returns null for an unknown enum
            // name; presumably the value was validated earlier — confirm.
            const tensorflow::protobuf::EnumValueDescriptor* evd =
                fd->enum_type()->FindValueByName(**attr_value);
            reflection->SetEnum(message, fd, evd);
          }
          return true;
        }
        default:
          return false;
      }
    }();

    if (!success) {
      return TokenError(StrFormat("error parsing attribute %s", name));
    }
  }

  return true;
}
4247 
// attributes ::= (',' attribute)*
//
// Parses attributes whose names/types are derived from the fields of the
// proto `message` (plus the extra `non_proto_attrs`), then copies the parsed
// values into the message.  Only bool and enum proto fields are supported.
bool HloParserImpl::ParseAttributesAsProtoMessage(
    const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  absl::flat_hash_map<std::string, AttrConfig> attrs;

  // Storage for attributes.  AttrConfig stores raw pointers into these
  // vectors, so their elements must not move after insertion.
  std::vector<optional<bool>> bool_params;
  std::vector<optional<std::string>> string_params;
  // Reserve enough capacity to make sure that the vector is not growing, so we
  // can rely on the pointers to stay valid.
  bool_params.reserve(descriptor->field_count());
  string_params.reserve(descriptor->field_count());

  // Populate the storage of expected attributes from the protobuf description.
  for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->field(field_idx);
    const std::string& field_name = fd->name();
    switch (fd->type()) {
      case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
        bool_params.emplace_back(absl::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
                             &bool_params.back()};
        break;
      }
      case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
        string_params.emplace_back(absl::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
                             &string_params.back()};
        break;
      }
      default:
        return TokenError(absl::StrFormat(
            "Unexpected protocol buffer type: %s ", fd->DebugString()));
    }
  }

  absl::flat_hash_set<std::string> non_proto_attrs_names;
  non_proto_attrs_names.reserve(non_proto_attrs.size());
  for (const auto& p : non_proto_attrs) {
    const std::string& attr_name = p.first;
    // If an attribute is both specified within 'non_proto_attrs' and an
    // attribute of the proto message, we prefer the attribute of the proto
    // message.
    if (attrs.find(attr_name) == attrs.end()) {
      non_proto_attrs_names.insert(attr_name);
      attrs[attr_name] = p.second;
    }
  }

  if (!ParseAttributes(attrs)) {
    return false;
  }

  // Transfer the proto-backed parsed values into the message; the names in
  // non_proto_attrs_names are skipped there.
  return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
}
4306 
ParseComputationName(HloComputation ** value)4307 bool HloParserImpl::ParseComputationName(HloComputation** value) {
4308   std::string name;
4309   LocTy loc = lexer_.GetLoc();
4310   if (!ParseName(&name)) {
4311     return Error(loc, "expects computation name");
4312   }
4313   std::pair<HloComputation*, LocTy>* computation =
4314       tensorflow::gtl::FindOrNull(computation_pool_, name);
4315   if (computation == nullptr) {
4316     return Error(loc, StrCat("computation does not exist: ", name));
4317   }
4318   *value = computation->first;
4319   return true;
4320 }
4321 
4322 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
4323 // The subattributes can appear in any order. 'size=' is required, others are
4324 // optional.
ParseWindow(Window * window,bool expect_outer_curlies)4325 bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
4326   LocTy loc = lexer_.GetLoc();
4327   if (expect_outer_curlies &&
4328       !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
4329     return false;
4330   }
4331 
4332   std::vector<int64> size;
4333   std::vector<int64> stride;
4334   std::vector<std::vector<int64>> pad;
4335   std::vector<int64> lhs_dilate;
4336   std::vector<int64> rhs_dilate;
4337   std::vector<int64> rhs_reversal;
4338   const auto end_token =
4339       expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
4340   while (lexer_.GetKind() != end_token) {
4341     LocTy attr_loc = lexer_.GetLoc();
4342     std::string field_name;
4343     if (!ParseAttributeName(&field_name)) {
4344       return Error(attr_loc, "expects sub-attributes in window");
4345     }
4346     bool ok = [&] {
4347       if (field_name == "size") {
4348         return ParseDxD("size", &size);
4349       }
4350       if (field_name == "stride") {
4351         return ParseDxD("stride", &stride);
4352       }
4353       if (field_name == "lhs_dilate") {
4354         return ParseDxD("lhs_dilate", &lhs_dilate);
4355       }
4356       if (field_name == "rhs_dilate") {
4357         return ParseDxD("rls_dilate", &rhs_dilate);
4358       }
4359       if (field_name == "pad") {
4360         return ParseWindowPad(&pad);
4361       }
4362       if (field_name == "rhs_reversal") {
4363         return ParseDxD("rhs_reversal", &rhs_reversal);
4364       }
4365       return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
4366     }();
4367     if (!ok) {
4368       return false;
4369     }
4370   }
4371 
4372   if (!stride.empty() && stride.size() != size.size()) {
4373     return Error(loc, "expects 'stride=' has the same size as 'size='");
4374   }
4375   if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
4376     return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
4377   }
4378   if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
4379     return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
4380   }
4381   if (!pad.empty() && pad.size() != size.size()) {
4382     return Error(loc, "expects 'pad=' has the same size as 'size='");
4383   }
4384 
4385   for (int i = 0; i < size.size(); i++) {
4386     window->add_dimensions()->set_size(size[i]);
4387     if (!pad.empty()) {
4388       window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
4389       window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
4390     }
4391     // If some field is not present, it has the default value.
4392     window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
4393     window->mutable_dimensions(i)->set_base_dilation(
4394         lhs_dilate.empty() ? 1 : lhs_dilate[i]);
4395     window->mutable_dimensions(i)->set_window_dilation(
4396         rhs_dilate.empty() ? 1 : rhs_dilate[i]);
4397     window->mutable_dimensions(i)->set_window_reversal(
4398         rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
4399   }
4400   return !expect_outer_curlies ||
4401          ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
4402 }
4403 
4404 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
4405 // The string looks like "dim_labels=0bf_0io->0bf".
4406 //
4407 // '?' dims don't appear in ConvolutionDimensionNumbers.  There can be more than
4408 // one '?' dim.
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)4409 bool HloParserImpl::ParseConvolutionDimensionNumbers(
4410     ConvolutionDimensionNumbers* dnums) {
4411   if (lexer_.GetKind() != TokKind::kDimLabels) {
4412     return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
4413   }
4414   std::string str = lexer_.GetStrVal();
4415 
4416   // The str is expected to have 3 items, lhs, rhs, out, and it must look like
4417   // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
4418   std::vector<std::string> split1 = absl::StrSplit(str, '_');
4419   if (split1.size() != 2) {
4420     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4421                << str;
4422   }
4423   std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
4424   if (split2.size() != 2) {
4425     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4426                << str;
4427   }
4428   absl::string_view lhs = split1[0];
4429   absl::string_view rhs = split2[0];
4430   absl::string_view out = split2[1];
4431 
4432   auto is_unique = [](absl::string_view str) -> bool {
4433     absl::flat_hash_set<char> chars;
4434     for (char c : str) {
4435       // '?' dims are skipped.
4436       if (c == '?') {
4437         continue;
4438       }
4439       if (!chars.insert(c).second) {
4440         return false;
4441       }
4442     }
4443     return true;
4444   };
4445 
4446   // lhs
4447   {
4448     if (!is_unique(lhs)) {
4449       return TokenError(
4450           StrCat("expects unique lhs dimension numbers, but sees ", lhs));
4451     }
4452     // Count number of spatial dimensions.
4453     for (char c : lhs) {
4454       if (c != 'b' && c != 'f' && c != '?') {
4455         dnums->add_input_spatial_dimensions(-1);
4456       }
4457     }
4458     for (int i = 0; i < lhs.size(); i++) {
4459       char c = lhs[i];
4460       if (c == '?') {
4461         continue;
4462       } else if (c == 'b') {
4463         dnums->set_input_batch_dimension(i);
4464       } else if (c == 'f') {
4465         dnums->set_input_feature_dimension(i);
4466       } else if (c < '0' + lhs.size() && c >= '0') {
4467         dnums->set_input_spatial_dimensions(c - '0', i);
4468       } else {
4469         return TokenError(StrFormat(
4470             "expects [0-%dbf?] in lhs dimension numbers", lhs.size() - 1));
4471       }
4472     }
4473   }
4474   // rhs
4475   {
4476     if (!is_unique(rhs)) {
4477       return TokenError(
4478           StrCat("expects unique rhs dimension numbers, but sees ", rhs));
4479     }
4480     // Count number of spatial dimensions.
4481     for (char c : rhs) {
4482       if (c != 'i' && c != 'o' && c != '?') {
4483         dnums->add_kernel_spatial_dimensions(-1);
4484       }
4485     }
4486     for (int i = 0; i < rhs.size(); i++) {
4487       char c = rhs[i];
4488       if (c == '?') {
4489         continue;
4490       } else if (c == 'i') {
4491         dnums->set_kernel_input_feature_dimension(i);
4492       } else if (c == 'o') {
4493         dnums->set_kernel_output_feature_dimension(i);
4494       } else if (c < '0' + rhs.size() && c >= '0') {
4495         dnums->set_kernel_spatial_dimensions(c - '0', i);
4496       } else {
4497         return TokenError(StrFormat(
4498             "expects [0-%dio?] in rhs dimension numbers", rhs.size() - 1));
4499       }
4500     }
4501   }
4502   // output
4503   {
4504     if (!is_unique(out)) {
4505       return TokenError(
4506           StrCat("expects unique output dimension numbers, but sees ", out));
4507     }
4508     // Count number of spatial dimensions.
4509     for (char c : out) {
4510       if (c != 'b' && c != 'f' && c != '?') {
4511         dnums->add_output_spatial_dimensions(-1);
4512       }
4513     }
4514     for (int i = 0; i < out.size(); i++) {
4515       char c = out[i];
4516       if (c == '?') {
4517         continue;
4518       } else if (c == 'b') {
4519         dnums->set_output_batch_dimension(i);
4520       } else if (c == 'f') {
4521         dnums->set_output_feature_dimension(i);
4522       } else if (c < '0' + out.size() && c >= '0') {
4523         dnums->set_output_spatial_dimensions(c - '0', i);
4524       } else {
4525         return TokenError(StrFormat(
4526             "expects [0-%dbf?] in output dimension numbers", out.size() - 1));
4527       }
4528     }
4529   }
4530 
4531   // lhs, rhs, and output should have the same number of spatial dimensions.
4532   if (dnums->input_spatial_dimensions_size() !=
4533           dnums->output_spatial_dimensions_size() ||
4534       dnums->input_spatial_dimensions_size() !=
4535           dnums->kernel_spatial_dimensions_size()) {
4536     return TokenError(
4537         StrFormat("input, kernel, and output must have same number of spatial "
4538                   "dimensions, but got %d, %d, %d, respectively.",
4539                   dnums->input_spatial_dimensions_size(),
4540                   dnums->kernel_spatial_dimensions_size(),
4541                   dnums->output_spatial_dimensions_size()));
4542   }
4543 
4544   lexer_.Lex();
4545   return true;
4546 }
4547 
4548 // ::= '{' ranges '}'
4549 //   ::= /*empty*/
4550 //   ::= range (',' range)*
4551 // range ::= '[' start ':' limit (':' stride)? ']'
4552 //
4553 // The slice ranges are printed as:
4554 //
4555 //  {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
4556 //
4557 // This function extracts the starts, limits, and strides as 3 vectors to the
4558 // result. If stride is not present, stride is 1. For example, if the slice
4559 // ranges is printed as:
4560 //
4561 //  {[2:3:4], [5:6:7], [8:9]}
4562 //
4563 // The parsed result will be:
4564 //
4565 //  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
4566 //
ParseSliceRanges(SliceRanges * result)4567 bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
4568   if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
4569     return false;
4570   }
4571   std::vector<std::vector<int64>> ranges;
4572   if (lexer_.GetKind() == TokKind::kRbrace) {
4573     // empty
4574     return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4575   }
4576   do {
4577     LocTy loc = lexer_.GetLoc();
4578     ranges.emplace_back();
4579     if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
4580                         &ranges.back())) {
4581       return false;
4582     }
4583     const auto& range = ranges.back();
4584     if (range.size() != 2 && range.size() != 3) {
4585       return Error(loc,
4586                    StrFormat("expects [start:limit:step] or [start:limit], "
4587                              "but sees %d elements.",
4588                              range.size()));
4589     }
4590   } while (EatIfPresent(TokKind::kComma));
4591 
4592   for (const auto& range : ranges) {
4593     result->starts.push_back(range[0]);
4594     result->limits.push_back(range[1]);
4595     result->strides.push_back(range.size() == 3 ? range[2] : 1);
4596   }
4597   return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4598 }
4599 
4600 // precisionlist ::= start precision_elements end
4601 // precision_elements
4602 //   ::= /*empty*/
4603 //   ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)4604 bool HloParserImpl::ParsePrecisionList(
4605     std::vector<PrecisionConfig::Precision>* result) {
4606   auto parse_and_add_item = [&]() {
4607     PrecisionConfig::Precision item;
4608     if (!ParsePrecision(&item)) {
4609       return false;
4610     }
4611     result->push_back(item);
4612     return true;
4613   };
4614   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4615                    parse_and_add_item);
4616 }
4617 
ParseHloComputation(HloComputation ** result)4618 bool HloParserImpl::ParseHloComputation(HloComputation** result) {
4619   if (lexer_.GetKind() == TokKind::kLbrace) {
4620     // This means it is a nested computation.
4621     return ParseInstructionList(result, /*computation_name=*/"_");
4622   }
4623   // This means it is a computation name.
4624   return ParseComputationName(result);
4625 }
4626 
ParseHloComputationList(std::vector<HloComputation * > * result)4627 bool HloParserImpl::ParseHloComputationList(
4628     std::vector<HloComputation*>* result) {
4629   auto parse_and_add_item = [&]() {
4630     HloComputation* computation;
4631     if (!ParseHloComputation(&computation)) {
4632       return false;
4633     }
4634     VLOG(3) << "parsed computation " << computation->name();
4635     result->push_back(computation);
4636     return true;
4637   };
4638   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4639                    parse_and_add_item);
4640 }
4641 
4642 // shapelist ::= '{' shapes '}'
4643 // precision_elements
4644 //   ::= /*empty*/
4645 //   ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)4646 bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
4647   auto parse_and_add_item = [&]() {
4648     Shape shape;
4649     if (!ParseShape(&shape)) {
4650       return false;
4651     }
4652     result->push_back(std::move(shape));
4653     return true;
4654   };
4655   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4656                    parse_and_add_item);
4657 }
4658 
4659 // int64list ::= start int64_elements end
4660 // int64_elements
4661 //   ::= /*empty*/
4662 //   ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)4663 bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
4664                                    const TokKind delim,
4665                                    std::vector<int64>* result) {
4666   auto parse_and_add_item = [&]() {
4667     int64_t i;
4668     if (!ParseInt64(&i)) {
4669       return false;
4670     }
4671     result->push_back(i);
4672     return true;
4673   };
4674   return ParseList(start, end, delim, parse_and_add_item);
4675 }
4676 
4677 // int64listlist ::= start int64list_elements end
4678 // int64list_elements
4679 //   ::= /*empty*/
4680 //   ::= int64list (delim int64list)*
4681 // int64list ::= start int64_elements end
4682 // int64_elements
4683 //   ::= /*empty*/
4684 //   ::= int64_val (delim int64_val)*
ParseInt64ListList(const TokKind start,const TokKind end,const TokKind delim,std::vector<std::vector<int64>> * result)4685 bool HloParserImpl::ParseInt64ListList(
4686     const TokKind start, const TokKind end, const TokKind delim,
4687     std::vector<std::vector<int64>>* result) {
4688   auto parse_and_add_item = [&]() {
4689     std::vector<int64> item;
4690     if (!ParseInt64List(start, end, delim, &item)) {
4691       return false;
4692     }
4693     result->push_back(item);
4694     return true;
4695   };
4696   return ParseList(start, end, delim, parse_and_add_item);
4697 }
4698 
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)4699 bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
4700                               const TokKind delim,
4701                               const std::function<bool()>& parse_and_add_item) {
4702   if (!ParseToken(start, StrCat("expects a list starting with ",
4703                                 TokKindToString(start)))) {
4704     return false;
4705   }
4706   if (lexer_.GetKind() == end) {
4707     // empty
4708   } else {
4709     do {
4710       if (!parse_and_add_item()) {
4711         return false;
4712       }
4713     } while (EatIfPresent(delim));
4714   }
4715   return ParseToken(
4716       end, StrCat("expects a list to end with ", TokKindToString(end)));
4717 }
4718 
4719 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)4720 bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
4721   if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
4722     return false;
4723   }
4724   *shape_loc = lexer_.GetLoc();
4725   return ParseShape(shape);
4726 }
4727 
// Returns true if the upcoming tokens could start a "param_list -> shape"
// signature, i.e. the next token is '('.
bool HloParserImpl::CanBeParamListToShape() {
  return lexer_.GetKind() == TokKind::kLparen;
}
4731 
4732 // param_list ::= '(' param_list1 ')'
4733 // param_list1
4734 //   ::= /*empty*/
4735 //   ::= param (',' param)*
4736 // param ::= name shape
ParseParamList()4737 bool HloParserImpl::ParseParamList() {
4738   if (!ParseToken(TokKind::kLparen,
4739                   "expects '(' at the beginning of param list")) {
4740     return false;
4741   }
4742 
4743   if (lexer_.GetKind() == TokKind::kRparen) {
4744     // empty
4745   } else {
4746     do {
4747       Shape shape;
4748       std::string name;
4749       if (!ParseName(&name) || !ParseShape(&shape)) {
4750         return false;
4751       }
4752     } while (EatIfPresent(TokKind::kComma));
4753   }
4754   return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
4755 }
4756 
4757 // dimension_sizes ::= '[' dimension_list ']'
4758 // dimension_list
4759 //   ::= /*empty*/
4760 //   ::= <=? int64 (',' param)*
4761 // param ::= name shape
ParseDimensionSizes(std::vector<int64> * dimension_sizes,std::vector<bool> * dynamic_dimensions)4762 bool HloParserImpl::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
4763                                         std::vector<bool>* dynamic_dimensions) {
4764   auto parse_and_add_item = [&]() {
4765     int64_t i;
4766     bool is_dynamic = false;
4767     if (lexer_.GetKind() == TokKind::kLeq) {
4768       is_dynamic = true;
4769       lexer_.Lex();
4770     }
4771     if (!ParseInt64(&i)) {
4772       return false;
4773     }
4774     dimension_sizes->push_back(i);
4775     dynamic_dimensions->push_back(is_dynamic);
4776     return true;
4777   };
4778   return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
4779                    parse_and_add_item);
4780 }
4781 
4782 // tiles
4783 //   ::= /*empty*/
4784 //   ::= 'T' '(' dim_list ')'
4785 // dim_list
4786 //   ::= /*empty*/
4787 //   ::= (int64 | '*') (',' (int64 | '*'))*
ParseTiles(std::vector<Tile> * tiles)4788 bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
4789   auto parse_and_add_tile_dimension = [&]() {
4790     int64_t i;
4791     if (ParseInt64(&i)) {
4792       tiles->back().add_dimensions(i);
4793       return true;
4794     }
4795     if (lexer_.GetKind() == TokKind::kAsterisk) {
4796       tiles->back().add_dimensions(Tile::kCombineDimension);
4797       lexer_.Lex();
4798       return true;
4799     }
4800     return false;
4801   };
4802 
4803   do {
4804     tiles->push_back(Tile());
4805     if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
4806                    parse_and_add_tile_dimension)) {
4807       return false;
4808     }
4809   } while (lexer_.GetKind() == TokKind::kLparen);
4810   return true;
4811 }
4812 
4813 // int_attribute
4814 //   ::= /*empty*/
4815 //   ::= attr_token '(' attr_value ')'
4816 // attr_token
4817 //   ::= 'E' | 'S'
4818 // attr_value
4819 //   ::= int64
ParseLayoutIntAttribute(int64 * attr_value,absl::string_view attr_description)4820 bool HloParserImpl::ParseLayoutIntAttribute(
4821     int64* attr_value, absl::string_view attr_description) {
4822   if (!ParseToken(TokKind::kLparen,
4823                   StrCat("expects ", attr_description, " to start with ",
4824                          TokKindToString(TokKind::kLparen)))) {
4825     return false;
4826   }
4827   if (!ParseInt64(attr_value)) {
4828     return false;
4829   }
4830   if (!ParseToken(TokKind::kRparen,
4831                   StrCat("expects ", attr_description, " to end with ",
4832                          TokKindToString(TokKind::kRparen)))) {
4833     return false;
4834   }
4835   return true;
4836 }
4837 
4838 // layout ::= '{' int64_list (':' tiles element_size_in_bits memory_space)? '}'
4839 // element_size_in_bits
4840 //   ::= /*empty*/
4841 //   ::= 'E' '(' int64 ')'
4842 // memory_space
4843 //   ::= /*empty*/
4844 //   ::= 'S' '(' int64 ')'
ParseLayout(Layout * layout)4845 bool HloParserImpl::ParseLayout(Layout* layout) {
4846   std::vector<int64> minor_to_major;
4847   std::vector<Tile> tiles;
4848   int64_t element_size_in_bits = 0;
4849   int64_t memory_space = 0;
4850 
4851   auto parse_and_add_item = [&]() {
4852     int64_t i;
4853     if (!ParseInt64(&i)) {
4854       return false;
4855     }
4856     minor_to_major.push_back(i);
4857     return true;
4858   };
4859 
4860   if (!ParseToken(TokKind::kLbrace,
4861                   StrCat("expects layout to start with ",
4862                          TokKindToString(TokKind::kLbrace)))) {
4863     return false;
4864   }
4865   if (lexer_.GetKind() != TokKind::kRbrace) {
4866     if (lexer_.GetKind() == TokKind::kInt) {
4867       // Parse minor to major.
4868       do {
4869         if (!parse_and_add_item()) {
4870           return false;
4871         }
4872       } while (EatIfPresent(TokKind::kComma));
4873     }
4874 
4875     if (lexer_.GetKind() == TokKind::kColon) {
4876       lexer_.Lex();
4877       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
4878         lexer_.Lex();
4879         ParseTiles(&tiles);
4880       }
4881 
4882       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
4883         lexer_.Lex();
4884         ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
4885       }
4886 
4887       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
4888         lexer_.Lex();
4889         ParseLayoutIntAttribute(&memory_space, "memory space");
4890       }
4891     }
4892   }
4893   if (!ParseToken(TokKind::kRbrace,
4894                   StrCat("expects layout to end with ",
4895                          TokKindToString(TokKind::kRbrace)))) {
4896     return false;
4897   }
4898 
4899   std::vector<Tile> vec_tiles(tiles.size());
4900   for (int i = 0; i < tiles.size(); i++) {
4901     vec_tiles[i] = Tile(tiles[i]);
4902   }
4903   *layout = LayoutUtil::MakeLayout(minor_to_major, vec_tiles,
4904                                    element_size_in_bits, memory_space);
4905   return true;
4906 }
4907 
4908 // shape ::= shape_val_
4909 // shape ::= '(' tuple_elements ')'
4910 // tuple_elements
4911 //   ::= /*empty*/
4912 //   ::= shape (',' shape)*
// Parses a shape: either a tuple "(shape, ...)" (possibly empty) or a
// primitive type followed by dimension sizes and an optional layout,
// e.g. "f32[2,<=3]{1,0}".
bool HloParserImpl::ParseShape(Shape* result) {
  if (EatIfPresent(TokKind::kLparen)) {  // Tuple
    std::vector<Shape> shapes;
    if (lexer_.GetKind() == TokKind::kRparen) {
      /*empty*/
    } else {
      // shape (',' shape)*
      do {
        shapes.emplace_back();
        if (!ParseShape(&shapes.back())) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    *result = ShapeUtil::MakeTupleShape(shapes);
    return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
  }

  if (lexer_.GetKind() != TokKind::kPrimitiveType) {
    return TokenError(absl::StrCat("expected primitive type, saw ",
                                   TokKindToString(lexer_.GetKind())));
  }
  PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
  lexer_.Lex();

  // Each element contains a dimension size and a bool indicating whether this
  // is a dynamic dimension.
  std::vector<int64> dimension_sizes;
  std::vector<bool> dynamic_dimensions;
  if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
    return false;
  }
  result->set_element_type(primitive_type);
  for (int i = 0; i < dimension_sizes.size(); ++i) {
    result->add_dimensions(dimension_sizes[i]);
    result->set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  // "invalid{}" after the dimensions explicitly marks the shape as having a
  // cleared (invalid) layout.
  if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "invalid") {
    lexer_.Lex();
    if (lexer_.GetKind() != TokKind::kLbrace) {
      return false;
    }
    lexer_.Lex();
    if (lexer_.GetKind() != TokKind::kRbrace) {
      return false;
    }
    lexer_.Lex();
    result->mutable_layout()->Clear();
    return true;
  }
  LayoutUtil::SetToDefaultLayout(result);
  // We need to lookahead to see if a following open brace is the start of a
  // layout. The specific problematic case is:
  //
  // ENTRY %foo (x: f32[42]) -> f32[123] {
  //  ...
  // }
  //
  // The open brace could either be the start of a computation or the start of a
  // layout for the f32[123] shape. We consider it the start of a layout if the
  // next token after the open brace is an integer or a colon.
  if (lexer_.GetKind() == TokKind::kLbrace &&
      (lexer_.LookAhead() == TokKind::kInt ||
       lexer_.LookAhead() == TokKind::kColon)) {
    Layout layout;
    if (!ParseLayout(&layout)) {
      return false;
    }
    // The layout must assign an order to every dimension of the shape.
    if (layout.minor_to_major_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
                    result->rank(), layout.minor_to_major_size()));
    }
    *result->mutable_layout() = layout;
  }
  return true;
}
4991 
// Returns true if the upcoming tokens could start a shape.
bool HloParserImpl::CanBeShape() {
  // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
  // with '('.
  return lexer_.GetKind() == TokKind::kPrimitiveType ||
         lexer_.GetKind() == TokKind::kLparen;
}
4998 
ParseName(std::string * result)4999 bool HloParserImpl::ParseName(std::string* result) {
5000   VLOG(3) << "ParseName";
5001   if (lexer_.GetKind() != TokKind::kIdent &&
5002       lexer_.GetKind() != TokKind::kName) {
5003     return TokenError("expects name");
5004   }
5005   *result = lexer_.GetStrVal();
5006   lexer_.Lex();
5007   return true;
5008 }
5009 
ParseAttributeName(std::string * result)5010 bool HloParserImpl::ParseAttributeName(std::string* result) {
5011   if (lexer_.GetKind() != TokKind::kAttributeName) {
5012     return TokenError("expects attribute name");
5013   }
5014   *result = lexer_.GetStrVal();
5015   lexer_.Lex();
5016   return true;
5017 }
5018 
ParseString(std::string * result)5019 bool HloParserImpl::ParseString(std::string* result) {
5020   VLOG(3) << "ParseString";
5021   if (lexer_.GetKind() != TokKind::kString) {
5022     return TokenError("expects string");
5023   }
5024   *result = lexer_.GetStrVal();
5025   lexer_.Lex();
5026   return true;
5027 }
5028 
ParseDxD(const std::string & name,std::vector<int64> * result)5029 bool HloParserImpl::ParseDxD(const std::string& name,
5030                              std::vector<int64>* result) {
5031   LocTy loc = lexer_.GetLoc();
5032   if (!result->empty()) {
5033     return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
5034   }
5035   // 1D
5036   if (lexer_.GetKind() == TokKind::kInt) {
5037     int64_t number;
5038     if (!ParseInt64(&number)) {
5039       return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
5040     }
5041     result->push_back(number);
5042     return true;
5043   }
5044   // 2D or higher.
5045   if (lexer_.GetKind() == TokKind::kDxD) {
5046     std::string str = lexer_.GetStrVal();
5047     if (!SplitToInt64s(str, 'x', result)) {
5048       return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
5049     }
5050     lexer_.Lex();
5051     return true;
5052   }
5053   return TokenError("expects token type kInt or kDxD");
5054 }
5055 
ParseWindowPad(std::vector<std::vector<int64>> * pad)5056 bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
5057   LocTy loc = lexer_.GetLoc();
5058   if (!pad->empty()) {
5059     return Error(loc, "sub-attribute 'pad=' already exists");
5060   }
5061   if (lexer_.GetKind() != TokKind::kPad) {
5062     return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
5063   }
5064   std::string str = lexer_.GetStrVal();
5065   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5066     std::vector<int64> low_high;
5067     if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
5068         low_high.size() != 2) {
5069       return Error(loc,
5070                    "expects padding_low and padding_high separated by '_'");
5071     }
5072     pad->push_back(low_high);
5073   }
5074   lexer_.Lex();
5075   return true;
5076 }
5077 
5078 // This is the inverse xla::ToString(PaddingConfig). The padding config string
5079 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
5080 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
5081 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)5082 bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
5083   if (lexer_.GetKind() != TokKind::kPad) {
5084     return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
5085   }
5086   LocTy loc = lexer_.GetLoc();
5087   std::string str = lexer_.GetStrVal();
5088   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5089     std::vector<int64> padding_dim;
5090     if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
5091         (padding_dim.size() != 2 && padding_dim.size() != 3)) {
5092       return Error(loc,
5093                    "expects padding config pattern like 'low_high_interior' or "
5094                    "'low_high'");
5095     }
5096     auto* dim = padding->add_dimensions();
5097     dim->set_edge_padding_low(padding_dim[0]);
5098     dim->set_edge_padding_high(padding_dim[1]);
5099     dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
5100   }
5101   lexer_.Lex();
5102   return true;
5103 }
5104 
5105 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)5106 bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
5107   absl::flat_hash_map<std::string, AttrConfig> attrs;
5108   optional<std::string> op_type;
5109   optional<std::string> op_name;
5110   optional<std::string> source_file;
5111   optional<int32> source_line;
5112   optional<std::vector<int64>> profile_type;
5113   attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
5114   attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
5115   attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
5116   attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
5117   attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
5118                            &profile_type};
5119   if (!ParseSubAttributes(attrs)) {
5120     return false;
5121   }
5122   if (op_type) {
5123     metadata->set_op_type(*op_type);
5124   }
5125   if (op_name) {
5126     metadata->set_op_name(*op_name);
5127   }
5128   if (source_file) {
5129     metadata->set_source_file(*source_file);
5130   }
5131   if (source_line) {
5132     metadata->set_source_line(*source_line);
5133   }
5134   if (profile_type) {
5135     for (const auto& type : *profile_type) {
5136       if (!ProfileType_IsValid(type)) {
5137         return false;
5138       }
5139       metadata->add_profile_type(static_cast<ProfileType>(type));
5140     }
5141   }
5142   return true;
5143 }
5144 
5145 // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleOrListMetadata(tensorflow::protobuf::RepeatedPtrField<OpMetadata> * metadata)5146 bool HloParserImpl::ParseSingleOrListMetadata(
5147     tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
5148   if (lexer_.GetKind() == TokKind::kLbrace &&
5149       lexer_.LookAhead() == TokKind::kLbrace) {
5150     if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
5151       return false;
5152     }
5153 
5154     if (lexer_.GetKind() != TokKind::kRbrace) {
5155       do {
5156         if (!ParseMetadata(metadata->Add())) {
5157           return false;
5158         }
5159       } while (EatIfPresent(TokKind::kComma));
5160     }
5161 
5162     return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
5163   }
5164 
5165   return ParseMetadata(metadata->Add());
5166 }
5167 
ParseOpShardingType(OpSharding::Type * type)5168 bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) {
5169   switch (lexer_.GetKind()) {
5170     case TokKind::kw_maximal:
5171       *type = OpSharding::MAXIMAL;
5172       lexer_.Lex();
5173       break;
5174     case TokKind::kw_replicated:
5175       *type = OpSharding::REPLICATED;
5176       lexer_.Lex();
5177       break;
5178     case TokKind::kw_manual:
5179       *type = OpSharding::MANUAL;
5180       lexer_.Lex();
5181       break;
5182     default:
5183       return false;
5184   }
5185   return true;
5186 }
5187 
ParseListShardingType(std::vector<OpSharding::Type> * types)5188 bool HloParserImpl::ParseListShardingType(
5189     std::vector<OpSharding::Type>* types) {
5190   if (!ParseToken(TokKind::kLbrace,
5191                   "expected '{' to start sharding type list")) {
5192     return false;
5193   }
5194 
5195   if (lexer_.GetKind() != TokKind::kRbrace) {
5196     do {
5197       OpSharding::Type type;
5198       if (!ParseOpShardingType(&type)) {
5199         return false;
5200       }
5201       types->emplace_back(type);
5202     } while (EatIfPresent(TokKind::kComma));
5203   }
5204 
5205   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding type list");
5206 }
5207 
ParseOpcode(HloOpcode * result)5208 bool HloParserImpl::ParseOpcode(HloOpcode* result) {
5209   VLOG(3) << "ParseOpcode";
5210   if (lexer_.GetKind() != TokKind::kIdent) {
5211     return TokenError("expects opcode");
5212   }
5213   std::string val = lexer_.GetStrVal();
5214   auto status_or_result = StringToHloOpcode(val);
5215   if (!status_or_result.ok()) {
5216     return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
5217                                 status_or_result.status().error_message()));
5218   }
5219   *result = status_or_result.ValueOrDie();
5220   lexer_.Lex();
5221   return true;
5222 }
5223 
ParseFftType(FftType * result)5224 bool HloParserImpl::ParseFftType(FftType* result) {
5225   VLOG(3) << "ParseFftType";
5226   if (lexer_.GetKind() != TokKind::kIdent) {
5227     return TokenError("expects fft type");
5228   }
5229   std::string val = lexer_.GetStrVal();
5230   if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
5231     return TokenError(StrFormat("expects fft type but sees: %s", val));
5232   }
5233   lexer_.Lex();
5234   return true;
5235 }
5236 
ParsePaddingType(PaddingType * result)5237 bool HloParserImpl::ParsePaddingType(PaddingType* result) {
5238   VLOG(3) << "ParsePaddingType";
5239   if (lexer_.GetKind() != TokKind::kIdent) {
5240     return TokenError("expects padding type");
5241   }
5242   std::string val = lexer_.GetStrVal();
5243   if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
5244     return TokenError(StrFormat("expects padding type but sees: %s", val));
5245   }
5246   lexer_.Lex();
5247   return true;
5248 }
5249 
ParseComparisonDirection(ComparisonDirection * result)5250 bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
5251   VLOG(3) << "ParseComparisonDirection";
5252   if (lexer_.GetKind() != TokKind::kIdent) {
5253     return TokenError("expects comparison direction");
5254   }
5255   std::string val = lexer_.GetStrVal();
5256   auto status_or_result = StringToComparisonDirection(val);
5257   if (!status_or_result.ok()) {
5258     return TokenError(
5259         StrFormat("expects comparison direction but sees: %s", val));
5260   }
5261   *result = status_or_result.ValueOrDie();
5262   lexer_.Lex();
5263   return true;
5264 }
5265 
ParseComparisonType(Comparison::Type * result)5266 bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
5267   VLOG(1) << "ParseComparisonType";
5268   if (lexer_.GetKind() != TokKind::kIdent) {
5269     return TokenError("expects comparison type");
5270   }
5271   std::string val = lexer_.GetStrVal();
5272   auto status_or_result = StringToComparisonType(val);
5273   if (!status_or_result.ok()) {
5274     return TokenError(StrFormat("expects comparison type but sees: %s", val));
5275   }
5276   *result = status_or_result.ValueOrDie();
5277   lexer_.Lex();
5278   return true;
5279 }
5280 
ParseFusionKind(HloInstruction::FusionKind * result)5281 bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
5282   VLOG(3) << "ParseFusionKind";
5283   if (lexer_.GetKind() != TokKind::kIdent) {
5284     return TokenError("expects fusion kind");
5285   }
5286   std::string val = lexer_.GetStrVal();
5287   auto status_or_result = StringToFusionKind(val);
5288   if (!status_or_result.ok()) {
5289     return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
5290                                 val,
5291                                 status_or_result.status().error_message()));
5292   }
5293   *result = status_or_result.ValueOrDie();
5294   lexer_.Lex();
5295   return true;
5296 }
5297 
ParseRandomDistribution(RandomDistribution * result)5298 bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
5299   VLOG(3) << "ParseRandomDistribution";
5300   if (lexer_.GetKind() != TokKind::kIdent) {
5301     return TokenError("expects random distribution");
5302   }
5303   std::string val = lexer_.GetStrVal();
5304   auto status_or_result = StringToRandomDistribution(val);
5305   if (!status_or_result.ok()) {
5306     return TokenError(
5307         StrFormat("expects random distribution but sees: %s, error: %s", val,
5308                   status_or_result.status().error_message()));
5309   }
5310   *result = status_or_result.ValueOrDie();
5311   lexer_.Lex();
5312   return true;
5313 }
5314 
ParseRandomAlgorithm(RandomAlgorithm * result)5315 bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
5316   VLOG(3) << "ParseRandomAlgorithm";
5317   if (lexer_.GetKind() != TokKind::kIdent) {
5318     return TokenError("expects random algorithm");
5319   }
5320   std::string val = lexer_.GetStrVal();
5321   auto status_or_result = StringToRandomAlgorithm(val);
5322   if (!status_or_result.ok()) {
5323     return TokenError(
5324         StrFormat("expects random algorithm but sees: %s, error: %s", val,
5325                   status_or_result.status().error_message()));
5326   }
5327   *result = status_or_result.ValueOrDie();
5328   lexer_.Lex();
5329   return true;
5330 }
5331 
ParsePrecision(PrecisionConfig::Precision * result)5332 bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
5333   VLOG(3) << "ParsePrecision";
5334   if (lexer_.GetKind() != TokKind::kIdent) {
5335     return TokenError("expects random distribution");
5336   }
5337   std::string val = lexer_.GetStrVal();
5338   auto status_or_result = StringToPrecision(val);
5339   if (!status_or_result.ok()) {
5340     return TokenError(StrFormat("expects precision but sees: %s, error: %s",
5341                                 val,
5342                                 status_or_result.status().error_message()));
5343   }
5344   *result = status_or_result.ValueOrDie();
5345   lexer_.Lex();
5346   return true;
5347 }
5348 
ParseInt64(int64 * result)5349 bool HloParserImpl::ParseInt64(int64* result) {
5350   VLOG(3) << "ParseInt64";
5351   if (lexer_.GetKind() != TokKind::kInt) {
5352     return TokenError("expects integer");
5353   }
5354   *result = lexer_.GetInt64Val();
5355   lexer_.Lex();
5356   return true;
5357 }
5358 
ParseDouble(double * result)5359 bool HloParserImpl::ParseDouble(double* result) {
5360   switch (lexer_.GetKind()) {
5361     case TokKind::kDecimal: {
5362       double val = lexer_.GetDecimalVal();
5363       // If GetDecimalVal returns +/-inf, that means that we overflowed
5364       // `double`.
5365       if (std::isinf(val)) {
5366         return TokenError(StrCat("Constant is out of range for double (+/-",
5367                                  std::numeric_limits<double>::max(),
5368                                  ") and so is unparsable."));
5369       }
5370       *result = val;
5371       break;
5372     }
5373     case TokKind::kInt:
5374       *result = static_cast<double>(lexer_.GetInt64Val());
5375       break;
5376     case TokKind::kw_inf:
5377       *result = std::numeric_limits<double>::infinity();
5378       break;
5379     case TokKind::kNegInf:
5380       *result = -std::numeric_limits<double>::infinity();
5381       break;
5382     default:
5383       return TokenError("expects decimal or integer");
5384   }
5385   lexer_.Lex();
5386   return true;
5387 }
5388 
ParseComplex(std::complex<double> * result)5389 bool HloParserImpl::ParseComplex(std::complex<double>* result) {
5390   if (lexer_.GetKind() != TokKind::kLparen) {
5391     return TokenError("expects '(' before complex number");
5392   }
5393   lexer_.Lex();
5394 
5395   double real;
5396   LocTy loc = lexer_.GetLoc();
5397   if (!ParseDouble(&real)) {
5398     return Error(loc,
5399                  "expect floating-point value for real part of complex number");
5400   }
5401 
5402   if (lexer_.GetKind() != TokKind::kComma) {
5403     return TokenError(
5404         absl::StrFormat("expect comma after real part of complex literal"));
5405   }
5406   lexer_.Lex();
5407 
5408   double imag;
5409   loc = lexer_.GetLoc();
5410   if (!ParseDouble(&imag)) {
5411     return Error(
5412         loc,
5413         "expect floating-point value for imaginary part of complex number");
5414   }
5415 
5416   if (lexer_.GetKind() != TokKind::kRparen) {
5417     return TokenError(absl::StrFormat("expect ')' after complex number"));
5418   }
5419 
5420   *result = std::complex<double>(real, imag);
5421   lexer_.Lex();
5422   return true;
5423 }
5424 
ParseBool(bool * result)5425 bool HloParserImpl::ParseBool(bool* result) {
5426   if (lexer_.GetKind() != TokKind::kw_true &&
5427       lexer_.GetKind() != TokKind::kw_false) {
5428     return TokenError("expects true or false");
5429   }
5430   *result = lexer_.GetKind() == TokKind::kw_true;
5431   lexer_.Lex();
5432   return true;
5433 }
5434 
ParseToken(TokKind kind,const std::string & msg)5435 bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
5436   VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
5437   if (lexer_.GetKind() != kind) {
5438     return TokenError(msg);
5439   }
5440   lexer_.Lex();
5441   return true;
5442 }
5443 
EatIfPresent(TokKind kind)5444 bool HloParserImpl::EatIfPresent(TokKind kind) {
5445   if (lexer_.GetKind() != kind) {
5446     return false;
5447   }
5448   lexer_.Lex();
5449   return true;
5450 }
5451 
AddInstruction(const std::string & name,HloInstruction * instruction,LocTy name_loc)5452 bool HloParserImpl::AddInstruction(const std::string& name,
5453                                    HloInstruction* instruction,
5454                                    LocTy name_loc) {
5455   auto result = current_name_table().insert({name, {instruction, name_loc}});
5456   if (!result.second) {
5457     Error(name_loc, StrCat("instruction already exists: ", name));
5458     return Error(/*loc=*/result.first->second.second,
5459                  "instruction previously defined here");
5460   }
5461   return true;
5462 }
5463 
AddComputation(const std::string & name,HloComputation * computation,LocTy name_loc)5464 bool HloParserImpl::AddComputation(const std::string& name,
5465                                    HloComputation* computation,
5466                                    LocTy name_loc) {
5467   auto result = computation_pool_.insert({name, {computation, name_loc}});
5468   if (!result.second) {
5469     Error(name_loc, StrCat("computation already exists: ", name));
5470     return Error(/*loc=*/result.first->second.second,
5471                  "computation previously defined here");
5472   }
5473   return true;
5474 }
5475 
ParseShapeOnly()5476 StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
5477   lexer_.Lex();
5478   Shape shape;
5479   if (!ParseShape(&shape)) {
5480     return InvalidArgument("Syntax error:\n%s", GetError());
5481   }
5482   if (lexer_.GetKind() != TokKind::kEof) {
5483     return InvalidArgument("Syntax error:\nExtra content after shape");
5484   }
5485   return shape;
5486 }
5487 
ParseShardingOnly()5488 StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
5489   lexer_.Lex();
5490   OpSharding op_sharding;
5491   if (!ParseSharding(&op_sharding)) {
5492     return InvalidArgument("Syntax error:\n%s", GetError());
5493   }
5494   if (lexer_.GetKind() != TokKind::kEof) {
5495     return InvalidArgument("Syntax error:\nExtra content after sharding");
5496   }
5497   return HloSharding::FromProto(op_sharding);
5498 }
5499 
ParseFrontendAttributesOnly()5500 StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
5501   lexer_.Lex();
5502   FrontendAttributes attributes;
5503   if (!ParseFrontendAttributes(&attributes)) {
5504     return InvalidArgument("Syntax error:\n%s", GetError());
5505   }
5506   if (lexer_.GetKind() != TokKind::kEof) {
5507     return InvalidArgument(
5508         "Syntax error:\nExtra content after frontend attributes");
5509   }
5510   return attributes;
5511 }
5512 
ParseParameterReplicationOnly()5513 StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
5514   lexer_.Lex();
5515   ParameterReplication parameter_replication;
5516   if (!ParseParameterReplication(&parameter_replication)) {
5517     return InvalidArgument("Syntax error:\n%s", GetError());
5518   }
5519   if (lexer_.GetKind() != TokKind::kEof) {
5520     return InvalidArgument(
5521         "Syntax error:\nExtra content after parameter replication");
5522   }
5523   return std::vector<bool>(
5524       parameter_replication.replicated_at_leaf_buffers().begin(),
5525       parameter_replication.replicated_at_leaf_buffers().end());
5526 }
5527 
ParseReplicaGroupsOnly()5528 StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
5529   lexer_.Lex();
5530   std::vector<ReplicaGroup> replica_groups;
5531   if (!ParseReplicaGroupsOnly(&replica_groups)) {
5532     return InvalidArgument("Syntax error:\n%s", GetError());
5533   }
5534   if (lexer_.GetKind() != TokKind::kEof) {
5535     return InvalidArgument("Syntax error:\nExtra content after replica groups");
5536   }
5537   return replica_groups;
5538 }
5539 
ParseWindowOnly()5540 StatusOr<Window> HloParserImpl::ParseWindowOnly() {
5541   lexer_.Lex();
5542   Window window;
5543   if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
5544     return InvalidArgument("Syntax error:\n%s", GetError());
5545   }
5546   if (lexer_.GetKind() != TokKind::kEof) {
5547     return InvalidArgument("Syntax error:\nExtra content after window");
5548   }
5549   return window;
5550 }
5551 
5552 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()5553 HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
5554   lexer_.Lex();
5555   ConvolutionDimensionNumbers dnums;
5556   if (!ParseConvolutionDimensionNumbers(&dnums)) {
5557     return InvalidArgument("Syntax error:\n%s", GetError());
5558   }
5559   if (lexer_.GetKind() != TokKind::kEof) {
5560     return InvalidArgument(
5561         "Syntax error:\nExtra content after convolution dnums");
5562   }
5563   return dnums;
5564 }
5565 
ParsePaddingConfigOnly()5566 StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
5567   lexer_.Lex();
5568   PaddingConfig padding_config;
5569   if (!ParsePaddingConfig(&padding_config)) {
5570     return InvalidArgument("Syntax error:\n%s", GetError());
5571   }
5572   if (lexer_.GetKind() != TokKind::kEof) {
5573     return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
5574   }
5575   return padding_config;
5576 }
5577 
// Parses the input as a single HLO instruction (with or without a left-hand
// side name) and installs it as `module`'s entry computation. Operand names
// that have no definition are materialized on the fly as parameters via the
// create_missing_instruction_ hook. Requires a freshly-constructed parser:
// the hook and the scoped name tables must both be unset on entry.
bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
  if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
    LOG(FATAL) << "Parser state is not clean. Please do not call any other "
                  "methods before calling ParseSingleInstruction.";
  }
  HloComputation::Builder builder(module->name());

  // The missing instruction hook we register creates the shaped instruction on
  // the fly as a parameter and returns it.
  int64_t parameter_count = 0;
  create_missing_instruction_ =
      [this, &builder, &parameter_count](
          const std::string& name,
          const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
    // Unnamed operands get synthesized names "_0", "_1", ... from the
    // running parameter counter.
    std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
    HloInstruction* parameter = builder.AddInstruction(
        HloInstruction::CreateParameter(parameter_count++, shape, new_name));
    current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
    return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
  };

  // Parse the instruction with the registered hook.
  // The Scope pushes a fresh name table that the hook above writes into; it
  // is popped when `scope` is destroyed at the end of this function.
  Scope scope(&scoped_name_tables_);
  if (CanBeShape()) {
    // This means that the instruction's left-hand side is probably omitted,
    // e.g.
    //
    //  f32[10] fusion(...), calls={...}
    if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
      return false;
    }
  } else {
    // This means that the instruction's left-hand side might exist, e.g.
    //
    //  foo = f32[10] fusion(...), calls={...}
    std::string root_name;
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  }

  // Exactly one instruction is allowed: anything left over is a syntax error.
  if (lexer_.GetKind() != TokKind::kEof) {
    Error(
        lexer_.GetLoc(),
        "Syntax error:\nExpected eof after parsing single instruction.  Did "
        "you mean to write an HLO module and forget the \"HloModule\" header?");
    return false;
  }

  // Install the parsed computation as the entry and move over any
  // computations accumulated during parsing (e.g. fusion bodies).
  module->AddEntryComputation(builder.Build());
  for (auto& comp : computations_) {
    module->AddEmbeddedComputation(std::move(comp));
  }
  TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  return true;
}
5634 
5635 }  // namespace
5636 
ParseAndReturnUnverifiedModule(absl::string_view str,const HloModuleConfig & config)5637 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5638     absl::string_view str, const HloModuleConfig& config) {
5639   auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
5640   HloParserImpl parser(str);
5641   TF_RETURN_IF_ERROR(parser.Run(module.get()));
5642   return std::move(module);
5643 }
5644 
ParseAndReturnUnverifiedModule(absl::string_view str)5645 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5646     absl::string_view str) {
5647   return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
5648 }
5649 
ParseSharding(absl::string_view str)5650 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
5651   HloParserImpl parser(str);
5652   return parser.ParseShardingOnly();
5653 }
5654 
ParseFrontendAttributes(absl::string_view str)5655 StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
5656   HloParserImpl parser(str);
5657   return parser.ParseFrontendAttributesOnly();
5658 }
5659 
ParseParameterReplication(absl::string_view str)5660 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
5661   HloParserImpl parser(str);
5662   return parser.ParseParameterReplicationOnly();
5663 }
5664 
ParseReplicaGroupsOnly(absl::string_view str)5665 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
5666     absl::string_view str) {
5667   HloParserImpl parser(str);
5668   return parser.ParseReplicaGroupsOnly();
5669 }
5670 
ParseWindow(absl::string_view str)5671 StatusOr<Window> ParseWindow(absl::string_view str) {
5672   HloParserImpl parser(str);
5673   return parser.ParseWindowOnly();
5674 }
5675 
ParseConvolutionDimensionNumbers(absl::string_view str)5676 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
5677     absl::string_view str) {
5678   HloParserImpl parser(str);
5679   return parser.ParseConvolutionDimensionNumbersOnly();
5680 }
5681 
ParsePaddingConfig(absl::string_view str)5682 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
5683   HloParserImpl parser(str);
5684   return parser.ParsePaddingConfigOnly();
5685 }
5686 
ParseShape(absl::string_view str)5687 StatusOr<Shape> ParseShape(absl::string_view str) {
5688   HloParserImpl parser(str);
5689   return parser.ParseShapeOnly();
5690 }
5691 
CreateHloParserForTests(absl::string_view str)5692 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
5693     absl::string_view str) {
5694   return absl::make_unique<HloParserImpl>(str);
5695 }
5696 
5697 }  // namespace xla
5698