1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <type_traits>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/base/casts.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "absl/types/variant.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/literal_util.h"
41 #include "tensorflow/compiler/xla/primitive_util.h"
42 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
43 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
44 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
45 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
46 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
47 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
50 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
51 #include "tensorflow/compiler/xla/service/shape_inference.h"
52 #include "tensorflow/compiler/xla/shape_util.h"
53 #include "tensorflow/compiler/xla/util.h"
54 #include "tensorflow/compiler/xla/xla_data.pb.h"
55 #include "tensorflow/core/lib/gtl/map_util.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/protobuf.h"
58
59 namespace xla {
60
61 namespace {
62
63 using absl::nullopt;
64 using absl::optional;
65 using absl::StrAppend;
66 using absl::StrCat;
67 using absl::StrFormat;
68 using absl::StrJoin;
69
70 // Creates and returns a schedule created using the order of the instructions in
71 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)72 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
73 HloSchedule schedule(module);
74 for (HloComputation* computation : module->computations()) {
75 if (!computation->IsFusionComputation()) {
76 for (HloInstruction* instruction : computation->instructions()) {
77 schedule.GetOrCreateSequence(computation).push_back(instruction);
78 }
79 }
80 }
81 return schedule;
82 }
83
CanInferShape(HloOpcode code)84 bool CanInferShape(HloOpcode code) {
85 switch (code) {
86 case HloOpcode::kAbs:
87 case HloOpcode::kAdd:
88 case HloOpcode::kAddDependency:
89 case HloOpcode::kAfterAll:
90 case HloOpcode::kAtan2:
91 case HloOpcode::kBatchNormGrad:
92 case HloOpcode::kBatchNormInference:
93 case HloOpcode::kBatchNormTraining:
94 case HloOpcode::kBroadcast:
95 case HloOpcode::kCall:
96 case HloOpcode::kCeil:
97 case HloOpcode::kCholesky:
98 case HloOpcode::kClamp:
99 case HloOpcode::kClz:
100 case HloOpcode::kCompare:
101 case HloOpcode::kComplex:
102 case HloOpcode::kConcatenate:
103 case HloOpcode::kConditional:
104 case HloOpcode::kConvolution:
105 case HloOpcode::kCopy:
106 case HloOpcode::kCos:
107 case HloOpcode::kDivide:
108 case HloOpcode::kDomain:
109 case HloOpcode::kDot:
110 case HloOpcode::kExp:
111 case HloOpcode::kExpm1:
112 case HloOpcode::kFft:
113 case HloOpcode::kFloor:
114 case HloOpcode::kGather:
115 case HloOpcode::kGetDimensionSize:
116 case HloOpcode::kSetDimensionSize:
117 case HloOpcode::kGetTupleElement:
118 case HloOpcode::kImag:
119 case HloOpcode::kIsFinite:
120 case HloOpcode::kLog:
121 case HloOpcode::kLog1p:
122 case HloOpcode::kLogistic:
123 case HloOpcode::kAnd:
124 case HloOpcode::kNot:
125 case HloOpcode::kOr:
126 case HloOpcode::kXor:
127 case HloOpcode::kMap:
128 case HloOpcode::kMaximum:
129 case HloOpcode::kMinimum:
130 case HloOpcode::kMultiply:
131 case HloOpcode::kNegate:
132 case HloOpcode::kPad:
133 case HloOpcode::kPartitionId:
134 case HloOpcode::kPopulationCount:
135 case HloOpcode::kPower:
136 case HloOpcode::kReal:
137 case HloOpcode::kReduce:
138 case HloOpcode::kRemainder:
139 case HloOpcode::kReplicaId:
140 case HloOpcode::kReverse:
141 case HloOpcode::kRoundNearestAfz:
142 case HloOpcode::kRsqrt:
143 case HloOpcode::kScatter:
144 case HloOpcode::kSelect:
145 case HloOpcode::kShiftLeft:
146 case HloOpcode::kShiftRightArithmetic:
147 case HloOpcode::kShiftRightLogical:
148 case HloOpcode::kSign:
149 case HloOpcode::kSin:
150 case HloOpcode::kSqrt:
151 case HloOpcode::kCbrt:
152 case HloOpcode::kReduceWindow:
153 case HloOpcode::kSelectAndScatter:
154 case HloOpcode::kSort:
155 case HloOpcode::kSubtract:
156 case HloOpcode::kTanh:
157 case HloOpcode::kTrace:
158 case HloOpcode::kTranspose:
159 case HloOpcode::kTriangularSolve:
160 case HloOpcode::kTuple:
161 case HloOpcode::kTupleSelect:
162 case HloOpcode::kWhile:
163 return true;
164 // Technically the following ops do not require an explicit result shape,
165 // but we made it so that we always write the shapes explicitly.
166 case HloOpcode::kAllGather:
167 case HloOpcode::kAllGatherStart:
168 case HloOpcode::kAllGatherDone:
169 case HloOpcode::kAllReduce:
170 case HloOpcode::kAllReduceStart:
171 case HloOpcode::kAllReduceDone:
172 case HloOpcode::kAllToAll:
173 case HloOpcode::kCollectivePermute:
174 case HloOpcode::kCollectivePermuteStart:
175 case HloOpcode::kCollectivePermuteDone:
176 case HloOpcode::kCopyDone:
177 case HloOpcode::kCopyStart:
178 case HloOpcode::kDynamicReshape:
179 case HloOpcode::kDynamicSlice:
180 case HloOpcode::kDynamicUpdateSlice:
181 case HloOpcode::kRecv:
182 case HloOpcode::kRecvDone:
183 case HloOpcode::kReduceScatter:
184 case HloOpcode::kSend:
185 case HloOpcode::kSendDone:
186 case HloOpcode::kSlice:
187 // The following ops require an explicit result shape.
188 case HloOpcode::kBitcast:
189 case HloOpcode::kBitcastConvert:
190 case HloOpcode::kConstant:
191 case HloOpcode::kConvert:
192 case HloOpcode::kCustomCall:
193 case HloOpcode::kFusion:
194 case HloOpcode::kInfeed:
195 case HloOpcode::kIota:
196 case HloOpcode::kOutfeed:
197 case HloOpcode::kParameter:
198 case HloOpcode::kReducePrecision:
199 case HloOpcode::kReshape:
200 case HloOpcode::kRng:
201 case HloOpcode::kRngBitGenerator:
202 case HloOpcode::kRngGetAndUpdateState:
203 return false;
204 }
205 }
206
// Parser for the HloModule::ToString() format text.
class HloParserImpl : public HloParser {
 public:
  using LocTy = HloLexer::LocTy;

  explicit HloParserImpl(absl::string_view str) : lexer_(str) {}

  // Runs the parser and constructs the resulting HLO in the given (empty)
  // HloModule. Returns the error status in case an error occurred.
  Status Run(HloModule* module) override;

  // Returns the error information accumulated so far, one entry per error,
  // joined with newlines.
  std::string GetError() const { return StrJoin(error_, "\n"); }

  // Stand alone parsing utils for various aggregate data types.
  StatusOr<Shape> ParseShapeOnly();
  StatusOr<HloSharding> ParseShardingOnly();
  StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
  StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
  StatusOr<Window> ParseWindowOnly();
  StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
  StatusOr<PaddingConfig> ParsePaddingConfigOnly();
  StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();

 private:
  using InstrNameTable =
      absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;

  // Returns the map from the instruction name to the instruction itself and
  // its location in the current scope (the top of scoped_name_tables_).
  InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }

  // Locates an instruction with the given name in the current_name_table() or
  // returns nullptr.
  //
  // When the name is not found or name is empty, if create_missing_instruction_
  // hook is registered and a "shape" is provided, the hook will be called to
  // create an instruction. This is useful when we reify parameters as they're
  // resolved; i.e. for ParseSingleInstruction.
  std::pair<HloInstruction*, LocTy>* FindInstruction(
      const std::string& name, const optional<Shape>& shape = nullopt);

  // Parse a single instruction worth of text.
  bool ParseSingleInstruction(HloModule* module);

  // Parses a module, returning false if an error occurred.
  bool ParseHloModule(HloModule* module);

  bool ParseComputations(HloModule* module);
  bool ParseComputation(HloComputation** entry_computation);
  bool ParseInstructionList(HloComputation** computation,
                            const std::string& computation_name);
  bool ParseInstruction(HloComputation::Builder* builder,
                        std::string* root_name);
  bool ParseInstructionRhs(HloComputation::Builder* builder,
                           const std::string& name, LocTy name_loc);
  bool ParseControlPredecessors(HloInstruction* instruction);
  bool ParseLiteral(Literal* literal);
  bool ParseLiteral(Literal* literal, const Shape& shape);
  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseDenseLiteral(Literal* literal, const Shape& shape);

  // Sets the sub-value of literal at the given linear index to the
  // given value. If the literal is dense, it must have the default layout.
  //
  // `loc` should be the source location of the value.
  bool SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, double value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, bool value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64_t index,
                         Literal* literal);
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64_t index,
                               Literal* literal);

  // Checks whether the given value is within the range of LiteralNativeT.
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
  template <typename LiteralNativeT>
  bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);

  bool ParseOperands(std::vector<HloInstruction*>* operands);
  // Fills parsed operands into 'operands' and expects a certain number of
  // operands.
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     const int expected_size);

  // Describes the start, limit, and stride on every dimension of the operand
  // being sliced.
  struct SliceRanges {
    std::vector<int64> starts;
    std::vector<int64> limits;
    std::vector<int64> strides;
  };

  // The data parsed for the kDomain instruction.
  struct DomainData {
    std::unique_ptr<DomainMetadata> entry_metadata;
    std::unique_ptr<DomainMetadata> exit_metadata;
  };

  // Types of attributes.
  enum class AttrTy {
    kBool,
    kInt64,
    kInt32,
    kFloat,
    kString,
    kLiteral,
    kBracedInt64List,
    kBracedInt64ListList,
    kHloComputation,
    kBracedHloComputationList,
    kFftType,
    kPaddingType,
    kComparisonDirection,
    kComparisonType,
    kWindow,
    kConvolutionDimensionNumbers,
    kSharding,
    kFrontendAttributes,
    kParameterReplication,
    kInstructionList,
    kSliceRanges,
    kPaddingConfig,
    kMetadata,
    kFusionKind,
    kDistribution,
    kDomain,
    kPrecisionList,
    kShape,
    kShapeList,
    kEnum,
    kRandomAlgorithm,
    kAliasing,
    kInstructionAliasing,
    kCustomCallSchedule,
    kCustomCallApiVersion,
  };

  struct AttrConfig {
    bool required;     // whether it's required or optional
    AttrTy attr_type;  // what type it is
    void* result;      // where to store the parsed result.
  };

  // attributes ::= (',' attribute)*
  //
  // Parses attributes given names and configs of the attributes. Each parsed
  // result is passed back through the result pointer in corresponding
  // AttrConfig. Note that the result pointer must point to an optional<T>
  // typed variable which outlives this function. Returns false on error. You
  // should not use any of the results if this function failed.
  //
  // Example usage:
  //
  //  absl::flat_hash_map<std::string, AttrConfig> attrs;
  //  optional<int64> foo;
  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
  //  optional<Window> bar;
  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
  //  if (!ParseAttributes(attrs)) {
  //    return false; // Do not use 'foo' 'bar' if failed.
  //  }
  //  // Do something with 'bar'.
  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
  //
  bool ParseAttributes(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs);

  // sub_attributes ::= '{' (','? attribute)* '}'
  //
  // Usage is the same as ParseAttributes. See immediately above.
  bool ParseSubAttributes(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributes or ParseSubAttributes.
  bool ParseAttributeHelper(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      absl::flat_hash_set<std::string>* seen_attrs);

  // Copy attributes from `attrs` to `message`, unless the attribute name is in
  // `non_proto_attrs`.
  bool CopyAttributeToProtoMessage(
      absl::flat_hash_set<std::string> non_proto_attrs,
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      tensorflow::protobuf::Message* message);

  // Parses an attribute string into a protocol buffer `message`.
  // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
  // set of mandatory attributes.
  // `non_proto_attrs` specifies attributes that are not written to the proto,
  // but added to the HloInstruction.
  bool ParseAttributesAsProtoMessage(
      const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
      tensorflow::protobuf::Message* message);

  // Parses a name and finds the corresponding hlo computation.
  bool ParseComputationName(HloComputation** value);
  // Parses a list of names and finds the corresponding hlo instructions.
  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
  // Pass expect_outer_curlies == true when parsing a Window in the context of a
  // larger computation. Pass false when parsing a stand-alone Window string.
  bool ParseWindow(Window* window, bool expect_outer_curlies);
  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
  bool ParsePaddingConfig(PaddingConfig* padding);
  bool ParseMetadata(OpMetadata* metadata);
  bool ParseSingleOrListMetadata(
      tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
  bool ParseOpShardingType(OpSharding::Type* type);
  bool ParseListShardingType(std::vector<OpSharding::Type>* types);
  bool ParseSharding(OpSharding* sharding);
  bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
  bool ParseParameterReplication(ParameterReplication* parameter_replication);
  bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);

  // Parses the metadata behind a kDomain instruction.
  bool ParseDomain(DomainData* domain);

  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
  bool ParseDxD(const std::string& name, std::vector<int64>* result);
  // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
  bool ParseWindowPad(std::vector<std::vector<int64>>* pad);

  bool ParseSliceRanges(SliceRanges* result);
  bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
  bool ParseHloComputation(HloComputation** result);
  bool ParseHloComputationList(std::vector<HloComputation*>* result);
  bool ParseShapeList(std::vector<Shape>* result);
  bool ParseInt64List(const TokKind start, const TokKind end,
                      const TokKind delim, std::vector<int64>* result);
  bool ParseInt64ListList(const TokKind start, const TokKind end,
                          const TokKind delim,
                          std::vector<std::vector<int64>>* result);
  // 'parse_and_add_item' is a lambda to parse an element in the list and add
  // the parsed element to the result. It's supposed to capture the result.
  bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
                 const std::function<bool()>& parse_and_add_item);

  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
  bool ParseParamList();
  bool ParseName(std::string* result);
  bool ParseAttributeName(std::string* result);
  bool ParseString(std::string* result);
  bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
                           std::vector<bool>* dynamic_dimensions);
  bool ParseShape(Shape* result);
  bool ParseLayout(Layout* layout);
  bool ParseLayoutIntAttribute(int64* attr_value,
                               absl::string_view attr_description);
  bool ParseTiles(std::vector<Tile>* tiles);
  bool ParseOpcode(HloOpcode* result);
  bool ParseFftType(FftType* result);
  bool ParsePaddingType(PaddingType* result);
  bool ParseComparisonDirection(ComparisonDirection* result);
  bool ParseComparisonType(Comparison::Type* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
  bool ParseRandomDistribution(RandomDistribution* result);
  bool ParseRandomAlgorithm(RandomAlgorithm* result);
  bool ParsePrecision(PrecisionConfig::Precision* result);
  bool ParseInt64(int64* result);
  bool ParseDouble(double* result);
  bool ParseComplex(std::complex<double>* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const std::string& msg);

  using AliasingData =
      absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;

  // Parses the aliasing information from string `s`, returns `false` if it
  // fails.
  bool ParseAliasing(AliasingData* data);

  // Parses the per-instruction aliasing information from string `s`, returns
  // `false` if it fails.
  bool ParseInstructionOutputOperandAliasing(
      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
          aliasing_output_operand_pairs);

  bool ParseCustomCallSchedule(CustomCallSchedule* result);
  bool ParseCustomCallApiVersion(CustomCallApiVersion* result);
  bool ParseShapeIndex(ShapeIndex* out);

  // Returns true if the current token is the beginning of a shape.
  bool CanBeShape();
  // Returns true if the current token is the beginning of a
  // param_list_to_shape.
  bool CanBeParamListToShape();

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(absl::string_view msg);
  bool Error(LocTy loc, absl::string_view msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const std::string& name, HloInstruction* instruction,
                      LocTy name_loc);
  // Adds the computation to the pool. Returns false and emits an error if the
  // computation already exists.
  bool AddComputation(const std::string& name, HloComputation* computation,
                      LocTy name_loc);

  HloLexer lexer_;

  // A stack for the instruction names. The top of the stack stores the
  // instruction name table for the current scope.
  //
  // An instruction's name is unique among its scope (i.e. its parent
  // computation), but it's not necessarily unique among all computations in the
  // module. When there are multiple levels of nested computations, the same
  // name could appear in both an outer computation and an inner computation. So
  // we need a stack to make sure a name is only visible within its scope.
  std::vector<InstrNameTable> scoped_name_tables_;

  // A helper class which pushes and pops to an InstrNameTable stack via RAII.
  class Scope {
   public:
    explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
        : scoped_name_tables_(scoped_name_tables) {
      scoped_name_tables_->emplace_back();
    }
    ~Scope() { scoped_name_tables_->pop_back(); }

   private:
    std::vector<InstrNameTable>* scoped_name_tables_;
  };

  // Map from the computation name to the computation itself and its location.
  absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
      computation_pool_;

  std::vector<std::unique_ptr<HloComputation>> computations_;
  // Accumulated error messages; see GetError().
  std::vector<std::string> error_;

  // When an operand name cannot be resolved, this function is called to create
  // a parameter instruction with the given name and shape. It registers the
  // name, instruction, and a placeholder location in the name table. It returns
  // the newly-created instruction and the placeholder location. If `name` is
  // empty, this should create the parameter with a generated name. This is
  // supposed to be set and used only in ParseSingleInstruction.
  std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
                                                   const Shape& shape)>
      create_missing_instruction_;
};
565
SplitToInt64s(absl::string_view s,char delim,std::vector<int64> * out)566 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
567 for (const auto& split : absl::StrSplit(s, delim)) {
568 int64_t val;
569 if (!absl::SimpleAtoi(split, &val)) {
570 return false;
571 }
572 out->push_back(val);
573 }
574 return true;
575 }
576
577 // Creates replica groups from the provided nested array. groups[i] represents
578 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64>> groups)579 std::vector<ReplicaGroup> CreateReplicaGroups(
580 absl::Span<const std::vector<int64>> groups) {
581 std::vector<ReplicaGroup> replica_groups;
582 absl::c_transform(groups, std::back_inserter(replica_groups),
583 [](const std::vector<int64>& ids) {
584 ReplicaGroup group;
585 *group.mutable_replica_ids() = {ids.begin(), ids.end()};
586 return group;
587 });
588 return replica_groups;
589 }
590
Error(LocTy loc,absl::string_view msg)591 bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
592 auto line_col = lexer_.GetLineAndColumn(loc);
593 const unsigned line = line_col.first;
594 const unsigned col = line_col.second;
595 std::vector<std::string> error_lines;
596 error_lines.push_back(
597 StrCat("was parsing ", line, ":", col, ": error: ", msg));
598 error_lines.emplace_back(lexer_.GetLine(loc));
599 error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
600
601 error_.push_back(StrJoin(error_lines, "\n"));
602 VLOG(1) << "Error: " << error_.back();
603 return false;
604 }
605
// Reports `msg` as an error at the lexer's current location. Always returns
// false so callers can write `return TokenError(...)`.
bool HloParserImpl::TokenError(absl::string_view msg) {
  return Error(lexer_.GetLoc(), msg);
}
609
Run(HloModule * module)610 Status HloParserImpl::Run(HloModule* module) {
611 lexer_.Lex();
612 if (lexer_.GetKind() == TokKind::kw_HloModule) {
613 // This means that the text contains a full HLO module.
614 if (!ParseHloModule(module)) {
615 return InvalidArgument(
616 "Syntax error when trying to parse the text as a HloModule:\n%s",
617 GetError());
618 }
619 return Status::OK();
620 }
621 // This means that the text is a single HLO instruction.
622 if (!ParseSingleInstruction(module)) {
623 return InvalidArgument(
624 "Syntax error when trying to parse the text as a single "
625 "HloInstruction:\n%s",
626 GetError());
627 }
628 return Status::OK();
629 }
630
631 std::pair<HloInstruction*, HloParserImpl::LocTy>*
FindInstruction(const std::string & name,const optional<Shape> & shape)632 HloParserImpl::FindInstruction(const std::string& name,
633 const optional<Shape>& shape) {
634 std::pair<HloInstruction*, LocTy>* instr = nullptr;
635 if (!name.empty()) {
636 instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
637 }
638
639 // Potentially call the missing instruction hook.
640 if (instr == nullptr && create_missing_instruction_ != nullptr &&
641 scoped_name_tables_.size() == 1) {
642 if (!shape.has_value()) {
643 Error(lexer_.GetLoc(),
644 "Operand had no shape in HLO text; cannot create parameter for "
645 "single-instruction module.");
646 return nullptr;
647 }
648 return create_missing_instruction_(name, *shape);
649 }
650
651 if (instr != nullptr && shape.has_value() &&
652 !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
653 Error(
654 lexer_.GetLoc(),
655 StrCat("The declared operand shape ",
656 ShapeUtil::HumanStringWithLayout(shape.value()),
657 " is not compatible with the shape of the operand instruction ",
658 ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
659 return nullptr;
660 }
661
662 return instr;
663 }
664
ParseShapeIndex(ShapeIndex * out)665 bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
666 if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
667 return false;
668 }
669
670 std::vector<int64> idxs;
671 while (lexer_.GetKind() != TokKind::kRbrace) {
672 int64_t idx;
673 if (!ParseInt64(&idx)) {
674 return false;
675 }
676 idxs.push_back(idx);
677 if (!EatIfPresent(TokKind::kComma)) {
678 break;
679 }
680 }
681 if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
682 return false;
683 }
684 *out = ShapeIndex(idxs.begin(), idxs.end());
685 return true;
686 }
687
ParseAliasing(AliasingData * data)688 bool HloParserImpl::ParseAliasing(AliasingData* data) {
689 if (!ParseToken(TokKind::kLbrace,
690 "Expects '{' at the start of aliasing description")) {
691 return false;
692 }
693
694 while (lexer_.GetKind() != TokKind::kRbrace) {
695 ShapeIndex out;
696 if (!ParseShapeIndex(&out)) {
697 return false;
698 }
699 std::string errmsg =
700 "Expected format: <output_shape_index>: (<input_param>, "
701 "<input_param_shape_index>) OR <output_shape_index>: <input_param>";
702 if (!ParseToken(TokKind::kColon, errmsg)) {
703 return false;
704 }
705
706 if (!ParseToken(TokKind::kLparen, errmsg)) {
707 return false;
708 }
709 int64_t param_num;
710 ParseInt64(¶m_num);
711 if (!ParseToken(TokKind::kComma, errmsg)) {
712 return false;
713 }
714 ShapeIndex param_idx;
715 if (!ParseShapeIndex(¶m_idx)) {
716 return false;
717 }
718
719 HloInputOutputAliasConfig::AliasKind alias_kind =
720 HloInputOutputAliasConfig::kMayAlias;
721 if (EatIfPresent(TokKind::kComma)) {
722 std::string type;
723 ParseName(&type);
724 if (type == "must-alias") {
725 alias_kind = HloInputOutputAliasConfig::kMustAlias;
726 } else if (type == "may-alias") {
727 alias_kind = HloInputOutputAliasConfig::kMayAlias;
728 } else {
729 return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
730 }
731 }
732
733 data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
734 std::forward_as_tuple(param_num, param_idx, alias_kind));
735 if (!ParseToken(TokKind::kRparen, errmsg)) {
736 return false;
737 }
738
739 if (!EatIfPresent(TokKind::kComma)) {
740 break;
741 }
742 }
743 if (!ParseToken(TokKind::kRbrace,
744 "Expects '}' at the end of aliasing description")) {
745 return false;
746 }
747 return true;
748 }
749
ParseInstructionOutputOperandAliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> * aliasing_output_operand_pairs)750 bool HloParserImpl::ParseInstructionOutputOperandAliasing(
751 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
752 aliasing_output_operand_pairs) {
753 if (!ParseToken(
754 TokKind::kLbrace,
755 "Expects '{' at the start of instruction aliasing description")) {
756 return false;
757 }
758
759 while (lexer_.GetKind() != TokKind::kRbrace) {
760 ShapeIndex out;
761 if (!ParseShapeIndex(&out)) {
762 return false;
763 }
764 std::string errmsg =
765 "Expected format: <output_shape_index>: (<operand_index>, "
766 "<operand_shape_index>)";
767 if (!ParseToken(TokKind::kColon, errmsg)) {
768 return false;
769 }
770
771 if (!ParseToken(TokKind::kLparen, errmsg)) {
772 return false;
773 }
774 int64_t operand_index;
775 ParseInt64(&operand_index);
776 if (!ParseToken(TokKind::kComma, errmsg)) {
777 return false;
778 }
779 ShapeIndex operand_shape_index;
780 if (!ParseShapeIndex(&operand_shape_index)) {
781 return false;
782 }
783
784 aliasing_output_operand_pairs->emplace_back(
785 out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
786 if (!ParseToken(TokKind::kRparen, errmsg)) {
787 return false;
788 }
789
790 if (!EatIfPresent(TokKind::kComma)) {
791 break;
792 }
793 }
794 if (!ParseToken(
795 TokKind::kRbrace,
796 "Expects '}' at the end of instruction aliasing description")) {
797 return false;
798 }
799 return true;
800 }
801
ParseCustomCallSchedule(CustomCallSchedule * result)802 bool HloParserImpl::ParseCustomCallSchedule(CustomCallSchedule* result) {
803 VLOG(3) << "ParseCustomCallSchedule";
804 if (lexer_.GetKind() != TokKind::kIdent) {
805 return TokenError("expects custom-call schedule");
806 }
807 std::string val = lexer_.GetStrVal();
808 auto status_or_result = StringToCustomCallSchedule(val);
809 if (!status_or_result.ok()) {
810 return TokenError(
811 StrFormat("expects custom-call schedule but sees: %s, error: %s", val,
812 status_or_result.status().error_message()));
813 }
814 *result = status_or_result.ValueOrDie();
815 lexer_.Lex();
816 return true;
817 }
818
ParseCustomCallApiVersion(CustomCallApiVersion * result)819 bool HloParserImpl::ParseCustomCallApiVersion(CustomCallApiVersion* result) {
820 VLOG(3) << "ParseCustomCallApiVersion";
821 if (lexer_.GetKind() != TokKind::kIdent) {
822 return TokenError("expects custom-call API version");
823 }
824 std::string val = lexer_.GetStrVal();
825 auto status_or_result = StringToCustomCallApiVersion(val);
826 if (!status_or_result.ok()) {
827 return TokenError(
828 StrFormat("expects custom-call API version but sees: %s, error: %s",
829 val, status_or_result.status().error_message()));
830 }
831 *result = status_or_result.ValueOrDie();
832 lexer_.Lex();
833 return true;
834 }
835
// ::= 'HloModule' name computations
//
// Parses a complete module: the 'HloModule' keyword, the module name, the
// optional module-level attributes (is_scheduled, input_output_alias,
// alias_passthrough_params), and then all computations. Attribute side
// effects (schedule, config, alias config) are applied only after the
// computations parse successfully. Returns false on any error.
bool HloParserImpl::ParseHloModule(HloModule* module) {
  if (lexer_.GetKind() != TokKind::kw_HloModule) {
    return TokenError("expects HloModule");
  }
  // Eat 'HloModule'
  lexer_.Lex();

  std::string name;
  if (!ParseName(&name)) {
    return false;
  }

  absl::optional<bool> is_scheduled;
  absl::optional<AliasingData> aliasing_data;
  absl::optional<bool> alias_passthrough_params;
  absl::flat_hash_map<std::string, AttrConfig> attrs;

  attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
  attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
                                 &aliasing_data};
  attrs["alias_passthrough_params"] = {/*required=*/false, AttrTy::kBool,
                                       &alias_passthrough_params};
  if (!ParseAttributes(attrs)) {
    return false;
  }
  module->set_name(name);
  if (!ParseComputations(module)) {
    return false;
  }

  // is_scheduled=true: the instruction order in the text IS the schedule.
  if (is_scheduled.has_value() && *is_scheduled) {
    TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  }
  if (alias_passthrough_params.has_value() && *alias_passthrough_params) {
    // set_config takes the whole config, so copy, mutate, and write back.
    HloModuleConfig config = module->config();
    config.set_alias_passthrough_params(true);
    module->set_config(config);
  }
  if (aliasing_data) {
    // Install the parsed input/output aliases; SetUpAlias validates each one
    // against the module's result shape.
    HloInputOutputAliasConfig alias_config(module->result_shape());
    for (auto& p : *aliasing_data) {
      Status st =
          alias_config.SetUpAlias(p.first, p.second.parameter_number,
                                  p.second.parameter_index, p.second.kind);
      if (!st.ok()) {
        return TokenError(st.error_message());
      }
    }
    module->input_output_alias_config() = alias_config;
  }

  return true;
}
890
891 // computations ::= (computation)+
ParseComputations(HloModule * module)892 bool HloParserImpl::ParseComputations(HloModule* module) {
893 HloComputation* entry_computation = nullptr;
894 do {
895 if (!ParseComputation(&entry_computation)) {
896 return false;
897 }
898 } while (lexer_.GetKind() != TokKind::kEof);
899
900 for (int i = 0; i < computations_.size(); i++) {
901 // If entry_computation is not nullptr, it means the computation it pointed
902 // to is marked with "ENTRY"; otherwise, no computation is marked with
903 // "ENTRY", and we use the last computation as the entry computation. We
904 // add the non-entry computations as embedded computations to the module.
905 if ((entry_computation != nullptr &&
906 computations_[i].get() != entry_computation) ||
907 (entry_computation == nullptr && i != computations_.size() - 1)) {
908 module->AddEmbeddedComputation(std::move(computations_[i]));
909 continue;
910 }
911 auto computation = module->AddEntryComputation(std::move(computations_[i]));
912 // The parameters and result layouts were set to default layout. Here we
913 // set the layouts to what the hlo text says.
914 for (int p = 0; p < computation->num_parameters(); p++) {
915 const Shape& param_shape = computation->parameter_instruction(p)->shape();
916 TF_CHECK_OK(module->mutable_entry_computation_layout()
917 ->mutable_parameter_layout(p)
918 ->CopyLayoutFromShape(param_shape));
919 }
920 const Shape& result_shape = computation->root_instruction()->shape();
921 TF_CHECK_OK(module->mutable_entry_computation_layout()
922 ->mutable_result_layout()
923 ->CopyLayoutFromShape(result_shape));
924 }
925 return true;
926 }
927
928 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
ParseComputation(HloComputation ** entry_computation)929 bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
930 LocTy maybe_entry_loc = lexer_.GetLoc();
931 const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
932
933 std::string name;
934 LocTy name_loc = lexer_.GetLoc();
935 if (!ParseName(&name)) {
936 return false;
937 }
938
939 LocTy shape_loc = nullptr;
940 Shape shape;
941 if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
942 return false;
943 }
944
945 HloComputation* computation = nullptr;
946 if (!ParseInstructionList(&computation, name)) {
947 return false;
948 }
949
950 // If param_list_to_shape was present, check compatibility.
951 if (shape_loc != nullptr &&
952 !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
953 return Error(
954 shape_loc,
955 StrCat(
956 "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
957 ", is not compatible with that of its root instruction ",
958 computation->root_instruction()->name(), ", ",
959 ShapeUtil::HumanString(computation->root_instruction()->shape())));
960 }
961
962 if (is_entry_computation) {
963 if (*entry_computation != nullptr) {
964 return Error(maybe_entry_loc, "expects only one ENTRY");
965 }
966 *entry_computation = computation;
967 }
968
969 return AddComputation(name, computation, name_loc);
970 }
971
972 // instruction_list ::= '{' instruction_list1 '}'
973 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation ** computation,const std::string & computation_name)974 bool HloParserImpl::ParseInstructionList(HloComputation** computation,
975 const std::string& computation_name) {
976 Scope scope(&scoped_name_tables_);
977 HloComputation::Builder builder(computation_name);
978 if (!ParseToken(TokKind::kLbrace,
979 "expects '{' at the beginning of instruction list.")) {
980 return false;
981 }
982 std::string root_name;
983 do {
984 if (!ParseInstruction(&builder, &root_name)) {
985 return false;
986 }
987 } while (lexer_.GetKind() != TokKind::kRbrace);
988 if (!ParseToken(TokKind::kRbrace,
989 "expects '}' at the end of instruction list.")) {
990 return false;
991 }
992 HloInstruction* root = nullptr;
993 if (!root_name.empty()) {
994 std::pair<HloInstruction*, LocTy>* root_node =
995 tensorflow::gtl::FindOrNull(current_name_table(), root_name);
996
997 // This means some instruction was marked as ROOT but we didn't find it in
998 // the pool, which should not happen.
999 if (root_node == nullptr) {
1000 // LOG(FATAL) crashes the program by calling abort().
1001 LOG(FATAL) << "instruction " << root_name
1002 << " was marked as ROOT but the parser has not seen it before";
1003 }
1004 root = root_node->first;
1005 }
1006
1007 // Now root can be either an existing instruction or a nullptr. If it's a
1008 // nullptr, the implementation of Builder will set the last instruction as
1009 // the root instruction.
1010 computations_.emplace_back(builder.Build(root));
1011 *computation = computations_.back().get();
1012 return true;
1013 }
1014
1015 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,std::string * root_name)1016 bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
1017 std::string* root_name) {
1018 std::string name;
1019 LocTy maybe_root_loc = lexer_.GetLoc();
1020 bool is_root = EatIfPresent(TokKind::kw_ROOT);
1021
1022 const LocTy name_loc = lexer_.GetLoc();
1023 if (!ParseName(&name) ||
1024 !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
1025 return false;
1026 }
1027
1028 if (is_root) {
1029 if (!root_name->empty()) {
1030 return Error(maybe_root_loc, "one computation should have only one ROOT");
1031 }
1032 *root_name = name;
1033 }
1034
1035 return ParseInstructionRhs(builder, name, name_loc);
1036 }
1037
ParseInstructionRhs(HloComputation::Builder * builder,const std::string & name,LocTy name_loc)1038 bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
1039 const std::string& name,
1040 LocTy name_loc) {
1041 Shape shape;
1042 HloOpcode opcode;
1043 std::vector<HloInstruction*> operands;
1044
1045 const bool parse_shape = CanBeShape();
1046 if ((parse_shape && !ParseShape(&shape)) || !ParseOpcode(&opcode)) {
1047 return false;
1048 }
1049 if (!parse_shape && !CanInferShape(opcode)) {
1050 return TokenError(StrFormat("cannot infer shape for opcode: %s",
1051 HloOpcodeString(opcode)));
1052 }
1053
1054 // Add optional attributes. These are added to any HloInstruction type if
1055 // present.
1056 absl::flat_hash_map<std::string, AttrConfig> attrs;
1057 optional<OpSharding> sharding;
1058 optional<FrontendAttributes> frontend_attributes;
1059 attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
1060 attrs["frontend_attributes"] = {
1061 /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
1062 optional<ParameterReplication> parameter_replication;
1063 attrs["parameter_replication"] = {/*required=*/false,
1064 AttrTy::kParameterReplication,
1065 ¶meter_replication};
1066 optional<std::vector<HloInstruction*>> predecessors;
1067 attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
1068 &predecessors};
1069 optional<OpMetadata> metadata;
1070 attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
1071
1072 optional<std::string> backend_config;
1073 attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
1074 &backend_config};
1075 optional<std::vector<int64>> outer_dimension_partitions;
1076 attrs["outer_dimension_partitions"] = {/*required=*/false,
1077 AttrTy::kBracedInt64List,
1078 &outer_dimension_partitions};
1079 const auto maybe_infer_shape =
1080 [&](const std::function<StatusOr<Shape>()>& infer, Shape* shape) {
1081 if (parse_shape) {
1082 return true;
1083 }
1084 auto inferred = infer();
1085 if (!inferred.ok()) {
1086 return TokenError(StrFormat(
1087 "failed to infer shape for opcode: %s, error: %s",
1088 HloOpcodeString(opcode), inferred.status().error_message()));
1089 }
1090 *shape = std::move(inferred).ValueOrDie();
1091 return true;
1092 };
1093 HloInstruction* instruction;
1094 switch (opcode) {
1095 case HloOpcode::kParameter: {
1096 int64_t parameter_number;
1097 if (!ParseToken(TokKind::kLparen,
1098 "expects '(' before parameter number") ||
1099 !ParseInt64(¶meter_number)) {
1100 return false;
1101 }
1102 if (parameter_number < 0) {
1103 Error(lexer_.GetLoc(), "parameter number must be >= 0");
1104 return false;
1105 }
1106 if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
1107 !ParseAttributes(attrs)) {
1108 return false;
1109 }
1110 instruction = builder->AddInstruction(
1111 HloInstruction::CreateParameter(parameter_number, shape, name));
1112 break;
1113 }
1114 case HloOpcode::kConstant: {
1115 Literal literal;
1116 if (!ParseToken(TokKind::kLparen,
1117 "expects '(' before constant literal") ||
1118 !ParseLiteral(&literal, shape) ||
1119 !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
1120 !ParseAttributes(attrs)) {
1121 return false;
1122 }
1123 instruction = builder->AddInstruction(
1124 HloInstruction::CreateConstant(std::move(literal)));
1125 break;
1126 }
1127 case HloOpcode::kIota: {
1128 optional<int64> iota_dimension;
1129 attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
1130 &iota_dimension};
1131 if (!ParseOperands(&operands, /*expected_size=*/0) ||
1132 !ParseAttributes(attrs)) {
1133 return false;
1134 }
1135 instruction = builder->AddInstruction(
1136 HloInstruction::CreateIota(shape, *iota_dimension));
1137 break;
1138 }
1139 // Unary ops.
1140 case HloOpcode::kAbs:
1141 case HloOpcode::kAllGatherDone:
1142 case HloOpcode::kAllReduceDone:
1143 case HloOpcode::kRoundNearestAfz:
1144 case HloOpcode::kBitcast:
1145 case HloOpcode::kCeil:
1146 case HloOpcode::kClz:
1147 case HloOpcode::kCollectivePermuteDone:
1148 case HloOpcode::kCopy:
1149 case HloOpcode::kCopyDone:
1150 case HloOpcode::kCos:
1151 case HloOpcode::kExp:
1152 case HloOpcode::kExpm1:
1153 case HloOpcode::kImag:
1154 case HloOpcode::kIsFinite:
1155 case HloOpcode::kFloor:
1156 case HloOpcode::kLog:
1157 case HloOpcode::kLog1p:
1158 case HloOpcode::kLogistic:
1159 case HloOpcode::kNot:
1160 case HloOpcode::kNegate:
1161 case HloOpcode::kPopulationCount:
1162 case HloOpcode::kReal:
1163 case HloOpcode::kRsqrt:
1164 case HloOpcode::kSign:
1165 case HloOpcode::kSin:
1166 case HloOpcode::kSqrt:
1167 case HloOpcode::kCbrt:
1168 case HloOpcode::kTanh: {
1169 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1170 !ParseAttributes(attrs)) {
1171 return false;
1172 }
1173 if (!maybe_infer_shape(
1174 [&] {
1175 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
1176 },
1177 &shape)) {
1178 return false;
1179 }
1180 instruction = builder->AddInstruction(
1181 HloInstruction::CreateUnary(shape, opcode, operands[0]));
1182 break;
1183 }
1184 // Binary ops.
1185 case HloOpcode::kAdd:
1186 case HloOpcode::kDivide:
1187 case HloOpcode::kMultiply:
1188 case HloOpcode::kSubtract:
1189 case HloOpcode::kAtan2:
1190 case HloOpcode::kComplex:
1191 case HloOpcode::kMaximum:
1192 case HloOpcode::kMinimum:
1193 case HloOpcode::kPower:
1194 case HloOpcode::kRemainder:
1195 case HloOpcode::kAnd:
1196 case HloOpcode::kOr:
1197 case HloOpcode::kXor:
1198 case HloOpcode::kShiftLeft:
1199 case HloOpcode::kShiftRightArithmetic:
1200 case HloOpcode::kShiftRightLogical: {
1201 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1202 !ParseAttributes(attrs)) {
1203 return false;
1204 }
1205 if (!maybe_infer_shape(
1206 [&] {
1207 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1208 operands[1]);
1209 },
1210 &shape)) {
1211 return false;
1212 }
1213 instruction = builder->AddInstruction(HloInstruction::CreateBinary(
1214 shape, opcode, operands[0], operands[1]));
1215 break;
1216 }
1217 // Ternary ops.
1218 case HloOpcode::kClamp:
1219 case HloOpcode::kSelect:
1220 case HloOpcode::kTupleSelect: {
1221 if (!ParseOperands(&operands, /*expected_size=*/3) ||
1222 !ParseAttributes(attrs)) {
1223 return false;
1224 }
1225 if (!maybe_infer_shape(
1226 [&] {
1227 return ShapeInference::InferTernaryOpShape(
1228 opcode, operands[0], operands[1], operands[2]);
1229 },
1230 &shape)) {
1231 return false;
1232 }
1233 instruction = builder->AddInstruction(HloInstruction::CreateTernary(
1234 shape, opcode, operands[0], operands[1], operands[2]));
1235 break;
1236 }
1237 // Other supported ops.
1238 case HloOpcode::kConvert: {
1239 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1240 !ParseAttributes(attrs)) {
1241 return false;
1242 }
1243 instruction = builder->AddInstruction(
1244 HloInstruction::CreateConvert(shape, operands[0]));
1245 break;
1246 }
1247 case HloOpcode::kBitcastConvert: {
1248 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1249 !ParseAttributes(attrs)) {
1250 return false;
1251 }
1252 instruction = builder->AddInstruction(
1253 HloInstruction::CreateBitcastConvert(shape, operands[0]));
1254 break;
1255 }
1256 case HloOpcode::kAllGather:
1257 case HloOpcode::kAllGatherStart: {
1258 optional<std::vector<std::vector<int64>>> tmp_groups;
1259 optional<std::vector<int64>> replica_group_ids;
1260 optional<int64> channel_id;
1261 optional<std::vector<int64>> dimensions;
1262 optional<bool> constrain_layout;
1263 optional<bool> use_global_device_ids;
1264 attrs["replica_groups"] = {/*required=*/false,
1265 AttrTy::kBracedInt64ListList, &tmp_groups};
1266 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1267 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1268 &dimensions};
1269 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1270 &constrain_layout};
1271 attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1272 &use_global_device_ids};
1273 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1274 return false;
1275 }
1276 std::vector<ReplicaGroup> replica_groups;
1277 if (tmp_groups) {
1278 replica_groups = CreateReplicaGroups(*tmp_groups);
1279 }
1280 if (opcode == HloOpcode::kAllGather) {
1281 instruction = builder->AddInstruction(HloInstruction::CreateAllGather(
1282 shape, operands, dimensions->at(0), replica_groups,
1283 constrain_layout ? *constrain_layout : false, channel_id,
1284 use_global_device_ids ? *use_global_device_ids : false));
1285 } else {
1286 instruction =
1287 builder->AddInstruction(HloInstruction::CreateAllGatherStart(
1288 shape, operands, dimensions->at(0), replica_groups,
1289 constrain_layout ? *constrain_layout : false, channel_id,
1290 use_global_device_ids ? *use_global_device_ids : false));
1291 }
1292 break;
1293 }
1294 case HloOpcode::kAllReduce:
1295 case HloOpcode::kAllReduceStart:
1296 case HloOpcode::kReduceScatter: {
1297 optional<std::vector<std::vector<int64>>> tmp_groups;
1298 optional<HloComputation*> to_apply;
1299 optional<std::vector<int64>> replica_group_ids;
1300 optional<int64> channel_id;
1301 optional<bool> constrain_layout;
1302 optional<bool> use_global_device_ids;
1303 optional<std::vector<int64>> dimensions;
1304 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1305 &to_apply};
1306 attrs["replica_groups"] = {/*required=*/false,
1307 AttrTy::kBracedInt64ListList, &tmp_groups};
1308 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1309 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1310 &constrain_layout};
1311 attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1312 &use_global_device_ids};
1313 if (opcode == HloOpcode::kReduceScatter) {
1314 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1315 &dimensions};
1316 }
1317 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1318 return false;
1319 }
1320 std::vector<ReplicaGroup> replica_groups;
1321 if (tmp_groups) {
1322 replica_groups = CreateReplicaGroups(*tmp_groups);
1323 }
1324 if (opcode == HloOpcode::kAllReduce) {
1325 instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
1326 shape, operands, *to_apply, replica_groups,
1327 constrain_layout ? *constrain_layout : false, channel_id,
1328 use_global_device_ids ? *use_global_device_ids : false));
1329 } else if (opcode == HloOpcode::kReduceScatter) {
1330 instruction =
1331 builder->AddInstruction(HloInstruction::CreateReduceScatter(
1332 shape, operands, *to_apply, replica_groups,
1333 constrain_layout ? *constrain_layout : false, channel_id,
1334 use_global_device_ids ? *use_global_device_ids : false,
1335 dimensions->at(0)));
1336 } else {
1337 instruction =
1338 builder->AddInstruction(HloInstruction::CreateAllReduceStart(
1339 shape, operands, *to_apply, replica_groups,
1340 constrain_layout ? *constrain_layout : false, channel_id,
1341 use_global_device_ids ? *use_global_device_ids : false));
1342 }
1343 break;
1344 }
1345 case HloOpcode::kAllToAll: {
1346 optional<std::vector<std::vector<int64>>> tmp_groups;
1347 attrs["replica_groups"] = {/*required=*/false,
1348 AttrTy::kBracedInt64ListList, &tmp_groups};
1349 optional<int64> channel_id;
1350 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1351 optional<std::vector<int64>> dimensions;
1352 attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1353 &dimensions};
1354 optional<bool> constrain_layout;
1355 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1356 &constrain_layout};
1357 if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1358 (dimensions && dimensions->size() != 1)) {
1359 return false;
1360 }
1361 std::vector<ReplicaGroup> replica_groups;
1362 if (tmp_groups) {
1363 replica_groups = CreateReplicaGroups(*tmp_groups);
1364 }
1365 optional<int64> split_dimension;
1366 if (dimensions) {
1367 split_dimension = dimensions->at(0);
1368 }
1369 instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
1370 shape, operands, replica_groups,
1371 constrain_layout ? *constrain_layout : false, channel_id,
1372 split_dimension));
1373 break;
1374 }
1375 case HloOpcode::kCollectivePermute:
1376 case HloOpcode::kCollectivePermuteStart: {
1377 optional<std::vector<std::vector<int64>>> source_targets;
1378 attrs["source_target_pairs"] = {
1379 /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
1380 optional<int64> channel_id;
1381 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1382 optional<std::vector<std::vector<int64_t>>> slice_sizes;
1383 attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList,
1384 &slice_sizes};
1385 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1386 return false;
1387 }
1388 std::vector<std::pair<int64, int64>> pairs(source_targets->size());
1389 for (int i = 0; i < pairs.size(); i++) {
1390 if ((*source_targets)[i].size() != 2) {
1391 return TokenError(
1392 "expects 'source_target_pairs=' to be a list of pairs");
1393 }
1394 pairs[i].first = (*source_targets)[i][0];
1395 pairs[i].second = (*source_targets)[i][1];
1396 }
1397 if (!slice_sizes.has_value()) {
1398 if (operands.size() != 1) {
1399 return TokenError(
1400 "CollectivePermute and CollectivePermuteStart must "
1401 "have exactly one operand (input buffer) unless "
1402 "it performs dynamic-slice and in-place update.");
1403 }
1404 if (opcode == HloOpcode::kCollectivePermute) {
1405 instruction =
1406 builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1407 shape, operands[0], pairs, channel_id));
1408 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
1409 instruction = builder->AddInstruction(
1410 HloInstruction::CreateCollectivePermuteStart(shape, operands[0],
1411 pairs, channel_id));
1412 } else {
1413 LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1414 "CollectivePermuteStart, but got "
1415 << HloOpcodeString(opcode);
1416 }
1417 } else {
1418 if (operands.size() != 4) {
1419 return TokenError(
1420 "CollectivePermute and CollectivePermuteStart must "
1421 "have exactly four operands for dynamic-slice and "
1422 "in-place update.");
1423 }
1424 if (opcode == HloOpcode::kCollectivePermute) {
1425 instruction =
1426 builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1427 shape, operands[0], operands[1], operands[2], operands[3],
1428 pairs, *slice_sizes, channel_id));
1429 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
1430 instruction = builder->AddInstruction(
1431 HloInstruction::CreateCollectivePermuteStart(
1432 shape, operands[0], operands[1], operands[2], operands[3],
1433 pairs, *slice_sizes, channel_id));
1434 } else {
1435 LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1436 "CollectivePermuteStart, but got "
1437 << HloOpcodeString(opcode);
1438 }
1439 }
1440 break;
1441 }
1442 case HloOpcode::kCopyStart: {
1443 // If the is_cross_program_prefetch attribute is not present then default
1444 // to false.
1445 optional<bool> is_cross_program_prefetch = false;
1446 attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
1447 &is_cross_program_prefetch};
1448 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1449 !ParseAttributes(attrs)) {
1450 return false;
1451 }
1452 instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
1453 shape, operands[0], *is_cross_program_prefetch));
1454 break;
1455 }
1456 case HloOpcode::kReplicaId: {
1457 if (!ParseOperands(&operands, /*expected_size=*/0) ||
1458 !ParseAttributes(attrs)) {
1459 return false;
1460 }
1461 instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
1462 break;
1463 }
1464 case HloOpcode::kPartitionId: {
1465 if (!ParseOperands(&operands, /*expected_size=*/0) ||
1466 !ParseAttributes(attrs)) {
1467 return false;
1468 }
1469 instruction =
1470 builder->AddInstruction(HloInstruction::CreatePartitionId());
1471 break;
1472 }
1473 case HloOpcode::kDynamicReshape: {
1474 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1475 return false;
1476 }
1477 instruction =
1478 builder->AddInstruction(HloInstruction::CreateDynamicReshape(
1479 shape, operands[0],
1480 absl::Span<HloInstruction* const>(operands).subspan(1)));
1481 break;
1482 }
1483 case HloOpcode::kReshape: {
1484 optional<int64> inferred_dimension;
1485 attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
1486 &inferred_dimension};
1487 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1488 !ParseAttributes(attrs)) {
1489 return false;
1490 }
1491 instruction = builder->AddInstruction(HloInstruction::CreateReshape(
1492 shape, operands[0], inferred_dimension.value_or(-1)));
1493 break;
1494 }
1495 case HloOpcode::kAfterAll: {
1496 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1497 return false;
1498 }
1499 if (operands.empty()) {
1500 instruction = builder->AddInstruction(HloInstruction::CreateToken());
1501 } else {
1502 instruction =
1503 builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
1504 }
1505 break;
1506 }
1507 case HloOpcode::kAddDependency: {
1508 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1509 !ParseAttributes(attrs)) {
1510 return false;
1511 }
1512 instruction = builder->AddInstruction(
1513 HloInstruction::CreateAddDependency(operands[0], operands[1]));
1514 break;
1515 }
1516 case HloOpcode::kSort: {
1517 optional<std::vector<int64>> dimensions;
1518 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1519 &dimensions};
1520 optional<bool> is_stable = false;
1521 attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
1522 optional<HloComputation*> to_apply;
1523 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1524 &to_apply};
1525 if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
1526 dimensions->size() != 1) {
1527 return false;
1528 }
1529 if (!maybe_infer_shape(
1530 [&] {
1531 absl::InlinedVector<const Shape*, 2> arg_shapes;
1532 arg_shapes.reserve(operands.size());
1533 for (auto* operand : operands) {
1534 arg_shapes.push_back(&operand->shape());
1535 }
1536 return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
1537 },
1538 &shape)) {
1539 return false;
1540 }
1541 instruction = builder->AddInstruction(
1542 HloInstruction::CreateSort(shape, dimensions->at(0), operands,
1543 to_apply.value(), is_stable.value()));
1544 break;
1545 }
1546 case HloOpcode::kTuple: {
1547 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1548 return false;
1549 }
1550 if (!maybe_infer_shape(
1551 [&] {
1552 absl::InlinedVector<const Shape*, 2> arg_shapes;
1553 arg_shapes.reserve(operands.size());
1554 for (auto* operand : operands) {
1555 arg_shapes.push_back(&operand->shape());
1556 }
1557 return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
1558 },
1559 &shape)) {
1560 return false;
1561 }
1562 // HloInstruction::CreateTuple() infers the shape of the tuple from
1563 // operands and should not be used here.
1564 instruction = builder->AddInstruction(
1565 HloInstruction::CreateVariadic(shape, HloOpcode::kTuple, operands));
1566 break;
1567 }
1568 case HloOpcode::kWhile: {
1569 optional<HloComputation*> condition;
1570 optional<HloComputation*> body;
1571 attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
1572 &condition};
1573 attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
1574 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1575 !ParseAttributes(attrs)) {
1576 return false;
1577 }
1578 if (!maybe_infer_shape(
1579 [&] {
1580 return ShapeInference::InferWhileShape(
1581 condition.value()->ComputeProgramShape(),
1582 body.value()->ComputeProgramShape(), operands[0]->shape());
1583 },
1584 &shape)) {
1585 return false;
1586 }
1587 instruction = builder->AddInstruction(HloInstruction::CreateWhile(
1588 shape, *condition, *body, /*init=*/operands[0]));
1589 break;
1590 }
1591 case HloOpcode::kRecv: {
1592 optional<int64> channel_id;
1593 // If the is_host_transfer attribute is not present then default to false.
1594 optional<bool> is_host_transfer = false;
1595 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1596 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1597 &is_host_transfer};
1598 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1599 !ParseAttributes(attrs)) {
1600 return false;
1601 }
1602 // If the is_host_transfer attribute is not present then default to false.
1603 instruction = builder->AddInstruction(HloInstruction::CreateRecv(
1604 shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
1605 break;
1606 }
1607 case HloOpcode::kRecvDone: {
1608 optional<int64> channel_id;
1609 // If the is_host_transfer attribute is not present then default to false.
1610 optional<bool> is_host_transfer = false;
1611 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1612 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1613 &is_host_transfer};
1614 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1615 !ParseAttributes(attrs)) {
1616 return false;
1617 }
1618 if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
1619 return false;
1620 }
1621 if (channel_id != operands[0]->channel_id()) {
1622 return false;
1623 }
1624 instruction = builder->AddInstruction(
1625 HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
1626 break;
1627 }
1628 case HloOpcode::kSend: {
1629 optional<int64> channel_id;
1630 // If the is_host_transfer attribute is not present then default to false.
1631 optional<bool> is_host_transfer = false;
1632 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1633 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1634 &is_host_transfer};
1635 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1636 !ParseAttributes(attrs)) {
1637 return false;
1638 }
1639 instruction = builder->AddInstruction(HloInstruction::CreateSend(
1640 operands[0], operands[1], *channel_id, *is_host_transfer));
1641 break;
1642 }
1643 case HloOpcode::kSendDone: {
1644 optional<int64> channel_id;
1645 // If the is_host_transfer attribute is not present then default to false.
1646 optional<bool> is_host_transfer = false;
1647 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
1648 attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
1649 &is_host_transfer};
1650 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1651 !ParseAttributes(attrs)) {
1652 return false;
1653 }
1654 if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
1655 return false;
1656 }
1657 if (channel_id != operands[0]->channel_id()) {
1658 return false;
1659 }
1660 instruction = builder->AddInstruction(
1661 HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
1662 break;
1663 }
1664 case HloOpcode::kGetTupleElement: {
1665 optional<int64> index;
1666 attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
1667 if (!ParseOperands(&operands, /*expected_size=*/1) ||
1668 !ParseAttributes(attrs)) {
1669 return false;
1670 }
1671 if (!maybe_infer_shape(
1672 [&] {
1673 return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
1674 *index);
1675 },
1676 &shape)) {
1677 return false;
1678 }
1679 instruction = builder->AddInstruction(
1680 HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
1681 break;
1682 }
1683 case HloOpcode::kCall: {
1684 optional<HloComputation*> to_apply;
1685 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1686 &to_apply};
1687 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1688 return false;
1689 }
1690 if (!maybe_infer_shape(
1691 [&] {
1692 absl::InlinedVector<const Shape*, 2> arg_shapes;
1693 arg_shapes.reserve(operands.size());
1694 for (auto* operand : operands) {
1695 arg_shapes.push_back(&operand->shape());
1696 }
1697 return ShapeInference::InferCallShape(
1698 arg_shapes, to_apply.value()->ComputeProgramShape());
1699 },
1700 &shape)) {
1701 return false;
1702 }
1703 instruction = builder->AddInstruction(
1704 HloInstruction::CreateCall(shape, operands, *to_apply));
1705 break;
1706 }
    case HloOpcode::kReduceWindow: {
      // reduce-window: variadic form takes N inputs followed by N init values,
      // so the operand count must be even. A missing window attribute defaults
      // to an empty (default-constructed) Window.
      optional<HloComputation*> reduce_computation;
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      if (operands.size() % 2) {
        auto loc = lexer_.GetLoc();
        return Error(loc, StrCat("expects an even number of operands, but has ",
                                 operands.size(), " operands"));
      }
      if (!maybe_infer_shape(
              [&] {
                // Shape inference uses the first input/init pair only.
                return ShapeInference::InferReduceWindowShape(
                    operands[0]->shape(), operands[1]->shape(), *window,
                    reduce_computation.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      // First half of the operands are the inputs, second half the init
      // values.
      instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
          shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *window, *reduce_computation));
      break;
    }
    case HloOpcode::kConvolution: {
      // convolution(lhs, rhs) with required dim_labels; window, group counts
      // and operand_precision are optional and defaulted below.
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64> feature_group_count;
      optional<int64> batch_group_count;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/true,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // Defaults: empty window, ungrouped convolution (count 1).
      if (!window) {
        window.emplace();
      }
      if (!feature_group_count) {
        feature_group_count = 1;
      }
      if (!batch_group_count) {
        batch_group_count = 1;
      }
      // Missing operand_precision means DEFAULT precision for every operand.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferConvolveShape(
                    operands[0]->shape(), operands[1]->shape(),
                    *feature_group_count, *batch_group_count, *window, *dnums,
                    /*preferred_element_type=*/absl::nullopt);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
          shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
          feature_group_count.value(), batch_group_count.value(), *window,
          *dnums, precision_config));
      break;
    }
    case HloOpcode::kFft: {
      // fft(operand), with required fft_type and fft_length attributes.
      optional<FftType> fft_type;
      optional<std::vector<int64>> fft_length;
      attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
      attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &fft_length};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferFftShape(operands[0]->shape(),
                                                     *fft_type, *fft_length);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFft(
          shape, operands[0], *fft_type, *fft_length));
      break;
    }
    case HloOpcode::kTriangularSolve: {
      // triangular-solve(a, b); its options are parsed directly into the
      // TriangularSolveOptions proto rather than through the attr table.
      TriangularSolveOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributesAsProtoMessage(
              /*non_proto_attrs=*/attrs, &options)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferTriangularSolveShape(
                    operands[0]->shape(), operands[1]->shape(), options);
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateTriangularSolve(
              shape, operands[0], operands[1], options));
      break;
    }
    case HloOpcode::kCompare: {
      // compare(lhs, rhs), direction=... is required; the comparison type is
      // optional (CreateCompare receives the possibly-empty optional as-is).
      optional<ComparisonDirection> direction;
      optional<Comparison::Type> type;
      attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
                            &direction};
      attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBinaryOpShape(opcode, operands[0],
                                                          operands[1]);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateCompare(
          shape, operands[0], operands[1], *direction, type));
      break;
    }
    case HloOpcode::kCholesky: {
      // cholesky(a); options are parsed straight into the CholeskyOptions
      // proto message.
      CholeskyOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributesAsProtoMessage(
              /*non_proto_attrs=*/attrs, &options)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferCholeskyShape(operands[0]->shape());
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCholesky(shape, operands[0], options));
      break;
    }
    case HloOpcode::kBroadcast: {
      // broadcast(operand), dimensions={...} maps operand dims to output dims.
      optional<std::vector<int64>> broadcast_dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &broadcast_dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBroadcastShape(
                    operands[0]->shape(), *broadcast_dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
          shape, operands[0], *broadcast_dimensions));
      break;
    }
    case HloOpcode::kConcatenate: {
      // concatenate(...): exactly one concat dimension is allowed.
      // NOTE(review): a dimensions list whose size != 1 returns false without
      // emitting an Error message, unlike most other failure paths here.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferConcatOpShape(arg_shapes,
                                                          dimensions->at(0));
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
          shape, operands, dimensions->at(0)));
      break;
    }
    case HloOpcode::kMap: {
      // map(...), to_apply=<computation>. The optional dimensions attribute is
      // accepted and used only for shape inference; CreateMap does not take it.
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferMapShape(
                    arg_shapes, to_apply.value()->ComputeProgramShape(),
                    *dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateMap(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduce: {
      // reduce: variadic form takes N inputs followed by N init values, so the
      // operand count must be even.
      auto loc = lexer_.GetLoc();

      optional<HloComputation*> reduce_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      optional<std::vector<int64>> dimensions_to_reduce;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions_to_reduce};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() % 2) {
        return Error(loc, StrCat("expects an even number of operands, but has ",
                                 operands.size(), " operands"));
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<const Shape*, 2> arg_shapes;
                arg_shapes.reserve(operands.size());
                for (auto* operand : operands) {
                  arg_shapes.push_back(&operand->shape());
                }
                return ShapeInference::InferReduceShape(
                    arg_shapes, *dimensions_to_reduce,
                    reduce_computation.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      // First half of the operands are the inputs, second half the init
      // values.
      instruction = builder->AddInstruction(HloInstruction::CreateReduce(
          shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *dimensions_to_reduce, *reduce_computation));
      break;
    }
    case HloOpcode::kReverse: {
      // reverse(operand), dimensions={...} lists the dims to reverse.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferReverseShape(operands[0]->shape(),
                                                         *dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateReverse(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kSelectAndScatter: {
      // select-and-scatter(operand, source, init_value) with required select=
      // and scatter= computations; a missing window defaults to an empty one.
      optional<HloComputation*> select;
      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
      optional<HloComputation*> scatter;
      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferSelectAndScatterShape(
                    operands[0]->shape(), select.value()->ComputeProgramShape(),
                    *window, operands[1]->shape(), operands[2]->shape(),
                    scatter.value()->ComputeProgramShape());
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
              shape, /*operand=*/operands[0], *select, *window,
              /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
      break;
    }
    case HloOpcode::kSlice: {
      // slice(operand), slice={[start:limit:stride], ...}.
      optional<SliceRanges> slice_ranges;
      attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSlice(
          shape, operands[0], slice_ranges->starts, slice_ranges->limits,
          slice_ranges->strides));
      break;
    }
    case HloOpcode::kDynamicSlice: {
      // dynamic-slice: operand followed by start indices, given either as a
      // single rank-1 operand or as one scalar operand per operand dimension.
      optional<std::vector<int64>> dynamic_slice_sizes;
      attrs["dynamic_slice_sizes"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.empty()) {
        return Error(loc, "Expected at least one operand.");
      }
      // Accept either the rank-1 start-indices form or the scalar-per-dim
      // form; anything else is an arity error.
      if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
          operands.size() != 1 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
          shape, /*operand=*/operands[0],
          /*start_indices=*/absl::MakeSpan(operands).subspan(1),
          *dynamic_slice_sizes));
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      // dynamic-update-slice(operand, update, start_indices...), with the same
      // two start-index encodings as dynamic-slice.
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() < 2) {
        return Error(loc, "Expected at least two operands.");
      }
      if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
          operands.size() != 2 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              shape, /*operand=*/operands[0], /*update=*/operands[1],
              /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
      break;
    }
    case HloOpcode::kTranspose: {
      // transpose(operand), dimensions={...} is the output-to-input
      // permutation.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferTransposeShape(operands[0]->shape(),
                                                           *dimensions);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kBatchNormTraining: {
      // batch-norm-training(operand, scale, offset) with required epsilon and
      // feature_index attributes.
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBatchNormTrainingShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(), *feature_index);
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormInference: {
      // batch-norm-inference(operand, scale, offset, mean, variance) with
      // required epsilon and feature_index attributes.
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBatchNormInferenceShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(), operands[3]->shape(),
                    operands[4]->shape(), *feature_index);
              },
              &shape)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormInference(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], /*mean=*/operands[3],
              /*variance=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormGrad: {
      // batch-norm-grad(operand, scale, mean, variance, grad_output) with
      // required epsilon and feature_index attributes.
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferBatchNormGradShape(
                    operands[0]->shape(), operands[1]->shape(),
                    operands[2]->shape(), operands[3]->shape(),
                    operands[4]->shape(), *feature_index);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
          shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*mean=*/operands[2], /*variance=*/operands[3],
          /*grad_output=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kPad: {
      // pad(operand, padding_value), padding=<config> is required.
      optional<PaddingConfig> padding;
      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferPadShape(
                    operands[0]->shape(), operands[1]->shape(), *padding);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreatePad(
          shape, operands[0], /*padding_value=*/operands[1], *padding));
      break;
    }
    case HloOpcode::kFusion: {
      // fusion(...), kind=<fusion kind>, calls=<fused computation>. The shape
      // is taken as parsed; no inference is attempted here.
      optional<HloComputation*> fusion_computation;
      attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
                        &fusion_computation};
      optional<HloInstruction::FusionKind> fusion_kind;
      attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFusion(
          shape, *fusion_kind, operands, *fusion_computation));
      break;
    }
2221 case HloOpcode::kInfeed: {
2222 optional<std::string> config;
2223 attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2224 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2225 !ParseAttributes(attrs)) {
2226 return false;
2227 }
2228 // We need to know the infeed data shape to construct the infeed
2229 // instruction. This is the zero-th element of the tuple-shaped output of
2230 // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
2231 // if the shape is not a non-empty tuple, so add guard so an error message
2232 // can be emitted instead of a check fail
2233 if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
2234 return Error(lexer_.GetLoc(),
2235 "infeed must have a non-empty tuple shape");
2236 }
2237 instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
2238 ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
2239 config ? *config : ""));
2240 break;
2241 }
    case HloOpcode::kOutfeed: {
      // outfeed(data, token), with optional outfeed_config and outfeed_shape.
      // When outfeed_shape is absent, the data operand's shape is used.
      optional<std::string> config;
      optional<Shape> outfeed_shape;
      attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
                                &outfeed_shape};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      HloInstruction* const outfeed_input = operands[0];
      HloInstruction* const outfeed_token = operands[1];
      // Deliberately shadows the outer `shape`: this is the outfeed payload
      // shape, not the instruction's (token) result shape.
      const Shape shape =
          outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
      instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
          shape, outfeed_input, outfeed_token, config ? *config : ""));
      break;
    }
    case HloOpcode::kRng: {
      // rng(params...), distribution=<uniform|normal|...> is required.
      optional<RandomDistribution> distribution;
      attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
                               &distribution};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRng(shape, *distribution, operands));
      break;
    }
    case HloOpcode::kRngGetAndUpdateState: {
      // rng-get-and-update-state: nullary; delta=<int64> is required.
      optional<int64> delta;
      attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
      break;
    }
    case HloOpcode::kRngBitGenerator: {
      // rng-bit-generator(state), algorithm=<random algorithm> is required.
      optional<RandomAlgorithm> algorithm;
      attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
                            &algorithm};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
              shape, operands[0], *algorithm));
      break;
    }
    case HloOpcode::kReducePrecision: {
      // reduce-precision(operand) with required exponent_bits/mantissa_bits.
      optional<int64> exponent_bits;
      optional<int64> mantissa_bits;
      attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &exponent_bits};
      attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &mantissa_bits};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateReducePrecision(
              shape, operands[0], static_cast<int>(*exponent_bits),
              static_cast<int>(*mantissa_bits)));
      break;
    }
    case HloOpcode::kConditional: {
      // conditional: operand 0 is the branch selector — a PRED scalar (two
      // branches via true_computation/false_computation) or an S32 scalar
      // (N branches via branch_computations). Operands 1..N are the branch
      // arguments, one per branch.
      optional<HloComputation*> true_computation;
      optional<HloComputation*> false_computation;
      optional<std::vector<HloComputation*>> branch_computations;
      if (!ParseOperands(&operands)) {
        return false;
      }
      if (!ShapeUtil::IsScalar(operands[0]->shape())) {
        return Error(lexer_.GetLoc(), "The first operand must be a scalar");
      }
      const bool branch_index_is_bool =
          operands[0]->shape().element_type() == PRED;
      // The required attribute set depends on the selector's element type, so
      // attrs are registered before ParseAttributes runs below.
      if (branch_index_is_bool) {
        attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
                                     &true_computation};
        attrs["false_computation"] = {
            /*required=*/true, AttrTy::kHloComputation, &false_computation};
      } else {
        if (operands[0]->shape().element_type() != S32) {
          return Error(lexer_.GetLoc(),
                       "The first operand must be a scalar of PRED or S32");
        }
        attrs["branch_computations"] = {/*required=*/true,
                                        AttrTy::kBracedHloComputationList,
                                        &branch_computations};
      }
      if (!ParseAttributes(attrs)) {
        return false;
      }
      // Normalize the bool form into the general branch list: {true, false}.
      if (branch_index_is_bool) {
        branch_computations.emplace({*true_computation, *false_computation});
      }
      // NOTE(review): this arity mismatch returns false without emitting an
      // Error message, unlike the other failure paths in this case.
      if (branch_computations->empty() ||
          operands.size() != branch_computations->size() + 1) {
        return false;
      }
      if (!maybe_infer_shape(
              [&] {
                absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
                branch_computation_shapes.reserve(branch_computations->size());
                for (auto* computation : *branch_computations) {
                  branch_computation_shapes.push_back(
                      computation->ComputeProgramShape());
                }
                absl::InlinedVector<Shape, 2> branch_operand_shapes;
                branch_operand_shapes.reserve(operands.size() - 1);
                for (int i = 1; i < operands.size(); ++i) {
                  branch_operand_shapes.push_back(operands[i]->shape());
                }
                return ShapeInference::InferConditionalShape(
                    operands[0]->shape(), branch_computation_shapes,
                    branch_operand_shapes);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConditional(
          shape, /*branch_index=*/operands[0],
          absl::MakeSpan(*branch_computations),
          absl::MakeSpan(operands).subspan(1)));
      break;
    }
    case HloOpcode::kCustomCall: {
      // custom-call: only custom_call_target is required; everything else is
      // optional. Construction picks one of four CreateCustomCall overloads
      // (layout-constrained / to_apply / called_computations / plain), then
      // the remaining optional attributes are applied via setters on the
      // HloCustomCallInstruction.
      optional<std::string> custom_call_target;
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64> feature_group_count;
      optional<int64> batch_group_count;
      optional<std::vector<Shape>> operand_layout_constraints;
      optional<bool> custom_call_has_side_effect;
      optional<HloComputation*> to_apply;
      optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
          output_to_operand_aliasing;
      optional<PaddingType> padding_type;
      optional<std::vector<HloComputation*>> called_computations;
      optional<CustomCallSchedule> custom_call_schedule;
      optional<CustomCallApiVersion> api_version;
      attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                     &custom_call_target};
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/false,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      attrs["operand_layout_constraints"] = {
          /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
      attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
                                              &custom_call_has_side_effect};
      attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
                           &to_apply};
      attrs["called_computations"] = {/*required=*/false,
                                      AttrTy::kBracedHloComputationList,
                                      &called_computations};
      attrs["output_to_operand_aliasing"] = {/*required=*/false,
                                             AttrTy::kInstructionAliasing,
                                             &output_to_operand_aliasing};

      attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
                               &padding_type};

      optional<Literal> literal;
      attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      // to_apply and called_computations are mutually exclusive ways of
      // attaching computations to a custom call.
      if (called_computations.has_value() && to_apply.has_value()) {
        return Error(lexer_.GetLoc(),
                     "A single instruction can't have both to_apply and "
                     "calls field");
      }
      attrs["schedule"] = {/*required=*/false, AttrTy::kCustomCallSchedule,
                           &custom_call_schedule};
      attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion,
                              &api_version};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }

      // An explicitly-written api_version must name a real version.
      if (api_version.has_value() &&
          *api_version == CustomCallApiVersion::API_VERSION_UNSPECIFIED) {
        return Error(lexer_.GetLoc(),
                     StrCat("Invalid API version: ",
                            CustomCallApiVersion_Name(*api_version)));
      }
      if (operand_layout_constraints.has_value()) {
        // Layout-constrained custom call: the result and every constraint
        // shape must carry a layout, and each constraint must be compatible
        // with the corresponding operand's shape.
        if (!LayoutUtil::HasLayout(shape)) {
          return Error(lexer_.GetLoc(),
                       "Layout must be set on layout-constrained custom call");
        }
        if (operands.size() != operand_layout_constraints->size()) {
          return Error(lexer_.GetLoc(),
                       StrCat("Expected ", operands.size(),
                              " operand layout constraints, ",
                              operand_layout_constraints->size(), " given"));
        }
        for (int64_t i = 0; i < operands.size(); ++i) {
          const Shape& operand_shape_with_layout =
              (*operand_layout_constraints)[i];
          if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
            return Error(lexer_.GetLoc(),
                         StrCat("Operand layout constraint shape ",
                                ShapeUtil::HumanStringWithLayout(
                                    operand_shape_with_layout),
                                " for operand ", i, " does not have a layout"));
          }
          if (!ShapeUtil::Compatible(operand_shape_with_layout,
                                     operands[i]->shape())) {
            return Error(
                lexer_.GetLoc(),
                StrCat(
                    "Operand layout constraint shape ",
                    ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
                    " for operand ", i,
                    " is not compatible with operand shape ",
                    ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
          }
        }
        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
            shape, operands, *custom_call_target, *operand_layout_constraints,
            backend_config ? *backend_config : ""));
      } else {
        if (to_apply.has_value()) {
          instruction =
              builder->AddInstruction(HloInstruction::CreateCustomCall(
                  shape, operands, *to_apply, *custom_call_target,
                  backend_config ? *backend_config : ""));
        } else if (called_computations.has_value()) {
          instruction =
              builder->AddInstruction(HloInstruction::CreateCustomCall(
                  shape, operands, *called_computations, *custom_call_target,
                  backend_config ? *backend_config : ""));
        } else {
          instruction =
              builder->AddInstruction(HloInstruction::CreateCustomCall(
                  shape, operands, *custom_call_target,
                  backend_config ? *backend_config : ""));
        }
      }
      // Apply the remaining optional attributes directly on the created
      // custom-call instruction.
      auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
      if (window.has_value()) {
        custom_call_instr->set_window(*window);
      }
      if (dnums.has_value()) {
        custom_call_instr->set_convolution_dimension_numbers(*dnums);
      }
      if (feature_group_count.has_value()) {
        custom_call_instr->set_feature_group_count(*feature_group_count);
      }
      if (batch_group_count.has_value()) {
        custom_call_instr->set_batch_group_count(*batch_group_count);
      }
      if (padding_type.has_value()) {
        custom_call_instr->set_padding_type(*padding_type);
      }
      if (custom_call_has_side_effect.has_value()) {
        custom_call_instr->set_custom_call_has_side_effect(
            *custom_call_has_side_effect);
      }
      if (custom_call_schedule.has_value()) {
        custom_call_instr->set_custom_call_schedule(*custom_call_schedule);
      }
      if (api_version.has_value()) {
        custom_call_instr->set_api_version(*api_version);
      }
      if (output_to_operand_aliasing.has_value()) {
        custom_call_instr->set_output_to_operand_aliasing(
            std::move(*output_to_operand_aliasing));
      }
      if (literal.has_value()) {
        custom_call_instr->set_literal(std::move(*literal));
      }
      // Missing operand_precision means DEFAULT precision for every operand.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      *custom_call_instr->mutable_precision_config() = precision_config;
      break;
    }
    case HloOpcode::kDot: {
      // dot(lhs, rhs): contracting/batch dimension lists and operand_precision
      // are all optional; omitted lists stay empty in the DotDimensionNumbers.
      optional<std::vector<int64>> lhs_contracting_dims;
      attrs["lhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
      optional<std::vector<int64>> rhs_contracting_dims;
      attrs["rhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
      optional<std::vector<int64>> lhs_batch_dims;
      attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &lhs_batch_dims};
      optional<std::vector<int64>> rhs_batch_dims;
      attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &rhs_batch_dims};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};

      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }

      // Copy whichever dimension lists were supplied into the proto.
      DotDimensionNumbers dnum;
      if (lhs_contracting_dims) {
        *dnum.mutable_lhs_contracting_dimensions() = {
            lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
      }
      if (rhs_contracting_dims) {
        *dnum.mutable_rhs_contracting_dimensions() = {
            rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
      }
      if (lhs_batch_dims) {
        *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
                                                lhs_batch_dims->end()};
      }
      if (rhs_batch_dims) {
        *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
                                                rhs_batch_dims->end()};
      }

      // Missing operand_precision means DEFAULT precision for every operand.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferDotOpShape(
                    operands[0]->shape(), operands[1]->shape(), dnum,
                    /*preferred_element_type=*/absl::nullopt);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDot(
          shape, operands[0], operands[1], dnum, precision_config));
      break;
    }
    case HloOpcode::kGather: {
      // gather(operand, start_indices) with the full set of required gather
      // dimension-number attributes; indices_are_sorted defaults to false.
      optional<std::vector<int64>> offset_dims;
      attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &offset_dims};
      optional<std::vector<int64>> collapsed_slice_dims;
      attrs["collapsed_slice_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
      optional<std::vector<int64>> start_index_map;
      attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
                                  &start_index_map};
      optional<int64> index_vector_dim;
      attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
                                   &index_vector_dim};
      optional<std::vector<int64>> slice_sizes;
      attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &slice_sizes};
      optional<bool> indices_are_sorted = false;
      attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                     &indices_are_sorted};

      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }

      // Pack the parsed dimension attributes into a GatherDimensionNumbers
      // proto.
      GatherDimensionNumbers dim_numbers =
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/*offset_dims,
              /*collapsed_slice_dims=*/*collapsed_slice_dims,
              /*start_index_map=*/*start_index_map,
              /*index_vector_dim=*/*index_vector_dim);
      if (!maybe_infer_shape(
              [&] {
                return ShapeInference::InferGatherShape(
                    operands[0]->shape(), operands[1]->shape(), dim_numbers,
                    *slice_sizes);
              },
              &shape)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateGather(
          shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
          dim_numbers, *slice_sizes, indices_are_sorted.value()));
      break;
    }
2641 case HloOpcode::kScatter: {
2642 optional<std::vector<int64>> update_window_dims;
2643 attrs["update_window_dims"] = {
2644 /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
2645 optional<std::vector<int64>> inserted_window_dims;
2646 attrs["inserted_window_dims"] = {
2647 /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
2648 optional<std::vector<int64>> scatter_dims_to_operand_dims;
2649 attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
2650 AttrTy::kBracedInt64List,
2651 &scatter_dims_to_operand_dims};
2652 optional<int64> index_vector_dim;
2653 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2654 &index_vector_dim};
2655
2656 optional<HloComputation*> update_computation;
2657 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
2658 &update_computation};
2659 optional<bool> indices_are_sorted = false;
2660 attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2661 &indices_are_sorted};
2662 optional<bool> unique_indices = false;
2663 attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
2664 &unique_indices};
2665
2666 if (!ParseOperands(&operands, /*expected_size=*/3) ||
2667 !ParseAttributes(attrs)) {
2668 return false;
2669 }
2670
2671 ScatterDimensionNumbers dim_numbers =
2672 HloScatterInstruction::MakeScatterDimNumbers(
2673 /*update_window_dims=*/*update_window_dims,
2674 /*inserted_window_dims=*/*inserted_window_dims,
2675 /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
2676 /*index_vector_dim=*/*index_vector_dim);
2677
2678 if (!maybe_infer_shape(
2679 [&] {
2680 return ShapeInference::InferScatterShape(
2681 operands[0]->shape(), operands[1]->shape(),
2682 operands[2]->shape(),
2683 update_computation.value()->ComputeProgramShape(),
2684 dim_numbers);
2685 },
2686 &shape)) {
2687 return false;
2688 }
2689 instruction = builder->AddInstruction(HloInstruction::CreateScatter(
2690 shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
2691 /*updates=*/operands[2], *update_computation, dim_numbers,
2692 indices_are_sorted.value(), unique_indices.value()));
2693 break;
2694 }
2695 case HloOpcode::kDomain: {
2696 DomainData domain;
2697 attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
2698 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2699 !ParseAttributes(attrs)) {
2700 return false;
2701 }
2702 if (!maybe_infer_shape(
2703 [&] {
2704 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
2705 },
2706 &shape)) {
2707 return false;
2708 }
2709 instruction = builder->AddInstruction(HloInstruction::CreateDomain(
2710 shape, operands[0], std::move(domain.exit_metadata),
2711 std::move(domain.entry_metadata)));
2712 break;
2713 }
2714 case HloOpcode::kTrace:
2715 return TokenError(StrCat("parsing not yet implemented for op: ",
2716 HloOpcodeString(opcode)));
2717 case HloOpcode::kGetDimensionSize: {
2718 optional<std::vector<int64>> dimensions;
2719 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2720 &dimensions};
2721 if (!ParseOperands(&operands, /*expected_size=*/1) ||
2722 !ParseAttributes(attrs)) {
2723 return false;
2724 }
2725 if (!maybe_infer_shape(
2726 [&] {
2727 return ShapeInference::InferGetDimensionSizeShape(
2728 operands[0]->shape(), dimensions->at(0));
2729 },
2730 &shape)) {
2731 return false;
2732 }
2733 instruction =
2734 builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
2735 shape, operands[0], (*dimensions)[0]));
2736 break;
2737 }
2738 case HloOpcode::kSetDimensionSize: {
2739 optional<std::vector<int64>> dimensions;
2740 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2741 &dimensions};
2742 if (!ParseOperands(&operands, /*expected_size=*/2) ||
2743 !ParseAttributes(attrs)) {
2744 return false;
2745 }
2746 if (!maybe_infer_shape(
2747 [&] {
2748 return ShapeInference::InferSetDimensionSizeShape(
2749 operands[0]->shape(), operands[1]->shape(),
2750 dimensions->at(0));
2751 },
2752 &shape)) {
2753 return false;
2754 }
2755 instruction =
2756 builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
2757 shape, operands[0], operands[1], (*dimensions)[0]));
2758 break;
2759 }
2760 }
2761
2762 instruction->SetAndSanitizeName(name);
2763 if (instruction->name() != name) {
2764 return Error(name_loc,
2765 StrCat("illegal instruction name: ", name,
2766 "; suggest renaming to: ", instruction->name()));
2767 }
2768
2769 // Add shared attributes like metadata to the instruction, if they were seen.
2770 if (sharding) {
2771 instruction->set_sharding(
2772 HloSharding::FromProto(sharding.value()).ValueOrDie());
2773 }
2774 if (parameter_replication) {
2775 int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
2776 const auto& replicated =
2777 parameter_replication->replicated_at_leaf_buffers();
2778 if (leaf_count != replicated.size()) {
2779 return Error(lexer_.GetLoc(),
2780 StrCat("parameter has ", leaf_count,
2781 " leaf buffers, but parameter_replication has ",
2782 replicated.size(), " elements."));
2783 }
2784 instruction->set_parameter_replicated_at_leaf_buffers(replicated);
2785 }
2786 if (predecessors) {
2787 for (auto* pre : *predecessors) {
2788 Status status = pre->AddControlDependencyTo(instruction);
2789 if (!status.ok()) {
2790 return Error(name_loc, StrCat("error adding control dependency for: ",
2791 name, " status: ", status.ToString()));
2792 }
2793 }
2794 }
2795 if (metadata) {
2796 instruction->set_metadata(*metadata);
2797 }
2798 if (backend_config) {
2799 instruction->set_raw_backend_config_string(std::move(*backend_config));
2800 }
2801 if (outer_dimension_partitions) {
2802 instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
2803 }
2804 if (frontend_attributes) {
2805 instruction->set_frontend_attributes(*frontend_attributes);
2806 }
2807 return AddInstruction(name, instruction, name_loc);
2808 } // NOLINT(readability/fn_size)
2809
2810 // ::= '{' (single_sharding | tuple_sharding) '}'
2811 //
// tuple_sharding ::= /*empty*/ | single_sharding (',' single_sharding)*
ParseSharding(OpSharding * sharding)2813 bool HloParserImpl::ParseSharding(OpSharding* sharding) {
2814 // A single sharding starts with '{' and is not followed by '{'.
2815 // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
2816 // an empty tuple.
2817 if (!ParseToken(TokKind::kLbrace,
2818 "expected '{' to start sharding attribute")) {
2819 return false;
2820 }
2821
2822 if (lexer_.GetKind() != TokKind::kLbrace &&
2823 lexer_.GetKind() != TokKind::kRbrace) {
2824 return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
2825 }
2826
2827 // Tuple sharding.
2828 // Allow empty tuple shardings.
2829 if (lexer_.GetKind() != TokKind::kRbrace) {
2830 do {
2831 if (!ParseSingleSharding(sharding->add_tuple_shardings(),
2832 /*lbrace_pre_lexed=*/false)) {
2833 return false;
2834 }
2835 } while (EatIfPresent(TokKind::kComma));
2836 }
2837 sharding->set_type(OpSharding::TUPLE);
2838
2839 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
2840 }
2841
2842 // frontend_attributes ::= '{' attributes '}'
2843 // attributes
2844 // ::= /*empty*/
2845 // ::= attribute '=' value (',' attribute '=' value)*
ParseFrontendAttributes(FrontendAttributes * frontend_attributes)2846 bool HloParserImpl::ParseFrontendAttributes(
2847 FrontendAttributes* frontend_attributes) {
2848 CHECK(frontend_attributes != nullptr);
2849 if (!ParseToken(TokKind::kLbrace,
2850 "expected '{' to start frontend attributes")) {
2851 return false;
2852 }
2853 if (lexer_.GetKind() == TokKind::kRbrace) {
2854 // empty
2855 } else {
2856 do {
2857 std::string attribute;
2858 if (!ParseAttributeName(&attribute)) {
2859 return false;
2860 }
2861 if (lexer_.GetKind() != TokKind::kString) {
2862 return false;
2863 }
2864 (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
2865 lexer_.Lex();
2866 } while (EatIfPresent(TokKind::kComma));
2867 }
2868 return ParseToken(TokKind::kRbrace,
2869 "expects '}' at the end of frontend attributes");
2870 }
2871
2872 // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
2873 // ('devices=' ('[' dims ']')* device_list)?
2874 // ('metadata=' metadata)* '}'
2875 //
// dims ::= int_list
// device_list ::= int_list
2877 // metadata ::= single_metadata |
2878 // ('{' [single_metadata (',' single_metadata)*] '}')
2879 // last_tile_dims ::= sharding_type_list
// Parses the body of one (non-tuple) sharding and fills `sharding` in.  When
// `lbrace_pre_lexed` is true, the opening '{' was already consumed by the
// caller (ParseSharding does this while disambiguating single vs. tuple).
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
                                        bool lbrace_pre_lexed) {
  if (!lbrace_pre_lexed &&
      !ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }

  LocTy loc = lexer_.GetLoc();
  // Scanning phase: accumulate flags and lists until the closing '}'.
  bool maximal = false;
  bool replicated = false;
  bool manual = false;
  bool last_tile_dim_replicate = false;
  bool last_tile_dims = false;
  std::vector<int64> devices;
  std::vector<int64> tile_assignment_dimensions;
  std::vector<OpSharding::Type> sharding_types;
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kw_manual:
        manual = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          // device=<int>: single device assignment (used by maximal).
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          // devices=[d0,d1,...]i0,i1,...: tile-assignment dimensions in
          // square brackets, followed by the flattened device list.
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }

          do {
            int64_t dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));

          // NOTE(review): message says "start" but this ']' closes the
          // devices shape list.
          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64_t device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else if (lexer_.GetStrVal() == "metadata") {
          // metadata=...: single metadata or a braced list of them.
          lexer_.Lex();
          if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
            return false;
          }
        } else if (lexer_.GetStrVal() == "last_tile_dims") {
          last_tile_dims = true;
          lexer_.Lex();
          if (!ParseListShardingType(&sharding_types)) {
            return false;
          }
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device=, devices= "
              "metadata= or last_tile_dims= ");
        }
        break;
      }
      case TokKind::kw_last_tile_dim_replicate:
        last_tile_dim_replicate = true;
        lexer_.Lex();
        break;
      case TokKind::kRbrace:
        // Unreachable: the loop condition already excludes '}'.
        break;
      default:
        return TokenError("unexpected token");
    }
  }

  // Validation phase: the accumulated flags must describe exactly one of the
  // sharding types (REPLICATED, MAXIMAL, MANUAL, or tiled OTHER).
  if (replicated) {
    if (!devices.empty()) {
      return Error(loc,
                   "replicated shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
    sharding->set_type(OpSharding::MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else if (manual) {
    if (!devices.empty()) {
      return Error(loc,
                   "manual shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::MANUAL);
  } else {
    // Tiled sharding: needs both a multi-device list and tile dimensions.
    if (devices.size() <= 1) {
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding->set_type(OpSharding::OTHER);
    for (int64_t dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
    for (int64_t device : devices) {
      sharding->add_tile_assignment_devices(device);
    }

    // last_tile_dims and replicate_on_last_tile_dim are mutually exclusive
    // encodings of what the trailing tile dimensions mean.
    if (last_tile_dims) {
      for (OpSharding::Type type : sharding_types) {
        sharding->add_last_tile_dims(type);
      }
    } else {
      sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
    }
  }

  // Consume the closing '}'.
  lexer_.Lex();
  return true;
}
3023
3024 // parameter_replication ::=
3025 // '{' ('true' | 'false')* (',' ('true' | 'false'))* '}'
ParseParameterReplication(ParameterReplication * parameter_replication)3026 bool HloParserImpl::ParseParameterReplication(
3027 ParameterReplication* parameter_replication) {
3028 if (!ParseToken(TokKind::kLbrace,
3029 "expected '{' to start parameter_replication attribute")) {
3030 return false;
3031 }
3032
3033 if (lexer_.GetKind() != TokKind::kRbrace) {
3034 do {
3035 if (lexer_.GetKind() == TokKind::kw_true) {
3036 parameter_replication->add_replicated_at_leaf_buffers(true);
3037 } else if (lexer_.GetKind() == TokKind::kw_false) {
3038 parameter_replication->add_replicated_at_leaf_buffers(false);
3039 } else {
3040 return false;
3041 }
3042 lexer_.Lex();
3043 } while (EatIfPresent(TokKind::kComma));
3044 }
3045
3046 return ParseToken(TokKind::kRbrace,
3047 "expected '}' to end parameter_replication attribute");
3048 }
3049
3050 // replica_groups ::='{' int64list_elements '}'
3051 // int64list_elements
3052 // ::= /*empty*/
3053 // ::= int64list (',' int64list)*
3054 // int64list ::= '{' int64_elements '}'
3055 // int64_elements
3056 // ::= /*empty*/
3057 // ::= int64_val (',' int64_val)*
ParseReplicaGroupsOnly(std::vector<ReplicaGroup> * replica_groups)3058 bool HloParserImpl::ParseReplicaGroupsOnly(
3059 std::vector<ReplicaGroup>* replica_groups) {
3060 std::vector<std::vector<int64>> result;
3061 if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3062 &result)) {
3063 return false;
3064 }
3065 *replica_groups = CreateReplicaGroups(result);
3066 return true;
3067 }
3068
3069 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
3070 // 'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)3071 bool HloParserImpl::ParseDomain(DomainData* domain) {
3072 absl::flat_hash_map<std::string, AttrConfig> attrs;
3073 optional<std::string> kind;
3074 optional<OpSharding> entry_sharding;
3075 optional<OpSharding> exit_sharding;
3076 attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
3077 attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
3078 attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
3079 if (!ParseSubAttributes(attrs)) {
3080 return false;
3081 }
3082 if (*kind == ShardingMetadata::KindName()) {
3083 auto entry_sharding_ptr = absl::make_unique<HloSharding>(
3084 HloSharding::FromProto(*entry_sharding).ValueOrDie());
3085 auto exit_sharding_ptr = absl::make_unique<HloSharding>(
3086 HloSharding::FromProto(*exit_sharding).ValueOrDie());
3087 domain->entry_metadata =
3088 absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
3089 domain->exit_metadata =
3090 absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
3091 } else {
3092 return TokenError(StrCat("unsupported domain kind: ", *kind));
3093 }
3094 return true;
3095 }
3096
3097 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)3098 bool HloParserImpl::ParseInstructionNames(
3099 std::vector<HloInstruction*>* instructions) {
3100 if (!ParseToken(TokKind::kLbrace,
3101 "expects '{' at the beginning of instruction name list")) {
3102 return false;
3103 }
3104 LocTy loc = lexer_.GetLoc();
3105 do {
3106 std::string name;
3107 if (!ParseName(&name)) {
3108 return Error(loc, "expects a instruction name");
3109 }
3110 std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
3111 if (!instr) {
3112 return TokenError(StrFormat("instruction '%s' is not defined", name));
3113 }
3114 instructions->push_back(instr->first);
3115 } while (EatIfPresent(TokKind::kComma));
3116
3117 return ParseToken(TokKind::kRbrace,
3118 "expects '}' at the end of instruction name list");
3119 }
3120
SetValueInLiteral(LocTy loc,int64_t value,int64_t index,Literal * literal)3121 bool HloParserImpl::SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
3122 Literal* literal) {
3123 const Shape& shape = literal->shape();
3124 switch (shape.element_type()) {
3125 case S8:
3126 return SetValueInLiteralHelper<int8>(loc, value, index, literal);
3127 case S16:
3128 return SetValueInLiteralHelper<int16>(loc, value, index, literal);
3129 case S32:
3130 return SetValueInLiteralHelper<int32>(loc, value, index, literal);
3131 case S64:
3132 return SetValueInLiteralHelper<int64>(loc, value, index, literal);
3133 case U8:
3134 return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
3135 literal);
3136 case U16:
3137 return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
3138 literal);
3139 case U32:
3140 return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
3141 literal);
3142 case U64:
3143 return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
3144 literal);
3145 case PRED:
3146 // Bool type literals with rank >= 1 are printed in 0s and 1s.
3147 return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
3148 literal);
3149 default:
3150 LOG(FATAL) << "unknown integral primitive type "
3151 << PrimitiveType_Name(shape.element_type());
3152 }
3153 }
3154
SetValueInLiteral(LocTy loc,double value,int64_t index,Literal * literal)3155 bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64_t index,
3156 Literal* literal) {
3157 const Shape& shape = literal->shape();
3158 switch (shape.element_type()) {
3159 case F16:
3160 return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
3161 case BF16:
3162 return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
3163 literal);
3164 case F32:
3165 return SetValueInLiteralHelper<float>(loc, value, index, literal);
3166 case F64:
3167 return SetValueInLiteralHelper<double>(loc, value, index, literal);
3168 default:
3169 LOG(FATAL) << "unknown floating point primitive type "
3170 << PrimitiveType_Name(shape.element_type());
3171 }
3172 }
3173
SetValueInLiteral(LocTy loc,bool value,int64_t index,Literal * literal)3174 bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64_t index,
3175 Literal* literal) {
3176 const Shape& shape = literal->shape();
3177 switch (shape.element_type()) {
3178 case PRED:
3179 return SetValueInLiteralHelper<bool>(loc, value, index, literal);
3180 default:
3181 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3182 << " is not PRED type";
3183 }
3184 }
3185
SetValueInLiteral(LocTy loc,std::complex<double> value,int64_t index,Literal * literal)3186 bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
3187 int64_t index, Literal* literal) {
3188 const Shape& shape = literal->shape();
3189 switch (shape.element_type()) {
3190 case C64:
3191 return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
3192 literal);
3193 case C128:
3194 return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
3195 literal);
3196 default:
3197 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3198 << " is not a complex type";
3199 }
3200 }
3201
// Renders a parsed literal element for use in error messages.
template <typename T>
std::string StringifyValue(T val) {
  return StrCat(val);
}
// Complex values print as "(real, imag)" with %f formatting.
template <>
std::string StringifyValue(std::complex<double> val) {
  return StrFormat("(%f, %f)", std::real(val), std::imag(val));
}
3210
// Evaluates to V when T == U; used below to select overloads by exact type
// via SFINAE (ill-formed substitution otherwise).
template <typename T, typename U, typename V>
using EnableIfSameWithType = std::enable_if_t<std::is_same<T, U>::value, V>;
3214
// GetNanPayload returns the NaN payload bits of `val`, or 0 for types that
// cannot carry one.  Overload selected when T is bool: no payload.
template <class T, EnableIfSameWithType<T, bool, bool> = false>
uint64_t GetNanPayload(T val) {
  return 0;
}

// Integral values carry no NaN payload either.
template <class T, EnableIfSameWithType<T, int64_t, bool> = false>
uint64_t GetNanPayload(T val) {
  return 0;
}

// For doubles: extract the payload bits from the raw IEEE-754 representation.
// An empty payload is mapped to the canonical quiet-NaN payload, which the
// caller (SetValueInLiteralHelper) compares against to decide whether to
// substitute the target type's default payload.
template <class T, EnableIfSameWithType<T, double, bool> = false>
uint64_t GetNanPayload(T val) {
  auto rep = absl::bit_cast<uint64_t>(val);
  if (auto payload = rep & NanPayloadBitMask<double>()) {
    return payload;
  }
  return QuietNanWithoutPayload<double>();
}
3233
// Builds a literal element from real/imag components.  For non-complex
// literal types, the imaginary component is dropped and the real part is
// returned as-is.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, LiteralComponentT, LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return real;
}

// Complex literal types are constructed from both components.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, std::complex<LiteralComponentT>,
                     LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return LiteralNativeT(real, imag);
}
3246
// ComponentType<T>::Type is the scalar component type of T: T itself for
// real types, and the underlying element type for std::complex<T>.
template <typename T>
struct ComponentType {
  using Type = T;
};

template <typename T>
struct ComponentType<std::complex<T>> {
  using Type = T;
};
3256
// GetReal/GetImag project a value onto its real and imaginary components.
// Real-valued inputs have an imaginary component of zero.
template <typename T>
T GetReal(T v) {
  return v;
}

template <typename T>
T GetReal(std::complex<T> v) {
  return std::real(v);
}

template <typename T>
T GetImag(T /*v*/) {
  return 0;
}

template <typename T>
T GetImag(std::complex<T> v) {
  return std::imag(v);
}
3276
3277 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,int64_t index,Literal * literal)3278 bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
3279 int64_t index, Literal* literal) {
3280 if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
3281 return false;
3282 }
3283
3284 // Check that the index is in range and assign into the literal
3285 if (index >= ShapeUtil::ElementsIn(literal->shape())) {
3286 return Error(loc, StrCat("tries to set value ", StringifyValue(value),
3287 " to a literal in shape ",
3288 ShapeUtil::HumanString(literal->shape()),
3289 " at linear index ", index,
3290 ", but the index is out of range"));
3291 }
3292 using ParsedElemComponentT = typename ComponentType<ParsedElemT>::Type;
3293 using LiteralNativeComponentT = typename ComponentType<LiteralNativeT>::Type;
3294 const auto handle_nan = [this, literal, index, loc](
3295 ParsedElemComponentT parsed_value_component,
3296 LiteralNativeComponentT*
3297 literal_value_component) {
3298 if (!std::isnan(static_cast<double>(parsed_value_component))) {
3299 return true;
3300 }
3301 auto nan_payload = GetNanPayload(parsed_value_component);
3302 if (nan_payload == QuietNanWithoutPayload<double>()) {
3303 nan_payload = QuietNanWithoutPayload<LiteralNativeComponentT>();
3304 }
3305 const auto kLargestPayload = NanPayloadBitMask<LiteralNativeComponentT>();
3306 if (nan_payload > kLargestPayload) {
3307 return Error(
3308 loc,
3309 StrCat("tries to set NaN payload 0x", absl::Hex(nan_payload),
3310 " to a literal in shape ",
3311 ShapeUtil::HumanString(literal->shape()), " at linear index ",
3312 index, ", but the NaN payload is out of range (0x",
3313 absl::Hex(kLargestPayload), ")"));
3314 }
3315 *literal_value_component = NanWithSignAndPayload<LiteralNativeComponentT>(
3316 /*sign=*/std::signbit(static_cast<double>(parsed_value_component)),
3317 /*nan_payload=*/nan_payload);
3318 return true;
3319 };
3320 const ParsedElemComponentT parsed_real_value = GetReal(value);
3321 auto literal_real_value =
3322 static_cast<LiteralNativeComponentT>(parsed_real_value);
3323 if (std::is_floating_point<ParsedElemT>::value ||
3324 std::is_same<ParsedElemT, std::complex<double>>::value) {
3325 if (!handle_nan(parsed_real_value, &literal_real_value)) {
3326 return false;
3327 }
3328 }
3329 const ParsedElemComponentT parsed_imag_value = GetImag(value);
3330 auto literal_imag_value =
3331 static_cast<LiteralNativeComponentT>(parsed_imag_value);
3332 if (std::is_same<ParsedElemT, std::complex<double>>::value) {
3333 if (!handle_nan(parsed_real_value, &literal_imag_value)) {
3334 return false;
3335 }
3336 }
3337 literal->data<LiteralNativeT>().at(index) =
3338 LiteralNativeFromRealImag<LiteralNativeT>(literal_real_value,
3339 literal_imag_value);
3340 return true;
3341 }
3342
3343 // Similar to ParseLiteral(Literal* literal, const Shape& shape), but parse the
3344 // shape instead of accepting one as argument.
ParseLiteral(Literal * literal)3345 bool HloParserImpl::ParseLiteral(Literal* literal) {
3346 if (lexer_.GetKind() == TokKind::kLparen) {
3347 // Consume Lparen
3348 lexer_.Lex();
3349 std::vector<Literal> elements;
3350 while (lexer_.GetKind() != TokKind::kRparen) {
3351 Literal element;
3352 if (!ParseLiteral(&element)) {
3353 return TokenError("Fails when parsing tuple element");
3354 }
3355 elements.emplace_back(std::move(element));
3356 if (lexer_.GetKind() != TokKind::kRparen) {
3357 ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3358 }
3359 }
3360
3361 *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3362 // Consume Rparen
3363 return ParseToken(TokKind::kRparen, "expects ')' to close a tuple literal");
3364 }
3365 Shape literal_shape;
3366 if (!ParseShape(&literal_shape)) {
3367 return false;
3368 }
3369 return ParseLiteral(literal, literal_shape);
3370 }
3371
3372 // literal
3373 // ::= tuple
3374 // ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)3375 bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
3376 return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
3377 : ParseNonTupleLiteral(literal, shape);
3378 }
3379
3380 // tuple
3381 // ::= shape '(' literal_list ')'
3382 // literal_list
3383 // ::= /*empty*/
3384 // ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)3385 bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
3386 if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
3387 return false;
3388 }
3389 std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
3390
3391 if (lexer_.GetKind() == TokKind::kRparen) {
3392 // empty
3393 } else {
3394 // literal, (',' literal)*
3395 for (int i = 0; i < elements.size(); i++) {
3396 if (i > 0) {
3397 ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3398 }
3399 if (!ParseLiteral(&elements[i],
3400 ShapeUtil::GetTupleElementShape(shape, i))) {
3401 return TokenError(StrCat("expects the ", i, "th element"));
3402 }
3403 }
3404 }
3405 *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3406 return ParseToken(TokKind::kRparen,
3407 StrCat("expects ')' at the end of the tuple with ",
3408 ShapeUtil::TupleElementCount(shape), "elements"));
3409 }
3410
3411 // non_tuple
3412 // ::= rank01
3413 // ::= rank2345
3414 // rank2345 ::= shape nested_array
// Parses a non-tuple literal.  CHECK-fails (crashes) rather than returning a
// parse error when the shape is not a dense array, since only dense layouts
// are supported here.
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
  return ParseDenseLiteral(literal, shape);
}
3419
ParseDenseLiteral(Literal * literal,const Shape & shape)3420 bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
3421 // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
3422 // `rank` is an int64, that's an implicit narrowing conversion, which is
3423 // implementation-defined behavior.
3424 const int rank = static_cast<int>(shape.rank());
3425
3426 // Create a literal with the given shape in default layout.
3427 *literal = LiteralUtil::CreateFromDimensions(
3428 shape.element_type(), AsInt64Slice(shape.dimensions()));
3429 int64_t nest_level = 0;
3430 int64_t linear_index = 0;
3431 // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
3432 // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
3433 // when we are parsing the 2nd '{' (right before '1'), we are seeing a
3434 // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
3435 // the first '}' (right after '3'), it means the sub-array ends, and the
3436 // sub-array is supposed to contain exactly 3 elements, so check if
3437 // elems_seen_per_dim[1] is 3.
3438 std::vector<int64> elems_seen_per_dim(rank);
3439 auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
3440 std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
3441 elems_seen_per_dim.begin() + dim);
3442 return StrCat("[",
3443 StrJoin(elems_seen_until_dim, ",",
3444 [](std::string* out, const int64_t num_elems) {
3445 StrAppend(out, num_elems - 1);
3446 }),
3447 "]");
3448 };
3449
3450 auto add_one_elem_seen = [&] {
3451 if (rank > 0) {
3452 if (nest_level != rank) {
3453 return TokenError(absl::StrFormat(
3454 "expects nested array in rank %d, but sees %d", rank, nest_level));
3455 }
3456 elems_seen_per_dim[rank - 1]++;
3457 if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
3458 return TokenError(absl::StrFormat(
3459 "expects %d elements on the minor-most dimension, but "
3460 "sees more",
3461 shape.dimensions(rank - 1)));
3462 }
3463 }
3464 return true;
3465 };
3466
3467 do {
3468 switch (lexer_.GetKind()) {
3469 default:
3470 return TokenError("unexpected token type in a literal");
3471 case TokKind::kLbrace: {
3472 nest_level++;
3473 if (nest_level > rank) {
3474 return TokenError(absl::StrFormat(
3475 "expects nested array in rank %d, but sees larger", rank));
3476 }
3477 if (nest_level > 1) {
3478 elems_seen_per_dim[nest_level - 2]++;
3479 if (elems_seen_per_dim[nest_level - 2] >
3480 shape.dimensions(nest_level - 2)) {
3481 return TokenError(absl::StrFormat(
3482 "expects %d elements in the %sth element, but sees more",
3483 shape.dimensions(nest_level - 2),
3484 get_index_str(nest_level - 2)));
3485 }
3486 }
3487 lexer_.Lex();
3488 break;
3489 }
3490 case TokKind::kRbrace: {
3491 nest_level--;
3492 if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
3493 return TokenError(absl::StrFormat(
3494 "expects %d elements in the %sth element, but sees %d",
3495 shape.dimensions(nest_level), get_index_str(nest_level),
3496 elems_seen_per_dim[nest_level]));
3497 }
3498 elems_seen_per_dim[nest_level] = 0;
3499 lexer_.Lex();
3500 break;
3501 }
3502 case TokKind::kLparen: {
3503 if (!primitive_util::IsComplexType(shape.element_type())) {
3504 return TokenError(
3505 absl::StrFormat("unexpected '(' in literal. Parens are only "
3506 "valid for complex literals"));
3507 }
3508
3509 std::complex<double> value;
3510 LocTy loc = lexer_.GetLoc();
3511 if (!add_one_elem_seen() || !ParseComplex(&value) ||
3512 !SetValueInLiteral(loc, value, linear_index++, literal)) {
3513 return false;
3514 }
3515 break;
3516 }
3517 case TokKind::kDots: {
3518 if (nest_level != 1) {
3519 return TokenError(absl::StrFormat(
3520 "expects `...` at nest level 1, but sees it at nest level %d",
3521 nest_level));
3522 }
3523 elems_seen_per_dim[0] = shape.dimensions(0);
3524 lexer_.Lex();
3525 // Fill data with deterministic (garbage) values. Use static to avoid
3526 // creating identical constants which could potentially got CSE'ed
3527 // away. This is a best-effort approach to make sure replaying a HLO
3528 // gives us same optimized HLO graph.
3529 static uint32 data = 0;
3530 uint32* raw_data = static_cast<uint32*>(literal->untyped_data());
3531 for (int64_t i = 0; i < literal->size_bytes() / 4; ++i) {
3532 raw_data[i] = data++;
3533 }
3534 uint8* raw_data_int8 = static_cast<uint8*>(literal->untyped_data());
3535 static uint8 data_int8 = 0;
3536 for (int64_t i = 0; i < literal->size_bytes() % 4; ++i) {
3537 raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
3538 }
3539 break;
3540 }
3541 case TokKind::kComma:
3542 // Skip.
3543 lexer_.Lex();
3544 break;
3545 case TokKind::kw_true:
3546 case TokKind::kw_false:
3547 case TokKind::kInt:
3548 case TokKind::kDecimal:
3549 case TokKind::kw_inf:
3550 case TokKind::kNegInf: {
3551 add_one_elem_seen();
3552 if (lexer_.GetKind() == TokKind::kw_true ||
3553 lexer_.GetKind() == TokKind::kw_false) {
3554 if (!SetValueInLiteral(lexer_.GetLoc(),
3555 lexer_.GetKind() == TokKind::kw_true,
3556 linear_index++, literal)) {
3557 return false;
3558 }
3559 lexer_.Lex();
3560 } else if (primitive_util::IsIntegralType(shape.element_type()) ||
3561 shape.element_type() == PRED) {
3562 LocTy loc = lexer_.GetLoc();
3563 int64_t value;
3564 if (!ParseInt64(&value)) {
3565 return Error(loc, StrCat("expects integer for primitive type: ",
3566 PrimitiveType_Name(shape.element_type())));
3567 }
3568 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3569 return false;
3570 }
3571 } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
3572 LocTy loc = lexer_.GetLoc();
3573 double value;
3574 if (!ParseDouble(&value)) {
3575 return Error(
3576 loc, StrCat("expect floating point value for primitive type: ",
3577 PrimitiveType_Name(shape.element_type())));
3578 }
3579 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3580 return false;
3581 }
3582 } else {
3583 return TokenError(StrCat("unsupported primitive type ",
3584 PrimitiveType_Name(shape.element_type())));
3585 }
3586 break;
3587 }
3588 } // end of switch
3589 } while (nest_level > 0);
3590
3591 *literal = literal->Relayout(shape.layout());
3592 return true;
3593 }
3594
// MinMaxFiniteValue is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange. It exposes the largest and
// smallest finite values representable by T, for range-checking parsed
// literal elements.
template <typename T>
struct MinMaxFiniteValue {
  static T max() { return std::numeric_limits<T>::max(); }
  // `lowest()` (not `min()`): for floating-point types std::numeric_limits'
  // min() is the smallest positive normal, which is not what we want here.
  static T min() { return std::numeric_limits<T>::lowest(); }
};
3602
// Specialization for half precision. The bounds are returned as double so
// callers can compare a parsed double against them directly.
template <>
struct MinMaxFiniteValue<Eigen::half> {
  static double max() {
    // Sadly this is not constexpr, so this forces `value` to be a method.
    return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
  }
  // fp16's finite range is symmetric around zero, so min is just -max.
  static double min() { return -max(); }
};
3611
// Specialization for bfloat16; mirrors the Eigen::half specialization above.
template <>
struct MinMaxFiniteValue<bfloat16> {
  static double max() {
    return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
  }
  // bfloat16's finite range is symmetric around zero, so min is just -max.
  static double min() { return -max(); }
};
3619
// MSVC's standard C++ library does not define isnan/isfinite for integer
// types. To work around that we provide our own overload set: floating-point
// arguments forward to the std predicates directly, while integral arguments
// are widened to double first (any integer converts to a finite, non-NaN
// double, so the integral overloads answer true/false accordingly).
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T v) {
  return std::isfinite(v);
}
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T v) {
  return std::isnan(v);
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T v) {
  const double widened = static_cast<double>(v);
  return std::isfinite(widened);
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T v) {
  const double widened = static_cast<double>(v);
  return std::isnan(widened);
}
3638
// Checks that `value` (parsed as ParsedElemT) fits in the range of the
// literal's native element type LiteralNativeT, reporting an error at `loc`
// if it does not. NaN and +/-infinity bypass the range check.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
  if (std::is_floating_point<ParsedElemT>::value) {
    // Round-trip the value through the native type so the range comparison
    // below is done against what the literal will actually store. If the
    // round-trip produced a non-finite value from a finite input, keep the
    // original so the overflow is caught by the comparisons below.
    auto value_as_native_t = static_cast<LiteralNativeT>(value);
    auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
    if (!IsFinite(value) || IsFinite(value_double_converted)) {
      value = value_double_converted;
    }
  }
  PrimitiveType literal_ty =
      primitive_util::NativeToPrimitiveType<LiteralNativeT>();
  if (IsNaN(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (std::is_unsigned<LiteralNativeT>::value) {
    // Unsigned target: reinterpret the parsed value's bits as uint64 and
    // compare against the native type's maximum.
    CHECK((std::is_same<ParsedElemT, int64>::value ||
           std::is_same<ParsedElemT, bool>::value))
        << "Unimplemented checking for ParsedElemT";

    const uint64 unsigned_value = value;
    const uint64 upper_bound =
        static_cast<uint64>(std::numeric_limits<LiteralNativeT>::max());
    if (unsigned_value > upper_bound) {
      // Value is out of range for LiteralNativeT.
      return Error(loc, StrCat("value ", value,
                               " is out of range for literal's primitive type ",
                               PrimitiveType_Name(literal_ty), " namely [0, ",
                               upper_bound, "]."));
    }
  } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
             value < MinMaxFiniteValue<LiteralNativeT>::min()) {
    // Value is out of range for LiteralNativeT.
    return Error(loc, StrCat("value ", value,
                             " is out of range for literal's primitive type ",
                             PrimitiveType_Name(literal_ty), " namely [",
                             MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
                             MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
  }
  return true;
}
3681
// Overload for complex literals: range-checks the real and imaginary
// components independently against the literal's component type, so error
// messages can name the offending part.
template <typename LiteralNativeT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
                                              std::complex<double> value) {
  // e.g. `float` for std::complex<float>
  using LiteralComplexComponentT =
      decltype(std::real(std::declval<LiteralNativeT>()));

  // We could do simply
  //
  //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
  //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
  //
  // but this would give bad error messages on failure.

  auto check_component = [&](absl::string_view name, double v) {
    if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
        v == -std::numeric_limits<double>::infinity()) {
      // Skip range-checking for non-finite values.
      return true;
    }

    double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
    double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
    if (v < min || v > max) {
      // Value is out of range for LiteralComplexComponentT.
      return Error(
          loc,
          StrCat(name, " part ", v,
                 " is out of range for literal's primitive type ",
                 PrimitiveType_Name(
                     primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
                 ", namely [", min, ", ", max, "]."));
    }
    return true;
  };
  return check_component("real", std::real(value)) &&
         check_component("imaginary", std::imag(value));
}
3720
// operands ::= '(' operands1 ')'
// operands1
//   ::= /*empty*/
//   ::= operand (, operand)*
// operand ::= (shape)? name
//
// Parses a parenthesized, comma-separated operand list, appending the
// referenced (previously declared) instructions to `operands`. Returns false
// and reports an error if any operand cannot be resolved.
bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands) {
  CHECK(operands != nullptr);
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of operands")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      LocTy loc = lexer_.GetLoc();
      std::string name;
      optional<Shape> shape;
      // An operand may be prefixed with its shape, e.g. "f32[10] x".
      if (CanBeShape()) {
        shape.emplace();
        if (!ParseShape(&shape.value())) {
          return false;
        }
      }
      if (!ParseName(&name)) {
        // When parsing a single instruction (as opposed to a whole module), an
        // HLO may have one or more operands with a shape but no name:
        //
        //  foo = add(f32[10], f32[10])
        //
        // create_missing_instruction_ is always non-null when parsing a single
        // instruction, and is responsible for creating kParameter instructions
        // for these operands.
        if (shape.has_value() && create_missing_instruction_ != nullptr &&
            scoped_name_tables_.size() == 1) {
          name = "";
        } else {
          return false;
        }
      }
      std::pair<HloInstruction*, LocTy>* instruction =
          FindInstruction(name, shape);
      if (instruction == nullptr) {
        return Error(loc, StrCat("instruction does not exist: ", name));
      }
      operands->push_back(instruction->first);
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
3771
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)3772 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
3773 const int expected_size) {
3774 CHECK(operands != nullptr);
3775 LocTy loc = lexer_.GetLoc();
3776 if (!ParseOperands(operands)) {
3777 return false;
3778 }
3779 if (expected_size != operands->size()) {
3780 return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
3781 operands->size(), " operands"));
3782 }
3783 return true;
3784 }
3785
3786 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3787 bool HloParserImpl::ParseSubAttributes(
3788 const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3789 LocTy loc = lexer_.GetLoc();
3790 if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
3791 return false;
3792 }
3793 absl::flat_hash_set<std::string> seen_attrs;
3794 if (lexer_.GetKind() == TokKind::kRbrace) {
3795 // empty
3796 } else {
3797 do {
3798 EatIfPresent(TokKind::kComma);
3799 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3800 return false;
3801 }
3802 } while (lexer_.GetKind() != TokKind::kRbrace);
3803 }
3804 // Check that all required attrs were seen.
3805 for (const auto& attr_it : attrs) {
3806 if (attr_it.second.required &&
3807 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3808 return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
3809 attr_it.first));
3810 }
3811 }
3812 return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
3813 }
3814
3815 // attributes ::= (',' attribute)*
ParseAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)3816 bool HloParserImpl::ParseAttributes(
3817 const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
3818 LocTy loc = lexer_.GetLoc();
3819 absl::flat_hash_set<std::string> seen_attrs;
3820 while (EatIfPresent(TokKind::kComma)) {
3821 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
3822 return false;
3823 }
3824 }
3825 // Check that all required attrs were seen.
3826 for (const auto& attr_it : attrs) {
3827 if (attr_it.second.required &&
3828 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
3829 return Error(loc, StrFormat("attribute %s is expected but not seen",
3830 attr_it.first));
3831 }
3832 }
3833 return true;
3834 }
3835
ParseAttributeHelper(const absl::flat_hash_map<std::string,AttrConfig> & attrs,absl::flat_hash_set<std::string> * seen_attrs)3836 bool HloParserImpl::ParseAttributeHelper(
3837 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
3838 absl::flat_hash_set<std::string>* seen_attrs) {
3839 LocTy loc = lexer_.GetLoc();
3840 std::string name;
3841 if (!ParseAttributeName(&name)) {
3842 return Error(loc, "error parsing attributes");
3843 }
3844 VLOG(3) << "Parsing attribute " << name;
3845 if (!seen_attrs->insert(name).second) {
3846 return Error(loc, StrFormat("attribute %s already exists", name));
3847 }
3848 auto attr_it = attrs.find(name);
3849 if (attr_it == attrs.end()) {
3850 std::string allowed_attrs;
3851 if (attrs.empty()) {
3852 allowed_attrs = "No attributes are allowed here.";
3853 } else {
3854 allowed_attrs =
3855 StrCat("Allowed attributes: ",
3856 StrJoin(attrs, ", ",
3857 [&](std::string* out,
3858 const std::pair<std::string, AttrConfig>& kv) {
3859 StrAppend(out, kv.first);
3860 }));
3861 }
3862 return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
3863 allowed_attrs));
3864 }
3865 AttrTy attr_type = attr_it->second.attr_type;
3866 void* attr_out_ptr = attr_it->second.result;
3867 bool success = [&] {
3868 LocTy attr_loc = lexer_.GetLoc();
3869 switch (attr_type) {
3870 case AttrTy::kBool: {
3871 bool result;
3872 if (!ParseBool(&result)) {
3873 return false;
3874 }
3875 static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
3876 return true;
3877 }
3878 case AttrTy::kInt64: {
3879 int64_t result;
3880 if (!ParseInt64(&result)) {
3881 return false;
3882 }
3883 static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
3884 return true;
3885 }
3886 case AttrTy::kInt32: {
3887 int64_t result;
3888 if (!ParseInt64(&result)) {
3889 return false;
3890 }
3891 if (result != static_cast<int32>(result)) {
3892 return Error(attr_loc, "value out of range for int32");
3893 }
3894 static_cast<optional<int32>*>(attr_out_ptr)
3895 ->emplace(static_cast<int32>(result));
3896 return true;
3897 }
3898 case AttrTy::kFloat: {
3899 double result;
3900 if (!ParseDouble(&result)) {
3901 return false;
3902 }
3903 if (result > std::numeric_limits<float>::max() ||
3904 result < std::numeric_limits<float>::lowest()) {
3905 return Error(attr_loc, "value out of range for float");
3906 }
3907 static_cast<optional<float>*>(attr_out_ptr)
3908 ->emplace(static_cast<float>(result));
3909 return true;
3910 }
3911 case AttrTy::kHloComputation: {
3912 HloComputation* result = nullptr;
3913 if (!ParseHloComputation(&result)) {
3914 return false;
3915 }
3916 static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
3917 return true;
3918 }
3919 case AttrTy::kBracedHloComputationList: {
3920 std::vector<HloComputation*> result;
3921 if (!ParseHloComputationList(&result)) {
3922 return false;
3923 }
3924 static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
3925 ->emplace(result);
3926 return true;
3927 }
3928 case AttrTy::kFftType: {
3929 FftType result;
3930 if (!ParseFftType(&result)) {
3931 return false;
3932 }
3933 static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
3934 return true;
3935 }
3936 case AttrTy::kPaddingType: {
3937 PaddingType result;
3938 if (!ParsePaddingType(&result)) {
3939 return false;
3940 }
3941 static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
3942 return true;
3943 }
3944 case AttrTy::kComparisonDirection: {
3945 ComparisonDirection result;
3946 if (!ParseComparisonDirection(&result)) {
3947 return false;
3948 }
3949 static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
3950 ->emplace(result);
3951 return true;
3952 }
3953 case AttrTy::kComparisonType: {
3954 Comparison::Type result;
3955 if (!ParseComparisonType(&result)) {
3956 return false;
3957 }
3958 static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
3959 return true;
3960 }
3961 case AttrTy::kEnum: {
3962 if (lexer_.GetKind() != TokKind::kIdent) {
3963 return TokenError("expects an enumeration value");
3964 }
3965 std::string result = lexer_.GetStrVal();
3966 lexer_.Lex();
3967 static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
3968 return true;
3969 }
3970 case AttrTy::kWindow: {
3971 Window result;
3972 if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
3973 return false;
3974 }
3975 static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
3976 return true;
3977 }
3978 case AttrTy::kConvolutionDimensionNumbers: {
3979 ConvolutionDimensionNumbers result;
3980 if (!ParseConvolutionDimensionNumbers(&result)) {
3981 return false;
3982 }
3983 static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
3984 ->emplace(result);
3985 return true;
3986 }
3987 case AttrTy::kSharding: {
3988 OpSharding sharding;
3989 if (!ParseSharding(&sharding)) {
3990 return false;
3991 }
3992 static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
3993 return true;
3994 }
3995 case AttrTy::kFrontendAttributes: {
3996 FrontendAttributes frontend_attributes;
3997 if (!ParseFrontendAttributes(&frontend_attributes)) {
3998 return false;
3999 }
4000 static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
4001 ->emplace(frontend_attributes);
4002 return true;
4003 }
4004 case AttrTy::kParameterReplication: {
4005 ParameterReplication parameter_replication;
4006 if (!ParseParameterReplication(¶meter_replication)) {
4007 return false;
4008 }
4009 static_cast<optional<ParameterReplication>*>(attr_out_ptr)
4010 ->emplace(parameter_replication);
4011 return true;
4012 }
4013 case AttrTy::kInstructionList: {
4014 std::vector<HloInstruction*> result;
4015 if (!ParseInstructionNames(&result)) {
4016 return false;
4017 }
4018 static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
4019 ->emplace(result);
4020 return true;
4021 }
4022 case AttrTy::kFusionKind: {
4023 HloInstruction::FusionKind result;
4024 if (!ParseFusionKind(&result)) {
4025 return false;
4026 }
4027 static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
4028 ->emplace(result);
4029 return true;
4030 }
4031 case AttrTy::kBracedInt64List: {
4032 std::vector<int64> result;
4033 if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4034 &result)) {
4035 return false;
4036 }
4037 static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
4038 ->emplace(result);
4039 return true;
4040 }
4041 case AttrTy::kBracedInt64ListList: {
4042 std::vector<std::vector<int64>> result;
4043 if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
4044 TokKind::kComma, &result)) {
4045 return false;
4046 }
4047 static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
4048 ->emplace(result);
4049 return true;
4050 }
4051 case AttrTy::kSliceRanges: {
4052 SliceRanges result;
4053 if (!ParseSliceRanges(&result)) {
4054 return false;
4055 }
4056 static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
4057 return true;
4058 }
4059 case AttrTy::kPaddingConfig: {
4060 PaddingConfig result;
4061 if (!ParsePaddingConfig(&result)) {
4062 return false;
4063 }
4064 static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
4065 return true;
4066 }
4067 case AttrTy::kString: {
4068 std::string result;
4069 if (!ParseString(&result)) {
4070 return false;
4071 }
4072 static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
4073 return true;
4074 }
4075 case AttrTy::kMetadata: {
4076 OpMetadata result;
4077 if (!ParseMetadata(&result)) {
4078 return false;
4079 }
4080 static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
4081 return true;
4082 }
4083 case AttrTy::kDistribution: {
4084 RandomDistribution result;
4085 if (!ParseRandomDistribution(&result)) {
4086 return false;
4087 }
4088 static_cast<optional<RandomDistribution>*>(attr_out_ptr)
4089 ->emplace(result);
4090 return true;
4091 }
4092 case AttrTy::kDomain: {
4093 return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
4094 }
4095 case AttrTy::kPrecisionList: {
4096 std::vector<PrecisionConfig::Precision> result;
4097 if (!ParsePrecisionList(&result)) {
4098 return false;
4099 }
4100 static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
4101 attr_out_ptr)
4102 ->emplace(result);
4103 return true;
4104 }
4105 case AttrTy::kShape: {
4106 Shape result;
4107 if (!ParseShape(&result)) {
4108 return false;
4109 }
4110 static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
4111 return true;
4112 }
4113 case AttrTy::kShapeList: {
4114 std::vector<Shape> result;
4115 if (!ParseShapeList(&result)) {
4116 return false;
4117 }
4118 static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
4119 ->emplace(result);
4120 return true;
4121 }
4122 case AttrTy::kRandomAlgorithm: {
4123 RandomAlgorithm result;
4124 if (!ParseRandomAlgorithm(&result)) {
4125 return false;
4126 }
4127 static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
4128 return true;
4129 }
4130 case AttrTy::kAliasing: {
4131 AliasingData aliasing_data;
4132 if (!ParseAliasing(&aliasing_data)) {
4133 return false;
4134 }
4135 static_cast<optional<AliasingData>*>(attr_out_ptr)
4136 ->emplace(aliasing_data);
4137 return true;
4138 }
4139 case AttrTy::kInstructionAliasing: {
4140 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
4141 aliasing_output_operand_pairs;
4142 if (!ParseInstructionOutputOperandAliasing(
4143 &aliasing_output_operand_pairs)) {
4144 return false;
4145 }
4146 static_cast<optional<
4147 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
4148 attr_out_ptr)
4149 ->emplace(std::move(aliasing_output_operand_pairs));
4150 return true;
4151 }
4152 case AttrTy::kLiteral: {
4153 Literal result;
4154 if (!ParseLiteral(&result)) {
4155 return false;
4156 }
4157 static_cast<optional<Literal>*>(attr_out_ptr)
4158 ->emplace(std::move(result));
4159 return true;
4160 }
4161 case AttrTy::kCustomCallSchedule: {
4162 CustomCallSchedule result;
4163 if (!ParseCustomCallSchedule(&result)) {
4164 return false;
4165 }
4166 static_cast<optional<CustomCallSchedule>*>(attr_out_ptr)
4167 ->emplace(result);
4168 return true;
4169 }
4170 case AttrTy::kCustomCallApiVersion: {
4171 CustomCallApiVersion result;
4172 if (!ParseCustomCallApiVersion(&result)) {
4173 return false;
4174 }
4175 static_cast<optional<CustomCallApiVersion>*>(attr_out_ptr)
4176 ->emplace(result);
4177 return true;
4178 }
4179 }
4180 }();
4181 if (!success) {
4182 return Error(loc, StrFormat("error parsing attribute %s", name));
4183 }
4184 return true;
4185 }
4186
// Copies parsed attribute values from `attrs` into the matching fields of
// `message` via protobuf reflection. Names listed in `non_proto_attrs` are
// skipped (they are handled outside the proto). Only bool and enum fields
// are supported; any other field type is reported as an error.
//
// NOTE(review): `non_proto_attrs` is taken by value, which copies the set on
// every call; a const reference would avoid that — confirm against the
// class declaration before changing the signature.
bool HloParserImpl::CopyAttributeToProtoMessage(
    absl::flat_hash_set<std::string> non_proto_attrs,
    const absl::flat_hash_map<std::string, AttrConfig>& attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  const tensorflow::protobuf::Reflection* reflection = message->GetReflection();

  for (const auto& p : attrs) {
    const std::string& name = p.first;
    // Attributes explicitly marked as non-proto are not copied.
    if (non_proto_attrs.find(name) != non_proto_attrs.end()) {
      continue;
    }
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->FindFieldByName(name);
    if (!fd) {
      std::string allowed_attrs = "Allowed attributes: ";

      for (int i = 0; i < descriptor->field_count(); ++i) {
        if (i == 0) {
          absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
        } else {
          absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
        }
      }
      return TokenError(
          StrFormat("unexpected attribute \"%s\". %s", name, allowed_attrs));
    }

    CHECK(!fd->is_repeated());  // Repeated fields not implemented.
    bool success = [&] {
      switch (fd->type()) {
        case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
          // Unset optionals leave the proto field at its default value.
          auto attr_value = static_cast<optional<bool>*>(p.second.result);
          if (attr_value->has_value()) {
            reflection->SetBool(message, fd, **attr_value);
          }
          return true;
        }
        case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
          // Enum attributes are parsed as strings and resolved against the
          // field's enum descriptor here.
          auto attr_value =
              static_cast<optional<std::string>*>(p.second.result);
          if (attr_value->has_value()) {
            const tensorflow::protobuf::EnumValueDescriptor* evd =
                fd->enum_type()->FindValueByName(**attr_value);
            reflection->SetEnum(message, fd, evd);
          }
          return true;
        }
        default:
          return false;
      }
    }();

    if (!success) {
      return TokenError(StrFormat("error parsing attribute %s", name));
    }
  }

  return true;
}
4247
// attributes ::= (',' attribute)*
//
// Builds an attribute table from the bool/enum fields of `message` (merged
// with `non_proto_attrs`), parses the attributes, and copies the parsed
// values into `message` via CopyAttributeToProtoMessage.
bool HloParserImpl::ParseAttributesAsProtoMessage(
    const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  absl::flat_hash_map<std::string, AttrConfig> attrs;

  // Storage for attributes. The AttrConfig entries below hold raw pointers
  // into these vectors, so their elements must not move.
  std::vector<optional<bool>> bool_params;
  std::vector<optional<std::string>> string_params;
  // Reserve enough capacity to make sure that the vector is not growing, so we
  // can rely on the pointers to stay valid.
  bool_params.reserve(descriptor->field_count());
  string_params.reserve(descriptor->field_count());

  // Populate the storage of expected attributes from the protobuf description.
  for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->field(field_idx);
    const std::string& field_name = fd->name();
    switch (fd->type()) {
      case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
        bool_params.emplace_back(absl::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
                             &bool_params.back()};
        break;
      }
      case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
        // Enum fields are parsed as strings and resolved later by
        // CopyAttributeToProtoMessage.
        string_params.emplace_back(absl::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
                             &string_params.back()};
        break;
      }
      default:
        return TokenError(absl::StrFormat(
            "Unexpected protocol buffer type: %s ", fd->DebugString()));
    }
  }

  absl::flat_hash_set<std::string> non_proto_attrs_names;
  non_proto_attrs_names.reserve(non_proto_attrs.size());
  for (const auto& p : non_proto_attrs) {
    const std::string& attr_name = p.first;
    // If an attribute is both specified within 'non_proto_attrs' and an
    // attribute of the proto message, we prefer the attribute of the proto
    // message.
    if (attrs.find(attr_name) == attrs.end()) {
      non_proto_attrs_names.insert(attr_name);
      attrs[attr_name] = p.second;
    }
  }

  if (!ParseAttributes(attrs)) {
    return false;
  }

  return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
}
4306
ParseComputationName(HloComputation ** value)4307 bool HloParserImpl::ParseComputationName(HloComputation** value) {
4308 std::string name;
4309 LocTy loc = lexer_.GetLoc();
4310 if (!ParseName(&name)) {
4311 return Error(loc, "expects computation name");
4312 }
4313 std::pair<HloComputation*, LocTy>* computation =
4314 tensorflow::gtl::FindOrNull(computation_pool_, name);
4315 if (computation == nullptr) {
4316 return Error(loc, StrCat("computation does not exist: ", name));
4317 }
4318 *value = computation->first;
4319 return true;
4320 }
4321
4322 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
4323 // The subattributes can appear in any order. 'size=' is required, others are
4324 // optional.
ParseWindow(Window * window,bool expect_outer_curlies)4325 bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
4326 LocTy loc = lexer_.GetLoc();
4327 if (expect_outer_curlies &&
4328 !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
4329 return false;
4330 }
4331
4332 std::vector<int64> size;
4333 std::vector<int64> stride;
4334 std::vector<std::vector<int64>> pad;
4335 std::vector<int64> lhs_dilate;
4336 std::vector<int64> rhs_dilate;
4337 std::vector<int64> rhs_reversal;
4338 const auto end_token =
4339 expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
4340 while (lexer_.GetKind() != end_token) {
4341 LocTy attr_loc = lexer_.GetLoc();
4342 std::string field_name;
4343 if (!ParseAttributeName(&field_name)) {
4344 return Error(attr_loc, "expects sub-attributes in window");
4345 }
4346 bool ok = [&] {
4347 if (field_name == "size") {
4348 return ParseDxD("size", &size);
4349 }
4350 if (field_name == "stride") {
4351 return ParseDxD("stride", &stride);
4352 }
4353 if (field_name == "lhs_dilate") {
4354 return ParseDxD("lhs_dilate", &lhs_dilate);
4355 }
4356 if (field_name == "rhs_dilate") {
4357 return ParseDxD("rls_dilate", &rhs_dilate);
4358 }
4359 if (field_name == "pad") {
4360 return ParseWindowPad(&pad);
4361 }
4362 if (field_name == "rhs_reversal") {
4363 return ParseDxD("rhs_reversal", &rhs_reversal);
4364 }
4365 return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
4366 }();
4367 if (!ok) {
4368 return false;
4369 }
4370 }
4371
4372 if (!stride.empty() && stride.size() != size.size()) {
4373 return Error(loc, "expects 'stride=' has the same size as 'size='");
4374 }
4375 if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
4376 return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
4377 }
4378 if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
4379 return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
4380 }
4381 if (!pad.empty() && pad.size() != size.size()) {
4382 return Error(loc, "expects 'pad=' has the same size as 'size='");
4383 }
4384
4385 for (int i = 0; i < size.size(); i++) {
4386 window->add_dimensions()->set_size(size[i]);
4387 if (!pad.empty()) {
4388 window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
4389 window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
4390 }
4391 // If some field is not present, it has the default value.
4392 window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
4393 window->mutable_dimensions(i)->set_base_dilation(
4394 lhs_dilate.empty() ? 1 : lhs_dilate[i]);
4395 window->mutable_dimensions(i)->set_window_dilation(
4396 rhs_dilate.empty() ? 1 : rhs_dilate[i]);
4397 window->mutable_dimensions(i)->set_window_reversal(
4398 rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
4399 }
4400 return !expect_outer_curlies ||
4401 ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
4402 }
4403
4404 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
4405 // The string looks like "dim_labels=0bf_0io->0bf".
4406 //
4407 // '?' dims don't appear in ConvolutionDimensionNumbers. There can be more than
4408 // one '?' dim.
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)4409 bool HloParserImpl::ParseConvolutionDimensionNumbers(
4410 ConvolutionDimensionNumbers* dnums) {
4411 if (lexer_.GetKind() != TokKind::kDimLabels) {
4412 return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
4413 }
4414 std::string str = lexer_.GetStrVal();
4415
4416 // The str is expected to have 3 items, lhs, rhs, out, and it must look like
4417 // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
4418 std::vector<std::string> split1 = absl::StrSplit(str, '_');
4419 if (split1.size() != 2) {
4420 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4421 << str;
4422 }
4423 std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
4424 if (split2.size() != 2) {
4425 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4426 << str;
4427 }
4428 absl::string_view lhs = split1[0];
4429 absl::string_view rhs = split2[0];
4430 absl::string_view out = split2[1];
4431
4432 auto is_unique = [](absl::string_view str) -> bool {
4433 absl::flat_hash_set<char> chars;
4434 for (char c : str) {
4435 // '?' dims are skipped.
4436 if (c == '?') {
4437 continue;
4438 }
4439 if (!chars.insert(c).second) {
4440 return false;
4441 }
4442 }
4443 return true;
4444 };
4445
4446 // lhs
4447 {
4448 if (!is_unique(lhs)) {
4449 return TokenError(
4450 StrCat("expects unique lhs dimension numbers, but sees ", lhs));
4451 }
4452 // Count number of spatial dimensions.
4453 for (char c : lhs) {
4454 if (c != 'b' && c != 'f' && c != '?') {
4455 dnums->add_input_spatial_dimensions(-1);
4456 }
4457 }
4458 for (int i = 0; i < lhs.size(); i++) {
4459 char c = lhs[i];
4460 if (c == '?') {
4461 continue;
4462 } else if (c == 'b') {
4463 dnums->set_input_batch_dimension(i);
4464 } else if (c == 'f') {
4465 dnums->set_input_feature_dimension(i);
4466 } else if (c < '0' + lhs.size() && c >= '0') {
4467 dnums->set_input_spatial_dimensions(c - '0', i);
4468 } else {
4469 return TokenError(StrFormat(
4470 "expects [0-%dbf?] in lhs dimension numbers", lhs.size() - 1));
4471 }
4472 }
4473 }
4474 // rhs
4475 {
4476 if (!is_unique(rhs)) {
4477 return TokenError(
4478 StrCat("expects unique rhs dimension numbers, but sees ", rhs));
4479 }
4480 // Count number of spatial dimensions.
4481 for (char c : rhs) {
4482 if (c != 'i' && c != 'o' && c != '?') {
4483 dnums->add_kernel_spatial_dimensions(-1);
4484 }
4485 }
4486 for (int i = 0; i < rhs.size(); i++) {
4487 char c = rhs[i];
4488 if (c == '?') {
4489 continue;
4490 } else if (c == 'i') {
4491 dnums->set_kernel_input_feature_dimension(i);
4492 } else if (c == 'o') {
4493 dnums->set_kernel_output_feature_dimension(i);
4494 } else if (c < '0' + rhs.size() && c >= '0') {
4495 dnums->set_kernel_spatial_dimensions(c - '0', i);
4496 } else {
4497 return TokenError(StrFormat(
4498 "expects [0-%dio?] in rhs dimension numbers", rhs.size() - 1));
4499 }
4500 }
4501 }
4502 // output
4503 {
4504 if (!is_unique(out)) {
4505 return TokenError(
4506 StrCat("expects unique output dimension numbers, but sees ", out));
4507 }
4508 // Count number of spatial dimensions.
4509 for (char c : out) {
4510 if (c != 'b' && c != 'f' && c != '?') {
4511 dnums->add_output_spatial_dimensions(-1);
4512 }
4513 }
4514 for (int i = 0; i < out.size(); i++) {
4515 char c = out[i];
4516 if (c == '?') {
4517 continue;
4518 } else if (c == 'b') {
4519 dnums->set_output_batch_dimension(i);
4520 } else if (c == 'f') {
4521 dnums->set_output_feature_dimension(i);
4522 } else if (c < '0' + out.size() && c >= '0') {
4523 dnums->set_output_spatial_dimensions(c - '0', i);
4524 } else {
4525 return TokenError(StrFormat(
4526 "expects [0-%dbf?] in output dimension numbers", out.size() - 1));
4527 }
4528 }
4529 }
4530
4531 // lhs, rhs, and output should have the same number of spatial dimensions.
4532 if (dnums->input_spatial_dimensions_size() !=
4533 dnums->output_spatial_dimensions_size() ||
4534 dnums->input_spatial_dimensions_size() !=
4535 dnums->kernel_spatial_dimensions_size()) {
4536 return TokenError(
4537 StrFormat("input, kernel, and output must have same number of spatial "
4538 "dimensions, but got %d, %d, %d, respectively.",
4539 dnums->input_spatial_dimensions_size(),
4540 dnums->kernel_spatial_dimensions_size(),
4541 dnums->output_spatial_dimensions_size()));
4542 }
4543
4544 lexer_.Lex();
4545 return true;
4546 }
4547
4548 // ::= '{' ranges '}'
4549 // ::= /*empty*/
4550 // ::= range (',' range)*
4551 // range ::= '[' start ':' limit (':' stride)? ']'
4552 //
4553 // The slice ranges are printed as:
4554 //
4555 // {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
4556 //
4557 // This function extracts the starts, limits, and strides as 3 vectors to the
4558 // result. If stride is not present, stride is 1. For example, if the slice
4559 // ranges is printed as:
4560 //
4561 // {[2:3:4], [5:6:7], [8:9]}
4562 //
4563 // The parsed result will be:
4564 //
4565 // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
4566 //
ParseSliceRanges(SliceRanges * result)4567 bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
4568 if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
4569 return false;
4570 }
4571 std::vector<std::vector<int64>> ranges;
4572 if (lexer_.GetKind() == TokKind::kRbrace) {
4573 // empty
4574 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4575 }
4576 do {
4577 LocTy loc = lexer_.GetLoc();
4578 ranges.emplace_back();
4579 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
4580 &ranges.back())) {
4581 return false;
4582 }
4583 const auto& range = ranges.back();
4584 if (range.size() != 2 && range.size() != 3) {
4585 return Error(loc,
4586 StrFormat("expects [start:limit:step] or [start:limit], "
4587 "but sees %d elements.",
4588 range.size()));
4589 }
4590 } while (EatIfPresent(TokKind::kComma));
4591
4592 for (const auto& range : ranges) {
4593 result->starts.push_back(range[0]);
4594 result->limits.push_back(range[1]);
4595 result->strides.push_back(range.size() == 3 ? range[2] : 1);
4596 }
4597 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4598 }
4599
4600 // precisionlist ::= start precision_elements end
4601 // precision_elements
4602 // ::= /*empty*/
4603 // ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)4604 bool HloParserImpl::ParsePrecisionList(
4605 std::vector<PrecisionConfig::Precision>* result) {
4606 auto parse_and_add_item = [&]() {
4607 PrecisionConfig::Precision item;
4608 if (!ParsePrecision(&item)) {
4609 return false;
4610 }
4611 result->push_back(item);
4612 return true;
4613 };
4614 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4615 parse_and_add_item);
4616 }
4617
ParseHloComputation(HloComputation ** result)4618 bool HloParserImpl::ParseHloComputation(HloComputation** result) {
4619 if (lexer_.GetKind() == TokKind::kLbrace) {
4620 // This means it is a nested computation.
4621 return ParseInstructionList(result, /*computation_name=*/"_");
4622 }
4623 // This means it is a computation name.
4624 return ParseComputationName(result);
4625 }
4626
ParseHloComputationList(std::vector<HloComputation * > * result)4627 bool HloParserImpl::ParseHloComputationList(
4628 std::vector<HloComputation*>* result) {
4629 auto parse_and_add_item = [&]() {
4630 HloComputation* computation;
4631 if (!ParseHloComputation(&computation)) {
4632 return false;
4633 }
4634 VLOG(3) << "parsed computation " << computation->name();
4635 result->push_back(computation);
4636 return true;
4637 };
4638 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4639 parse_and_add_item);
4640 }
4641
4642 // shapelist ::= '{' shapes '}'
4643 // precision_elements
4644 // ::= /*empty*/
4645 // ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)4646 bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
4647 auto parse_and_add_item = [&]() {
4648 Shape shape;
4649 if (!ParseShape(&shape)) {
4650 return false;
4651 }
4652 result->push_back(std::move(shape));
4653 return true;
4654 };
4655 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4656 parse_and_add_item);
4657 }
4658
4659 // int64list ::= start int64_elements end
4660 // int64_elements
4661 // ::= /*empty*/
4662 // ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)4663 bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
4664 const TokKind delim,
4665 std::vector<int64>* result) {
4666 auto parse_and_add_item = [&]() {
4667 int64_t i;
4668 if (!ParseInt64(&i)) {
4669 return false;
4670 }
4671 result->push_back(i);
4672 return true;
4673 };
4674 return ParseList(start, end, delim, parse_and_add_item);
4675 }
4676
4677 // int64listlist ::= start int64list_elements end
4678 // int64list_elements
4679 // ::= /*empty*/
4680 // ::= int64list (delim int64list)*
4681 // int64list ::= start int64_elements end
4682 // int64_elements
4683 // ::= /*empty*/
4684 // ::= int64_val (delim int64_val)*
ParseInt64ListList(const TokKind start,const TokKind end,const TokKind delim,std::vector<std::vector<int64>> * result)4685 bool HloParserImpl::ParseInt64ListList(
4686 const TokKind start, const TokKind end, const TokKind delim,
4687 std::vector<std::vector<int64>>* result) {
4688 auto parse_and_add_item = [&]() {
4689 std::vector<int64> item;
4690 if (!ParseInt64List(start, end, delim, &item)) {
4691 return false;
4692 }
4693 result->push_back(item);
4694 return true;
4695 };
4696 return ParseList(start, end, delim, parse_and_add_item);
4697 }
4698
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)4699 bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
4700 const TokKind delim,
4701 const std::function<bool()>& parse_and_add_item) {
4702 if (!ParseToken(start, StrCat("expects a list starting with ",
4703 TokKindToString(start)))) {
4704 return false;
4705 }
4706 if (lexer_.GetKind() == end) {
4707 // empty
4708 } else {
4709 do {
4710 if (!parse_and_add_item()) {
4711 return false;
4712 }
4713 } while (EatIfPresent(delim));
4714 }
4715 return ParseToken(
4716 end, StrCat("expects a list to end with ", TokKindToString(end)));
4717 }
4718
4719 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)4720 bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
4721 if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
4722 return false;
4723 }
4724 *shape_loc = lexer_.GetLoc();
4725 return ParseShape(shape);
4726 }
4727
CanBeParamListToShape()4728 bool HloParserImpl::CanBeParamListToShape() {
4729 return lexer_.GetKind() == TokKind::kLparen;
4730 }
4731
4732 // param_list ::= '(' param_list1 ')'
4733 // param_list1
4734 // ::= /*empty*/
4735 // ::= param (',' param)*
4736 // param ::= name shape
ParseParamList()4737 bool HloParserImpl::ParseParamList() {
4738 if (!ParseToken(TokKind::kLparen,
4739 "expects '(' at the beginning of param list")) {
4740 return false;
4741 }
4742
4743 if (lexer_.GetKind() == TokKind::kRparen) {
4744 // empty
4745 } else {
4746 do {
4747 Shape shape;
4748 std::string name;
4749 if (!ParseName(&name) || !ParseShape(&shape)) {
4750 return false;
4751 }
4752 } while (EatIfPresent(TokKind::kComma));
4753 }
4754 return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
4755 }
4756
4757 // dimension_sizes ::= '[' dimension_list ']'
4758 // dimension_list
4759 // ::= /*empty*/
4760 // ::= <=? int64 (',' param)*
4761 // param ::= name shape
ParseDimensionSizes(std::vector<int64> * dimension_sizes,std::vector<bool> * dynamic_dimensions)4762 bool HloParserImpl::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
4763 std::vector<bool>* dynamic_dimensions) {
4764 auto parse_and_add_item = [&]() {
4765 int64_t i;
4766 bool is_dynamic = false;
4767 if (lexer_.GetKind() == TokKind::kLeq) {
4768 is_dynamic = true;
4769 lexer_.Lex();
4770 }
4771 if (!ParseInt64(&i)) {
4772 return false;
4773 }
4774 dimension_sizes->push_back(i);
4775 dynamic_dimensions->push_back(is_dynamic);
4776 return true;
4777 };
4778 return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
4779 parse_and_add_item);
4780 }
4781
4782 // tiles
4783 // ::= /*empty*/
4784 // ::= 'T' '(' dim_list ')'
4785 // dim_list
4786 // ::= /*empty*/
4787 // ::= (int64 | '*') (',' (int64 | '*'))*
ParseTiles(std::vector<Tile> * tiles)4788 bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
4789 auto parse_and_add_tile_dimension = [&]() {
4790 int64_t i;
4791 if (ParseInt64(&i)) {
4792 tiles->back().add_dimensions(i);
4793 return true;
4794 }
4795 if (lexer_.GetKind() == TokKind::kAsterisk) {
4796 tiles->back().add_dimensions(Tile::kCombineDimension);
4797 lexer_.Lex();
4798 return true;
4799 }
4800 return false;
4801 };
4802
4803 do {
4804 tiles->push_back(Tile());
4805 if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
4806 parse_and_add_tile_dimension)) {
4807 return false;
4808 }
4809 } while (lexer_.GetKind() == TokKind::kLparen);
4810 return true;
4811 }
4812
4813 // int_attribute
4814 // ::= /*empty*/
4815 // ::= attr_token '(' attr_value ')'
4816 // attr_token
4817 // ::= 'E' | 'S'
4818 // attr_value
4819 // ::= int64
ParseLayoutIntAttribute(int64 * attr_value,absl::string_view attr_description)4820 bool HloParserImpl::ParseLayoutIntAttribute(
4821 int64* attr_value, absl::string_view attr_description) {
4822 if (!ParseToken(TokKind::kLparen,
4823 StrCat("expects ", attr_description, " to start with ",
4824 TokKindToString(TokKind::kLparen)))) {
4825 return false;
4826 }
4827 if (!ParseInt64(attr_value)) {
4828 return false;
4829 }
4830 if (!ParseToken(TokKind::kRparen,
4831 StrCat("expects ", attr_description, " to end with ",
4832 TokKindToString(TokKind::kRparen)))) {
4833 return false;
4834 }
4835 return true;
4836 }
4837
4838 // layout ::= '{' int64_list (':' tiles element_size_in_bits memory_space)? '}'
4839 // element_size_in_bits
4840 // ::= /*empty*/
4841 // ::= 'E' '(' int64 ')'
4842 // memory_space
4843 // ::= /*empty*/
4844 // ::= 'S' '(' int64 ')'
ParseLayout(Layout * layout)4845 bool HloParserImpl::ParseLayout(Layout* layout) {
4846 std::vector<int64> minor_to_major;
4847 std::vector<Tile> tiles;
4848 int64_t element_size_in_bits = 0;
4849 int64_t memory_space = 0;
4850
4851 auto parse_and_add_item = [&]() {
4852 int64_t i;
4853 if (!ParseInt64(&i)) {
4854 return false;
4855 }
4856 minor_to_major.push_back(i);
4857 return true;
4858 };
4859
4860 if (!ParseToken(TokKind::kLbrace,
4861 StrCat("expects layout to start with ",
4862 TokKindToString(TokKind::kLbrace)))) {
4863 return false;
4864 }
4865 if (lexer_.GetKind() != TokKind::kRbrace) {
4866 if (lexer_.GetKind() == TokKind::kInt) {
4867 // Parse minor to major.
4868 do {
4869 if (!parse_and_add_item()) {
4870 return false;
4871 }
4872 } while (EatIfPresent(TokKind::kComma));
4873 }
4874
4875 if (lexer_.GetKind() == TokKind::kColon) {
4876 lexer_.Lex();
4877 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
4878 lexer_.Lex();
4879 ParseTiles(&tiles);
4880 }
4881
4882 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
4883 lexer_.Lex();
4884 ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
4885 }
4886
4887 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
4888 lexer_.Lex();
4889 ParseLayoutIntAttribute(&memory_space, "memory space");
4890 }
4891 }
4892 }
4893 if (!ParseToken(TokKind::kRbrace,
4894 StrCat("expects layout to end with ",
4895 TokKindToString(TokKind::kRbrace)))) {
4896 return false;
4897 }
4898
4899 std::vector<Tile> vec_tiles(tiles.size());
4900 for (int i = 0; i < tiles.size(); i++) {
4901 vec_tiles[i] = Tile(tiles[i]);
4902 }
4903 *layout = LayoutUtil::MakeLayout(minor_to_major, vec_tiles,
4904 element_size_in_bits, memory_space);
4905 return true;
4906 }
4907
// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
//   ::= /*empty*/
//   ::= shape (',' shape)*
//
// Parses a shape, which is either a tuple of shapes or a primitive type
// followed by dimension sizes and an optional layout.
bool HloParserImpl::ParseShape(Shape* result) {
  if (EatIfPresent(TokKind::kLparen)) {  // Tuple
    std::vector<Shape> shapes;
    if (lexer_.GetKind() == TokKind::kRparen) {
      /*empty*/
    } else {
      // shape (',' shape)*
      do {
        shapes.emplace_back();
        if (!ParseShape(&shapes.back())) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    *result = ShapeUtil::MakeTupleShape(shapes);
    return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
  }

  if (lexer_.GetKind() != TokKind::kPrimitiveType) {
    return TokenError(absl::StrCat("expected primitive type, saw ",
                                   TokKindToString(lexer_.GetKind())));
  }
  PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
  lexer_.Lex();

  // Each element contains a dimension size and a bool indicating whether this
  // is a dynamic dimension.
  std::vector<int64> dimension_sizes;
  std::vector<bool> dynamic_dimensions;
  if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
    return false;
  }
  result->set_element_type(primitive_type);
  for (int i = 0; i < dimension_sizes.size(); ++i) {
    result->add_dimensions(dimension_sizes[i]);
    result->set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  // "invalid{}" after the dimensions denotes a shape whose layout is
  // explicitly cleared (no layout set), e.g. "f32[2]invalid{}".
  if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "invalid") {
    lexer_.Lex();
    if (lexer_.GetKind() != TokKind::kLbrace) {
      return false;
    }
    lexer_.Lex();
    if (lexer_.GetKind() != TokKind::kRbrace) {
      return false;
    }
    lexer_.Lex();
    result->mutable_layout()->Clear();
    return true;
  }
  LayoutUtil::SetToDefaultLayout(result);
  // We need to lookahead to see if a following open brace is the start of a
  // layout. The specific problematic case is:
  //
  // ENTRY %foo (x: f32[42]) -> f32[123] {
  //  ...
  // }
  //
  // The open brace could either be the start of a computation or the start of a
  // layout for the f32[123] shape. We consider it the start of a layout if the
  // next token after the open brace is an integer or a colon.
  if (lexer_.GetKind() == TokKind::kLbrace &&
      (lexer_.LookAhead() == TokKind::kInt ||
       lexer_.LookAhead() == TokKind::kColon)) {
    Layout layout;
    if (!ParseLayout(&layout)) {
      return false;
    }
    // An explicit layout must cover every dimension of the shape.
    if (layout.minor_to_major_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
                    result->rank(), layout.minor_to_major_size()));
    }
    *result->mutable_layout() = layout;
  }
  return true;
}
4991
CanBeShape()4992 bool HloParserImpl::CanBeShape() {
4993 // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
4994 // with '('.
4995 return lexer_.GetKind() == TokKind::kPrimitiveType ||
4996 lexer_.GetKind() == TokKind::kLparen;
4997 }
4998
ParseName(std::string * result)4999 bool HloParserImpl::ParseName(std::string* result) {
5000 VLOG(3) << "ParseName";
5001 if (lexer_.GetKind() != TokKind::kIdent &&
5002 lexer_.GetKind() != TokKind::kName) {
5003 return TokenError("expects name");
5004 }
5005 *result = lexer_.GetStrVal();
5006 lexer_.Lex();
5007 return true;
5008 }
5009
ParseAttributeName(std::string * result)5010 bool HloParserImpl::ParseAttributeName(std::string* result) {
5011 if (lexer_.GetKind() != TokKind::kAttributeName) {
5012 return TokenError("expects attribute name");
5013 }
5014 *result = lexer_.GetStrVal();
5015 lexer_.Lex();
5016 return true;
5017 }
5018
ParseString(std::string * result)5019 bool HloParserImpl::ParseString(std::string* result) {
5020 VLOG(3) << "ParseString";
5021 if (lexer_.GetKind() != TokKind::kString) {
5022 return TokenError("expects string");
5023 }
5024 *result = lexer_.GetStrVal();
5025 lexer_.Lex();
5026 return true;
5027 }
5028
ParseDxD(const std::string & name,std::vector<int64> * result)5029 bool HloParserImpl::ParseDxD(const std::string& name,
5030 std::vector<int64>* result) {
5031 LocTy loc = lexer_.GetLoc();
5032 if (!result->empty()) {
5033 return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
5034 }
5035 // 1D
5036 if (lexer_.GetKind() == TokKind::kInt) {
5037 int64_t number;
5038 if (!ParseInt64(&number)) {
5039 return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
5040 }
5041 result->push_back(number);
5042 return true;
5043 }
5044 // 2D or higher.
5045 if (lexer_.GetKind() == TokKind::kDxD) {
5046 std::string str = lexer_.GetStrVal();
5047 if (!SplitToInt64s(str, 'x', result)) {
5048 return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
5049 }
5050 lexer_.Lex();
5051 return true;
5052 }
5053 return TokenError("expects token type kInt or kDxD");
5054 }
5055
ParseWindowPad(std::vector<std::vector<int64>> * pad)5056 bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
5057 LocTy loc = lexer_.GetLoc();
5058 if (!pad->empty()) {
5059 return Error(loc, "sub-attribute 'pad=' already exists");
5060 }
5061 if (lexer_.GetKind() != TokKind::kPad) {
5062 return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
5063 }
5064 std::string str = lexer_.GetStrVal();
5065 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5066 std::vector<int64> low_high;
5067 if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
5068 low_high.size() != 2) {
5069 return Error(loc,
5070 "expects padding_low and padding_high separated by '_'");
5071 }
5072 pad->push_back(low_high);
5073 }
5074 lexer_.Lex();
5075 return true;
5076 }
5077
5078 // This is the inverse xla::ToString(PaddingConfig). The padding config string
5079 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
5080 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
5081 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)5082 bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
5083 if (lexer_.GetKind() != TokKind::kPad) {
5084 return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
5085 }
5086 LocTy loc = lexer_.GetLoc();
5087 std::string str = lexer_.GetStrVal();
5088 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5089 std::vector<int64> padding_dim;
5090 if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
5091 (padding_dim.size() != 2 && padding_dim.size() != 3)) {
5092 return Error(loc,
5093 "expects padding config pattern like 'low_high_interior' or "
5094 "'low_high'");
5095 }
5096 auto* dim = padding->add_dimensions();
5097 dim->set_edge_padding_low(padding_dim[0]);
5098 dim->set_edge_padding_high(padding_dim[1]);
5099 dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
5100 }
5101 lexer_.Lex();
5102 return true;
5103 }
5104
// '{' metadata_string '}'
//
// Parses an OpMetadata proto from its sub-attribute form, e.g.
// {op_type="Foo", op_name="bar", source_file="f.py", source_line=3}.
// All sub-attributes are optional; only those present are set on `metadata`.
bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  optional<std::string> op_type;
  optional<std::string> op_name;
  optional<std::string> source_file;
  optional<int32> source_line;
  optional<std::vector<int64>> profile_type;
  attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
  attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
  attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
  attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
  attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
                           &profile_type};
  if (!ParseSubAttributes(attrs)) {
    return false;
  }
  // Copy only the fields that were present in the text.
  if (op_type) {
    metadata->set_op_type(*op_type);
  }
  if (op_name) {
    metadata->set_op_name(*op_name);
  }
  if (source_file) {
    metadata->set_source_file(*source_file);
  }
  if (source_line) {
    metadata->set_source_line(*source_line);
  }
  if (profile_type) {
    // Each entry must be a valid ProfileType enum value; reject otherwise.
    for (const auto& type : *profile_type) {
      if (!ProfileType_IsValid(type)) {
        return false;
      }
      metadata->add_profile_type(static_cast<ProfileType>(type));
    }
  }
  return true;
}
5144
5145 // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleOrListMetadata(tensorflow::protobuf::RepeatedPtrField<OpMetadata> * metadata)5146 bool HloParserImpl::ParseSingleOrListMetadata(
5147 tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
5148 if (lexer_.GetKind() == TokKind::kLbrace &&
5149 lexer_.LookAhead() == TokKind::kLbrace) {
5150 if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
5151 return false;
5152 }
5153
5154 if (lexer_.GetKind() != TokKind::kRbrace) {
5155 do {
5156 if (!ParseMetadata(metadata->Add())) {
5157 return false;
5158 }
5159 } while (EatIfPresent(TokKind::kComma));
5160 }
5161
5162 return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
5163 }
5164
5165 return ParseMetadata(metadata->Add());
5166 }
5167
ParseOpShardingType(OpSharding::Type * type)5168 bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) {
5169 switch (lexer_.GetKind()) {
5170 case TokKind::kw_maximal:
5171 *type = OpSharding::MAXIMAL;
5172 lexer_.Lex();
5173 break;
5174 case TokKind::kw_replicated:
5175 *type = OpSharding::REPLICATED;
5176 lexer_.Lex();
5177 break;
5178 case TokKind::kw_manual:
5179 *type = OpSharding::MANUAL;
5180 lexer_.Lex();
5181 break;
5182 default:
5183 return false;
5184 }
5185 return true;
5186 }
5187
ParseListShardingType(std::vector<OpSharding::Type> * types)5188 bool HloParserImpl::ParseListShardingType(
5189 std::vector<OpSharding::Type>* types) {
5190 if (!ParseToken(TokKind::kLbrace,
5191 "expected '{' to start sharding type list")) {
5192 return false;
5193 }
5194
5195 if (lexer_.GetKind() != TokKind::kRbrace) {
5196 do {
5197 OpSharding::Type type;
5198 if (!ParseOpShardingType(&type)) {
5199 return false;
5200 }
5201 types->emplace_back(type);
5202 } while (EatIfPresent(TokKind::kComma));
5203 }
5204
5205 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding type list");
5206 }
5207
ParseOpcode(HloOpcode * result)5208 bool HloParserImpl::ParseOpcode(HloOpcode* result) {
5209 VLOG(3) << "ParseOpcode";
5210 if (lexer_.GetKind() != TokKind::kIdent) {
5211 return TokenError("expects opcode");
5212 }
5213 std::string val = lexer_.GetStrVal();
5214 auto status_or_result = StringToHloOpcode(val);
5215 if (!status_or_result.ok()) {
5216 return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
5217 status_or_result.status().error_message()));
5218 }
5219 *result = status_or_result.ValueOrDie();
5220 lexer_.Lex();
5221 return true;
5222 }
5223
ParseFftType(FftType * result)5224 bool HloParserImpl::ParseFftType(FftType* result) {
5225 VLOG(3) << "ParseFftType";
5226 if (lexer_.GetKind() != TokKind::kIdent) {
5227 return TokenError("expects fft type");
5228 }
5229 std::string val = lexer_.GetStrVal();
5230 if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
5231 return TokenError(StrFormat("expects fft type but sees: %s", val));
5232 }
5233 lexer_.Lex();
5234 return true;
5235 }
5236
ParsePaddingType(PaddingType * result)5237 bool HloParserImpl::ParsePaddingType(PaddingType* result) {
5238 VLOG(3) << "ParsePaddingType";
5239 if (lexer_.GetKind() != TokKind::kIdent) {
5240 return TokenError("expects padding type");
5241 }
5242 std::string val = lexer_.GetStrVal();
5243 if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
5244 return TokenError(StrFormat("expects padding type but sees: %s", val));
5245 }
5246 lexer_.Lex();
5247 return true;
5248 }
5249
ParseComparisonDirection(ComparisonDirection * result)5250 bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
5251 VLOG(3) << "ParseComparisonDirection";
5252 if (lexer_.GetKind() != TokKind::kIdent) {
5253 return TokenError("expects comparison direction");
5254 }
5255 std::string val = lexer_.GetStrVal();
5256 auto status_or_result = StringToComparisonDirection(val);
5257 if (!status_or_result.ok()) {
5258 return TokenError(
5259 StrFormat("expects comparison direction but sees: %s", val));
5260 }
5261 *result = status_or_result.ValueOrDie();
5262 lexer_.Lex();
5263 return true;
5264 }
5265
ParseComparisonType(Comparison::Type * result)5266 bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
5267 VLOG(1) << "ParseComparisonType";
5268 if (lexer_.GetKind() != TokKind::kIdent) {
5269 return TokenError("expects comparison type");
5270 }
5271 std::string val = lexer_.GetStrVal();
5272 auto status_or_result = StringToComparisonType(val);
5273 if (!status_or_result.ok()) {
5274 return TokenError(StrFormat("expects comparison type but sees: %s", val));
5275 }
5276 *result = status_or_result.ValueOrDie();
5277 lexer_.Lex();
5278 return true;
5279 }
5280
ParseFusionKind(HloInstruction::FusionKind * result)5281 bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
5282 VLOG(3) << "ParseFusionKind";
5283 if (lexer_.GetKind() != TokKind::kIdent) {
5284 return TokenError("expects fusion kind");
5285 }
5286 std::string val = lexer_.GetStrVal();
5287 auto status_or_result = StringToFusionKind(val);
5288 if (!status_or_result.ok()) {
5289 return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
5290 val,
5291 status_or_result.status().error_message()));
5292 }
5293 *result = status_or_result.ValueOrDie();
5294 lexer_.Lex();
5295 return true;
5296 }
5297
ParseRandomDistribution(RandomDistribution * result)5298 bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
5299 VLOG(3) << "ParseRandomDistribution";
5300 if (lexer_.GetKind() != TokKind::kIdent) {
5301 return TokenError("expects random distribution");
5302 }
5303 std::string val = lexer_.GetStrVal();
5304 auto status_or_result = StringToRandomDistribution(val);
5305 if (!status_or_result.ok()) {
5306 return TokenError(
5307 StrFormat("expects random distribution but sees: %s, error: %s", val,
5308 status_or_result.status().error_message()));
5309 }
5310 *result = status_or_result.ValueOrDie();
5311 lexer_.Lex();
5312 return true;
5313 }
5314
ParseRandomAlgorithm(RandomAlgorithm * result)5315 bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
5316 VLOG(3) << "ParseRandomAlgorithm";
5317 if (lexer_.GetKind() != TokKind::kIdent) {
5318 return TokenError("expects random algorithm");
5319 }
5320 std::string val = lexer_.GetStrVal();
5321 auto status_or_result = StringToRandomAlgorithm(val);
5322 if (!status_or_result.ok()) {
5323 return TokenError(
5324 StrFormat("expects random algorithm but sees: %s, error: %s", val,
5325 status_or_result.status().error_message()));
5326 }
5327 *result = status_or_result.ValueOrDie();
5328 lexer_.Lex();
5329 return true;
5330 }
5331
ParsePrecision(PrecisionConfig::Precision * result)5332 bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
5333 VLOG(3) << "ParsePrecision";
5334 if (lexer_.GetKind() != TokKind::kIdent) {
5335 return TokenError("expects random distribution");
5336 }
5337 std::string val = lexer_.GetStrVal();
5338 auto status_or_result = StringToPrecision(val);
5339 if (!status_or_result.ok()) {
5340 return TokenError(StrFormat("expects precision but sees: %s, error: %s",
5341 val,
5342 status_or_result.status().error_message()));
5343 }
5344 *result = status_or_result.ValueOrDie();
5345 lexer_.Lex();
5346 return true;
5347 }
5348
ParseInt64(int64 * result)5349 bool HloParserImpl::ParseInt64(int64* result) {
5350 VLOG(3) << "ParseInt64";
5351 if (lexer_.GetKind() != TokKind::kInt) {
5352 return TokenError("expects integer");
5353 }
5354 *result = lexer_.GetInt64Val();
5355 lexer_.Lex();
5356 return true;
5357 }
5358
ParseDouble(double * result)5359 bool HloParserImpl::ParseDouble(double* result) {
5360 switch (lexer_.GetKind()) {
5361 case TokKind::kDecimal: {
5362 double val = lexer_.GetDecimalVal();
5363 // If GetDecimalVal returns +/-inf, that means that we overflowed
5364 // `double`.
5365 if (std::isinf(val)) {
5366 return TokenError(StrCat("Constant is out of range for double (+/-",
5367 std::numeric_limits<double>::max(),
5368 ") and so is unparsable."));
5369 }
5370 *result = val;
5371 break;
5372 }
5373 case TokKind::kInt:
5374 *result = static_cast<double>(lexer_.GetInt64Val());
5375 break;
5376 case TokKind::kw_inf:
5377 *result = std::numeric_limits<double>::infinity();
5378 break;
5379 case TokKind::kNegInf:
5380 *result = -std::numeric_limits<double>::infinity();
5381 break;
5382 default:
5383 return TokenError("expects decimal or integer");
5384 }
5385 lexer_.Lex();
5386 return true;
5387 }
5388
ParseComplex(std::complex<double> * result)5389 bool HloParserImpl::ParseComplex(std::complex<double>* result) {
5390 if (lexer_.GetKind() != TokKind::kLparen) {
5391 return TokenError("expects '(' before complex number");
5392 }
5393 lexer_.Lex();
5394
5395 double real;
5396 LocTy loc = lexer_.GetLoc();
5397 if (!ParseDouble(&real)) {
5398 return Error(loc,
5399 "expect floating-point value for real part of complex number");
5400 }
5401
5402 if (lexer_.GetKind() != TokKind::kComma) {
5403 return TokenError(
5404 absl::StrFormat("expect comma after real part of complex literal"));
5405 }
5406 lexer_.Lex();
5407
5408 double imag;
5409 loc = lexer_.GetLoc();
5410 if (!ParseDouble(&imag)) {
5411 return Error(
5412 loc,
5413 "expect floating-point value for imaginary part of complex number");
5414 }
5415
5416 if (lexer_.GetKind() != TokKind::kRparen) {
5417 return TokenError(absl::StrFormat("expect ')' after complex number"));
5418 }
5419
5420 *result = std::complex<double>(real, imag);
5421 lexer_.Lex();
5422 return true;
5423 }
5424
ParseBool(bool * result)5425 bool HloParserImpl::ParseBool(bool* result) {
5426 if (lexer_.GetKind() != TokKind::kw_true &&
5427 lexer_.GetKind() != TokKind::kw_false) {
5428 return TokenError("expects true or false");
5429 }
5430 *result = lexer_.GetKind() == TokKind::kw_true;
5431 lexer_.Lex();
5432 return true;
5433 }
5434
ParseToken(TokKind kind,const std::string & msg)5435 bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
5436 VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
5437 if (lexer_.GetKind() != kind) {
5438 return TokenError(msg);
5439 }
5440 lexer_.Lex();
5441 return true;
5442 }
5443
EatIfPresent(TokKind kind)5444 bool HloParserImpl::EatIfPresent(TokKind kind) {
5445 if (lexer_.GetKind() != kind) {
5446 return false;
5447 }
5448 lexer_.Lex();
5449 return true;
5450 }
5451
AddInstruction(const std::string & name,HloInstruction * instruction,LocTy name_loc)5452 bool HloParserImpl::AddInstruction(const std::string& name,
5453 HloInstruction* instruction,
5454 LocTy name_loc) {
5455 auto result = current_name_table().insert({name, {instruction, name_loc}});
5456 if (!result.second) {
5457 Error(name_loc, StrCat("instruction already exists: ", name));
5458 return Error(/*loc=*/result.first->second.second,
5459 "instruction previously defined here");
5460 }
5461 return true;
5462 }
5463
AddComputation(const std::string & name,HloComputation * computation,LocTy name_loc)5464 bool HloParserImpl::AddComputation(const std::string& name,
5465 HloComputation* computation,
5466 LocTy name_loc) {
5467 auto result = computation_pool_.insert({name, {computation, name_loc}});
5468 if (!result.second) {
5469 Error(name_loc, StrCat("computation already exists: ", name));
5470 return Error(/*loc=*/result.first->second.second,
5471 "computation previously defined here");
5472 }
5473 return true;
5474 }
5475
ParseShapeOnly()5476 StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
5477 lexer_.Lex();
5478 Shape shape;
5479 if (!ParseShape(&shape)) {
5480 return InvalidArgument("Syntax error:\n%s", GetError());
5481 }
5482 if (lexer_.GetKind() != TokKind::kEof) {
5483 return InvalidArgument("Syntax error:\nExtra content after shape");
5484 }
5485 return shape;
5486 }
5487
ParseShardingOnly()5488 StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
5489 lexer_.Lex();
5490 OpSharding op_sharding;
5491 if (!ParseSharding(&op_sharding)) {
5492 return InvalidArgument("Syntax error:\n%s", GetError());
5493 }
5494 if (lexer_.GetKind() != TokKind::kEof) {
5495 return InvalidArgument("Syntax error:\nExtra content after sharding");
5496 }
5497 return HloSharding::FromProto(op_sharding);
5498 }
5499
ParseFrontendAttributesOnly()5500 StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
5501 lexer_.Lex();
5502 FrontendAttributes attributes;
5503 if (!ParseFrontendAttributes(&attributes)) {
5504 return InvalidArgument("Syntax error:\n%s", GetError());
5505 }
5506 if (lexer_.GetKind() != TokKind::kEof) {
5507 return InvalidArgument(
5508 "Syntax error:\nExtra content after frontend attributes");
5509 }
5510 return attributes;
5511 }
5512
ParseParameterReplicationOnly()5513 StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
5514 lexer_.Lex();
5515 ParameterReplication parameter_replication;
5516 if (!ParseParameterReplication(¶meter_replication)) {
5517 return InvalidArgument("Syntax error:\n%s", GetError());
5518 }
5519 if (lexer_.GetKind() != TokKind::kEof) {
5520 return InvalidArgument(
5521 "Syntax error:\nExtra content after parameter replication");
5522 }
5523 return std::vector<bool>(
5524 parameter_replication.replicated_at_leaf_buffers().begin(),
5525 parameter_replication.replicated_at_leaf_buffers().end());
5526 }
5527
ParseReplicaGroupsOnly()5528 StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
5529 lexer_.Lex();
5530 std::vector<ReplicaGroup> replica_groups;
5531 if (!ParseReplicaGroupsOnly(&replica_groups)) {
5532 return InvalidArgument("Syntax error:\n%s", GetError());
5533 }
5534 if (lexer_.GetKind() != TokKind::kEof) {
5535 return InvalidArgument("Syntax error:\nExtra content after replica groups");
5536 }
5537 return replica_groups;
5538 }
5539
ParseWindowOnly()5540 StatusOr<Window> HloParserImpl::ParseWindowOnly() {
5541 lexer_.Lex();
5542 Window window;
5543 if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
5544 return InvalidArgument("Syntax error:\n%s", GetError());
5545 }
5546 if (lexer_.GetKind() != TokKind::kEof) {
5547 return InvalidArgument("Syntax error:\nExtra content after window");
5548 }
5549 return window;
5550 }
5551
5552 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()5553 HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
5554 lexer_.Lex();
5555 ConvolutionDimensionNumbers dnums;
5556 if (!ParseConvolutionDimensionNumbers(&dnums)) {
5557 return InvalidArgument("Syntax error:\n%s", GetError());
5558 }
5559 if (lexer_.GetKind() != TokKind::kEof) {
5560 return InvalidArgument(
5561 "Syntax error:\nExtra content after convolution dnums");
5562 }
5563 return dnums;
5564 }
5565
ParsePaddingConfigOnly()5566 StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
5567 lexer_.Lex();
5568 PaddingConfig padding_config;
5569 if (!ParsePaddingConfig(&padding_config)) {
5570 return InvalidArgument("Syntax error:\n%s", GetError());
5571 }
5572 if (lexer_.GetKind() != TokKind::kEof) {
5573 return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
5574 }
5575 return padding_config;
5576 }
5577
// Parses the input as a single HLO instruction (with or without a left-hand
// side) and installs it as `module`'s entry computation. Operands that are
// referenced but not defined in the text are synthesized on the fly as
// parameters via the create_missing_instruction_ hook registered below.
// Returns false (with errors recorded) on any parse failure or trailing
// input. Must be called on a fresh parser (checked via LOG(FATAL)).
bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
  if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
    LOG(FATAL) << "Parser state is not clean. Please do not call any other "
                  "methods before calling ParseSingleInstruction.";
  }
  HloComputation::Builder builder(module->name());

  // The missing instruction hook we register creates the shaped instruction on
  // the fly as a parameter and returns it.
  int64_t parameter_count = 0;
  create_missing_instruction_ =
      [this, &builder, &parameter_count](
          const std::string& name,
          const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
    // Unnamed operands get synthetic names "_0", "_1", ... from the counter.
    std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
    HloInstruction* parameter = builder.AddInstruction(
        HloInstruction::CreateParameter(parameter_count++, shape, new_name));
    current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
    // Return a pointer into the name table so callers see the stored entry.
    return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
  };

  // Parse the instruction with the registered hook.
  Scope scope(&scoped_name_tables_);
  if (CanBeShape()) {
    // This means that the instruction's left-hand side is probably omitted,
    // e.g.
    //
    // f32[10] fusion(...), calls={...}
    if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
      return false;
    }
  } else {
    // This means that the instruction's left-hand side might exist, e.g.
    //
    // foo = f32[10] fusion(...), calls={...}
    std::string root_name;
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  }

  // Anything left over after the single instruction is a syntax error.
  if (lexer_.GetKind() != TokKind::kEof) {
    Error(
        lexer_.GetLoc(),
        "Syntax error:\nExpected eof after parsing single instruction. Did "
        "you mean to write an HLO module and forget the \"HloModule\" header?");
    return false;
  }

  // Transfer the built computation (and any computations parsed along the
  // way, e.g. fusion bodies) into the module.
  module->AddEntryComputation(builder.Build());
  for (auto& comp : computations_) {
    module->AddEmbeddedComputation(std::move(comp));
  }
  TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  return true;
}
5634
5635 } // namespace
5636
ParseAndReturnUnverifiedModule(absl::string_view str,const HloModuleConfig & config)5637 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5638 absl::string_view str, const HloModuleConfig& config) {
5639 auto module = absl::make_unique<HloModule>(/*name=*/"_", config);
5640 HloParserImpl parser(str);
5641 TF_RETURN_IF_ERROR(parser.Run(module.get()));
5642 return std::move(module);
5643 }
5644
ParseAndReturnUnverifiedModule(absl::string_view str)5645 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5646 absl::string_view str) {
5647 return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
5648 }
5649
ParseSharding(absl::string_view str)5650 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
5651 HloParserImpl parser(str);
5652 return parser.ParseShardingOnly();
5653 }
5654
ParseFrontendAttributes(absl::string_view str)5655 StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
5656 HloParserImpl parser(str);
5657 return parser.ParseFrontendAttributesOnly();
5658 }
5659
ParseParameterReplication(absl::string_view str)5660 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
5661 HloParserImpl parser(str);
5662 return parser.ParseParameterReplicationOnly();
5663 }
5664
ParseReplicaGroupsOnly(absl::string_view str)5665 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
5666 absl::string_view str) {
5667 HloParserImpl parser(str);
5668 return parser.ParseReplicaGroupsOnly();
5669 }
5670
ParseWindow(absl::string_view str)5671 StatusOr<Window> ParseWindow(absl::string_view str) {
5672 HloParserImpl parser(str);
5673 return parser.ParseWindowOnly();
5674 }
5675
ParseConvolutionDimensionNumbers(absl::string_view str)5676 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
5677 absl::string_view str) {
5678 HloParserImpl parser(str);
5679 return parser.ParseConvolutionDimensionNumbersOnly();
5680 }
5681
ParsePaddingConfig(absl::string_view str)5682 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
5683 HloParserImpl parser(str);
5684 return parser.ParsePaddingConfigOnly();
5685 }
5686
ParseShape(absl::string_view str)5687 StatusOr<Shape> ParseShape(absl::string_view str) {
5688 HloParserImpl parser(str);
5689 return parser.ParseShapeOnly();
5690 }
5691
CreateHloParserForTests(absl::string_view str)5692 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
5693 absl::string_view str) {
5694 return absl::make_unique<HloParserImpl>(str);
5695 }
5696
5697 } // namespace xla
5698