1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17
18 #include <memory>
19 #include <string>
20 #include <type_traits>
21 #include <vector>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/ascii.h"
28 #include "absl/strings/numbers.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/types/span.h"
35 #include "absl/types/variant.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/literal_util.h"
38 #include "tensorflow/compiler/xla/primitive_util.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
41 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
48 #include "tensorflow/compiler/xla/service/shape_inference.h"
49 #include "tensorflow/compiler/xla/shape_util.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/compiler/xla/xla_data.pb.h"
52 #include "tensorflow/core/lib/gtl/map_util.h"
53 #include "tensorflow/core/platform/macros.h"
54 #include "tensorflow/core/platform/protobuf.h"
55
56 namespace xla {
57
58 namespace {
59
60 using absl::nullopt;
61 using absl::optional;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrFormat;
65 using absl::StrJoin;
66
67 // Creates and returns a schedule created using the order of the instructions in
68 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)69 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
70 HloSchedule schedule(module);
71 for (HloComputation* computation : module->computations()) {
72 if (!computation->IsFusionComputation()) {
73 for (HloInstruction* instruction : computation->instructions()) {
74 schedule.GetOrCreateSequence(computation).push_back(instruction);
75 }
76 }
77 }
78 return schedule;
79 }
80
CanInferShape(HloOpcode code)81 bool CanInferShape(HloOpcode code) {
82 switch (code) {
83 case HloOpcode::kAbs:
84 case HloOpcode::kAdd:
85 case HloOpcode::kAddDependency:
86 case HloOpcode::kAfterAll:
87 case HloOpcode::kAtan2:
88 case HloOpcode::kBatchNormGrad:
89 case HloOpcode::kBatchNormInference:
90 case HloOpcode::kBatchNormTraining:
91 case HloOpcode::kBroadcast:
92 case HloOpcode::kCall:
93 case HloOpcode::kCeil:
94 case HloOpcode::kCholesky:
95 case HloOpcode::kClamp:
96 case HloOpcode::kClz:
97 case HloOpcode::kCompare:
98 case HloOpcode::kComplex:
99 case HloOpcode::kConcatenate:
100 case HloOpcode::kConditional:
101 case HloOpcode::kConvolution:
102 case HloOpcode::kCopy:
103 case HloOpcode::kCos:
104 case HloOpcode::kDivide:
105 case HloOpcode::kDomain:
106 case HloOpcode::kDot:
107 case HloOpcode::kExp:
108 case HloOpcode::kExpm1:
109 case HloOpcode::kFft:
110 case HloOpcode::kFloor:
111 case HloOpcode::kGather:
112 case HloOpcode::kGetDimensionSize:
113 case HloOpcode::kSetDimensionSize:
114 case HloOpcode::kGetTupleElement:
115 case HloOpcode::kImag:
116 case HloOpcode::kIsFinite:
117 case HloOpcode::kLog:
118 case HloOpcode::kLog1p:
119 case HloOpcode::kLogistic:
120 case HloOpcode::kAnd:
121 case HloOpcode::kNot:
122 case HloOpcode::kOr:
123 case HloOpcode::kXor:
124 case HloOpcode::kMap:
125 case HloOpcode::kMaximum:
126 case HloOpcode::kMinimum:
127 case HloOpcode::kMultiply:
128 case HloOpcode::kNegate:
129 case HloOpcode::kPad:
130 case HloOpcode::kPartitionId:
131 case HloOpcode::kPopulationCount:
132 case HloOpcode::kPower:
133 case HloOpcode::kReal:
134 case HloOpcode::kReduce:
135 case HloOpcode::kRemainder:
136 case HloOpcode::kReplicaId:
137 case HloOpcode::kReverse:
138 case HloOpcode::kRoundNearestAfz:
139 case HloOpcode::kRsqrt:
140 case HloOpcode::kScatter:
141 case HloOpcode::kSelect:
142 case HloOpcode::kShiftLeft:
143 case HloOpcode::kShiftRightArithmetic:
144 case HloOpcode::kShiftRightLogical:
145 case HloOpcode::kSign:
146 case HloOpcode::kSin:
147 case HloOpcode::kSqrt:
148 case HloOpcode::kCbrt:
149 case HloOpcode::kReduceWindow:
150 case HloOpcode::kSelectAndScatter:
151 case HloOpcode::kSort:
152 case HloOpcode::kSubtract:
153 case HloOpcode::kTanh:
154 case HloOpcode::kTrace:
155 case HloOpcode::kTranspose:
156 case HloOpcode::kTriangularSolve:
157 case HloOpcode::kTuple:
158 case HloOpcode::kTupleSelect:
159 case HloOpcode::kWhile:
160 return true;
161 // Technically the following ops do not require an explicit result shape,
162 // but we made it so that we always write the shapes explicitly.
163 case HloOpcode::kAllGather:
164 case HloOpcode::kAllReduce:
165 case HloOpcode::kAllToAll:
166 case HloOpcode::kCollectivePermute:
167 case HloOpcode::kCollectivePermuteStart:
168 case HloOpcode::kCollectivePermuteDone:
169 case HloOpcode::kCopyDone:
170 case HloOpcode::kCopyStart:
171 case HloOpcode::kDynamicReshape:
172 case HloOpcode::kDynamicSlice:
173 case HloOpcode::kDynamicUpdateSlice:
174 case HloOpcode::kRecv:
175 case HloOpcode::kRecvDone:
176 case HloOpcode::kSend:
177 case HloOpcode::kSendDone:
178 case HloOpcode::kSlice:
179 // The following ops require an explicit result shape.
180 case HloOpcode::kBitcast:
181 case HloOpcode::kBitcastConvert:
182 case HloOpcode::kConstant:
183 case HloOpcode::kConvert:
184 case HloOpcode::kCustomCall:
185 case HloOpcode::kFusion:
186 case HloOpcode::kInfeed:
187 case HloOpcode::kIota:
188 case HloOpcode::kOutfeed:
189 case HloOpcode::kParameter:
190 case HloOpcode::kReducePrecision:
191 case HloOpcode::kReshape:
192 case HloOpcode::kRng:
193 case HloOpcode::kRngBitGenerator:
194 case HloOpcode::kRngGetAndUpdateState:
195 return false;
196 }
197 }
198
199 // Parser for the HloModule::ToString() format text.
200 class HloParserImpl : public HloParser {
201 public:
202 using LocTy = HloLexer::LocTy;
203
HloParserImpl(absl::string_view str)204 explicit HloParserImpl(absl::string_view str) : lexer_(str) {}
205
206 // Runs the parser and constructs the resulting HLO in the given (empty)
207 // HloModule. Returns the error status in case an error occurred.
208 Status Run(HloModule* module) override;
209
210 // Returns the error information.
GetError() const211 std::string GetError() const { return StrJoin(error_, "\n"); }
212
213 // Stand alone parsing utils for various aggregate data types.
214 StatusOr<Shape> ParseShapeOnly();
215 StatusOr<HloSharding> ParseShardingOnly();
216 StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
217 StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
218 StatusOr<Window> ParseWindowOnly();
219 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
220 StatusOr<PaddingConfig> ParsePaddingConfigOnly();
221 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();
222
223 private:
224 using InstrNameTable =
225 absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;
226
227 // Returns the map from the instruction name to the instruction itself and its
228 // location in the current scope.
current_name_table()229 InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
230
231 // Locates an instruction with the given name in the current_name_table() or
232 // returns nullptr.
233 //
234 // When the name is not found or name is empty, if create_missing_instruction_
235 // hook is registered and a "shape" is provided, the hook will be called to
236 // create an instruction. This is useful when we reify parameters as they're
237 // resolved; i.e. for ParseSingleInstruction.
238 std::pair<HloInstruction*, LocTy>* FindInstruction(
239 const std::string& name, const optional<Shape>& shape = nullopt);
240
241 // Parse a single instruction worth of text.
242 bool ParseSingleInstruction(HloModule* module);
243
244 // Parses a module, returning false if an error occurred.
245 bool ParseHloModule(HloModule* module);
246
247 bool ParseComputations(HloModule* module);
248 bool ParseComputation(HloComputation** entry_computation);
249 bool ParseInstructionList(HloComputation** computation,
250 const std::string& computation_name);
251 bool ParseInstruction(HloComputation::Builder* builder,
252 std::string* root_name);
253 bool ParseInstructionRhs(HloComputation::Builder* builder,
254 const std::string& name, LocTy name_loc);
255 bool ParseControlPredecessors(HloInstruction* instruction);
256 bool ParseLiteral(Literal* literal);
257 bool ParseLiteral(Literal* literal, const Shape& shape);
258 bool ParseTupleLiteral(Literal* literal, const Shape& shape);
259 bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
260 bool ParseDenseLiteral(Literal* literal, const Shape& shape);
261
262 // Sets the sub-value of literal at the given linear index to the
263 // given value. If the literal is dense, it must have the default layout.
264 //
265 // `loc` should be the source location of the value.
266 bool SetValueInLiteral(LocTy loc, int64 value, int64 index, Literal* literal);
267 bool SetValueInLiteral(LocTy loc, double value, int64 index,
268 Literal* literal);
269 bool SetValueInLiteral(LocTy loc, bool value, int64 index, Literal* literal);
270 bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64 index,
271 Literal* literal);
272 // `loc` should be the source location of the value.
273 template <typename LiteralNativeT, typename ParsedElemT>
274 bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64 index,
275 Literal* literal);
276
277 // Checks whether the given value is within the range of LiteralNativeT.
278 // `loc` should be the source location of the value.
279 template <typename LiteralNativeT, typename ParsedElemT>
280 bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
281 template <typename LiteralNativeT>
282 bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);
283
284 bool ParseOperands(std::vector<HloInstruction*>* operands);
285 // Fills parsed operands into 'operands' and expects a certain number of
286 // operands.
287 bool ParseOperands(std::vector<HloInstruction*>* operands,
288 const int expected_size);
289
290 // Describes the start, limit, and stride on every dimension of the operand
291 // being sliced.
292 struct SliceRanges {
293 std::vector<int64> starts;
294 std::vector<int64> limits;
295 std::vector<int64> strides;
296 };
297
298 // The data parsed for the kDomain instruction.
299 struct DomainData {
300 std::unique_ptr<DomainMetadata> entry_metadata;
301 std::unique_ptr<DomainMetadata> exit_metadata;
302 };
303
304 // Types of attributes.
305 enum class AttrTy {
306 kBool,
307 kInt64,
308 kInt32,
309 kFloat,
310 kString,
311 kLiteral,
312 kBracedInt64List,
313 kBracedInt64ListList,
314 kHloComputation,
315 kBracedHloComputationList,
316 kFftType,
317 kPaddingType,
318 kComparisonDirection,
319 kComparisonType,
320 kWindow,
321 kConvolutionDimensionNumbers,
322 kSharding,
323 kFrontendAttributes,
324 kParameterReplication,
325 kInstructionList,
326 kSliceRanges,
327 kPaddingConfig,
328 kMetadata,
329 kFusionKind,
330 kDistribution,
331 kDomain,
332 kPrecisionList,
333 kShape,
334 kShapeList,
335 kEnum,
336 kRandomAlgorithm,
337 kAliasing,
338 kInstructionAliasing,
339 };
340
341 struct AttrConfig {
342 bool required; // whether it's required or optional
343 AttrTy attr_type; // what type it is
344 void* result; // where to store the parsed result.
345 };
346
347 // attributes ::= (',' attribute)*
348 //
349 // Parses attributes given names and configs of the attributes. Each parsed
350 // result is passed back through the result pointer in corresponding
351 // AttrConfig. Note that the result pointer must point to a optional<T> typed
352 // variable which outlives this function. Returns false on error. You should
353 // not use the any of the results if this function failed.
354 //
355 // Example usage:
356 //
357 // absl::flat_hash_map<std::string, AttrConfig> attrs;
358 // optional<int64> foo;
359 // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
360 // optional<Window> bar;
361 // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
362 // if (!ParseAttributes(attrs)) {
363 // return false; // Do not use 'foo' 'bar' if failed.
364 // }
365 // // Do something with 'bar'.
366 // if (foo) { // If attr foo is seen, do something with 'foo'. }
367 //
368 bool ParseAttributes(
369 const absl::flat_hash_map<std::string, AttrConfig>& attrs);
370
371 // sub_attributes ::= '{' (','? attribute)* '}'
372 //
373 // Usage is the same as ParseAttributes. See immediately above.
374 bool ParseSubAttributes(
375 const absl::flat_hash_map<std::string, AttrConfig>& attrs);
376
377 // Parses one attribute. If it has already been seen, return error. Returns
378 // true and adds to seen_attrs on success.
379 //
380 // Do not call this except in ParseAttributes or ParseSubAttributes.
381 bool ParseAttributeHelper(
382 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
383 absl::flat_hash_set<std::string>* seen_attrs);
384
385 // Copy attributes from `attrs` to `message`, unless the attribute name is in
386 // `non_proto_attrs`.
387 bool CopyAttributeToProtoMessage(
388 absl::flat_hash_set<std::string> non_proto_attrs,
389 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
390 tensorflow::protobuf::Message* message);
391
392 // Parses an attribute string into a protocol buffer `message`.
393 // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
394 // set of mandatory attributes.
395 // `non_proto_attrs` specifies attributes that are not written to the proto,
396 // but added to the HloInstruction.
397 bool ParseAttributesAsProtoMessage(
398 const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
399 tensorflow::protobuf::Message* message);
400
401 // Parses a name and finds the corresponding hlo computation.
402 bool ParseComputationName(HloComputation** value);
403 // Parses a list of names and finds the corresponding hlo instructions.
404 bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
405 // Pass expect_outer_curlies == true when parsing a Window in the context of a
406 // larger computation. Pass false when parsing a stand-alone Window string.
407 bool ParseWindow(Window* window, bool expect_outer_curlies);
408 bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
409 bool ParsePaddingConfig(PaddingConfig* padding);
410 bool ParseMetadata(OpMetadata* metadata);
411 bool ParseSingleOrListMetadata(
412 tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
413 bool ParseSharding(OpSharding* sharding);
414 bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
415 bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
416 bool ParseParameterReplication(ParameterReplication* parameter_replication);
417 bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
418
419 // Parses the metadata behind a kDOmain instruction.
420 bool ParseDomain(DomainData* domain);
421
422 // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
423 bool ParseDxD(const std::string& name, std::vector<int64>* result);
424 // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
425 bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
426
427 bool ParseSliceRanges(SliceRanges* result);
428 bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
429 bool ParseHloComputation(HloComputation** result);
430 bool ParseHloComputationList(std::vector<HloComputation*>* result);
431 bool ParseShapeList(std::vector<Shape>* result);
432 bool ParseInt64List(const TokKind start, const TokKind end,
433 const TokKind delim, std::vector<int64>* result);
434 bool ParseInt64ListList(const TokKind start, const TokKind end,
435 const TokKind delim,
436 std::vector<std::vector<int64>>* result);
437 // 'parse_and_add_item' is an lambda to parse an element in the list and add
438 // the parsed element to the result. It's supposed to capture the result.
439 bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
440 const std::function<bool()>& parse_and_add_item);
441
442 bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
443 bool ParseParamList();
444 bool ParseName(std::string* result);
445 bool ParseAttributeName(std::string* result);
446 bool ParseString(std::string* result);
447 bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
448 std::vector<bool>* dynamic_dimensions);
449 bool ParseShape(Shape* result);
450 bool ParseLayout(Layout* layout);
451 bool ParseLayoutIntAttribute(int64* attr_value,
452 absl::string_view attr_description);
453 bool ParseTiles(std::vector<Tile>* tiles);
454 bool ParseOpcode(HloOpcode* result);
455 bool ParseFftType(FftType* result);
456 bool ParsePaddingType(PaddingType* result);
457 bool ParseComparisonDirection(ComparisonDirection* result);
458 bool ParseComparisonType(Comparison::Type* result);
459 bool ParseFusionKind(HloInstruction::FusionKind* result);
460 bool ParseRandomDistribution(RandomDistribution* result);
461 bool ParseRandomAlgorithm(RandomAlgorithm* result);
462 bool ParsePrecision(PrecisionConfig::Precision* result);
463 bool ParseInt64(int64* result);
464 bool ParseDouble(double* result);
465 bool ParseComplex(std::complex<double>* result);
466 bool ParseBool(bool* result);
467 bool ParseToken(TokKind kind, const std::string& msg);
468
469 using AliasingData =
470 absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;
471
472 // Parses the aliasing information from string `s`, returns `false` if it
473 // fails.
474 bool ParseAliasing(AliasingData* data);
475
476 // Parses the per-instruction aliasing information from string `s`, returns
477 // `false` if it fails.
478 bool ParseInstructionOutputOperandAliasing(
479 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
480 aliasing_output_operand_pairs);
481
482 bool ParseShapeIndex(ShapeIndex* out);
483
484 // Returns true if the current token is the beginning of a shape.
485 bool CanBeShape();
486 // Returns true if the current token is the beginning of a
487 // param_list_to_shape.
488 bool CanBeParamListToShape();
489
490 // Logs the current parsing line and the given message. Always returns false.
491 bool TokenError(absl::string_view msg);
492 bool Error(LocTy loc, absl::string_view msg);
493
494 // If the current token is 'kind', eats it (i.e. lexes the next token) and
495 // returns true.
496 bool EatIfPresent(TokKind kind);
497
498 // Adds the instruction to the pool. Returns false and emits an error if the
499 // instruction already exists.
500 bool AddInstruction(const std::string& name, HloInstruction* instruction,
501 LocTy name_loc);
502 // Adds the computation to the pool. Returns false and emits an error if the
503 // computation already exists.
504 bool AddComputation(const std::string& name, HloComputation* computation,
505 LocTy name_loc);
506
507 HloLexer lexer_;
508
509 // A stack for the instruction names. The top of the stack stores the
510 // instruction name table for the current scope.
511 //
512 // A instruction's name is unique among its scope (i.e. its parent
513 // computation), but it's not necessarily unique among all computations in the
514 // module. When there are multiple levels of nested computations, the same
515 // name could appear in both an outer computation and an inner computation. So
516 // we need a stack to make sure a name is only visible within its scope,
517 std::vector<InstrNameTable> scoped_name_tables_;
518
519 // A helper class which pushes and pops to an InstrNameTable stack via RAII.
520 class Scope {
521 public:
Scope(std::vector<InstrNameTable> * scoped_name_tables)522 explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
523 : scoped_name_tables_(scoped_name_tables) {
524 scoped_name_tables_->emplace_back();
525 }
~Scope()526 ~Scope() { scoped_name_tables_->pop_back(); }
527
528 private:
529 std::vector<InstrNameTable>* scoped_name_tables_;
530 };
531
532 // Map from the computation name to the computation itself and its location.
533 absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
534 computation_pool_;
535
536 std::vector<std::unique_ptr<HloComputation>> computations_;
537 std::vector<std::string> error_;
538
539 // When an operand name cannot be resolved, this function is called to create
540 // a parameter instruction with the given name and shape. It registers the
541 // name, instruction, and a placeholder location in the name table. It returns
542 // the newly-created instruction and the placeholder location. If `name` is
543 // empty, this should create the parameter with a generated name. This is
544 // supposed to be set and used only in ParseSingleInstruction.
545 std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
546 const Shape& shape)>
547 create_missing_instruction_;
548 };
549
SplitToInt64s(absl::string_view s,char delim,std::vector<int64> * out)550 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
551 for (const auto& split : absl::StrSplit(s, delim)) {
552 int64 val;
553 if (!absl::SimpleAtoi(split, &val)) {
554 return false;
555 }
556 out->push_back(val);
557 }
558 return true;
559 }
560
561 // Creates replica groups from the provided nested array. groups[i] represents
562 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64>> groups)563 std::vector<ReplicaGroup> CreateReplicaGroups(
564 absl::Span<const std::vector<int64>> groups) {
565 std::vector<ReplicaGroup> replica_groups;
566 absl::c_transform(groups, std::back_inserter(replica_groups),
567 [](const std::vector<int64>& ids) {
568 ReplicaGroup group;
569 *group.mutable_replica_ids() = {ids.begin(), ids.end()};
570 return group;
571 });
572 return replica_groups;
573 }
574
Error(LocTy loc,absl::string_view msg)575 bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
576 auto line_col = lexer_.GetLineAndColumn(loc);
577 const unsigned line = line_col.first;
578 const unsigned col = line_col.second;
579 std::vector<std::string> error_lines;
580 error_lines.push_back(
581 StrCat("was parsing ", line, ":", col, ": error: ", msg));
582 error_lines.emplace_back(lexer_.GetLine(loc));
583 error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
584
585 error_.push_back(StrJoin(error_lines, "\n"));
586 VLOG(1) << "Error: " << error_.back();
587 return false;
588 }
589
TokenError(absl::string_view msg)590 bool HloParserImpl::TokenError(absl::string_view msg) {
591 return Error(lexer_.GetLoc(), msg);
592 }
593
Run(HloModule * module)594 Status HloParserImpl::Run(HloModule* module) {
595 lexer_.Lex();
596 if (lexer_.GetKind() == TokKind::kw_HloModule) {
597 // This means that the text contains a full HLO module.
598 if (!ParseHloModule(module)) {
599 return InvalidArgument(
600 "Syntax error when trying to parse the text as a HloModule:\n%s",
601 GetError());
602 }
603 return Status::OK();
604 }
605 // This means that the text is a single HLO instruction.
606 if (!ParseSingleInstruction(module)) {
607 return InvalidArgument(
608 "Syntax error when trying to parse the text as a single "
609 "HloInstruction:\n%s",
610 GetError());
611 }
612 return Status::OK();
613 }
614
615 std::pair<HloInstruction*, HloParserImpl::LocTy>*
FindInstruction(const std::string & name,const optional<Shape> & shape)616 HloParserImpl::FindInstruction(const std::string& name,
617 const optional<Shape>& shape) {
618 std::pair<HloInstruction*, LocTy>* instr = nullptr;
619 if (!name.empty()) {
620 instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
621 }
622
623 // Potentially call the missing instruction hook.
624 if (instr == nullptr && create_missing_instruction_ != nullptr &&
625 scoped_name_tables_.size() == 1) {
626 if (!shape.has_value()) {
627 Error(lexer_.GetLoc(),
628 "Operand had no shape in HLO text; cannot create parameter for "
629 "single-instruction module.");
630 return nullptr;
631 }
632 return create_missing_instruction_(name, *shape);
633 }
634
635 if (instr != nullptr && shape.has_value() &&
636 !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
637 Error(
638 lexer_.GetLoc(),
639 StrCat("The declared operand shape ",
640 ShapeUtil::HumanStringWithLayout(shape.value()),
641 " is not compatible with the shape of the operand instruction ",
642 ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
643 return nullptr;
644 }
645
646 return instr;
647 }
648
ParseShapeIndex(ShapeIndex * out)649 bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
650 if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
651 return false;
652 }
653
654 std::vector<int64> idxs;
655 while (lexer_.GetKind() != TokKind::kRbrace) {
656 int64 idx;
657 if (!ParseInt64(&idx)) {
658 return false;
659 }
660 idxs.push_back(idx);
661 if (!EatIfPresent(TokKind::kComma)) {
662 break;
663 }
664 }
665 if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
666 return false;
667 }
668 *out = ShapeIndex(idxs.begin(), idxs.end());
669 return true;
670 }
671
ParseAliasing(AliasingData * data)672 bool HloParserImpl::ParseAliasing(AliasingData* data) {
673 if (!ParseToken(TokKind::kLbrace,
674 "Expects '{' at the start of aliasing description")) {
675 return false;
676 }
677
678 while (lexer_.GetKind() != TokKind::kRbrace) {
679 ShapeIndex out;
680 if (!ParseShapeIndex(&out)) {
681 return false;
682 }
683 std::string errmsg =
684 "Expected format: <output_shape_index>: (<input_param>, "
685 "<input_param_shape_index>) OR <output_shape_index>: <input_param>";
686 if (!ParseToken(TokKind::kColon, errmsg)) {
687 return false;
688 }
689
690 if (!ParseToken(TokKind::kLparen, errmsg)) {
691 return false;
692 }
693 int64 param_num;
694 ParseInt64(¶m_num);
695 if (!ParseToken(TokKind::kComma, errmsg)) {
696 return false;
697 }
698 ShapeIndex param_idx;
699 if (!ParseShapeIndex(¶m_idx)) {
700 return false;
701 }
702
703 HloInputOutputAliasConfig::AliasKind alias_kind =
704 HloInputOutputAliasConfig::kMayAlias;
705 if (EatIfPresent(TokKind::kComma)) {
706 std::string type;
707 ParseName(&type);
708 if (type == "must-alias") {
709 alias_kind = HloInputOutputAliasConfig::kMustAlias;
710 } else if (type == "may-alias") {
711 alias_kind = HloInputOutputAliasConfig::kMayAlias;
712 } else {
713 return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
714 }
715 }
716
717 data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
718 std::forward_as_tuple(param_num, param_idx, alias_kind));
719 if (!ParseToken(TokKind::kRparen, errmsg)) {
720 return false;
721 }
722
723 if (!EatIfPresent(TokKind::kComma)) {
724 break;
725 }
726 }
727 if (!ParseToken(TokKind::kRbrace,
728 "Expects '}' at the end of aliasing description")) {
729 return false;
730 }
731 return true;
732 }
733
ParseInstructionOutputOperandAliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> * aliasing_output_operand_pairs)734 bool HloParserImpl::ParseInstructionOutputOperandAliasing(
735 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
736 aliasing_output_operand_pairs) {
737 if (!ParseToken(
738 TokKind::kLbrace,
739 "Expects '{' at the start of instruction aliasing description")) {
740 return false;
741 }
742
743 while (lexer_.GetKind() != TokKind::kRbrace) {
744 ShapeIndex out;
745 if (!ParseShapeIndex(&out)) {
746 return false;
747 }
748 std::string errmsg =
749 "Expected format: <output_shape_index>: (<operand_index>, "
750 "<operand_shape_index>)";
751 if (!ParseToken(TokKind::kColon, errmsg)) {
752 return false;
753 }
754
755 if (!ParseToken(TokKind::kLparen, errmsg)) {
756 return false;
757 }
758 int64 operand_index;
759 ParseInt64(&operand_index);
760 if (!ParseToken(TokKind::kComma, errmsg)) {
761 return false;
762 }
763 ShapeIndex operand_shape_index;
764 if (!ParseShapeIndex(&operand_shape_index)) {
765 return false;
766 }
767
768 aliasing_output_operand_pairs->emplace_back(
769 out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
770 if (!ParseToken(TokKind::kRparen, errmsg)) {
771 return false;
772 }
773
774 if (!EatIfPresent(TokKind::kComma)) {
775 break;
776 }
777 }
778 if (!ParseToken(
779 TokKind::kRbrace,
780 "Expects '}' at the end of instruction aliasing description")) {
781 return false;
782 }
783 return true;
784 }
785
786 // ::= 'HloModule' name computations
ParseHloModule(HloModule * module)787 bool HloParserImpl::ParseHloModule(HloModule* module) {
788 if (lexer_.GetKind() != TokKind::kw_HloModule) {
789 return TokenError("expects HloModule");
790 }
791 // Eat 'HloModule'
792 lexer_.Lex();
793
794 std::string name;
795 if (!ParseName(&name)) {
796 return false;
797 }
798
799 absl::optional<bool> is_scheduled;
800 absl::optional<AliasingData> aliasing_data;
801 absl::flat_hash_map<std::string, AttrConfig> attrs;
802
803 attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
804 attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
805 &aliasing_data};
806 if (!ParseAttributes(attrs)) {
807 return false;
808 }
809 module->set_name(name);
810 if (!ParseComputations(module)) {
811 return false;
812 }
813
814 if (is_scheduled.has_value() && *is_scheduled) {
815 TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
816 }
817 if (aliasing_data) {
818 HloInputOutputAliasConfig alias_config(module->result_shape());
819 for (auto& p : *aliasing_data) {
820 Status st =
821 alias_config.SetUpAlias(p.first, p.second.parameter_number,
822 p.second.parameter_index, p.second.kind);
823 if (!st.ok()) {
824 return TokenError(st.error_message());
825 }
826 }
827 module->input_output_alias_config() = alias_config;
828 }
829
830 return true;
831 }
832
833 // computations ::= (computation)+
ParseComputations(HloModule * module)834 bool HloParserImpl::ParseComputations(HloModule* module) {
835 HloComputation* entry_computation = nullptr;
836 do {
837 if (!ParseComputation(&entry_computation)) {
838 return false;
839 }
840 } while (lexer_.GetKind() != TokKind::kEof);
841
842 for (int i = 0; i < computations_.size(); i++) {
843 // If entry_computation is not nullptr, it means the computation it pointed
844 // to is marked with "ENTRY"; otherwise, no computation is marked with
845 // "ENTRY", and we use the last computation as the entry computation. We
846 // add the non-entry computations as embedded computations to the module.
847 if ((entry_computation != nullptr &&
848 computations_[i].get() != entry_computation) ||
849 (entry_computation == nullptr && i != computations_.size() - 1)) {
850 module->AddEmbeddedComputation(std::move(computations_[i]));
851 continue;
852 }
853 auto computation = module->AddEntryComputation(std::move(computations_[i]));
854 // The parameters and result layouts were set to default layout. Here we
855 // set the layouts to what the hlo text says.
856 for (int p = 0; p < computation->num_parameters(); p++) {
857 const Shape& param_shape = computation->parameter_instruction(p)->shape();
858 TF_CHECK_OK(module->mutable_entry_computation_layout()
859 ->mutable_parameter_layout(p)
860 ->CopyLayoutFromShape(param_shape));
861 }
862 const Shape& result_shape = computation->root_instruction()->shape();
863 TF_CHECK_OK(module->mutable_entry_computation_layout()
864 ->mutable_result_layout()
865 ->CopyLayoutFromShape(result_shape));
866 }
867 return true;
868 }
869
870 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
ParseComputation(HloComputation ** entry_computation)871 bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
872 LocTy maybe_entry_loc = lexer_.GetLoc();
873 const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
874
875 std::string name;
876 LocTy name_loc = lexer_.GetLoc();
877 if (!ParseName(&name)) {
878 return false;
879 }
880
881 LocTy shape_loc = nullptr;
882 Shape shape;
883 if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
884 return false;
885 }
886
887 HloComputation* computation = nullptr;
888 if (!ParseInstructionList(&computation, name)) {
889 return false;
890 }
891
892 // If param_list_to_shape was present, check compatibility.
893 if (shape_loc != nullptr &&
894 !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
895 return Error(
896 shape_loc,
897 StrCat(
898 "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
899 ", is not compatible with that of its root instruction ",
900 computation->root_instruction()->name(), ", ",
901 ShapeUtil::HumanString(computation->root_instruction()->shape())));
902 }
903
904 if (is_entry_computation) {
905 if (*entry_computation != nullptr) {
906 return Error(maybe_entry_loc, "expects only one ENTRY");
907 }
908 *entry_computation = computation;
909 }
910
911 return AddComputation(name, computation, name_loc);
912 }
913
914 // instruction_list ::= '{' instruction_list1 '}'
915 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation ** computation,const std::string & computation_name)916 bool HloParserImpl::ParseInstructionList(HloComputation** computation,
917 const std::string& computation_name) {
918 Scope scope(&scoped_name_tables_);
919 HloComputation::Builder builder(computation_name);
920 if (!ParseToken(TokKind::kLbrace,
921 "expects '{' at the beginning of instruction list.")) {
922 return false;
923 }
924 std::string root_name;
925 do {
926 if (!ParseInstruction(&builder, &root_name)) {
927 return false;
928 }
929 } while (lexer_.GetKind() != TokKind::kRbrace);
930 if (!ParseToken(TokKind::kRbrace,
931 "expects '}' at the end of instruction list.")) {
932 return false;
933 }
934 HloInstruction* root = nullptr;
935 if (!root_name.empty()) {
936 std::pair<HloInstruction*, LocTy>* root_node =
937 tensorflow::gtl::FindOrNull(current_name_table(), root_name);
938
939 // This means some instruction was marked as ROOT but we didn't find it in
940 // the pool, which should not happen.
941 if (root_node == nullptr) {
942 // LOG(FATAL) crashes the program by calling abort().
943 LOG(FATAL) << "instruction " << root_name
944 << " was marked as ROOT but the parser has not seen it before";
945 }
946 root = root_node->first;
947 }
948
949 // Now root can be either an existing instruction or a nullptr. If it's a
950 // nullptr, the implementation of Builder will set the last instruction as
951 // the root instruction.
952 computations_.emplace_back(builder.Build(root));
953 *computation = computations_.back().get();
954 return true;
955 }
956
957 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,std::string * root_name)958 bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
959 std::string* root_name) {
960 std::string name;
961 LocTy maybe_root_loc = lexer_.GetLoc();
962 bool is_root = EatIfPresent(TokKind::kw_ROOT);
963
964 const LocTy name_loc = lexer_.GetLoc();
965 if (!ParseName(&name) ||
966 !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
967 return false;
968 }
969
970 if (is_root) {
971 if (!root_name->empty()) {
972 return Error(maybe_root_loc, "one computation should have only one ROOT");
973 }
974 *root_name = name;
975 }
976
977 return ParseInstructionRhs(builder, name, name_loc);
978 }
979
ParseInstructionRhs(HloComputation::Builder * builder,const std::string & name,LocTy name_loc)980 bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
981 const std::string& name,
982 LocTy name_loc) {
983 Shape shape;
984 HloOpcode opcode;
985 std::vector<HloInstruction*> operands;
986
987 const bool parse_shape = CanBeShape();
988 if ((parse_shape && !ParseShape(&shape)) || !ParseOpcode(&opcode)) {
989 return false;
990 }
991 if (!parse_shape && !CanInferShape(opcode)) {
992 return TokenError(StrFormat("cannot infer shape for opcode: %s",
993 HloOpcodeString(opcode)));
994 }
995
996 // Add optional attributes. These are added to any HloInstruction type if
997 // present.
998 absl::flat_hash_map<std::string, AttrConfig> attrs;
999 optional<OpSharding> sharding;
1000 optional<FrontendAttributes> frontend_attributes;
1001 attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
1002 attrs["frontend_attributes"] = {
1003 /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
1004 optional<ParameterReplication> parameter_replication;
1005 attrs["parameter_replication"] = {/*required=*/false,
1006 AttrTy::kParameterReplication,
1007 ¶meter_replication};
1008 optional<std::vector<HloInstruction*>> predecessors;
1009 attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
1010 &predecessors};
1011 optional<OpMetadata> metadata;
1012 attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
1013
1014 optional<std::string> backend_config;
1015 attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
1016 &backend_config};
1017 optional<std::vector<int64>> outer_dimension_partitions;
1018 attrs["outer_dimension_partitions"] = {/*required=*/false,
1019 AttrTy::kBracedInt64List,
1020 &outer_dimension_partitions};
1021 const auto maybe_infer_shape =
1022 [&](const std::function<StatusOr<Shape>()>& infer, Shape* shape) {
1023 if (parse_shape) {
1024 return true;
1025 }
1026 auto inferred = infer();
1027 if (!inferred.ok()) {
1028 return TokenError(StrFormat(
1029 "failed to infer shape for opcode: %s, error: %s",
1030 HloOpcodeString(opcode), inferred.status().error_message()));
1031 }
1032 *shape = std::move(inferred).ValueOrDie();
1033 return true;
1034 };
1035 HloInstruction* instruction;
1036 switch (opcode) {
1037 case HloOpcode::kParameter: {
1038 int64 parameter_number;
1039 if (!ParseToken(TokKind::kLparen,
1040 "expects '(' before parameter number") ||
1041 !ParseInt64(¶meter_number)) {
1042 return false;
1043 }
1044 if (parameter_number < 0) {
1045 Error(lexer_.GetLoc(), "parameter number must be >= 0");
1046 return false;
1047 }
1048 if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
1049 !ParseAttributes(attrs)) {
1050 return false;
1051 }
1052 instruction = builder->AddInstruction(
1053 HloInstruction::CreateParameter(parameter_number, shape, name));
1054 break;
1055 }
1056 case HloOpcode::kConstant: {
1057 Literal literal;
1058 if (!ParseToken(TokKind::kLparen,
1059 "expects '(' before constant literal") ||
1060 !ParseLiteral(&literal, shape) ||
1061 !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
1062 !ParseAttributes(attrs)) {
1063 return false;
1064 }
1065 instruction = builder->AddInstruction(
1066 HloInstruction::CreateConstant(std::move(literal)));
1067 break;
1068 }
1069 case HloOpcode::kIota: {
1070 optional<int64> iota_dimension;
1071 attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
1072 &iota_dimension};
1073 if (!ParseOperands(&operands, /*expected_size=*/0) ||
1074 !ParseAttributes(attrs)) {
1075 return false;
1076 }
1077 instruction = builder->AddInstruction(
1078 HloInstruction::CreateIota(shape, *iota_dimension));
1079 break;
1080 }
1081 // Unary ops.
1082 case HloOpcode::kAbs:
1083 case HloOpcode::kRoundNearestAfz:
1084 case HloOpcode::kBitcast:
1085 case HloOpcode::kCeil:
1086 case HloOpcode::kClz:
1087 case HloOpcode::kCollectivePermuteDone:
1088 case HloOpcode::kCopy:
1089 case HloOpcode::kCopyDone:
1090 case HloOpcode::kCos:
1091 case HloOpcode::kExp:
1092 case HloOpcode::kExpm1:
1093 case HloOpcode::kImag:
1094 case HloOpcode::kIsFinite:
1095 case HloOpcode::kFloor:
1096 case HloOpcode::kLog:
1097 case HloOpcode::kLog1p:
1098 case HloOpcode::kLogistic:
1099 case HloOpcode::kNot:
1100 case HloOpcode::kNegate:
1101 case HloOpcode::kPopulationCount:
1102 case HloOpcode::kReal:
1103 case HloOpcode::kRsqrt:
1104 case HloOpcode::kSign:
1105 case HloOpcode::kSin:
1106 case HloOpcode::kSqrt:
1107 case HloOpcode::kCbrt:
1108 case HloOpcode::kTanh: {
1109 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1110 !ParseAttributes(attrs)) {
1111 return false;
1112 }
1113 if (!maybe_infer_shape(
1114 [&] {
1115 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
1116 },
1117 &shape)) {
1118 return false;
1119 }
1120 instruction = builder->AddInstruction(
1121 HloInstruction::CreateUnary(shape, opcode, operands[0]));
1122 break;
1123 }
1124 // Binary ops.
1125 case HloOpcode::kAdd:
1126 case HloOpcode::kDivide:
1127 case HloOpcode::kMultiply:
1128 case HloOpcode::kSubtract:
1129 case HloOpcode::kAtan2:
1130 case HloOpcode::kComplex:
1131 case HloOpcode::kMaximum:
1132 case HloOpcode::kMinimum:
1133 case HloOpcode::kPower:
1134 case HloOpcode::kRemainder:
1135 case HloOpcode::kAnd:
1136 case HloOpcode::kOr:
1137 case HloOpcode::kXor:
1138 case HloOpcode::kShiftLeft:
1139 case HloOpcode::kShiftRightArithmetic:
1140 case HloOpcode::kShiftRightLogical: {
1141 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1142 !ParseAttributes(attrs)) {
1143 return false;
1144 }
1145 if (!maybe_infer_shape(
1146 [&] {
1147 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1148 operands[1]);
1149 },
1150 &shape)) {
1151 return false;
1152 }
1153 instruction = builder->AddInstruction(HloInstruction::CreateBinary(
1154 shape, opcode, operands[0], operands[1]));
1155 break;
1156 }
1157 // Ternary ops.
1158 case HloOpcode::kClamp:
1159 case HloOpcode::kSelect:
1160 case HloOpcode::kTupleSelect: {
1161 if (!ParseOperands(&operands, /*expected_size=*/3) ||
1162 !ParseAttributes(attrs)) {
1163 return false;
1164 }
1165 if (!maybe_infer_shape(
1166 [&] {
1167 return ShapeInference::InferTernaryOpShape(
1168 opcode, operands[0], operands[1], operands[2]);
1169 },
1170 &shape)) {
1171 return false;
1172 }
1173 instruction = builder->AddInstruction(HloInstruction::CreateTernary(
1174 shape, opcode, operands[0], operands[1], operands[2]));
1175 break;
1176 }
1177 // Other supported ops.
1178 case HloOpcode::kConvert: {
1179 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1180 !ParseAttributes(attrs)) {
1181 return false;
1182 }
1183 instruction = builder->AddInstruction(
1184 HloInstruction::CreateConvert(shape, operands[0]));
1185 break;
1186 }
1187 case HloOpcode::kBitcastConvert: {
1188 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1189 !ParseAttributes(attrs)) {
1190 return false;
1191 }
1192 instruction = builder->AddInstruction(
1193 HloInstruction::CreateBitcastConvert(shape, operands[0]));
1194 break;
1195 }
1196 case HloOpcode::kAllGather: {
1197 optional<std::vector<std::vector<int64>>> tmp_groups;
1198 optional<std::vector<int64>> replica_group_ids;
1199 optional<int64> channel_id;
1200 optional<std::vector<int64>> dimensions;
1201 optional<bool> constrain_layout;
1202 optional<bool> use_global_device_ids;
1203 attrs["replica_groups"] = {/*required=*/false,
1204 AttrTy::kBracedInt64ListList, &tmp_groups};
1205 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1206 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1207 &dimensions};
1208 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1209 &constrain_layout};
1210 attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1211 &use_global_device_ids};
1212 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1213 return false;
1214 }
1215 std::vector<ReplicaGroup> replica_groups;
1216 if (tmp_groups) {
1217 replica_groups = CreateReplicaGroups(*tmp_groups);
1218 }
1219 instruction = builder->AddInstruction(HloInstruction::CreateAllGather(
1220 shape, operands[0], dimensions->at(0), replica_groups,
1221 constrain_layout ? *constrain_layout : false, channel_id,
1222 use_global_device_ids ? *use_global_device_ids : false));
1223 break;
1224 }
1225 case HloOpcode::kAllReduce: {
1226 optional<std::vector<std::vector<int64>>> tmp_groups;
1227 optional<HloComputation*> to_apply;
1228 optional<std::vector<int64>> replica_group_ids;
1229 optional<int64> channel_id;
1230 optional<bool> constrain_layout;
1231 optional<bool> use_global_device_ids;
1232 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1233 &to_apply};
1234 attrs["replica_groups"] = {/*required=*/false,
1235 AttrTy::kBracedInt64ListList, &tmp_groups};
1236 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1237 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1238 &constrain_layout};
1239 attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1240 &use_global_device_ids};
1241 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1242 return false;
1243 }
1244 std::vector<ReplicaGroup> replica_groups;
1245 if (tmp_groups) {
1246 replica_groups = CreateReplicaGroups(*tmp_groups);
1247 }
1248 instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
1249 shape, operands, *to_apply, replica_groups,
1250 constrain_layout ? *constrain_layout : false, channel_id,
1251 use_global_device_ids ? *use_global_device_ids : false));
1252 break;
1253 }
1254 case HloOpcode::kAllToAll: {
1255 optional<std::vector<std::vector<int64>>> tmp_groups;
1256 attrs["replica_groups"] = {/*required=*/false,
1257 AttrTy::kBracedInt64ListList, &tmp_groups};
1258 optional<int64> channel_id;
1259 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1260 optional<std::vector<int64>> dimensions;
1261 attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1262 &dimensions};
1263 optional<bool> constrain_layout;
1264 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1265 &constrain_layout};
1266 if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1267 (dimensions && dimensions->size() != 1)) {
1268 return false;
1269 }
1270 std::vector<ReplicaGroup> replica_groups;
1271 if (tmp_groups) {
1272 replica_groups = CreateReplicaGroups(*tmp_groups);
1273 }
1274 optional<int64> split_dimension;
1275 if (dimensions) {
1276 split_dimension = dimensions->at(0);
1277 }
1278 instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
1279 shape, operands, replica_groups,
1280 constrain_layout ? *constrain_layout : false, channel_id,
1281 split_dimension));
1282 break;
1283 }
1284 case HloOpcode::kCollectivePermute:
1285 case HloOpcode::kCollectivePermuteStart: {
1286 optional<std::vector<std::vector<int64>>> source_targets;
1287 attrs["source_target_pairs"] = {
1288 /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
1289 optional<int64> channel_id;
1290 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1291 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1292 !ParseAttributes(attrs)) {
1293 return false;
1294 }
1295 std::vector<std::pair<int64, int64>> pairs(source_targets->size());
1296 for (int i = 0; i < pairs.size(); i++) {
1297 if ((*source_targets)[i].size() != 2) {
1298 return TokenError(
1299 "expects 'source_target_pairs=' to be a list of pairs");
1300 }
1301 pairs[i].first = (*source_targets)[i][0];
1302 pairs[i].second = (*source_targets)[i][1];
1303 }
1304 if (opcode == HloOpcode::kCollectivePermute) {
1305 instruction =
1306 builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1307 shape, operands[0], pairs, channel_id));
1308 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
1309 instruction = builder->AddInstruction(
1310 HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
1311 pairs, channel_id));
1312 } else {
1313 LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1314 "CollectivePermuteStart, but got "
1315 << HloOpcodeString(opcode);
1316 }
1317 break;
1318 }
1319 case HloOpcode::kCopyStart: {
1320 // If the is_cross_program_prefetch attribute is not present then default
1321 // to false.
1322 optional<bool> is_cross_program_prefetch = false;
1323 attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
1324 &is_cross_program_prefetch};
1325 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1326 !ParseAttributes(attrs)) {
1327 return false;
1328 }
1329 instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
1330 shape, operands[0], *is_cross_program_prefetch));
1331 break;
1332 }
1333 case HloOpcode::kReplicaId: {
1334 if (!ParseOperands(&operands, /*expected_size=*/0) ||
1335 !ParseAttributes(attrs)) {
1336 return false;
1337 }
1338 instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
1339 break;
1340 }
1341 case HloOpcode::kPartitionId: {
1342 if (!ParseOperands(&operands, /*expected_size=*/0) ||
1343 !ParseAttributes(attrs)) {
1344 return false;
1345 }
1346 instruction =
1347 builder->AddInstruction(HloInstruction::CreatePartitionId());
1348 break;
1349 }
1350 case HloOpcode::kDynamicReshape: {
1351 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1352 return false;
1353 }
1354 instruction =
1355 builder->AddInstruction(HloInstruction::CreateDynamicReshape(
1356 shape, operands[0],
1357 absl::Span<HloInstruction* const>(operands).subspan(1)));
1358 break;
1359 }
1360 case HloOpcode::kReshape: {
1361 optional<int64> inferred_dimension;
1362 attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
1363 &inferred_dimension};
1364 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1365 !ParseAttributes(attrs)) {
1366 return false;
1367 }
1368 instruction = builder->AddInstruction(HloInstruction::CreateReshape(
1369 shape, operands[0], inferred_dimension.value_or(-1)));
1370 break;
1371 }
1372 case HloOpcode::kAfterAll: {
1373 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1374 return false;
1375 }
1376 if (operands.empty()) {
1377 instruction = builder->AddInstruction(HloInstruction::CreateToken());
1378 } else {
1379 instruction =
1380 builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
1381 }
1382 break;
1383 }
1384 case HloOpcode::kAddDependency: {
1385 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1386 !ParseAttributes(attrs)) {
1387 return false;
1388 }
1389 instruction = builder->AddInstruction(
1390 HloInstruction::CreateAddDependency(operands[0], operands[1]));
1391 break;
1392 }
1393 case HloOpcode::kSort: {
1394 optional<std::vector<int64>> dimensions;
1395 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1396 &dimensions};
1397 optional<bool> is_stable = false;
1398 attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
1399 optional<HloComputation*> to_apply;
1400 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1401 &to_apply};
1402 if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1403 dimensions->size() != 1) {
1404 return false;
1405 }
1406 if (!maybe_infer_shape(
1407 [&] {
1408 absl::InlinedVector<const Shape*, 2> arg_shapes;
1409 arg_shapes.reserve(operands.size());
1410 for (auto* operand : operands) {
1411 arg_shapes.push_back(&operand->shape());
1412 }
1413 return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
1414 },
1415 &shape)) {
1416 return false;
1417 }
1418 instruction = builder->AddInstruction(
1419 HloInstruction::CreateSort(shape, dimensions->at(0), operands,
1420 to_apply.value(), is_stable.value()));
1421 break;
1422 }
1423 case HloOpcode::kTuple: {
1424 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1425 return false;
1426 }
1427 instruction =
1428 builder->AddInstruction(HloInstruction::CreateTuple(operands));
1429 break;
1430 }
1431 case HloOpcode::kWhile: {
1432 optional<HloComputation*> condition;
1433 optional<HloComputation*> body;
1434 attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
1435 &condition};
1436 attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
1437 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1438 !ParseAttributes(attrs)) {
1439 return false;
1440 }
1441 if (!maybe_infer_shape(
1442 [&] {
1443 return ShapeInference::InferWhileShape(
1444 condition.value()->ComputeProgramShape(),
1445 body.value()->ComputeProgramShape(), operands[0]->shape());
1446 },
1447 &shape)) {
1448 return false;
1449 }
1450 instruction = builder->AddInstruction(HloInstruction::CreateWhile(
1451 shape, *condition, *body, /*init=*/operands[0]));
1452 break;
1453 }
1454 case HloOpcode::kRecv: {
1455 optional<int64> channel_id;
1456 // If the is_host_transfer attribute is not present then default to false.
1457 optional<bool> is_host_transfer = false;
1458 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1459 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1460 &is_host_transfer};
1461 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1462 !ParseAttributes(attrs)) {
1463 return false;
1464 }
1465 // If the is_host_transfer attribute is not present then default to false.
1466 instruction = builder->AddInstruction(HloInstruction::CreateRecv(
1467 shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
1468 break;
1469 }
1470 case HloOpcode::kRecvDone: {
1471 optional<int64> channel_id;
1472 // If the is_host_transfer attribute is not present then default to false.
1473 optional<bool> is_host_transfer = false;
1474 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1475 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1476 &is_host_transfer};
1477 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1478 !ParseAttributes(attrs)) {
1479 return false;
1480 }
1481 if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
1482 return false;
1483 }
1484 if (channel_id != operands[0]->channel_id()) {
1485 return false;
1486 }
1487 instruction = builder->AddInstruction(
1488 HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
1489 break;
1490 }
1491 case HloOpcode::kSend: {
1492 optional<int64> channel_id;
1493 // If the is_host_transfer attribute is not present then default to false.
1494 optional<bool> is_host_transfer = false;
1495 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1496 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1497 &is_host_transfer};
1498 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1499 !ParseAttributes(attrs)) {
1500 return false;
1501 }
1502 instruction = builder->AddInstruction(HloInstruction::CreateSend(
1503 operands[0], operands[1], *channel_id, *is_host_transfer));
1504 break;
1505 }
1506 case HloOpcode::kSendDone: {
1507 optional<int64> channel_id;
1508 // If the is_host_transfer attribute is not present then default to false.
1509 optional<bool> is_host_transfer = false;
1510 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1511 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1512 &is_host_transfer};
1513 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1514 !ParseAttributes(attrs)) {
1515 return false;
1516 }
1517 if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
1518 return false;
1519 }
1520 if (channel_id != operands[0]->channel_id()) {
1521 return false;
1522 }
1523 instruction = builder->AddInstruction(
1524 HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
1525 break;
1526 }
1527 case HloOpcode::kGetTupleElement: {
1528 optional<int64> index;
1529 attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
1530 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1531 !ParseAttributes(attrs)) {
1532 return false;
1533 }
1534 if (!maybe_infer_shape(
1535 [&] {
1536 return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
1537 *index);
1538 },
1539 &shape)) {
1540 return false;
1541 }
1542 instruction = builder->AddInstruction(
1543 HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
1544 break;
1545 }
1546 case HloOpcode::kCall: {
1547 optional<HloComputation*> to_apply;
1548 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1549 &to_apply};
1550 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1551 return false;
1552 }
1553 if (!maybe_infer_shape(
1554 [&] {
1555 absl::InlinedVector<const Shape*, 2> arg_shapes;
1556 arg_shapes.reserve(operands.size());
1557 for (auto* operand : operands) {
1558 arg_shapes.push_back(&operand->shape());
1559 }
1560 return ShapeInference::InferCallShape(
1561 arg_shapes, to_apply.value()->ComputeProgramShape());
1562 },
1563 &shape)) {
1564 return false;
1565 }
1566 instruction = builder->AddInstruction(
1567 HloInstruction::CreateCall(shape, operands, *to_apply));
1568 break;
1569 }
1570 case HloOpcode::kReduceWindow: {
1571 optional<HloComputation*> reduce_computation;
1572 optional<Window> window;
1573 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1574 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1575 &reduce_computation};
1576 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1577 return false;
1578 }
1579 if (!window) {
1580 window.emplace();
1581 }
1582 if (operands.size() % 2) {
1583 auto loc = lexer_.GetLoc();
1584 return Error(loc, StrCat("expects an even number of operands, but has ",
1585 operands.size(), " operands"));
1586 }
1587 if (!maybe_infer_shape(
1588 [&] {
1589 return ShapeInference::InferReduceWindowShape(
1590 operands[0]->shape(), operands[1]->shape(), *window,
1591 reduce_computation.value()->ComputeProgramShape());
1592 },
1593 &shape)) {
1594 return false;
1595 }
1596 instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
1597 shape, /*operands=*/
1598 absl::Span<HloInstruction* const>(operands).subspan(
1599 0, operands.size() / 2),
1600 /*init_values=*/
1601 absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
1602 2),
1603 *window, *reduce_computation));
1604 break;
1605 }
1606 case HloOpcode::kConvolution: {
1607 optional<Window> window;
1608 optional<ConvolutionDimensionNumbers> dnums;
1609 optional<int64> feature_group_count;
1610 optional<int64> batch_group_count;
1611 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1612 attrs["dim_labels"] = {/*required=*/true,
1613 AttrTy::kConvolutionDimensionNumbers, &dnums};
1614 attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
1615 &feature_group_count};
1616 attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
1617 &batch_group_count};
1618 optional<std::vector<PrecisionConfig::Precision>> operand_precision;
1619 attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
1620 &operand_precision};
1621 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1622 !ParseAttributes(attrs)) {
1623 return false;
1624 }
1625 if (!window) {
1626 window.emplace();
1627 }
1628 if (!feature_group_count) {
1629 feature_group_count = 1;
1630 }
1631 if (!batch_group_count) {
1632 batch_group_count = 1;
1633 }
1634 PrecisionConfig precision_config;
1635 if (operand_precision) {
1636 *precision_config.mutable_operand_precision() = {
1637 operand_precision->begin(), operand_precision->end()};
1638 } else {
1639 precision_config.mutable_operand_precision()->Resize(
1640 operands.size(), PrecisionConfig::DEFAULT);
1641 }
1642 if (!maybe_infer_shape(
1643 [&] {
1644 return ShapeInference::InferConvolveShape(
1645 operands[0]->shape(), operands[1]->shape(),
1646 *feature_group_count, *batch_group_count, *window, *dnums,
1647 /*preferred_element_type=*/absl::nullopt);
1648 },
1649 &shape)) {
1650 return false;
1651 }
1652 instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
1653 shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
1654 feature_group_count.value(), batch_group_count.value(), *window,
1655 *dnums, precision_config));
1656 break;
1657 }
1658 case HloOpcode::kFft: {
1659 optional<FftType> fft_type;
1660 optional<std::vector<int64>> fft_length;
1661 attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
1662 attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
1663 &fft_length};
1664 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1665 !ParseAttributes(attrs)) {
1666 return false;
1667 }
1668 if (!maybe_infer_shape(
1669 [&] {
1670 return ShapeInference::InferFftShape(operands[0]->shape(),
1671 *fft_type, *fft_length);
1672 },
1673 &shape)) {
1674 return false;
1675 }
1676 instruction = builder->AddInstruction(HloInstruction::CreateFft(
1677 shape, operands[0], *fft_type, *fft_length));
1678 break;
1679 }
1680 case HloOpcode::kTriangularSolve: {
1681 TriangularSolveOptions options;
1682 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1683 !ParseAttributesAsProtoMessage(
1684 /*non_proto_attrs=*/attrs, &options)) {
1685 return false;
1686 }
1687 if (!maybe_infer_shape(
1688 [&] {
1689 return ShapeInference::InferTriangularSolveShape(
1690 operands[0]->shape(), operands[1]->shape(), options);
1691 },
1692 &shape)) {
1693 return false;
1694 }
1695 instruction =
1696 builder->AddInstruction(HloInstruction::CreateTriangularSolve(
1697 shape, operands[0], operands[1], options));
1698 break;
1699 }
1700 case HloOpcode::kCompare: {
1701 optional<ComparisonDirection> direction;
1702 optional<Comparison::Type> type;
1703 attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
1704 &direction};
1705 attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
1706 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1707 !ParseAttributes(attrs)) {
1708 return false;
1709 }
1710 if (!maybe_infer_shape(
1711 [&] {
1712 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1713 operands[1]);
1714 },
1715 &shape)) {
1716 return false;
1717 }
1718 instruction = builder->AddInstruction(HloInstruction::CreateCompare(
1719 shape, operands[0], operands[1], *direction, type));
1720 break;
1721 }
1722 case HloOpcode::kCholesky: {
1723 CholeskyOptions options;
1724 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1725 !ParseAttributesAsProtoMessage(
1726 /*non_proto_attrs=*/attrs, &options)) {
1727 return false;
1728 }
1729 if (!maybe_infer_shape(
1730 [&] {
1731 return ShapeInference::InferCholeskyShape(operands[0]->shape());
1732 },
1733 &shape)) {
1734 return false;
1735 }
1736 instruction = builder->AddInstruction(
1737 HloInstruction::CreateCholesky(shape, operands[0], options));
1738 break;
1739 }
1740 case HloOpcode::kBroadcast: {
1741 optional<std::vector<int64>> broadcast_dimensions;
1742 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1743 &broadcast_dimensions};
1744 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1745 !ParseAttributes(attrs)) {
1746 return false;
1747 }
1748 if (!maybe_infer_shape(
1749 [&] {
1750 return ShapeInference::InferBroadcastShape(
1751 operands[0]->shape(), *broadcast_dimensions);
1752 },
1753 &shape)) {
1754 return false;
1755 }
1756 instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
1757 shape, operands[0], *broadcast_dimensions));
1758 break;
1759 }
1760 case HloOpcode::kConcatenate: {
1761 optional<std::vector<int64>> dimensions;
1762 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1763 &dimensions};
1764 if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1765 dimensions->size() != 1) {
1766 return false;
1767 }
1768 if (!maybe_infer_shape(
1769 [&] {
1770 absl::InlinedVector<const Shape*, 2> arg_shapes;
1771 arg_shapes.reserve(operands.size());
1772 for (auto* operand : operands) {
1773 arg_shapes.push_back(&operand->shape());
1774 }
1775 return ShapeInference::InferConcatOpShape(arg_shapes,
1776 dimensions->at(0));
1777 },
1778 &shape)) {
1779 return false;
1780 }
1781 instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
1782 shape, operands, dimensions->at(0)));
1783 break;
1784 }
1785 case HloOpcode::kMap: {
1786 optional<HloComputation*> to_apply;
1787 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1788 &to_apply};
1789 optional<std::vector<int64>> dimensions;
1790 attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1791 &dimensions};
1792 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1793 return false;
1794 }
1795 if (!maybe_infer_shape(
1796 [&] {
1797 absl::InlinedVector<const Shape*, 2> arg_shapes;
1798 arg_shapes.reserve(operands.size());
1799 for (auto* operand : operands) {
1800 arg_shapes.push_back(&operand->shape());
1801 }
1802 return ShapeInference::InferMapShape(
1803 arg_shapes, to_apply.value()->ComputeProgramShape(),
1804 *dimensions);
1805 },
1806 &shape)) {
1807 return false;
1808 }
1809 instruction = builder->AddInstruction(
1810 HloInstruction::CreateMap(shape, operands, *to_apply));
1811 break;
1812 }
1813 case HloOpcode::kReduce: {
1814 auto loc = lexer_.GetLoc();
1815
1816 optional<HloComputation*> reduce_computation;
1817 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1818 &reduce_computation};
1819 optional<std::vector<int64>> dimensions_to_reduce;
1820 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1821 &dimensions_to_reduce};
1822 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1823 return false;
1824 }
1825 if (operands.size() % 2) {
1826 return Error(loc, StrCat("expects an even number of operands, but has ",
1827 operands.size(), " operands"));
1828 }
1829 if (!maybe_infer_shape(
1830 [&] {
1831 absl::InlinedVector<const Shape*, 2> arg_shapes;
1832 arg_shapes.reserve(operands.size());
1833 for (auto* operand : operands) {
1834 arg_shapes.push_back(&operand->shape());
1835 }
1836 return ShapeInference::InferReduceShape(
1837 arg_shapes, *dimensions_to_reduce,
1838 reduce_computation.value()->ComputeProgramShape());
1839 },
1840 &shape)) {
1841 return false;
1842 }
1843 instruction = builder->AddInstruction(HloInstruction::CreateReduce(
1844 shape, /*operands=*/
1845 absl::Span<HloInstruction* const>(operands).subspan(
1846 0, operands.size() / 2),
1847 /*init_values=*/
1848 absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
1849 2),
1850 *dimensions_to_reduce, *reduce_computation));
1851 break;
1852 }
1853 case HloOpcode::kReverse: {
1854 optional<std::vector<int64>> dimensions;
1855 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1856 &dimensions};
1857 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1858 !ParseAttributes(attrs)) {
1859 return false;
1860 }
1861 if (!maybe_infer_shape(
1862 [&] {
1863 return ShapeInference::InferReverseShape(operands[0]->shape(),
1864 *dimensions);
1865 },
1866 &shape)) {
1867 return false;
1868 }
1869 instruction = builder->AddInstruction(
1870 HloInstruction::CreateReverse(shape, operands[0], *dimensions));
1871 break;
1872 }
1873 case HloOpcode::kSelectAndScatter: {
1874 optional<HloComputation*> select;
1875 attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
1876 optional<HloComputation*> scatter;
1877 attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
1878 optional<Window> window;
1879 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
1880 if (!ParseOperands(&operands, /*expected_size=*/3) ||
1881 !ParseAttributes(attrs)) {
1882 return false;
1883 }
1884 if (!window) {
1885 window.emplace();
1886 }
1887 if (!maybe_infer_shape(
1888 [&] {
1889 return ShapeInference::InferSelectAndScatterShape(
1890 operands[0]->shape(), select.value()->ComputeProgramShape(),
1891 *window, operands[1]->shape(), operands[2]->shape(),
1892 scatter.value()->ComputeProgramShape());
1893 },
1894 &shape)) {
1895 return false;
1896 }
1897 instruction =
1898 builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
1899 shape, /*operand=*/operands[0], *select, *window,
1900 /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
1901 break;
1902 }
1903 case HloOpcode::kSlice: {
1904 optional<SliceRanges> slice_ranges;
1905 attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
1906 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1907 !ParseAttributes(attrs)) {
1908 return false;
1909 }
1910 instruction = builder->AddInstruction(HloInstruction::CreateSlice(
1911 shape, operands[0], slice_ranges->starts, slice_ranges->limits,
1912 slice_ranges->strides));
1913 break;
1914 }
1915 case HloOpcode::kDynamicSlice: {
1916 optional<std::vector<int64>> dynamic_slice_sizes;
1917 attrs["dynamic_slice_sizes"] = {
1918 /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
1919 LocTy loc = lexer_.GetLoc();
1920 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1921 return false;
1922 }
1923 if (operands.empty()) {
1924 return Error(loc, "Expected at least one operand.");
1925 }
1926 if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
1927 operands.size() != 1 + operands[0]->shape().rank()) {
1928 return Error(loc, "Wrong number of operands.");
1929 }
1930 instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
1931 shape, /*operand=*/operands[0],
1932 /*start_indices=*/absl::MakeSpan(operands).subspan(1),
1933 *dynamic_slice_sizes));
1934 break;
1935 }
1936 case HloOpcode::kDynamicUpdateSlice: {
1937 LocTy loc = lexer_.GetLoc();
1938 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1939 return false;
1940 }
1941 if (operands.size() < 2) {
1942 return Error(loc, "Expected at least two operands.");
1943 }
1944 if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
1945 operands.size() != 2 + operands[0]->shape().rank()) {
1946 return Error(loc, "Wrong number of operands.");
1947 }
1948 instruction =
1949 builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1950 shape, /*operand=*/operands[0], /*update=*/operands[1],
1951 /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
1952 break;
1953 }
1954 case HloOpcode::kTranspose: {
1955 optional<std::vector<int64>> dimensions;
1956 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1957 &dimensions};
1958 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1959 !ParseAttributes(attrs)) {
1960 return false;
1961 }
1962 if (!maybe_infer_shape(
1963 [&] {
1964 return ShapeInference::InferTransposeShape(operands[0]->shape(),
1965 *dimensions);
1966 },
1967 &shape)) {
1968 return false;
1969 }
1970 instruction = builder->AddInstruction(
1971 HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
1972 break;
1973 }
1974 case HloOpcode::kBatchNormTraining: {
1975 optional<float> epsilon;
1976 attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
1977 optional<int64> feature_index;
1978 attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
1979 &feature_index};
1980 if (!ParseOperands(&operands, /*expected_size=*/3) ||
1981 !ParseAttributes(attrs)) {
1982 return false;
1983 }
1984 if (!maybe_infer_shape(
1985 [&] {
1986 return ShapeInference::InferBatchNormTrainingShape(
1987 operands[0]->shape(), operands[1]->shape(),
1988 operands[2]->shape(), *feature_index);
1989 },
1990 &shape)) {
1991 return false;
1992 }
1993 instruction =
1994 builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
1995 shape, /*operand=*/operands[0], /*scale=*/operands[1],
1996 /*offset=*/operands[2], *epsilon, *feature_index));
1997 break;
1998 }
1999 case HloOpcode::kBatchNormInference: {
2000 optional<float> epsilon;
2001 attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
2002 optional<int64> feature_index;
2003 attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
2004 &feature_index};
2005 if (!ParseOperands(&operands, /*expected_size=*/5) ||
2006 !ParseAttributes(attrs)) {
2007 return false;
2008 }
2009 if (!maybe_infer_shape(
2010 [&] {
2011 return ShapeInference::InferBatchNormInferenceShape(
2012 operands[0]->shape(), operands[1]->shape(),
2013 operands[2]->shape(), operands[3]->shape(),
2014 operands[4]->shape(), *feature_index);
2015 },
2016 &shape)) {
2017 return false;
2018 }
2019 instruction =
2020 builder->AddInstruction(HloInstruction::CreateBatchNormInference(
2021 shape, /*operand=*/operands[0], /*scale=*/operands[1],
2022 /*offset=*/operands[2], /*mean=*/operands[3],
2023 /*variance=*/operands[4], *epsilon, *feature_index));
2024 break;
2025 }
2026 case HloOpcode::kBatchNormGrad: {
2027 optional<float> epsilon;
2028 attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
2029 optional<int64> feature_index;
2030 attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
2031 &feature_index};
2032 if (!ParseOperands(&operands, /*expected_size=*/5) ||
2033 !ParseAttributes(attrs)) {
2034 return false;
2035 }
2036 if (!maybe_infer_shape(
2037 [&] {
2038 return ShapeInference::InferBatchNormGradShape(
2039 operands[0]->shape(), operands[1]->shape(),
2040 operands[2]->shape(), operands[3]->shape(),
2041 operands[4]->shape(), *feature_index);
2042 },
2043 &shape)) {
2044 return false;
2045 }
2046 instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
2047 shape, /*operand=*/operands[0], /*scale=*/operands[1],
2048 /*mean=*/operands[2], /*variance=*/operands[3],
2049 /*grad_output=*/operands[4], *epsilon, *feature_index));
2050 break;
2051 }
2052 case HloOpcode::kPad: {
2053 optional<PaddingConfig> padding;
2054 attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
2055 if (!ParseOperands(&operands, /*expected_size=*/2) ||
2056 !ParseAttributes(attrs)) {
2057 return false;
2058 }
2059 if (!maybe_infer_shape(
2060 [&] {
2061 return ShapeInference::InferPadShape(
2062 operands[0]->shape(), operands[1]->shape(), *padding);
2063 },
2064 &shape)) {
2065 return false;
2066 }
2067 instruction = builder->AddInstruction(HloInstruction::CreatePad(
2068 shape, operands[0], /*padding_value=*/operands[1], *padding));
2069 break;
2070 }
2071 case HloOpcode::kFusion: {
2072 optional<HloComputation*> fusion_computation;
2073 attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
2074 &fusion_computation};
2075 optional<HloInstruction::FusionKind> fusion_kind;
2076 attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
2077 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2078 return false;
2079 }
2080 instruction = builder->AddInstruction(HloInstruction::CreateFusion(
2081 shape, *fusion_kind, operands, *fusion_computation));
2082 break;
2083 }
2084 case HloOpcode::kInfeed: {
2085 optional<std::string> config;
2086 attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2087 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2088 !ParseAttributes(attrs)) {
2089 return false;
2090 }
2091 // We need to know the infeed data shape to construct the infeed
2092 // instruction. This is the zero-th element of the tuple-shaped output of
2093 // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
2094 // if the shape is not a non-empty tuple, so add guard so an error message
2095 // can be emitted instead of a check fail
2096 if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
2097 return Error(lexer_.GetLoc(),
2098 "infeed must have a non-empty tuple shape");
2099 }
2100 instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
2101 ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
2102 config ? *config : ""));
2103 break;
2104 }
2105 case HloOpcode::kOutfeed: {
2106 optional<std::string> config;
2107 optional<Shape> outfeed_shape;
2108 attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2109 attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
2110 &outfeed_shape};
2111 if (!ParseOperands(&operands, /*expected_size=*/2) ||
2112 !ParseAttributes(attrs)) {
2113 return false;
2114 }
2115 HloInstruction* const outfeed_input = operands[0];
2116 HloInstruction* const outfeed_token = operands[1];
2117 const Shape shape =
2118 outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
2119 instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
2120 shape, outfeed_input, outfeed_token, config ? *config : ""));
2121 break;
2122 }
2123 case HloOpcode::kRng: {
2124 optional<RandomDistribution> distribution;
2125 attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
2126 &distribution};
2127 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2128 return false;
2129 }
2130 instruction = builder->AddInstruction(
2131 HloInstruction::CreateRng(shape, *distribution, operands));
2132 break;
2133 }
2134 case HloOpcode::kRngGetAndUpdateState: {
2135 optional<int64> delta;
2136 attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
2137 if (!ParseOperands(&operands, /*expected_size=*/0) ||
2138 !ParseAttributes(attrs)) {
2139 return false;
2140 }
2141 instruction = builder->AddInstruction(
2142 HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
2143 break;
2144 }
2145 case HloOpcode::kRngBitGenerator: {
2146 optional<RandomAlgorithm> algorithm;
2147 attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
2148 &algorithm};
2149 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2150 return false;
2151 }
2152 instruction =
2153 builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
2154 shape, operands[0], *algorithm));
2155 break;
2156 }
2157 case HloOpcode::kReducePrecision: {
2158 optional<int64> exponent_bits;
2159 optional<int64> mantissa_bits;
2160 attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
2161 &exponent_bits};
2162 attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
2163 &mantissa_bits};
2164 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2165 !ParseAttributes(attrs)) {
2166 return false;
2167 }
2168 instruction =
2169 builder->AddInstruction(HloInstruction::CreateReducePrecision(
2170 shape, operands[0], static_cast<int>(*exponent_bits),
2171 static_cast<int>(*mantissa_bits)));
2172 break;
2173 }
2174 case HloOpcode::kConditional: {
2175 optional<HloComputation*> true_computation;
2176 optional<HloComputation*> false_computation;
2177 optional<std::vector<HloComputation*>> branch_computations;
2178 if (!ParseOperands(&operands)) {
2179 return false;
2180 }
2181 if (!ShapeUtil::IsScalar(operands[0]->shape())) {
2182 return Error(lexer_.GetLoc(), "The first operand must be a scalar");
2183 }
2184 const bool branch_index_is_bool =
2185 operands[0]->shape().element_type() == PRED;
2186 if (branch_index_is_bool) {
2187 attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
2188 &true_computation};
2189 attrs["false_computation"] = {
2190 /*required=*/true, AttrTy::kHloComputation, &false_computation};
2191 } else {
2192 if (operands[0]->shape().element_type() != S32) {
2193 return Error(lexer_.GetLoc(),
2194 "The first operand must be a scalar of PRED or S32");
2195 }
2196 attrs["branch_computations"] = {/*required=*/true,
2197 AttrTy::kBracedHloComputationList,
2198 &branch_computations};
2199 }
2200 if (!ParseAttributes(attrs)) {
2201 return false;
2202 }
2203 if (branch_index_is_bool) {
2204 branch_computations.emplace({*true_computation, *false_computation});
2205 }
2206 if (branch_computations->empty() ||
2207 operands.size() != branch_computations->size() + 1) {
2208 return false;
2209 }
2210 if (!maybe_infer_shape(
2211 [&] {
2212 absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
2213 branch_computation_shapes.reserve(branch_computations->size());
2214 for (auto* computation : *branch_computations) {
2215 branch_computation_shapes.push_back(
2216 computation->ComputeProgramShape());
2217 }
2218 absl::InlinedVector<Shape, 2> branch_operand_shapes;
2219 branch_operand_shapes.reserve(operands.size() - 1);
2220 for (int i = 1; i < operands.size(); ++i) {
2221 branch_operand_shapes.push_back(operands[i]->shape());
2222 }
2223 return ShapeInference::InferConditionalShape(
2224 operands[0]->shape(), branch_computation_shapes,
2225 branch_operand_shapes);
2226 },
2227 &shape)) {
2228 return false;
2229 }
2230 instruction = builder->AddInstruction(HloInstruction::CreateConditional(
2231 shape, /*branch_index=*/operands[0],
2232 absl::MakeSpan(*branch_computations),
2233 absl::MakeSpan(operands).subspan(1)));
2234 break;
2235 }
2236 case HloOpcode::kCustomCall: {
2237 optional<std::string> custom_call_target;
2238 optional<Window> window;
2239 optional<ConvolutionDimensionNumbers> dnums;
2240 optional<int64> feature_group_count;
2241 optional<int64> batch_group_count;
2242 optional<std::vector<Shape>> operand_layout_constraints;
2243 optional<bool> custom_call_has_side_effect;
2244 optional<HloComputation*> to_apply;
2245 optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
2246 output_to_operand_aliasing;
2247 optional<PaddingType> padding_type;
2248 optional<std::vector<HloComputation*>> called_computations;
2249 attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
2250 &custom_call_target};
2251 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
2252 attrs["dim_labels"] = {/*required=*/false,
2253 AttrTy::kConvolutionDimensionNumbers, &dnums};
2254 attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
2255 &feature_group_count};
2256 attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
2257 &batch_group_count};
2258 attrs["operand_layout_constraints"] = {
2259 /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
2260 attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
2261 &custom_call_has_side_effect};
2262 attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
2263 &to_apply};
2264 attrs["called_computations"] = {/*required=*/false,
2265 AttrTy::kBracedHloComputationList,
2266 &called_computations};
2267 attrs["output_to_operand_aliasing"] = {/*required=*/false,
2268 AttrTy::kInstructionAliasing,
2269 &output_to_operand_aliasing};
2270
2271 attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
2272 &padding_type};
2273
2274 optional<Literal> literal;
2275 attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
2276 optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2277 attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2278 &operand_precision};
2279 if (called_computations.has_value() && to_apply.has_value()) {
2280 return Error(lexer_.GetLoc(),
2281 "A single instruction can't have both to_apply and "
2282 "calls field");
2283 }
2284 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
2285 return false;
2286 }
2287 if (operand_layout_constraints.has_value()) {
2288 if (!LayoutUtil::HasLayout(shape)) {
2289 return Error(lexer_.GetLoc(),
2290 "Layout must be set on layout-constrained custom call");
2291 }
2292 if (operands.size() != operand_layout_constraints->size()) {
2293 return Error(lexer_.GetLoc(),
2294 StrCat("Expected ", operands.size(),
2295 " operand layout constraints, ",
2296 operand_layout_constraints->size(), " given"));
2297 }
2298 for (int64 i = 0; i < operands.size(); ++i) {
2299 const Shape& operand_shape_with_layout =
2300 (*operand_layout_constraints)[i];
2301 if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
2302 return Error(lexer_.GetLoc(),
2303 StrCat("Operand layout constraint shape ",
2304 ShapeUtil::HumanStringWithLayout(
2305 operand_shape_with_layout),
2306 " for operand ", i, " does not have a layout"));
2307 }
2308 if (!ShapeUtil::Compatible(operand_shape_with_layout,
2309 operands[i]->shape())) {
2310 return Error(
2311 lexer_.GetLoc(),
2312 StrCat(
2313 "Operand layout constraint shape ",
2314 ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
2315 " for operand ", i,
2316 " is not compatible with operand shape ",
2317 ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
2318 }
2319 }
2320 instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
2321 shape, operands, *custom_call_target, *operand_layout_constraints,
2322 backend_config ? *backend_config : ""));
2323 } else {
2324 if (to_apply.has_value()) {
2325 instruction =
2326 builder->AddInstruction(HloInstruction::CreateCustomCall(
2327 shape, operands, *to_apply, *custom_call_target,
2328 backend_config ? *backend_config : ""));
2329 } else if (called_computations.has_value()) {
2330 instruction =
2331 builder->AddInstruction(HloInstruction::CreateCustomCall(
2332 shape, operands, *called_computations, *custom_call_target,
2333 backend_config ? *backend_config : ""));
2334 } else {
2335 instruction =
2336 builder->AddInstruction(HloInstruction::CreateCustomCall(
2337 shape, operands, *custom_call_target,
2338 backend_config ? *backend_config : ""));
2339 }
2340 }
2341 auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
2342 if (window.has_value()) {
2343 custom_call_instr->set_window(*window);
2344 }
2345 if (dnums.has_value()) {
2346 custom_call_instr->set_convolution_dimension_numbers(*dnums);
2347 }
2348 if (feature_group_count.has_value()) {
2349 custom_call_instr->set_feature_group_count(*feature_group_count);
2350 }
2351 if (batch_group_count.has_value()) {
2352 custom_call_instr->set_batch_group_count(*batch_group_count);
2353 }
2354 if (padding_type.has_value()) {
2355 custom_call_instr->set_padding_type(*padding_type);
2356 }
2357 if (custom_call_has_side_effect.has_value()) {
2358 custom_call_instr->set_custom_call_has_side_effect(
2359 *custom_call_has_side_effect);
2360 }
2361 if (output_to_operand_aliasing.has_value()) {
2362 custom_call_instr->set_output_to_operand_aliasing(
2363 std::move(*output_to_operand_aliasing));
2364 }
2365 if (literal.has_value()) {
2366 custom_call_instr->set_literal(std::move(*literal));
2367 }
2368 PrecisionConfig precision_config;
2369 if (operand_precision) {
2370 *precision_config.mutable_operand_precision() = {
2371 operand_precision->begin(), operand_precision->end()};
2372 } else {
2373 precision_config.mutable_operand_precision()->Resize(
2374 operands.size(), PrecisionConfig::DEFAULT);
2375 }
2376 *custom_call_instr->mutable_precision_config() = precision_config;
2377 break;
2378 }
2379 case HloOpcode::kDot: {
2380 optional<std::vector<int64>> lhs_contracting_dims;
2381 attrs["lhs_contracting_dims"] = {
2382 /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
2383 optional<std::vector<int64>> rhs_contracting_dims;
2384 attrs["rhs_contracting_dims"] = {
2385 /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
2386 optional<std::vector<int64>> lhs_batch_dims;
2387 attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
2388 &lhs_batch_dims};
2389 optional<std::vector<int64>> rhs_batch_dims;
2390 attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
2391 &rhs_batch_dims};
2392 optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2393 attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2394 &operand_precision};
2395
2396 if (!ParseOperands(&operands, /*expected_size=*/2) ||
2397 !ParseAttributes(attrs)) {
2398 return false;
2399 }
2400
2401 DotDimensionNumbers dnum;
2402 if (lhs_contracting_dims) {
2403 *dnum.mutable_lhs_contracting_dimensions() = {
2404 lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
2405 }
2406 if (rhs_contracting_dims) {
2407 *dnum.mutable_rhs_contracting_dimensions() = {
2408 rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
2409 }
2410 if (lhs_batch_dims) {
2411 *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
2412 lhs_batch_dims->end()};
2413 }
2414 if (rhs_batch_dims) {
2415 *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
2416 rhs_batch_dims->end()};
2417 }
2418
2419 PrecisionConfig precision_config;
2420 if (operand_precision) {
2421 *precision_config.mutable_operand_precision() = {
2422 operand_precision->begin(), operand_precision->end()};
2423 } else {
2424 precision_config.mutable_operand_precision()->Resize(
2425 operands.size(), PrecisionConfig::DEFAULT);
2426 }
2427 if (!maybe_infer_shape(
2428 [&] {
2429 return ShapeInference::InferDotOpShape(
2430 operands[0]->shape(), operands[1]->shape(), dnum,
2431 /*preferred_element_type=*/absl::nullopt);
2432 },
2433 &shape)) {
2434 return false;
2435 }
2436 instruction = builder->AddInstruction(HloInstruction::CreateDot(
2437 shape, operands[0], operands[1], dnum, precision_config));
2438 break;
2439 }
2440 case HloOpcode::kGather: {
2441 optional<std::vector<int64>> offset_dims;
2442 attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
2443 &offset_dims};
2444 optional<std::vector<int64>> collapsed_slice_dims;
2445 attrs["collapsed_slice_dims"] = {
2446 /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
2447 optional<std::vector<int64>> start_index_map;
2448 attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
2449 &start_index_map};
2450 optional<int64> index_vector_dim;
2451 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2452 &index_vector_dim};
2453 optional<std::vector<int64>> slice_sizes;
2454 attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
2455 &slice_sizes};
2456 optional<bool> indices_are_sorted = false;
2457 attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2458 &indices_are_sorted};
2459
2460 if (!ParseOperands(&operands, /*expected_size=*/2) ||
2461 !ParseAttributes(attrs)) {
2462 return false;
2463 }
2464
2465 GatherDimensionNumbers dim_numbers =
2466 HloGatherInstruction::MakeGatherDimNumbers(
2467 /*offset_dims=*/*offset_dims,
2468 /*collapsed_slice_dims=*/*collapsed_slice_dims,
2469 /*start_index_map=*/*start_index_map,
2470 /*index_vector_dim=*/*index_vector_dim);
2471 if (!maybe_infer_shape(
2472 [&] {
2473 return ShapeInference::InferGatherShape(
2474 operands[0]->shape(), operands[1]->shape(), dim_numbers,
2475 *slice_sizes);
2476 },
2477 &shape)) {
2478 return false;
2479 }
2480 instruction = builder->AddInstruction(HloInstruction::CreateGather(
2481 shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
2482 dim_numbers, *slice_sizes, indices_are_sorted.value()));
2483 break;
2484 }
2485 case HloOpcode::kScatter: {
2486 optional<std::vector<int64>> update_window_dims;
2487 attrs["update_window_dims"] = {
2488 /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
2489 optional<std::vector<int64>> inserted_window_dims;
2490 attrs["inserted_window_dims"] = {
2491 /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
2492 optional<std::vector<int64>> scatter_dims_to_operand_dims;
2493 attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
2494 AttrTy::kBracedInt64List,
2495 &scatter_dims_to_operand_dims};
2496 optional<int64> index_vector_dim;
2497 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2498 &index_vector_dim};
2499
2500 optional<HloComputation*> update_computation;
2501 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
2502 &update_computation};
2503 optional<bool> indices_are_sorted = false;
2504 attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2505 &indices_are_sorted};
2506 optional<bool> unique_indices = false;
2507 attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
2508 &unique_indices};
2509
2510 if (!ParseOperands(&operands, /*expected_size=*/3) ||
2511 !ParseAttributes(attrs)) {
2512 return false;
2513 }
2514
2515 ScatterDimensionNumbers dim_numbers =
2516 HloScatterInstruction::MakeScatterDimNumbers(
2517 /*update_window_dims=*/*update_window_dims,
2518 /*inserted_window_dims=*/*inserted_window_dims,
2519 /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
2520 /*index_vector_dim=*/*index_vector_dim);
2521
2522 if (!maybe_infer_shape(
2523 [&] {
2524 return ShapeInference::InferScatterShape(
2525 operands[0]->shape(), operands[1]->shape(),
2526 operands[2]->shape(),
2527 update_computation.value()->ComputeProgramShape(),
2528 dim_numbers);
2529 },
2530 &shape)) {
2531 return false;
2532 }
2533 instruction = builder->AddInstruction(HloInstruction::CreateScatter(
2534 shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
2535 /*updates=*/operands[2], *update_computation, dim_numbers,
2536 indices_are_sorted.value(), unique_indices.value()));
2537 break;
2538 }
2539 case HloOpcode::kDomain: {
2540 DomainData domain;
2541 attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
2542 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2543 !ParseAttributes(attrs)) {
2544 return false;
2545 }
2546 if (!maybe_infer_shape(
2547 [&] {
2548 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
2549 },
2550 &shape)) {
2551 return false;
2552 }
2553 instruction = builder->AddInstruction(HloInstruction::CreateDomain(
2554 shape, operands[0], std::move(domain.exit_metadata),
2555 std::move(domain.entry_metadata)));
2556 break;
2557 }
2558 case HloOpcode::kTrace:
2559 return TokenError(StrCat("parsing not yet implemented for op: ",
2560 HloOpcodeString(opcode)));
2561 case HloOpcode::kGetDimensionSize: {
2562 optional<std::vector<int64>> dimensions;
2563 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2564 &dimensions};
2565 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2566 !ParseAttributes(attrs)) {
2567 return false;
2568 }
2569 if (!maybe_infer_shape(
2570 [&] {
2571 return ShapeInference::InferGetDimensionSizeShape(
2572 operands[0]->shape(), dimensions->at(0));
2573 },
2574 &shape)) {
2575 return false;
2576 }
2577 instruction =
2578 builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
2579 shape, operands[0], (*dimensions)[0]));
2580 break;
2581 }
2582 case HloOpcode::kSetDimensionSize: {
2583 optional<std::vector<int64>> dimensions;
2584 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2585 &dimensions};
2586 if (!ParseOperands(&operands, /*expected_size=*/2) ||
2587 !ParseAttributes(attrs)) {
2588 return false;
2589 }
2590 if (!maybe_infer_shape(
2591 [&] {
2592 return ShapeInference::InferSetDimensionSizeShape(
2593 operands[0]->shape(), operands[1]->shape(),
2594 dimensions->at(0));
2595 },
2596 &shape)) {
2597 return false;
2598 }
2599 instruction =
2600 builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
2601 shape, operands[0], operands[1], (*dimensions)[0]));
2602 break;
2603 }
2604 }
2605
2606 instruction->SetAndSanitizeName(name);
2607 if (instruction->name() != name) {
2608 return Error(name_loc,
2609 StrCat("illegal instruction name: ", name,
2610 "; suggest renaming to: ", instruction->name()));
2611 }
2612
2613 // Add shared attributes like metadata to the instruction, if they were seen.
2614 if (sharding) {
2615 instruction->set_sharding(
2616 HloSharding::FromProto(sharding.value()).ValueOrDie());
2617 }
2618 if (parameter_replication) {
2619 int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
2620 const auto& replicated =
2621 parameter_replication->replicated_at_leaf_buffers();
2622 if (leaf_count != replicated.size()) {
2623 return Error(lexer_.GetLoc(),
2624 StrCat("parameter has ", leaf_count,
2625 " leaf buffers, but parameter_replication has ",
2626 replicated.size(), " elements."));
2627 }
2628 instruction->set_parameter_replicated_at_leaf_buffers(replicated);
2629 }
2630 if (predecessors) {
2631 for (auto* pre : *predecessors) {
2632 Status status = pre->AddControlDependencyTo(instruction);
2633 if (!status.ok()) {
2634 return Error(name_loc, StrCat("error adding control dependency for: ",
2635 name, " status: ", status.ToString()));
2636 }
2637 }
2638 }
2639 if (metadata) {
2640 instruction->set_metadata(*metadata);
2641 }
2642 if (backend_config) {
2643 instruction->set_raw_backend_config_string(std::move(*backend_config));
2644 }
2645 if (outer_dimension_partitions) {
2646 instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
2647 }
2648 if (frontend_attributes) {
2649 instruction->set_frontend_attributes(*frontend_attributes);
2650 }
2651 return AddInstruction(name, instruction, name_loc);
2652 } // NOLINT(readability/fn_size)
2653
2654 // ::= '{' (single_sharding | tuple_sharding) '}'
2655 //
2656 // tuple_sharding ::= single_sharding* (',' single_sharding)*
ParseSharding(OpSharding * sharding)2657 bool HloParserImpl::ParseSharding(OpSharding* sharding) {
2658 // A single sharding starts with '{' and is not followed by '{'.
2659 // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
2660 // an empty tuple.
2661 if (!ParseToken(TokKind::kLbrace,
2662 "expected '{' to start sharding attribute")) {
2663 return false;
2664 }
2665
2666 if (lexer_.GetKind() != TokKind::kLbrace &&
2667 lexer_.GetKind() != TokKind::kRbrace) {
2668 return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
2669 }
2670
2671 // Tuple sharding.
2672 // Allow empty tuple shardings.
2673 if (lexer_.GetKind() != TokKind::kRbrace) {
2674 do {
2675 if (!ParseSingleSharding(sharding->add_tuple_shardings(),
2676 /*lbrace_pre_lexed=*/false)) {
2677 return false;
2678 }
2679 } while (EatIfPresent(TokKind::kComma));
2680 }
2681 sharding->set_type(OpSharding::TUPLE);
2682
2683 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
2684 }
2685
2686 // frontend_attributes ::= '{' attributes '}'
2687 // attributes
2688 // ::= /*empty*/
2689 // ::= attribute '=' value (',' attribute '=' value)*
ParseFrontendAttributes(FrontendAttributes * frontend_attributes)2690 bool HloParserImpl::ParseFrontendAttributes(
2691 FrontendAttributes* frontend_attributes) {
2692 CHECK(frontend_attributes != nullptr);
2693 if (!ParseToken(TokKind::kLbrace,
2694 "expected '{' to start frontend attributes")) {
2695 return false;
2696 }
2697 if (lexer_.GetKind() == TokKind::kRbrace) {
2698 // empty
2699 } else {
2700 do {
2701 std::string attribute;
2702 if (!ParseAttributeName(&attribute)) {
2703 return false;
2704 }
2705 if (lexer_.GetKind() != TokKind::kString) {
2706 return false;
2707 }
2708 (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
2709 lexer_.Lex();
2710 } while (EatIfPresent(TokKind::kComma));
2711 }
2712 return ParseToken(TokKind::kRbrace,
2713 "expects '}' at the end of frontend attributes");
2714 }
2715
2716 // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
2717 // ('devices=' ('[' dims ']')* device_list)?
2718 // ('metadata=' metadata)* '}'
2719 //
2720 // dims ::= int_list device_list ::= int_list
2721 // metadata ::= single_metadata |
2722 // ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleSharding(OpSharding * sharding,bool lbrace_pre_lexed)2723 bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
2724 bool lbrace_pre_lexed) {
2725 if (!lbrace_pre_lexed &&
2726 !ParseToken(TokKind::kLbrace,
2727 "expected '{' to start sharding attribute")) {
2728 return false;
2729 }
2730
2731 LocTy loc = lexer_.GetLoc();
2732 bool maximal = false;
2733 bool replicated = false;
2734 bool manual = false;
2735 bool last_tile_dim_replicate = false;
2736 std::vector<int64> devices;
2737 std::vector<int64> tile_assignment_dimensions;
2738 while (lexer_.GetKind() != TokKind::kRbrace) {
2739 switch (lexer_.GetKind()) {
2740 case TokKind::kw_maximal:
2741 maximal = true;
2742 lexer_.Lex();
2743 break;
2744 case TokKind::kw_replicated:
2745 replicated = true;
2746 lexer_.Lex();
2747 break;
2748 case TokKind::kw_manual:
2749 manual = true;
2750 lexer_.Lex();
2751 break;
2752 case TokKind::kAttributeName: {
2753 if (lexer_.GetStrVal() == "device") {
2754 if (lexer_.Lex() != TokKind::kInt) {
2755 return TokenError("device= attribute must be an integer");
2756 }
2757 devices = {lexer_.GetInt64Val()};
2758 lexer_.Lex();
2759 } else if (lexer_.GetStrVal() == "devices") {
2760 lexer_.Lex();
2761 if (!ParseToken(TokKind::kLsquare,
2762 "expected '[' to start sharding devices shape")) {
2763 return false;
2764 }
2765
2766 do {
2767 int64 dim;
2768 if (!ParseInt64(&dim)) {
2769 return false;
2770 }
2771 tile_assignment_dimensions.push_back(dim);
2772 } while (EatIfPresent(TokKind::kComma));
2773
2774 if (!ParseToken(TokKind::kRsquare,
2775 "expected ']' to start sharding devices shape")) {
2776 return false;
2777 }
2778 do {
2779 int64 device;
2780 if (!ParseInt64(&device)) {
2781 return false;
2782 }
2783 devices.push_back(device);
2784 } while (EatIfPresent(TokKind::kComma));
2785 } else if (lexer_.GetStrVal() == "metadata") {
2786 lexer_.Lex();
2787 if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
2788 return false;
2789 }
2790 } else {
2791 return TokenError(
2792 "unknown attribute in sharding: expected device=, devices= or "
2793 "metadata=");
2794 }
2795 break;
2796 }
2797 case TokKind::kw_last_tile_dim_replicate:
2798 last_tile_dim_replicate = true;
2799 lexer_.Lex();
2800 break;
2801 case TokKind::kRbrace:
2802 break;
2803 default:
2804 return TokenError("unexpected token");
2805 }
2806 }
2807
2808 if (replicated) {
2809 if (!devices.empty()) {
2810 return Error(loc,
2811 "replicated shardings should not have any devices assigned");
2812 }
2813 sharding->set_type(OpSharding::REPLICATED);
2814 } else if (maximal) {
2815 if (devices.size() != 1) {
2816 return Error(loc,
2817 "maximal shardings should have exactly one device assigned");
2818 }
2819 sharding->set_type(OpSharding::MAXIMAL);
2820 sharding->add_tile_assignment_devices(devices[0]);
2821 } else if (manual) {
2822 if (!devices.empty()) {
2823 return Error(loc,
2824 "manual shardings should not have any devices assigned");
2825 }
2826 sharding->set_type(OpSharding::MANUAL);
2827 } else {
2828 if (devices.size() <= 1) {
2829 return Error(
2830 loc, "non-maximal shardings must have more than one device assigned");
2831 }
2832 if (tile_assignment_dimensions.empty()) {
2833 return Error(
2834 loc,
2835 "non-maximal shardings must have a tile assignment list including "
2836 "dimensions");
2837 }
2838 sharding->set_type(OpSharding::OTHER);
2839 for (int64 dim : tile_assignment_dimensions) {
2840 sharding->add_tile_assignment_dimensions(dim);
2841 }
2842 for (int64 device : devices) {
2843 sharding->add_tile_assignment_devices(device);
2844 }
2845 sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
2846 }
2847
2848 lexer_.Lex();
2849 return true;
2850 }
2851
2852 // parameter_replication ::=
2853 // '{' ('true' | 'false')* (',' ('true' | 'false'))* '}'
ParseParameterReplication(ParameterReplication * parameter_replication)2854 bool HloParserImpl::ParseParameterReplication(
2855 ParameterReplication* parameter_replication) {
2856 if (!ParseToken(TokKind::kLbrace,
2857 "expected '{' to start parameter_replication attribute")) {
2858 return false;
2859 }
2860
2861 if (lexer_.GetKind() != TokKind::kRbrace) {
2862 do {
2863 if (lexer_.GetKind() == TokKind::kw_true) {
2864 parameter_replication->add_replicated_at_leaf_buffers(true);
2865 } else if (lexer_.GetKind() == TokKind::kw_false) {
2866 parameter_replication->add_replicated_at_leaf_buffers(false);
2867 } else {
2868 return false;
2869 }
2870 lexer_.Lex();
2871 } while (EatIfPresent(TokKind::kComma));
2872 }
2873
2874 return ParseToken(TokKind::kRbrace,
2875 "expected '}' to end parameter_replication attribute");
2876 }
2877
2878 // replica_groups ::='{' int64list_elements '}'
2879 // int64list_elements
2880 // ::= /*empty*/
2881 // ::= int64list (',' int64list)*
2882 // int64list ::= '{' int64_elements '}'
2883 // int64_elements
2884 // ::= /*empty*/
2885 // ::= int64_val (',' int64_val)*
ParseReplicaGroupsOnly(std::vector<ReplicaGroup> * replica_groups)2886 bool HloParserImpl::ParseReplicaGroupsOnly(
2887 std::vector<ReplicaGroup>* replica_groups) {
2888 std::vector<std::vector<int64>> result;
2889 if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
2890 &result)) {
2891 return false;
2892 }
2893 *replica_groups = CreateReplicaGroups(result);
2894 return true;
2895 }
2896
2897 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
2898 // 'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)2899 bool HloParserImpl::ParseDomain(DomainData* domain) {
2900 absl::flat_hash_map<std::string, AttrConfig> attrs;
2901 optional<std::string> kind;
2902 optional<OpSharding> entry_sharding;
2903 optional<OpSharding> exit_sharding;
2904 attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
2905 attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
2906 attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
2907 if (!ParseSubAttributes(attrs)) {
2908 return false;
2909 }
2910 if (*kind == ShardingMetadata::KindName()) {
2911 auto entry_sharding_ptr = absl::make_unique<HloSharding>(
2912 HloSharding::FromProto(*entry_sharding).ValueOrDie());
2913 auto exit_sharding_ptr = absl::make_unique<HloSharding>(
2914 HloSharding::FromProto(*exit_sharding).ValueOrDie());
2915 domain->entry_metadata =
2916 absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
2917 domain->exit_metadata =
2918 absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
2919 } else {
2920 return TokenError(StrCat("unsupported domain kind: ", *kind));
2921 }
2922 return true;
2923 }
2924
2925 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)2926 bool HloParserImpl::ParseInstructionNames(
2927 std::vector<HloInstruction*>* instructions) {
2928 if (!ParseToken(TokKind::kLbrace,
2929 "expects '{' at the beginning of instruction name list")) {
2930 return false;
2931 }
2932 LocTy loc = lexer_.GetLoc();
2933 do {
2934 std::string name;
2935 if (!ParseName(&name)) {
2936 return Error(loc, "expects a instruction name");
2937 }
2938 std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
2939 if (!instr) {
2940 return TokenError(StrFormat("instruction '%s' is not defined", name));
2941 }
2942 instructions->push_back(instr->first);
2943 } while (EatIfPresent(TokKind::kComma));
2944
2945 return ParseToken(TokKind::kRbrace,
2946 "expects '}' at the end of instruction name list");
2947 }
2948
SetValueInLiteral(LocTy loc,int64 value,int64 index,Literal * literal)2949 bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, int64 index,
2950 Literal* literal) {
2951 const Shape& shape = literal->shape();
2952 switch (shape.element_type()) {
2953 case S8:
2954 return SetValueInLiteralHelper<int8>(loc, value, index, literal);
2955 case S16:
2956 return SetValueInLiteralHelper<int16>(loc, value, index, literal);
2957 case S32:
2958 return SetValueInLiteralHelper<int32>(loc, value, index, literal);
2959 case S64:
2960 return SetValueInLiteralHelper<int64>(loc, value, index, literal);
2961 case U8:
2962 return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
2963 literal);
2964 case U16:
2965 return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
2966 literal);
2967 case U32:
2968 return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
2969 literal);
2970 case U64:
2971 return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
2972 literal);
2973 case PRED:
2974 // Bool type literals with rank >= 1 are printed in 0s and 1s.
2975 return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
2976 literal);
2977 default:
2978 LOG(FATAL) << "unknown integral primitive type "
2979 << PrimitiveType_Name(shape.element_type());
2980 }
2981 }
2982
SetValueInLiteral(LocTy loc,double value,int64 index,Literal * literal)2983 bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64 index,
2984 Literal* literal) {
2985 const Shape& shape = literal->shape();
2986 switch (shape.element_type()) {
2987 case F16:
2988 return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
2989 case BF16:
2990 return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
2991 literal);
2992 case F32:
2993 return SetValueInLiteralHelper<float>(loc, value, index, literal);
2994 case F64:
2995 return SetValueInLiteralHelper<double>(loc, value, index, literal);
2996 default:
2997 LOG(FATAL) << "unknown floating point primitive type "
2998 << PrimitiveType_Name(shape.element_type());
2999 }
3000 }
3001
SetValueInLiteral(LocTy loc,bool value,int64 index,Literal * literal)3002 bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64 index,
3003 Literal* literal) {
3004 const Shape& shape = literal->shape();
3005 switch (shape.element_type()) {
3006 case PRED:
3007 return SetValueInLiteralHelper<bool>(loc, value, index, literal);
3008 default:
3009 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3010 << " is not PRED type";
3011 }
3012 }
3013
SetValueInLiteral(LocTy loc,std::complex<double> value,int64 index,Literal * literal)3014 bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
3015 int64 index, Literal* literal) {
3016 const Shape& shape = literal->shape();
3017 switch (shape.element_type()) {
3018 case C64:
3019 return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
3020 literal);
3021 case C128:
3022 return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
3023 literal);
3024 default:
3025 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3026 << " is not a complex type";
3027 }
3028 }
3029
3030 template <typename T>
StringifyValue(T val)3031 std::string StringifyValue(T val) {
3032 return StrCat(val);
3033 }
3034 template <>
StringifyValue(std::complex<double> val)3035 std::string StringifyValue(std::complex<double> val) {
3036 return StrFormat("(%f, %f)", std::real(val), std::imag(val));
3037 }
3038
3039 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,int64 index,Literal * literal)3040 bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
3041 int64 index, Literal* literal) {
3042 if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
3043 return false;
3044 }
3045
3046 // Check that the index is in range and assign into the literal
3047 if (index >= ShapeUtil::ElementsIn(literal->shape())) {
3048 return Error(loc, StrCat("tries to set value ", StringifyValue(value),
3049 " to a literal in shape ",
3050 ShapeUtil::HumanString(literal->shape()),
3051 " at linear index ", index,
3052 ", but the index is out of range"));
3053 }
3054 literal->data<LiteralNativeT>().at(index) =
3055 static_cast<LiteralNativeT>(value);
3056 return true;
3057 }
3058
ParseLiteral(Literal * literal)3059 bool HloParserImpl::ParseLiteral(Literal* literal) {
3060 Shape literal_shape;
3061 if (!ParseShape(&literal_shape)) {
3062 return false;
3063 }
3064 return ParseLiteral(literal, literal_shape);
3065 }
3066
3067 // literal
3068 // ::= tuple
3069 // ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)3070 bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
3071 return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
3072 : ParseNonTupleLiteral(literal, shape);
3073 }
3074
3075 // tuple
3076 // ::= shape '(' literal_list ')'
3077 // literal_list
3078 // ::= /*empty*/
3079 // ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)3080 bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
3081 if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
3082 return false;
3083 }
3084 std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
3085
3086 if (lexer_.GetKind() == TokKind::kRparen) {
3087 // empty
3088 } else {
3089 // literal, (',' literal)*
3090 for (int i = 0; i < elements.size(); i++) {
3091 if (i > 0) {
3092 ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3093 }
3094 if (!ParseLiteral(&elements[i],
3095 ShapeUtil::GetTupleElementShape(shape, i))) {
3096 return TokenError(StrCat("expects the ", i, "th element"));
3097 }
3098 }
3099 }
3100 *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3101 return ParseToken(TokKind::kRparen,
3102 StrCat("expects ')' at the end of the tuple with ",
3103 ShapeUtil::TupleElementCount(shape), "elements"));
3104 }
3105
3106 // non_tuple
3107 // ::= rank01
3108 // ::= rank2345
3109 // rank2345 ::= shape nested_array
ParseNonTupleLiteral(Literal * literal,const Shape & shape)3110 bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
3111 CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
3112 return ParseDenseLiteral(literal, shape);
3113 }
3114
ParseDenseLiteral(Literal * literal,const Shape & shape)3115 bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
3116 // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
3117 // `rank` is an int64, that's an implicit narrowing conversion, which is
3118 // implementation-defined behavior.
3119 const int rank = static_cast<int>(shape.rank());
3120
3121 // Create a literal with the given shape in default layout.
3122 *literal = LiteralUtil::CreateFromDimensions(
3123 shape.element_type(), AsInt64Slice(shape.dimensions()));
3124 int64 nest_level = 0;
3125 int64 linear_index = 0;
3126 // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
3127 // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
3128 // when we are parsing the 2nd '{' (right before '1'), we are seeing a
3129 // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
3130 // the first '}' (right after '3'), it means the sub-array ends, and the
3131 // sub-array is supposed to contain exactly 3 elements, so check if
3132 // elems_seen_per_dim[1] is 3.
3133 std::vector<int64> elems_seen_per_dim(rank);
3134 auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
3135 std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
3136 elems_seen_per_dim.begin() + dim);
3137 return StrCat("[",
3138 StrJoin(elems_seen_until_dim, ",",
3139 [](std::string* out, const int64 num_elems) {
3140 StrAppend(out, num_elems - 1);
3141 }),
3142 "]");
3143 };
3144
3145 auto add_one_elem_seen = [&] {
3146 if (rank > 0) {
3147 if (nest_level != rank) {
3148 return TokenError(absl::StrFormat(
3149 "expects nested array in rank %d, but sees %d", rank, nest_level));
3150 }
3151 elems_seen_per_dim[rank - 1]++;
3152 if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
3153 return TokenError(absl::StrFormat(
3154 "expects %d elements on the minor-most dimension, but "
3155 "sees more",
3156 shape.dimensions(rank - 1)));
3157 }
3158 }
3159 return true;
3160 };
3161
3162 do {
3163 switch (lexer_.GetKind()) {
3164 default:
3165 return TokenError("unexpected token type in a literal");
3166 case TokKind::kLbrace: {
3167 nest_level++;
3168 if (nest_level > rank) {
3169 return TokenError(absl::StrFormat(
3170 "expects nested array in rank %d, but sees larger", rank));
3171 }
3172 if (nest_level > 1) {
3173 elems_seen_per_dim[nest_level - 2]++;
3174 if (elems_seen_per_dim[nest_level - 2] >
3175 shape.dimensions(nest_level - 2)) {
3176 return TokenError(absl::StrFormat(
3177 "expects %d elements in the %sth element, but sees more",
3178 shape.dimensions(nest_level - 2),
3179 get_index_str(nest_level - 2)));
3180 }
3181 }
3182 lexer_.Lex();
3183 break;
3184 }
3185 case TokKind::kRbrace: {
3186 nest_level--;
3187 if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
3188 return TokenError(absl::StrFormat(
3189 "expects %d elements in the %sth element, but sees %d",
3190 shape.dimensions(nest_level), get_index_str(nest_level),
3191 elems_seen_per_dim[nest_level]));
3192 }
3193 elems_seen_per_dim[nest_level] = 0;
3194 lexer_.Lex();
3195 break;
3196 }
3197 case TokKind::kLparen: {
3198 if (!primitive_util::IsComplexType(shape.element_type())) {
3199 return TokenError(
3200 absl::StrFormat("unexpected '(' in literal. Parens are only "
3201 "valid for complex literals"));
3202 }
3203
3204 std::complex<double> value;
3205 LocTy loc = lexer_.GetLoc();
3206 if (!add_one_elem_seen() || !ParseComplex(&value) ||
3207 !SetValueInLiteral(loc, value, linear_index++, literal)) {
3208 return false;
3209 }
3210 break;
3211 }
3212 case TokKind::kDots: {
3213 if (nest_level != 1) {
3214 return TokenError(absl::StrFormat(
3215 "expects `...` at nest level 1, but sees it at nest level %d",
3216 nest_level));
3217 }
3218 elems_seen_per_dim[0] = shape.dimensions(0);
3219 lexer_.Lex();
3220 // Fill data with deterministic (garbage) values. Use static to avoid
3221 // creating identical constants which could potentially got CSE'ed
3222 // away. This is a best-effort approach to make sure replaying a HLO
3223 // gives us same optimized HLO graph.
3224 static uint32 data = 0;
3225 uint32* raw_data = static_cast<uint32*>(literal->untyped_data());
3226 for (int64 i = 0; i < literal->size_bytes() / 4; ++i) {
3227 raw_data[i] = data++;
3228 }
3229 uint8* raw_data_int8 = static_cast<uint8*>(literal->untyped_data());
3230 static uint8 data_int8 = 0;
3231 for (int64 i = 0; i < literal->size_bytes() % 4; ++i) {
3232 raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
3233 }
3234 break;
3235 }
3236 case TokKind::kComma:
3237 // Skip.
3238 lexer_.Lex();
3239 break;
3240 case TokKind::kw_true:
3241 case TokKind::kw_false:
3242 case TokKind::kInt:
3243 case TokKind::kDecimal:
3244 case TokKind::kw_nan:
3245 case TokKind::kNegNan:
3246 case TokKind::kw_inf:
3247 case TokKind::kNegInf: {
3248 add_one_elem_seen();
3249 if (lexer_.GetKind() == TokKind::kw_true ||
3250 lexer_.GetKind() == TokKind::kw_false) {
3251 if (!SetValueInLiteral(lexer_.GetLoc(),
3252 lexer_.GetKind() == TokKind::kw_true,
3253 linear_index++, literal)) {
3254 return false;
3255 }
3256 lexer_.Lex();
3257 } else if (primitive_util::IsIntegralType(shape.element_type()) ||
3258 shape.element_type() == PRED) {
3259 LocTy loc = lexer_.GetLoc();
3260 int64 value;
3261 if (!ParseInt64(&value)) {
3262 return Error(loc, StrCat("expects integer for primitive type: ",
3263 PrimitiveType_Name(shape.element_type())));
3264 }
3265 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3266 return false;
3267 }
3268 } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
3269 LocTy loc = lexer_.GetLoc();
3270 double value;
3271 if (!ParseDouble(&value)) {
3272 return Error(
3273 loc, StrCat("expect floating point value for primitive type: ",
3274 PrimitiveType_Name(shape.element_type())));
3275 }
3276 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3277 return false;
3278 }
3279 } else {
3280 return TokenError(StrCat("unsupported primitive type ",
3281 PrimitiveType_Name(shape.element_type())));
3282 }
3283 break;
3284 }
3285 } // end of switch
3286 } while (nest_level > 0);
3287
3288 *literal = literal->Relayout(shape.layout());
3289 return true;
3290 }
3291
// MinMaxFiniteValue<T> exposes the smallest (most negative) and largest finite
// values of T. It is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange.
template <typename T>
struct MinMaxFiniteValue {
  // Lowest finite value of T (for floating point, the most negative one).
  static T min() { return Limits::lowest(); }
  // Largest finite value of T.
  static T max() { return Limits::max(); }

 private:
  using Limits = std::numeric_limits<T>;
};
3299
3300 template <>
3301 struct MinMaxFiniteValue<Eigen::half> {
maxxla::__anonf806807d0111::MinMaxFiniteValue3302 static double max() {
3303 // Sadly this is not constexpr, so this forces `value` to be a method.
3304 return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
3305 }
minxla::__anonf806807d0111::MinMaxFiniteValue3306 static double min() { return -max(); }
3307 };
3308
3309 template <>
3310 struct MinMaxFiniteValue<bfloat16> {
maxxla::__anonf806807d0111::MinMaxFiniteValue3311 static double max() {
3312 return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
3313 }
minxla::__anonf806807d0111::MinMaxFiniteValue3314 static double min() { return -max(); }
3315 };
3316
// MSVC's standard C++ library does not define isnan/isfinite for integer types.
// To work around that we will need to provide our own.
//
// Floating-point arguments defer to <cmath>.
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T val) {
  return std::isfinite(val);
}
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T val) {
  return std::isnan(val);
}
// Integral arguments: every integer value converts to a finite, non-NaN
// double, so the answers are constant.
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T /*val*/) {
  return true;
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T /*val*/) {
  return false;
}
3335
3336 template <typename LiteralNativeT, typename ParsedElemT>
CheckParsedValueIsInRange(LocTy loc,ParsedElemT value)3337 bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
3338 if (std::is_floating_point<ParsedElemT>::value) {
3339 auto value_as_native_t = static_cast<LiteralNativeT>(value);
3340 auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
3341 if (!IsFinite(value) || IsFinite(value_double_converted)) {
3342 value = value_double_converted;
3343 }
3344 }
3345 PrimitiveType literal_ty =
3346 primitive_util::NativeToPrimitiveType<LiteralNativeT>();
3347 if (IsNaN(value) ||
3348 (std::numeric_limits<ParsedElemT>::has_infinity &&
3349 (std::numeric_limits<ParsedElemT>::infinity() == value ||
3350 -std::numeric_limits<ParsedElemT>::infinity() == value))) {
3351 // Skip range checking for non-finite value.
3352 } else if (std::is_unsigned<LiteralNativeT>::value) {
3353 CHECK((std::is_same<ParsedElemT, int64>::value ||
3354 std::is_same<ParsedElemT, bool>::value))
3355 << "Unimplemented checking for ParsedElemT";
3356
3357 const uint64 unsigned_value = value;
3358 const uint64 upper_bound =
3359 static_cast<uint64>(std::numeric_limits<LiteralNativeT>::max());
3360 if (unsigned_value > upper_bound) {
3361 // Value is out of range for LiteralNativeT.
3362 return Error(loc, StrCat("value ", value,
3363 " is out of range for literal's primitive type ",
3364 PrimitiveType_Name(literal_ty), " namely [0, ",
3365 upper_bound, "]."));
3366 }
3367 } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
3368 value < MinMaxFiniteValue<LiteralNativeT>::min()) {
3369 // Value is out of range for LiteralNativeT.
3370 return Error(loc, StrCat("value ", value,
3371 " is out of range for literal's primitive type ",
3372 PrimitiveType_Name(literal_ty), " namely [",
3373 MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
3374 MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
3375 }
3376 return true;
3377 }
3378
3379 template <typename LiteralNativeT>
CheckParsedValueIsInRange(LocTy loc,std::complex<double> value)3380 bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
3381 std::complex<double> value) {
3382 // e.g. `float` for std::complex<float>
3383 using LiteralComplexComponentT =
3384 decltype(std::real(std::declval<LiteralNativeT>()));
3385
3386 // We could do simply
3387 //
3388 // return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
3389 // CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
3390 //
3391 // but this would give bad error messages on failure.
3392
3393 auto check_component = [&](absl::string_view name, double v) {
3394 if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
3395 v == -std::numeric_limits<double>::infinity()) {
3396 // Skip range-checking for non-finite values.
3397 return true;
3398 }
3399
3400 double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
3401 double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
3402 if (v < min || v > max) {
3403 // Value is out of range for LitearlComplexComponentT.
3404 return Error(
3405 loc,
3406 StrCat(name, " part ", v,
3407 " is out of range for literal's primitive type ",
3408 PrimitiveType_Name(
3409 primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
3410 ", namely [", min, ", ", max, "]."));
3411 }
3412 return true;
3413 };
3414 return check_component("real", std::real(value)) &&
3415 check_component("imaginary", std::imag(value));
3416 }
3417
3418 // operands ::= '(' operands1 ')'
3419 // operands1
3420 // ::= /*empty*/
3421 // ::= operand (, operand)*
3422 // operand ::= (shape)? name
ParseOperands(std::vector<HloInstruction * > * operands)3423 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands) {
3424 CHECK(operands != nullptr);
3425 if (!ParseToken(TokKind::kLparen,
3426 "expects '(' at the beginning of operands")) {
3427 return false;
3428 }
3429 if (lexer_.GetKind() == TokKind::kRparen) {
3430 // empty
3431 } else {
3432 do {
3433 LocTy loc = lexer_.GetLoc();
3434 std::string name;
3435 optional<Shape> shape;
3436 if (CanBeShape()) {
3437 shape.emplace();
3438 if (!ParseShape(&shape.value())) {
3439 return false;
3440 }
3441 }
3442 if (!ParseName(&name)) {
3443 // When parsing a single instruction (as opposed to a whole module), an
3444 // HLO may have one or more operands with a shape but no name:
3445 //
3446 // foo = add(f32[10], f32[10])
3447 //
3448 // create_missing_instruction_ is always non-null when parsing a single
3449 // instruction, and is responsible for creating kParameter instructions
3450 // for these operands.
3451 if (shape.has_value() && create_missing_instruction_ != nullptr &&
3452 scoped_name_tables_.size() == 1) {
3453 name = "";
3454 } else {
3455 return false;
3456 }
3457 }
3458 std::pair<HloInstruction*, LocTy>* instruction =
3459 FindInstruction(name, shape);
3460 if (instruction == nullptr) {
3461 return Error(loc, StrCat("instruction does not exist: ", name));
3462 }
3463 operands->push_back(instruction->first);
3464 } while (EatIfPresent(TokKind::kComma));
3465 }
3466 return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
3467 }
3468
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)3469 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
3470 const int expected_size) {
3471 CHECK(operands != nullptr);
3472 LocTy loc = lexer_.GetLoc();
3473 if (!ParseOperands(operands)) {
3474 return false;
3475 }
3476 if (expected_size != operands->size()) {
3477 return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
3478 operands->size(), " operands"));
3479 }
3480 return true;
3481 }
3482
3483 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3484 bool HloParserImpl::ParseSubAttributes(
3485 const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3486 LocTy loc = lexer_.GetLoc();
3487 if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
3488 return false;
3489 }
3490 absl::flat_hash_set<std::string> seen_attrs;
3491 if (lexer_.GetKind() == TokKind::kRbrace) {
3492 // empty
3493 } else {
3494 do {
3495 EatIfPresent(TokKind::kComma);
3496 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3497 return false;
3498 }
3499 } while (lexer_.GetKind() != TokKind::kRbrace);
3500 }
3501 // Check that all required attrs were seen.
3502 for (const auto& attr_it : attrs) {
3503 if (attr_it.second.required &&
3504 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3505 return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
3506 attr_it.first));
3507 }
3508 }
3509 return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
3510 }
3511
3512 // attributes ::= (',' attribute)*
ParseAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3513 bool HloParserImpl::ParseAttributes(
3514 const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3515 LocTy loc = lexer_.GetLoc();
3516 absl::flat_hash_set<std::string> seen_attrs;
3517 while (EatIfPresent(TokKind::kComma)) {
3518 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3519 return false;
3520 }
3521 }
3522 // Check that all required attrs were seen.
3523 for (const auto& attr_it : attrs) {
3524 if (attr_it.second.required &&
3525 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3526 return Error(loc, StrFormat("attribute %s is expected but not seen",
3527 attr_it.first));
3528 }
3529 }
3530 return true;
3531 }
3532
ParseAttributeHelper(const absl::flat_hash_map<std::string,AttrConfig> & attrs,absl::flat_hash_set<std::string> * seen_attrs)3533 bool HloParserImpl::ParseAttributeHelper(
3534 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
3535 absl::flat_hash_set<std::string>* seen_attrs) {
3536 LocTy loc = lexer_.GetLoc();
3537 std::string name;
3538 if (!ParseAttributeName(&name)) {
3539 return Error(loc, "error parsing attributes");
3540 }
3541 VLOG(3) << "Parsing attribute " << name;
3542 if (!seen_attrs->insert(name).second) {
3543 return Error(loc, StrFormat("attribute %s already exists", name));
3544 }
3545 auto attr_it = attrs.find(name);
3546 if (attr_it == attrs.end()) {
3547 std::string allowed_attrs;
3548 if (attrs.empty()) {
3549 allowed_attrs = "No attributes are allowed here.";
3550 } else {
3551 allowed_attrs =
3552 StrCat("Allowed attributes: ",
3553 StrJoin(attrs, ", ",
3554 [&](std::string* out,
3555 const std::pair<std::string, AttrConfig>& kv) {
3556 StrAppend(out, kv.first);
3557 }));
3558 }
3559 return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
3560 allowed_attrs));
3561 }
3562 AttrTy attr_type = attr_it->second.attr_type;
3563 void* attr_out_ptr = attr_it->second.result;
3564 bool success = [&] {
3565 LocTy attr_loc = lexer_.GetLoc();
3566 switch (attr_type) {
3567 case AttrTy::kBool: {
3568 bool result;
3569 if (!ParseBool(&result)) {
3570 return false;
3571 }
3572 static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
3573 return true;
3574 }
3575 case AttrTy::kInt64: {
3576 int64 result;
3577 if (!ParseInt64(&result)) {
3578 return false;
3579 }
3580 static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
3581 return true;
3582 }
3583 case AttrTy::kInt32: {
3584 int64 result;
3585 if (!ParseInt64(&result)) {
3586 return false;
3587 }
3588 if (result != static_cast<int32>(result)) {
3589 return Error(attr_loc, "value out of range for int32");
3590 }
3591 static_cast<optional<int32>*>(attr_out_ptr)
3592 ->emplace(static_cast<int32>(result));
3593 return true;
3594 }
3595 case AttrTy::kFloat: {
3596 double result;
3597 if (!ParseDouble(&result)) {
3598 return false;
3599 }
3600 if (result > std::numeric_limits<float>::max() ||
3601 result < std::numeric_limits<float>::lowest()) {
3602 return Error(attr_loc, "value out of range for float");
3603 }
3604 static_cast<optional<float>*>(attr_out_ptr)
3605 ->emplace(static_cast<float>(result));
3606 return true;
3607 }
3608 case AttrTy::kHloComputation: {
3609 HloComputation* result = nullptr;
3610 if (!ParseHloComputation(&result)) {
3611 return false;
3612 }
3613 static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
3614 return true;
3615 }
3616 case AttrTy::kBracedHloComputationList: {
3617 std::vector<HloComputation*> result;
3618 if (!ParseHloComputationList(&result)) {
3619 return false;
3620 }
3621 static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
3622 ->emplace(result);
3623 return true;
3624 }
3625 case AttrTy::kFftType: {
3626 FftType result;
3627 if (!ParseFftType(&result)) {
3628 return false;
3629 }
3630 static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
3631 return true;
3632 }
3633 case AttrTy::kPaddingType: {
3634 PaddingType result;
3635 if (!ParsePaddingType(&result)) {
3636 return false;
3637 }
3638 static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
3639 return true;
3640 }
3641 case AttrTy::kComparisonDirection: {
3642 ComparisonDirection result;
3643 if (!ParseComparisonDirection(&result)) {
3644 return false;
3645 }
3646 static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
3647 ->emplace(result);
3648 return true;
3649 }
3650 case AttrTy::kComparisonType: {
3651 Comparison::Type result;
3652 if (!ParseComparisonType(&result)) {
3653 return false;
3654 }
3655 static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
3656 return true;
3657 }
3658 case AttrTy::kEnum: {
3659 if (lexer_.GetKind() != TokKind::kIdent) {
3660 return TokenError("expects an enumeration value");
3661 }
3662 std::string result = lexer_.GetStrVal();
3663 lexer_.Lex();
3664 static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
3665 return true;
3666 }
3667 case AttrTy::kWindow: {
3668 Window result;
3669 if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
3670 return false;
3671 }
3672 static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
3673 return true;
3674 }
3675 case AttrTy::kConvolutionDimensionNumbers: {
3676 ConvolutionDimensionNumbers result;
3677 if (!ParseConvolutionDimensionNumbers(&result)) {
3678 return false;
3679 }
3680 static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
3681 ->emplace(result);
3682 return true;
3683 }
3684 case AttrTy::kSharding: {
3685 OpSharding sharding;
3686 if (!ParseSharding(&sharding)) {
3687 return false;
3688 }
3689 static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
3690 return true;
3691 }
3692 case AttrTy::kFrontendAttributes: {
3693 FrontendAttributes frontend_attributes;
3694 if (!ParseFrontendAttributes(&frontend_attributes)) {
3695 return false;
3696 }
3697 static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
3698 ->emplace(frontend_attributes);
3699 return true;
3700 }
3701 case AttrTy::kParameterReplication: {
3702 ParameterReplication parameter_replication;
3703 if (!ParseParameterReplication(¶meter_replication)) {
3704 return false;
3705 }
3706 static_cast<optional<ParameterReplication>*>(attr_out_ptr)
3707 ->emplace(parameter_replication);
3708 return true;
3709 }
3710 case AttrTy::kInstructionList: {
3711 std::vector<HloInstruction*> result;
3712 if (!ParseInstructionNames(&result)) {
3713 return false;
3714 }
3715 static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
3716 ->emplace(result);
3717 return true;
3718 }
3719 case AttrTy::kFusionKind: {
3720 HloInstruction::FusionKind result;
3721 if (!ParseFusionKind(&result)) {
3722 return false;
3723 }
3724 static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
3725 ->emplace(result);
3726 return true;
3727 }
3728 case AttrTy::kBracedInt64List: {
3729 std::vector<int64> result;
3730 if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3731 &result)) {
3732 return false;
3733 }
3734 static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
3735 ->emplace(result);
3736 return true;
3737 }
3738 case AttrTy::kBracedInt64ListList: {
3739 std::vector<std::vector<int64>> result;
3740 if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
3741 TokKind::kComma, &result)) {
3742 return false;
3743 }
3744 static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
3745 ->emplace(result);
3746 return true;
3747 }
3748 case AttrTy::kSliceRanges: {
3749 SliceRanges result;
3750 if (!ParseSliceRanges(&result)) {
3751 return false;
3752 }
3753 static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
3754 return true;
3755 }
3756 case AttrTy::kPaddingConfig: {
3757 PaddingConfig result;
3758 if (!ParsePaddingConfig(&result)) {
3759 return false;
3760 }
3761 static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
3762 return true;
3763 }
3764 case AttrTy::kString: {
3765 std::string result;
3766 if (!ParseString(&result)) {
3767 return false;
3768 }
3769 static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
3770 return true;
3771 }
3772 case AttrTy::kMetadata: {
3773 OpMetadata result;
3774 if (!ParseMetadata(&result)) {
3775 return false;
3776 }
3777 static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
3778 return true;
3779 }
3780 case AttrTy::kDistribution: {
3781 RandomDistribution result;
3782 if (!ParseRandomDistribution(&result)) {
3783 return false;
3784 }
3785 static_cast<optional<RandomDistribution>*>(attr_out_ptr)
3786 ->emplace(result);
3787 return true;
3788 }
3789 case AttrTy::kDomain: {
3790 return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
3791 }
3792 case AttrTy::kPrecisionList: {
3793 std::vector<PrecisionConfig::Precision> result;
3794 if (!ParsePrecisionList(&result)) {
3795 return false;
3796 }
3797 static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
3798 attr_out_ptr)
3799 ->emplace(result);
3800 return true;
3801 }
3802 case AttrTy::kShape: {
3803 Shape result;
3804 if (!ParseShape(&result)) {
3805 return false;
3806 }
3807 static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
3808 return true;
3809 }
3810 case AttrTy::kShapeList: {
3811 std::vector<Shape> result;
3812 if (!ParseShapeList(&result)) {
3813 return false;
3814 }
3815 static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
3816 ->emplace(result);
3817 return true;
3818 }
3819 case AttrTy::kRandomAlgorithm: {
3820 RandomAlgorithm result;
3821 if (!ParseRandomAlgorithm(&result)) {
3822 return false;
3823 }
3824 static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
3825 return true;
3826 }
3827 case AttrTy::kAliasing: {
3828 AliasingData aliasing_data;
3829 if (!ParseAliasing(&aliasing_data)) {
3830 return false;
3831 }
3832 static_cast<optional<AliasingData>*>(attr_out_ptr)
3833 ->emplace(aliasing_data);
3834 return true;
3835 }
3836 case AttrTy::kInstructionAliasing: {
3837 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
3838 aliasing_output_operand_pairs;
3839 if (!ParseInstructionOutputOperandAliasing(
3840 &aliasing_output_operand_pairs)) {
3841 return false;
3842 }
3843 static_cast<optional<
3844 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
3845 attr_out_ptr)
3846 ->emplace(std::move(aliasing_output_operand_pairs));
3847 return true;
3848 }
3849 case AttrTy::kLiteral: {
3850 if (!ParseToken(TokKind::kLparen, "expects '(' before literal")) {
3851 return false;
3852 }
3853 Literal result;
3854 if (!ParseLiteral(&result)) {
3855 return false;
3856 }
3857 if (!ParseToken(TokKind::kRparen, "expects ')' after literal")) {
3858 return false;
3859 }
3860 static_cast<optional<Literal>*>(attr_out_ptr)
3861 ->emplace(std::move(result));
3862 return true;
3863 }
3864 }
3865 }();
3866 if (!success) {
3867 return Error(loc, StrFormat("error parsing attribute %s", name));
3868 }
3869 return true;
3870 }
3871
CopyAttributeToProtoMessage(absl::flat_hash_set<std::string> non_proto_attrs,const absl::flat_hash_map<std::string,AttrConfig> & attrs,tensorflow::protobuf::Message * message)3872 bool HloParserImpl::CopyAttributeToProtoMessage(
3873 absl::flat_hash_set<std::string> non_proto_attrs,
3874 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
3875 tensorflow::protobuf::Message* message) {
3876 const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
3877 const tensorflow::protobuf::Reflection* reflection = message->GetReflection();
3878
3879 for (const auto& p : attrs) {
3880 const std::string& name = p.first;
3881 if (non_proto_attrs.find(name) != non_proto_attrs.end()) {
3882 continue;
3883 }
3884 const tensorflow::protobuf::FieldDescriptor* fd =
3885 descriptor->FindFieldByName(name);
3886 if (!fd) {
3887 std::string allowed_attrs = "Allowed attributes: ";
3888
3889 for (int i = 0; i < descriptor->field_count(); ++i) {
3890 if (i == 0) {
3891 absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
3892 } else {
3893 absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
3894 }
3895 }
3896 return TokenError(
3897 StrFormat("unexpected attribute \"%s\". %s", name, allowed_attrs));
3898 }
3899
3900 CHECK(!fd->is_repeated()); // Repeated fields not implemented.
3901 bool success = [&] {
3902 switch (fd->type()) {
3903 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
3904 auto attr_value = static_cast<optional<bool>*>(p.second.result);
3905 if (attr_value->has_value()) {
3906 reflection->SetBool(message, fd, **attr_value);
3907 }
3908 return true;
3909 }
3910 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
3911 auto attr_value =
3912 static_cast<optional<std::string>*>(p.second.result);
3913 if (attr_value->has_value()) {
3914 const tensorflow::protobuf::EnumValueDescriptor* evd =
3915 fd->enum_type()->FindValueByName(**attr_value);
3916 reflection->SetEnum(message, fd, evd);
3917 }
3918 return true;
3919 }
3920 default:
3921 return false;
3922 }
3923 }();
3924
3925 if (!success) {
3926 return TokenError(StrFormat("error parsing attribute %s", name));
3927 }
3928 }
3929
3930 return true;
3931 }
3932
3933 // attributes ::= (',' attribute)*
ParseAttributesAsProtoMessage(const absl::flat_hash_map<std::string,AttrConfig> & non_proto_attrs,tensorflow::protobuf::Message * message)3934 bool HloParserImpl::ParseAttributesAsProtoMessage(
3935 const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
3936 tensorflow::protobuf::Message* message) {
3937 const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
3938 absl::flat_hash_map<std::string, AttrConfig> attrs;
3939
3940 // Storage for attributes.
3941 std::vector<optional<bool>> bool_params;
3942 std::vector<optional<std::string>> string_params;
3943 // Reserve enough capacity to make sure that the vector is not growing, so we
3944 // can rely on the pointers to stay valid.
3945 bool_params.reserve(descriptor->field_count());
3946 string_params.reserve(descriptor->field_count());
3947
3948 // Populate the storage of expected attributes from the protobuf description.
3949 for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
3950 const tensorflow::protobuf::FieldDescriptor* fd =
3951 descriptor->field(field_idx);
3952 const std::string& field_name = fd->name();
3953 switch (fd->type()) {
3954 case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
3955 bool_params.emplace_back(absl::nullopt);
3956 attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
3957 &bool_params.back()};
3958 break;
3959 }
3960 case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
3961 string_params.emplace_back(absl::nullopt);
3962 attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
3963 &string_params.back()};
3964 break;
3965 }
3966 default:
3967 return TokenError(absl::StrFormat(
3968 "Unexpected protocol buffer type: %s ", fd->DebugString()));
3969 }
3970 }
3971
3972 absl::flat_hash_set<std::string> non_proto_attrs_names;
3973 non_proto_attrs_names.reserve(non_proto_attrs.size());
3974 for (const auto& p : non_proto_attrs) {
3975 const std::string& attr_name = p.first;
3976 // If an attribute is both specified within 'non_proto_attrs' and an
3977 // attribute of the proto message, we prefer the attribute of the proto
3978 // message.
3979 if (attrs.find(attr_name) == attrs.end()) {
3980 non_proto_attrs_names.insert(attr_name);
3981 attrs[attr_name] = p.second;
3982 }
3983 }
3984
3985 if (!ParseAttributes(attrs)) {
3986 return false;
3987 }
3988
3989 return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
3990 }
3991
ParseComputationName(HloComputation ** value)3992 bool HloParserImpl::ParseComputationName(HloComputation** value) {
3993 std::string name;
3994 LocTy loc = lexer_.GetLoc();
3995 if (!ParseName(&name)) {
3996 return Error(loc, "expects computation name");
3997 }
3998 std::pair<HloComputation*, LocTy>* computation =
3999 tensorflow::gtl::FindOrNull(computation_pool_, name);
4000 if (computation == nullptr) {
4001 return Error(loc, StrCat("computation does not exist: ", name));
4002 }
4003 *value = computation->first;
4004 return true;
4005 }
4006
4007 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
4008 // The subattributes can appear in any order. 'size=' is required, others are
4009 // optional.
ParseWindow(Window * window,bool expect_outer_curlies)4010 bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
4011 LocTy loc = lexer_.GetLoc();
4012 if (expect_outer_curlies &&
4013 !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
4014 return false;
4015 }
4016
4017 std::vector<int64> size;
4018 std::vector<int64> stride;
4019 std::vector<std::vector<int64>> pad;
4020 std::vector<int64> lhs_dilate;
4021 std::vector<int64> rhs_dilate;
4022 std::vector<int64> rhs_reversal;
4023 const auto end_token =
4024 expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
4025 while (lexer_.GetKind() != end_token) {
4026 LocTy attr_loc = lexer_.GetLoc();
4027 std::string field_name;
4028 if (!ParseAttributeName(&field_name)) {
4029 return Error(attr_loc, "expects sub-attributes in window");
4030 }
4031 bool ok = [&] {
4032 if (field_name == "size") {
4033 return ParseDxD("size", &size);
4034 }
4035 if (field_name == "stride") {
4036 return ParseDxD("stride", &stride);
4037 }
4038 if (field_name == "lhs_dilate") {
4039 return ParseDxD("lhs_dilate", &lhs_dilate);
4040 }
4041 if (field_name == "rhs_dilate") {
4042 return ParseDxD("rls_dilate", &rhs_dilate);
4043 }
4044 if (field_name == "pad") {
4045 return ParseWindowPad(&pad);
4046 }
4047 if (field_name == "rhs_reversal") {
4048 return ParseDxD("rhs_reversal", &rhs_reversal);
4049 }
4050 return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
4051 }();
4052 if (!ok) {
4053 return false;
4054 }
4055 }
4056
4057 if (!stride.empty() && stride.size() != size.size()) {
4058 return Error(loc, "expects 'stride=' has the same size as 'size='");
4059 }
4060 if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
4061 return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
4062 }
4063 if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
4064 return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
4065 }
4066 if (!pad.empty() && pad.size() != size.size()) {
4067 return Error(loc, "expects 'pad=' has the same size as 'size='");
4068 }
4069
4070 for (int i = 0; i < size.size(); i++) {
4071 window->add_dimensions()->set_size(size[i]);
4072 if (!pad.empty()) {
4073 window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
4074 window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
4075 }
4076 // If some field is not present, it has the default value.
4077 window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
4078 window->mutable_dimensions(i)->set_base_dilation(
4079 lhs_dilate.empty() ? 1 : lhs_dilate[i]);
4080 window->mutable_dimensions(i)->set_window_dilation(
4081 rhs_dilate.empty() ? 1 : rhs_dilate[i]);
4082 window->mutable_dimensions(i)->set_window_reversal(
4083 rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
4084 }
4085 return !expect_outer_curlies ||
4086 ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
4087 }
4088
4089 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
4090 // Thestring looks like "dim_labels=0bf_0io->0bf".
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)4091 bool HloParserImpl::ParseConvolutionDimensionNumbers(
4092 ConvolutionDimensionNumbers* dnums) {
4093 if (lexer_.GetKind() != TokKind::kDimLabels) {
4094 return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
4095 }
4096 std::string str = lexer_.GetStrVal();
4097
4098 // The str is expected to have 3 items, lhs, rhs, out, and it must look like
4099 // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
4100 std::vector<std::string> split1 = absl::StrSplit(str, '_');
4101 if (split1.size() != 2) {
4102 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4103 << str;
4104 }
4105 std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
4106 if (split2.size() != 2) {
4107 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4108 << str;
4109 }
4110 absl::string_view lhs = split1[0];
4111 absl::string_view rhs = split2[0];
4112 absl::string_view out = split2[1];
4113
4114 const int64 rank = lhs.length();
4115 if (rank != rhs.length() || rank != out.length()) {
4116 return TokenError(
4117 "convolution lhs, rhs, and output must have the same rank");
4118 }
4119 if (rank < 2) {
4120 return TokenError("convolution rank must >=2");
4121 }
4122
4123 auto is_unique = [](std::string str) -> bool {
4124 absl::c_sort(str);
4125 return std::unique(str.begin(), str.end()) == str.end();
4126 };
4127
4128 // lhs
4129 {
4130 if (!is_unique(std::string(lhs))) {
4131 return TokenError(
4132 StrCat("expects unique lhs dimension numbers, but sees ", lhs));
4133 }
4134 for (int i = 0; i < rank - 2; i++) {
4135 dnums->add_input_spatial_dimensions(-1);
4136 }
4137 for (int i = 0; i < rank; i++) {
4138 char c = lhs[i];
4139 if (c == 'b') {
4140 dnums->set_input_batch_dimension(i);
4141 } else if (c == 'f') {
4142 dnums->set_input_feature_dimension(i);
4143 } else if (c < '0' + rank && c >= '0') {
4144 dnums->set_input_spatial_dimensions(c - '0', i);
4145 } else {
4146 return TokenError(
4147 StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
4148 }
4149 }
4150 }
4151 // rhs
4152 {
4153 if (!is_unique(std::string(rhs))) {
4154 return TokenError(
4155 StrCat("expects unique rhs dimension numbers, but sees ", rhs));
4156 }
4157 for (int i = 0; i < rank - 2; i++) {
4158 dnums->add_kernel_spatial_dimensions(-1);
4159 }
4160 for (int i = 0; i < rank; i++) {
4161 char c = rhs[i];
4162 if (c == 'i') {
4163 dnums->set_kernel_input_feature_dimension(i);
4164 } else if (c == 'o') {
4165 dnums->set_kernel_output_feature_dimension(i);
4166 } else if (c < '0' + rank && c >= '0') {
4167 dnums->set_kernel_spatial_dimensions(c - '0', i);
4168 } else {
4169 return TokenError(
4170 StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
4171 }
4172 }
4173 }
4174 // output
4175 {
4176 if (!is_unique(std::string(out))) {
4177 return TokenError(
4178 StrCat("expects unique output dimension numbers, but sees ", out));
4179 }
4180 for (int i = 0; i < rank - 2; i++) {
4181 dnums->add_output_spatial_dimensions(-1);
4182 }
4183 for (int i = 0; i < rank; i++) {
4184 char c = out[i];
4185 if (c == 'b') {
4186 dnums->set_output_batch_dimension(i);
4187 } else if (c == 'f') {
4188 dnums->set_output_feature_dimension(i);
4189 } else if (c < '0' + rank && c >= '0') {
4190 dnums->set_output_spatial_dimensions(c - '0', i);
4191 } else {
4192 return TokenError(StrFormat(
4193 "expects [0-%dbf] in output dimension numbers", rank - 1));
4194 }
4195 }
4196 }
4197
4198 lexer_.Lex();
4199 return true;
4200 }
4201
4202 // ::= '{' ranges '}'
4203 // ::= /*empty*/
4204 // ::= range (',' range)*
4205 // range ::= '[' start ':' limit (':' stride)? ']'
4206 //
4207 // The slice ranges are printed as:
4208 //
4209 // {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
4210 //
4211 // This function extracts the starts, limits, and strides as 3 vectors to the
4212 // result. If stride is not present, stride is 1. For example, if the slice
4213 // ranges is printed as:
4214 //
4215 // {[2:3:4], [5:6:7], [8:9]}
4216 //
4217 // The parsed result will be:
4218 //
4219 // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
4220 //
ParseSliceRanges(SliceRanges * result)4221 bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
4222 if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
4223 return false;
4224 }
4225 std::vector<std::vector<int64>> ranges;
4226 if (lexer_.GetKind() == TokKind::kRbrace) {
4227 // empty
4228 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4229 }
4230 do {
4231 LocTy loc = lexer_.GetLoc();
4232 ranges.emplace_back();
4233 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
4234 &ranges.back())) {
4235 return false;
4236 }
4237 const auto& range = ranges.back();
4238 if (range.size() != 2 && range.size() != 3) {
4239 return Error(loc,
4240 StrFormat("expects [start:limit:step] or [start:limit], "
4241 "but sees %d elements.",
4242 range.size()));
4243 }
4244 } while (EatIfPresent(TokKind::kComma));
4245
4246 for (const auto& range : ranges) {
4247 result->starts.push_back(range[0]);
4248 result->limits.push_back(range[1]);
4249 result->strides.push_back(range.size() == 3 ? range[2] : 1);
4250 }
4251 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4252 }
4253
4254 // precisionlist ::= start precision_elements end
4255 // precision_elements
4256 // ::= /*empty*/
4257 // ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)4258 bool HloParserImpl::ParsePrecisionList(
4259 std::vector<PrecisionConfig::Precision>* result) {
4260 auto parse_and_add_item = [&]() {
4261 PrecisionConfig::Precision item;
4262 if (!ParsePrecision(&item)) {
4263 return false;
4264 }
4265 result->push_back(item);
4266 return true;
4267 };
4268 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4269 parse_and_add_item);
4270 }
4271
ParseHloComputation(HloComputation ** result)4272 bool HloParserImpl::ParseHloComputation(HloComputation** result) {
4273 if (lexer_.GetKind() == TokKind::kLbrace) {
4274 // This means it is a nested computation.
4275 return ParseInstructionList(result, /*computation_name=*/"_");
4276 }
4277 // This means it is a computation name.
4278 return ParseComputationName(result);
4279 }
4280
ParseHloComputationList(std::vector<HloComputation * > * result)4281 bool HloParserImpl::ParseHloComputationList(
4282 std::vector<HloComputation*>* result) {
4283 auto parse_and_add_item = [&]() {
4284 HloComputation* computation;
4285 if (!ParseHloComputation(&computation)) {
4286 return false;
4287 }
4288 VLOG(3) << "parsed computation " << computation->name();
4289 result->push_back(computation);
4290 return true;
4291 };
4292 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4293 parse_and_add_item);
4294 }
4295
4296 // shapelist ::= '{' shapes '}'
4297 // precision_elements
4298 // ::= /*empty*/
4299 // ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)4300 bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
4301 auto parse_and_add_item = [&]() {
4302 Shape shape;
4303 if (!ParseShape(&shape)) {
4304 return false;
4305 }
4306 result->push_back(std::move(shape));
4307 return true;
4308 };
4309 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4310 parse_and_add_item);
4311 }
4312
4313 // int64list ::= start int64_elements end
4314 // int64_elements
4315 // ::= /*empty*/
4316 // ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)4317 bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
4318 const TokKind delim,
4319 std::vector<int64>* result) {
4320 auto parse_and_add_item = [&]() {
4321 int64 i;
4322 if (!ParseInt64(&i)) {
4323 return false;
4324 }
4325 result->push_back(i);
4326 return true;
4327 };
4328 return ParseList(start, end, delim, parse_and_add_item);
4329 }
4330
4331 // int64listlist ::= start int64list_elements end
4332 // int64list_elements
4333 // ::= /*empty*/
4334 // ::= int64list (delim int64list)*
4335 // int64list ::= start int64_elements end
4336 // int64_elements
4337 // ::= /*empty*/
4338 // ::= int64_val (delim int64_val)*
ParseInt64ListList(const TokKind start,const TokKind end,const TokKind delim,std::vector<std::vector<int64>> * result)4339 bool HloParserImpl::ParseInt64ListList(
4340 const TokKind start, const TokKind end, const TokKind delim,
4341 std::vector<std::vector<int64>>* result) {
4342 auto parse_and_add_item = [&]() {
4343 std::vector<int64> item;
4344 if (!ParseInt64List(start, end, delim, &item)) {
4345 return false;
4346 }
4347 result->push_back(item);
4348 return true;
4349 };
4350 return ParseList(start, end, delim, parse_and_add_item);
4351 }
4352
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)4353 bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
4354 const TokKind delim,
4355 const std::function<bool()>& parse_and_add_item) {
4356 if (!ParseToken(start, StrCat("expects a list starting with ",
4357 TokKindToString(start)))) {
4358 return false;
4359 }
4360 if (lexer_.GetKind() == end) {
4361 // empty
4362 } else {
4363 do {
4364 if (!parse_and_add_item()) {
4365 return false;
4366 }
4367 } while (EatIfPresent(delim));
4368 }
4369 return ParseToken(
4370 end, StrCat("expects a list to end with ", TokKindToString(end)));
4371 }
4372
4373 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)4374 bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
4375 if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
4376 return false;
4377 }
4378 *shape_loc = lexer_.GetLoc();
4379 return ParseShape(shape);
4380 }
4381
CanBeParamListToShape()4382 bool HloParserImpl::CanBeParamListToShape() {
4383 return lexer_.GetKind() == TokKind::kLparen;
4384 }
4385
4386 // param_list ::= '(' param_list1 ')'
4387 // param_list1
4388 // ::= /*empty*/
4389 // ::= param (',' param)*
4390 // param ::= name shape
ParseParamList()4391 bool HloParserImpl::ParseParamList() {
4392 if (!ParseToken(TokKind::kLparen,
4393 "expects '(' at the beginning of param list")) {
4394 return false;
4395 }
4396
4397 if (lexer_.GetKind() == TokKind::kRparen) {
4398 // empty
4399 } else {
4400 do {
4401 Shape shape;
4402 std::string name;
4403 if (!ParseName(&name) || !ParseShape(&shape)) {
4404 return false;
4405 }
4406 } while (EatIfPresent(TokKind::kComma));
4407 }
4408 return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
4409 }
4410
4411 // dimension_sizes ::= '[' dimension_list ']'
4412 // dimension_list
4413 // ::= /*empty*/
4414 // ::= <=? int64 (',' param)*
4415 // param ::= name shape
ParseDimensionSizes(std::vector<int64> * dimension_sizes,std::vector<bool> * dynamic_dimensions)4416 bool HloParserImpl::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
4417 std::vector<bool>* dynamic_dimensions) {
4418 auto parse_and_add_item = [&]() {
4419 int64 i;
4420 bool is_dynamic = false;
4421 if (lexer_.GetKind() == TokKind::kLeq) {
4422 is_dynamic = true;
4423 lexer_.Lex();
4424 }
4425 if (!ParseInt64(&i)) {
4426 return false;
4427 }
4428 dimension_sizes->push_back(i);
4429 dynamic_dimensions->push_back(is_dynamic);
4430 return true;
4431 };
4432 return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
4433 parse_and_add_item);
4434 }
4435
4436 // tiles
4437 // ::= /*empty*/
4438 // ::= 'T' '(' dim_list ')'
4439 // dim_list
4440 // ::= /*empty*/
4441 // ::= (int64 | '*') (',' (int64 | '*'))*
ParseTiles(std::vector<Tile> * tiles)4442 bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
4443 auto parse_and_add_tile_dimension = [&]() {
4444 tensorflow::int64 i;
4445 if (ParseInt64(&i)) {
4446 tiles->back().add_dimensions(i);
4447 return true;
4448 }
4449 if (lexer_.GetKind() == TokKind::kAsterisk) {
4450 tiles->back().add_dimensions(Tile::kCombineDimension);
4451 lexer_.Lex();
4452 return true;
4453 }
4454 return false;
4455 };
4456
4457 do {
4458 tiles->push_back(Tile());
4459 if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
4460 parse_and_add_tile_dimension)) {
4461 return false;
4462 }
4463 } while (lexer_.GetKind() == TokKind::kLparen);
4464 return true;
4465 }
4466
4467 // int_attribute
4468 // ::= /*empty*/
4469 // ::= attr_token '(' attr_value ')'
4470 // attr_token
4471 // ::= 'E' | 'S'
4472 // attr_value
4473 // ::= int64
ParseLayoutIntAttribute(int64 * attr_value,absl::string_view attr_description)4474 bool HloParserImpl::ParseLayoutIntAttribute(
4475 int64* attr_value, absl::string_view attr_description) {
4476 if (!ParseToken(TokKind::kLparen,
4477 StrCat("expects ", attr_description, " to start with ",
4478 TokKindToString(TokKind::kLparen)))) {
4479 return false;
4480 }
4481 if (!ParseInt64(attr_value)) {
4482 return false;
4483 }
4484 if (!ParseToken(TokKind::kRparen,
4485 StrCat("expects ", attr_description, " to end with ",
4486 TokKindToString(TokKind::kRparen)))) {
4487 return false;
4488 }
4489 return true;
4490 }
4491
4492 // layout ::= '{' int64_list (':' tiles element_size_in_bits memory_space)? '}'
4493 // element_size_in_bits
4494 // ::= /*empty*/
4495 // ::= 'E' '(' int64 ')'
4496 // memory_space
4497 // ::= /*empty*/
4498 // ::= 'S' '(' int64 ')'
ParseLayout(Layout * layout)4499 bool HloParserImpl::ParseLayout(Layout* layout) {
4500 std::vector<int64> minor_to_major;
4501 std::vector<Tile> tiles;
4502 tensorflow::int64 element_size_in_bits = 0;
4503 tensorflow::int64 memory_space = 0;
4504
4505 auto parse_and_add_item = [&]() {
4506 int64 i;
4507 if (!ParseInt64(&i)) {
4508 return false;
4509 }
4510 minor_to_major.push_back(i);
4511 return true;
4512 };
4513
4514 if (!ParseToken(TokKind::kLbrace,
4515 StrCat("expects layout to start with ",
4516 TokKindToString(TokKind::kLbrace)))) {
4517 return false;
4518 }
4519 if (lexer_.GetKind() != TokKind::kRbrace) {
4520 if (lexer_.GetKind() == TokKind::kInt) {
4521 // Parse minor to major.
4522 do {
4523 if (!parse_and_add_item()) {
4524 return false;
4525 }
4526 } while (EatIfPresent(TokKind::kComma));
4527 }
4528
4529 if (lexer_.GetKind() == TokKind::kColon) {
4530 lexer_.Lex();
4531 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
4532 lexer_.Lex();
4533 ParseTiles(&tiles);
4534 }
4535
4536 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
4537 lexer_.Lex();
4538 ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
4539 }
4540
4541 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
4542 lexer_.Lex();
4543 ParseLayoutIntAttribute(&memory_space, "memory space");
4544 }
4545 }
4546 }
4547 if (!ParseToken(TokKind::kRbrace,
4548 StrCat("expects layout to end with ",
4549 TokKindToString(TokKind::kRbrace)))) {
4550 return false;
4551 }
4552
4553 std::vector<Tile> vec_tiles(tiles.size());
4554 for (int i = 0; i < tiles.size(); i++) {
4555 vec_tiles[i] = Tile(tiles[i]);
4556 }
4557 *layout = LayoutUtil::MakeLayout(minor_to_major, vec_tiles,
4558 element_size_in_bits, memory_space);
4559 return true;
4560 }
4561
4562 // shape ::= shape_val_
4563 // shape ::= '(' tuple_elements ')'
4564 // tuple_elements
4565 // ::= /*empty*/
4566 // ::= shape (',' shape)*
ParseShape(Shape * result)4567 bool HloParserImpl::ParseShape(Shape* result) {
4568 if (EatIfPresent(TokKind::kLparen)) { // Tuple
4569 std::vector<Shape> shapes;
4570 if (lexer_.GetKind() == TokKind::kRparen) {
4571 /*empty*/
4572 } else {
4573 // shape (',' shape)*
4574 do {
4575 shapes.emplace_back();
4576 if (!ParseShape(&shapes.back())) {
4577 return false;
4578 }
4579 } while (EatIfPresent(TokKind::kComma));
4580 }
4581 *result = ShapeUtil::MakeTupleShape(shapes);
4582 return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
4583 }
4584
4585 if (lexer_.GetKind() != TokKind::kPrimitiveType) {
4586 return TokenError(absl::StrCat("expected primitive type, saw ",
4587 TokKindToString(lexer_.GetKind())));
4588 }
4589 PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
4590 lexer_.Lex();
4591
4592 // Each element contains a dimension size and a bool indicating whether this
4593 // is a dynamic dimension.
4594 std::vector<int64> dimension_sizes;
4595 std::vector<bool> dynamic_dimensions;
4596 if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
4597 return false;
4598 }
4599 result->set_element_type(primitive_type);
4600 for (int i = 0; i < dimension_sizes.size(); ++i) {
4601 result->add_dimensions(dimension_sizes[i]);
4602 result->set_dynamic_dimension(i, dynamic_dimensions[i]);
4603 }
4604 LayoutUtil::SetToDefaultLayout(result);
4605
4606 // We need to lookahead to see if a following open brace is the start of a
4607 // layout. The specific problematic case is:
4608 //
4609 // ENTRY %foo (x: f32[42]) -> f32[123] {
4610 // ...
4611 // }
4612 //
4613 // The open brace could either be the start of a computation or the start of a
4614 // layout for the f32[123] shape. We consider it the start of a layout if the
4615 // next token after the open brace is an integer or a colon.
4616 if (lexer_.GetKind() == TokKind::kLbrace &&
4617 (lexer_.LookAhead() == TokKind::kInt ||
4618 lexer_.LookAhead() == TokKind::kColon)) {
4619 Layout layout;
4620 if (!ParseLayout(&layout)) {
4621 return false;
4622 }
4623 if (layout.minor_to_major_size() != result->rank()) {
4624 return Error(
4625 lexer_.GetLoc(),
4626 StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
4627 result->rank(), layout.minor_to_major_size()));
4628 }
4629 *result->mutable_layout() = layout;
4630 }
4631 return true;
4632 }
4633
CanBeShape()4634 bool HloParserImpl::CanBeShape() {
4635 // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
4636 // with '('.
4637 return lexer_.GetKind() == TokKind::kPrimitiveType ||
4638 lexer_.GetKind() == TokKind::kLparen;
4639 }
4640
ParseName(std::string * result)4641 bool HloParserImpl::ParseName(std::string* result) {
4642 VLOG(3) << "ParseName";
4643 if (lexer_.GetKind() != TokKind::kIdent &&
4644 lexer_.GetKind() != TokKind::kName) {
4645 return TokenError("expects name");
4646 }
4647 *result = lexer_.GetStrVal();
4648 lexer_.Lex();
4649 return true;
4650 }
4651
ParseAttributeName(std::string * result)4652 bool HloParserImpl::ParseAttributeName(std::string* result) {
4653 if (lexer_.GetKind() != TokKind::kAttributeName) {
4654 return TokenError("expects attribute name");
4655 }
4656 *result = lexer_.GetStrVal();
4657 lexer_.Lex();
4658 return true;
4659 }
4660
ParseString(std::string * result)4661 bool HloParserImpl::ParseString(std::string* result) {
4662 VLOG(3) << "ParseString";
4663 if (lexer_.GetKind() != TokKind::kString) {
4664 return TokenError("expects string");
4665 }
4666 *result = lexer_.GetStrVal();
4667 lexer_.Lex();
4668 return true;
4669 }
4670
ParseDxD(const std::string & name,std::vector<int64> * result)4671 bool HloParserImpl::ParseDxD(const std::string& name,
4672 std::vector<int64>* result) {
4673 LocTy loc = lexer_.GetLoc();
4674 if (!result->empty()) {
4675 return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
4676 }
4677 // 1D
4678 if (lexer_.GetKind() == TokKind::kInt) {
4679 int64 number;
4680 if (!ParseInt64(&number)) {
4681 return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
4682 }
4683 result->push_back(number);
4684 return true;
4685 }
4686 // 2D or higher.
4687 if (lexer_.GetKind() == TokKind::kDxD) {
4688 std::string str = lexer_.GetStrVal();
4689 if (!SplitToInt64s(str, 'x', result)) {
4690 return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
4691 }
4692 lexer_.Lex();
4693 return true;
4694 }
4695 return TokenError("expects token type kInt or kDxD");
4696 }
4697
ParseWindowPad(std::vector<std::vector<int64>> * pad)4698 bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
4699 LocTy loc = lexer_.GetLoc();
4700 if (!pad->empty()) {
4701 return Error(loc, "sub-attribute 'pad=' already exists");
4702 }
4703 if (lexer_.GetKind() != TokKind::kPad) {
4704 return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
4705 }
4706 std::string str = lexer_.GetStrVal();
4707 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
4708 std::vector<int64> low_high;
4709 if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
4710 low_high.size() != 2) {
4711 return Error(loc,
4712 "expects padding_low and padding_high separated by '_'");
4713 }
4714 pad->push_back(low_high);
4715 }
4716 lexer_.Lex();
4717 return true;
4718 }
4719
4720 // This is the inverse xla::ToString(PaddingConfig). The padding config string
4721 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
4722 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
4723 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)4724 bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
4725 if (lexer_.GetKind() != TokKind::kPad) {
4726 return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
4727 }
4728 LocTy loc = lexer_.GetLoc();
4729 std::string str = lexer_.GetStrVal();
4730 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
4731 std::vector<int64> padding_dim;
4732 if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
4733 (padding_dim.size() != 2 && padding_dim.size() != 3)) {
4734 return Error(loc,
4735 "expects padding config pattern like 'low_high_interior' or "
4736 "'low_high'");
4737 }
4738 auto* dim = padding->add_dimensions();
4739 dim->set_edge_padding_low(padding_dim[0]);
4740 dim->set_edge_padding_high(padding_dim[1]);
4741 dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
4742 }
4743 lexer_.Lex();
4744 return true;
4745 }
4746
4747 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)4748 bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
4749 absl::flat_hash_map<std::string, AttrConfig> attrs;
4750 optional<std::string> op_type;
4751 optional<std::string> op_name;
4752 optional<std::string> source_file;
4753 optional<int32> source_line;
4754 optional<std::vector<int64>> profile_type;
4755 attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
4756 attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
4757 attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
4758 attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
4759 attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
4760 &profile_type};
4761 if (!ParseSubAttributes(attrs)) {
4762 return false;
4763 }
4764 if (op_type) {
4765 metadata->set_op_type(*op_type);
4766 }
4767 if (op_name) {
4768 metadata->set_op_name(*op_name);
4769 }
4770 if (source_file) {
4771 metadata->set_source_file(*source_file);
4772 }
4773 if (source_line) {
4774 metadata->set_source_line(*source_line);
4775 }
4776 if (profile_type) {
4777 for (const auto& type : *profile_type) {
4778 if (!ProfileType_IsValid(type)) {
4779 return false;
4780 }
4781 metadata->add_profile_type(static_cast<ProfileType>(type));
4782 }
4783 }
4784 return true;
4785 }
4786
4787 // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleOrListMetadata(tensorflow::protobuf::RepeatedPtrField<OpMetadata> * metadata)4788 bool HloParserImpl::ParseSingleOrListMetadata(
4789 tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
4790 if (lexer_.GetKind() == TokKind::kLbrace &&
4791 lexer_.LookAhead() == TokKind::kLbrace) {
4792 if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
4793 return false;
4794 }
4795
4796 if (lexer_.GetKind() != TokKind::kRbrace) {
4797 do {
4798 if (!ParseMetadata(metadata->Add())) {
4799 return false;
4800 }
4801 } while (EatIfPresent(TokKind::kComma));
4802 }
4803
4804 return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
4805 }
4806
4807 return ParseMetadata(metadata->Add());
4808 }
4809
ParseOpcode(HloOpcode * result)4810 bool HloParserImpl::ParseOpcode(HloOpcode* result) {
4811 VLOG(3) << "ParseOpcode";
4812 if (lexer_.GetKind() != TokKind::kIdent) {
4813 return TokenError("expects opcode");
4814 }
4815 std::string val = lexer_.GetStrVal();
4816 auto status_or_result = StringToHloOpcode(val);
4817 if (!status_or_result.ok()) {
4818 return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
4819 status_or_result.status().error_message()));
4820 }
4821 *result = status_or_result.ValueOrDie();
4822 lexer_.Lex();
4823 return true;
4824 }
4825
ParseFftType(FftType * result)4826 bool HloParserImpl::ParseFftType(FftType* result) {
4827 VLOG(3) << "ParseFftType";
4828 if (lexer_.GetKind() != TokKind::kIdent) {
4829 return TokenError("expects fft type");
4830 }
4831 std::string val = lexer_.GetStrVal();
4832 if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
4833 return TokenError(StrFormat("expects fft type but sees: %s", val));
4834 }
4835 lexer_.Lex();
4836 return true;
4837 }
4838
ParsePaddingType(PaddingType * result)4839 bool HloParserImpl::ParsePaddingType(PaddingType* result) {
4840 VLOG(3) << "ParsePaddingType";
4841 if (lexer_.GetKind() != TokKind::kIdent) {
4842 return TokenError("expects padding type");
4843 }
4844 std::string val = lexer_.GetStrVal();
4845 if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
4846 return TokenError(StrFormat("expects padding type but sees: %s", val));
4847 }
4848 lexer_.Lex();
4849 return true;
4850 }
4851
ParseComparisonDirection(ComparisonDirection * result)4852 bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
4853 VLOG(3) << "ParseComparisonDirection";
4854 if (lexer_.GetKind() != TokKind::kIdent) {
4855 return TokenError("expects comparison direction");
4856 }
4857 std::string val = lexer_.GetStrVal();
4858 auto status_or_result = StringToComparisonDirection(val);
4859 if (!status_or_result.ok()) {
4860 return TokenError(
4861 StrFormat("expects comparison direction but sees: %s", val));
4862 }
4863 *result = status_or_result.ValueOrDie();
4864 lexer_.Lex();
4865 return true;
4866 }
4867
ParseComparisonType(Comparison::Type * result)4868 bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
4869 VLOG(1) << "ParseComparisonType";
4870 if (lexer_.GetKind() != TokKind::kIdent) {
4871 return TokenError("expects comparison type");
4872 }
4873 std::string val = lexer_.GetStrVal();
4874 auto status_or_result = StringToComparisonType(val);
4875 if (!status_or_result.ok()) {
4876 return TokenError(StrFormat("expects comparison type but sees: %s", val));
4877 }
4878 *result = status_or_result.ValueOrDie();
4879 lexer_.Lex();
4880 return true;
4881 }
4882
ParseFusionKind(HloInstruction::FusionKind * result)4883 bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
4884 VLOG(3) << "ParseFusionKind";
4885 if (lexer_.GetKind() != TokKind::kIdent) {
4886 return TokenError("expects fusion kind");
4887 }
4888 std::string val = lexer_.GetStrVal();
4889 auto status_or_result = StringToFusionKind(val);
4890 if (!status_or_result.ok()) {
4891 return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
4892 val,
4893 status_or_result.status().error_message()));
4894 }
4895 *result = status_or_result.ValueOrDie();
4896 lexer_.Lex();
4897 return true;
4898 }
4899
ParseRandomDistribution(RandomDistribution * result)4900 bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
4901 VLOG(3) << "ParseRandomDistribution";
4902 if (lexer_.GetKind() != TokKind::kIdent) {
4903 return TokenError("expects random distribution");
4904 }
4905 std::string val = lexer_.GetStrVal();
4906 auto status_or_result = StringToRandomDistribution(val);
4907 if (!status_or_result.ok()) {
4908 return TokenError(
4909 StrFormat("expects random distribution but sees: %s, error: %s", val,
4910 status_or_result.status().error_message()));
4911 }
4912 *result = status_or_result.ValueOrDie();
4913 lexer_.Lex();
4914 return true;
4915 }
4916
ParseRandomAlgorithm(RandomAlgorithm * result)4917 bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
4918 VLOG(3) << "ParseRandomAlgorithm";
4919 if (lexer_.GetKind() != TokKind::kIdent) {
4920 return TokenError("expects random algorithm");
4921 }
4922 std::string val = lexer_.GetStrVal();
4923 auto status_or_result = StringToRandomAlgorithm(val);
4924 if (!status_or_result.ok()) {
4925 return TokenError(
4926 StrFormat("expects random algorithm but sees: %s, error: %s", val,
4927 status_or_result.status().error_message()));
4928 }
4929 *result = status_or_result.ValueOrDie();
4930 lexer_.Lex();
4931 return true;
4932 }
4933
ParsePrecision(PrecisionConfig::Precision * result)4934 bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
4935 VLOG(3) << "ParsePrecision";
4936 if (lexer_.GetKind() != TokKind::kIdent) {
4937 return TokenError("expects random distribution");
4938 }
4939 std::string val = lexer_.GetStrVal();
4940 auto status_or_result = StringToPrecision(val);
4941 if (!status_or_result.ok()) {
4942 return TokenError(StrFormat("expects precision but sees: %s, error: %s",
4943 val,
4944 status_or_result.status().error_message()));
4945 }
4946 *result = status_or_result.ValueOrDie();
4947 lexer_.Lex();
4948 return true;
4949 }
4950
ParseInt64(int64 * result)4951 bool HloParserImpl::ParseInt64(int64* result) {
4952 VLOG(3) << "ParseInt64";
4953 if (lexer_.GetKind() != TokKind::kInt) {
4954 return TokenError("expects integer");
4955 }
4956 *result = lexer_.GetInt64Val();
4957 lexer_.Lex();
4958 return true;
4959 }
4960
ParseDouble(double * result)4961 bool HloParserImpl::ParseDouble(double* result) {
4962 switch (lexer_.GetKind()) {
4963 case TokKind::kDecimal: {
4964 double val = lexer_.GetDecimalVal();
4965 // If GetDecimalVal returns +/-inf, that means that we overflowed
4966 // `double`.
4967 if (std::isinf(val)) {
4968 return TokenError(StrCat("Constant is out of range for double (+/-",
4969 std::numeric_limits<double>::max(),
4970 ") and so is unparsable."));
4971 }
4972 *result = val;
4973 break;
4974 }
4975 case TokKind::kInt:
4976 *result = static_cast<double>(lexer_.GetInt64Val());
4977 break;
4978 case TokKind::kw_nan:
4979 *result = std::numeric_limits<double>::quiet_NaN();
4980 break;
4981 case TokKind::kNegNan:
4982 *result = -std::numeric_limits<double>::quiet_NaN();
4983 break;
4984 case TokKind::kw_inf:
4985 *result = std::numeric_limits<double>::infinity();
4986 break;
4987 case TokKind::kNegInf:
4988 *result = -std::numeric_limits<double>::infinity();
4989 break;
4990 default:
4991 return TokenError("expects decimal or integer");
4992 }
4993 lexer_.Lex();
4994 return true;
4995 }
4996
ParseComplex(std::complex<double> * result)4997 bool HloParserImpl::ParseComplex(std::complex<double>* result) {
4998 if (lexer_.GetKind() != TokKind::kLparen) {
4999 return TokenError("expects '(' before complex number");
5000 }
5001 lexer_.Lex();
5002
5003 double real;
5004 LocTy loc = lexer_.GetLoc();
5005 if (!ParseDouble(&real)) {
5006 return Error(loc,
5007 "expect floating-point value for real part of complex number");
5008 }
5009
5010 if (lexer_.GetKind() != TokKind::kComma) {
5011 return TokenError(
5012 absl::StrFormat("expect comma after real part of complex literal"));
5013 }
5014 lexer_.Lex();
5015
5016 double imag;
5017 loc = lexer_.GetLoc();
5018 if (!ParseDouble(&imag)) {
5019 return Error(
5020 loc,
5021 "expect floating-point value for imaginary part of complex number");
5022 }
5023
5024 if (lexer_.GetKind() != TokKind::kRparen) {
5025 return TokenError(absl::StrFormat("expect ')' after complex number"));
5026 }
5027
5028 *result = std::complex<double>(real, imag);
5029 lexer_.Lex();
5030 return true;
5031 }
5032
ParseBool(bool * result)5033 bool HloParserImpl::ParseBool(bool* result) {
5034 if (lexer_.GetKind() != TokKind::kw_true &&
5035 lexer_.GetKind() != TokKind::kw_false) {
5036 return TokenError("expects true or false");
5037 }
5038 *result = lexer_.GetKind() == TokKind::kw_true;
5039 lexer_.Lex();
5040 return true;
5041 }
5042
ParseToken(TokKind kind,const std::string & msg)5043 bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
5044 VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
5045 if (lexer_.GetKind() != kind) {
5046 return TokenError(msg);
5047 }
5048 lexer_.Lex();
5049 return true;
5050 }
5051
EatIfPresent(TokKind kind)5052 bool HloParserImpl::EatIfPresent(TokKind kind) {
5053 if (lexer_.GetKind() != kind) {
5054 return false;
5055 }
5056 lexer_.Lex();
5057 return true;
5058 }
5059
AddInstruction(const std::string & name,HloInstruction * instruction,LocTy name_loc)5060 bool HloParserImpl::AddInstruction(const std::string& name,
5061 HloInstruction* instruction,
5062 LocTy name_loc) {
5063 auto result = current_name_table().insert({name, {instruction, name_loc}});
5064 if (!result.second) {
5065 Error(name_loc, StrCat("instruction already exists: ", name));
5066 return Error(/*loc=*/result.first->second.second,
5067 "instruction previously defined here");
5068 }
5069 return true;
5070 }
5071
AddComputation(const std::string & name,HloComputation * computation,LocTy name_loc)5072 bool HloParserImpl::AddComputation(const std::string& name,
5073 HloComputation* computation,
5074 LocTy name_loc) {
5075 auto result = computation_pool_.insert({name, {computation, name_loc}});
5076 if (!result.second) {
5077 Error(name_loc, StrCat("computation already exists: ", name));
5078 return Error(/*loc=*/result.first->second.second,
5079 "computation previously defined here");
5080 }
5081 return true;
5082 }
5083
ParseShapeOnly()5084 StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
5085 lexer_.Lex();
5086 Shape shape;
5087 if (!ParseShape(&shape)) {
5088 return InvalidArgument("Syntax error:\n%s", GetError());
5089 }
5090 if (lexer_.GetKind() != TokKind::kEof) {
5091 return InvalidArgument("Syntax error:\nExtra content after shape");
5092 }
5093 return shape;
5094 }
5095
ParseShardingOnly()5096 StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
5097 lexer_.Lex();
5098 OpSharding op_sharding;
5099 if (!ParseSharding(&op_sharding)) {
5100 return InvalidArgument("Syntax error:\n%s", GetError());
5101 }
5102 if (lexer_.GetKind() != TokKind::kEof) {
5103 return InvalidArgument("Syntax error:\nExtra content after sharding");
5104 }
5105 return HloSharding::FromProto(op_sharding);
5106 }
5107
ParseFrontendAttributesOnly()5108 StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
5109 lexer_.Lex();
5110 FrontendAttributes attributes;
5111 if (!ParseFrontendAttributes(&attributes)) {
5112 return InvalidArgument("Syntax error:\n%s", GetError());
5113 }
5114 if (lexer_.GetKind() != TokKind::kEof) {
5115 return InvalidArgument(
5116 "Syntax error:\nExtra content after frontend attributes");
5117 }
5118 return attributes;
5119 }
5120
ParseParameterReplicationOnly()5121 StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
5122 lexer_.Lex();
5123 ParameterReplication parameter_replication;
5124 if (!ParseParameterReplication(¶meter_replication)) {
5125 return InvalidArgument("Syntax error:\n%s", GetError());
5126 }
5127 if (lexer_.GetKind() != TokKind::kEof) {
5128 return InvalidArgument(
5129 "Syntax error:\nExtra content after parameter replication");
5130 }
5131 return std::vector<bool>(
5132 parameter_replication.replicated_at_leaf_buffers().begin(),
5133 parameter_replication.replicated_at_leaf_buffers().end());
5134 }
5135
ParseReplicaGroupsOnly()5136 StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
5137 lexer_.Lex();
5138 std::vector<ReplicaGroup> replica_groups;
5139 if (!ParseReplicaGroupsOnly(&replica_groups)) {
5140 return InvalidArgument("Syntax error:\n%s", GetError());
5141 }
5142 if (lexer_.GetKind() != TokKind::kEof) {
5143 return InvalidArgument("Syntax error:\nExtra content after replica groups");
5144 }
5145 return replica_groups;
5146 }
5147
ParseWindowOnly()5148 StatusOr<Window> HloParserImpl::ParseWindowOnly() {
5149 lexer_.Lex();
5150 Window window;
5151 if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
5152 return InvalidArgument("Syntax error:\n%s", GetError());
5153 }
5154 if (lexer_.GetKind() != TokKind::kEof) {
5155 return InvalidArgument("Syntax error:\nExtra content after window");
5156 }
5157 return window;
5158 }
5159
5160 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()5161 HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
5162 lexer_.Lex();
5163 ConvolutionDimensionNumbers dnums;
5164 if (!ParseConvolutionDimensionNumbers(&dnums)) {
5165 return InvalidArgument("Syntax error:\n%s", GetError());
5166 }
5167 if (lexer_.GetKind() != TokKind::kEof) {
5168 return InvalidArgument(
5169 "Syntax error:\nExtra content after convolution dnums");
5170 }
5171 return dnums;
5172 }
5173
ParsePaddingConfigOnly()5174 StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
5175 lexer_.Lex();
5176 PaddingConfig padding_config;
5177 if (!ParsePaddingConfig(&padding_config)) {
5178 return InvalidArgument("Syntax error:\n%s", GetError());
5179 }
5180 if (lexer_.GetKind() != TokKind::kEof) {
5181 return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
5182 }
5183 return padding_config;
5184 }
5185
ParseSingleInstruction(HloModule * module)5186 bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
5187 if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
5188 LOG(FATAL) << "Parser state is not clean. Please do not call any other "
5189 "methods before calling ParseSingleInstruction.";
5190 }
5191 HloComputation::Builder builder(module->name());
5192
5193 // The missing instruction hook we register creates the shaped instruction on
5194 // the fly as a parameter and returns it.
5195 int64 parameter_count = 0;
5196 create_missing_instruction_ =
5197 [this, &builder, ¶meter_count](
5198 const std::string& name,
5199 const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
5200 std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
5201 HloInstruction* parameter = builder.AddInstruction(
5202 HloInstruction::CreateParameter(parameter_count++, shape, new_name));
5203 current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
5204 return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
5205 };
5206
5207 // Parse the instruction with the registered hook.
5208 Scope scope(&scoped_name_tables_);
5209 if (CanBeShape()) {
5210 // This means that the instruction's left-hand side is probably omitted,
5211 // e.g.
5212 //
5213 // f32[10] fusion(...), calls={...}
5214 if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
5215 return false;
5216 }
5217 } else {
5218 // This means that the instruction's left-hand side might exist, e.g.
5219 //
5220 // foo = f32[10] fusion(...), calls={...}
5221 std::string root_name;
5222 if (!ParseInstruction(&builder, &root_name)) {
5223 return false;
5224 }
5225 }
5226
5227 if (lexer_.GetKind() != TokKind::kEof) {
5228 Error(
5229 lexer_.GetLoc(),
5230 "Syntax error:\nExpected eof after parsing single instruction. Did "
5231 "you mean to write an HLO module and forget the \"HloModule\" header?");
5232 return false;
5233 }
5234
5235 module->AddEntryComputation(builder.Build());
5236 for (auto& comp : computations_) {
5237 module->AddEmbeddedComputation(std::move(comp));
5238 }
5239 TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
5240 return true;
5241 }
5242
5243 } // namespace
5244
ParseAndReturnUnverifiedModule(absl::string_view str,const HloModuleConfig & config)5245 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5246 absl::string_view str, const HloModuleConfig& config) {
5247 auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
5248 HloParserImpl parser(str);
5249 TF_RETURN_IF_ERROR(parser.Run(module.get()));
5250 return std::move(module);
5251 }
5252
ParseAndReturnUnverifiedModule(absl::string_view str)5253 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5254 absl::string_view str) {
5255 return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
5256 }
5257
ParseSharding(absl::string_view str)5258 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
5259 HloParserImpl parser(str);
5260 return parser.ParseShardingOnly();
5261 }
5262
ParseFrontendAttributes(absl::string_view str)5263 StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
5264 HloParserImpl parser(str);
5265 return parser.ParseFrontendAttributesOnly();
5266 }
5267
ParseParameterReplication(absl::string_view str)5268 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
5269 HloParserImpl parser(str);
5270 return parser.ParseParameterReplicationOnly();
5271 }
5272
ParseReplicaGroupsOnly(absl::string_view str)5273 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
5274 absl::string_view str) {
5275 HloParserImpl parser(str);
5276 return parser.ParseReplicaGroupsOnly();
5277 }
5278
ParseWindow(absl::string_view str)5279 StatusOr<Window> ParseWindow(absl::string_view str) {
5280 HloParserImpl parser(str);
5281 return parser.ParseWindowOnly();
5282 }
5283
ParseConvolutionDimensionNumbers(absl::string_view str)5284 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
5285 absl::string_view str) {
5286 HloParserImpl parser(str);
5287 return parser.ParseConvolutionDimensionNumbersOnly();
5288 }
5289
ParsePaddingConfig(absl::string_view str)5290 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
5291 HloParserImpl parser(str);
5292 return parser.ParsePaddingConfigOnly();
5293 }
5294
ParseShape(absl::string_view str)5295 StatusOr<Shape> ParseShape(absl::string_view str) {
5296 HloParserImpl parser(str);
5297 return parser.ParseShapeOnly();
5298 }
5299
CreateHloParserForTests(absl::string_view str)5300 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
5301 absl::string_view str) {
5302 return absl::make_unique<HloParserImpl>(str);
5303 }
5304
5305 } // namespace xla
5306