1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17
18 #include <limits>
19 #include <string>
20 #include <unordered_map>
21
22 #include "absl/base/casts.h"
23 #include "absl/strings/ascii.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_split.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/lib/strings/numbers.h"
32 #include "tensorflow/core/platform/regexp.h"
33
34 namespace xla {
35 namespace {
36
37 using absl::string_view;
38
39 constexpr int kEOF = -1;
40 constexpr int kError = -2;
41
42 // [a-zA-Z0-9_.-]
IsIdentifierChar(char c)43 bool IsIdentifierChar(char c) {
44 return absl::ascii_isalnum(static_cast<unsigned char>(c)) || c == '-' ||
45 c == '.' || c == '_';
46 }
47
48 } // namespace
49
GetNextChar()50 int HloLexer::GetNextChar() {
51 int current_char = PeekCurrentChar();
52 if (current_char != kEOF && current_char != kError) {
53 current_ptr_++;
54 }
55 return current_char;
56 }
57
PeekCurrentChar() const58 int HloLexer::PeekCurrentChar() const {
59 if (current_ptr_ == buf_.end()) {
60 return kEOF;
61 }
62 char current_char = *current_ptr_;
63 if (current_char == 0) {
64 // '\0' should not appear in the middle of the string.
65 return kError;
66 }
67 return static_cast<unsigned char>(current_char);
68 }
69
CanDereference(const char * ptr) const70 bool HloLexer::CanDereference(const char* ptr) const {
71 return ptr < buf_.end() && ptr >= buf_.begin();
72 }
73
StringPieceFromPointers(const char * begin,const char * end) const74 absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
75 const char* end) const {
76 CHECK(begin <= end);
77 CHECK(begin == buf_.end() || CanDereference(begin));
78 CHECK(end == buf_.end() || CanDereference(end));
79 return absl::string_view(begin, end - begin);
80 }
81
LookAhead()82 TokKind HloLexer::LookAhead() {
83 if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
84 return GetKind();
85 }
86
87 const char* old_current_ptr = current_ptr_;
88 TokenState old_token_state = token_state_;
89 Lex();
90 TokKind kind = GetKind();
91 token_state_ = old_token_state;
92 current_ptr_ = old_current_ptr;
93 return kind;
94 }
95
LexToken()96 TokKind HloLexer::LexToken() {
97 while (true) {
98 token_state_.token_start = current_ptr_;
99
100 int current_char = GetNextChar();
101 switch (current_char) {
102 default:
103 // [a-zA-Z_]
104 if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
105 current_char == '_') {
106 return LexIdentifier();
107 }
108 return TokKind::kError;
109 case kEOF:
110 // Hit the end of the input buffer.
111 return TokKind::kEof;
112 case kError:
113 // Hit an invalid character in the input buffer.
114 return TokKind::kError;
115 case ' ':
116 case '\t':
117 case '\n':
118 case '\r':
119 // Ignore whitespace.
120 continue;
121 case '0':
122 case '1':
123 case '2':
124 case '3':
125 case '4':
126 case '5':
127 case '6':
128 case '7':
129 case '8':
130 case '9':
131 case '-':
132 case '?':
133 if (current_char == '-' && PeekCurrentChar() == '>') {
134 current_ptr_++;
135 return TokKind::kArrow;
136 }
137 return LexNumberOrPattern();
138 case '=':
139 return TokKind::kEqual;
140 case '<':
141 if (current_char == '<' && PeekCurrentChar() == '=') {
142 current_ptr_++;
143 return TokKind::kLeq;
144 }
145 return TokKind::kError;
146 case ',':
147 return TokKind::kComma;
148 case '%':
149 return LexPercent();
150 case ':':
151 return TokKind::kColon;
152 case '*':
153 return TokKind::kAsterisk;
154 case '[':
155 return TokKind::kLsquare;
156 case ']':
157 return TokKind::kRsquare;
158 case '{':
159 return TokKind::kLbrace;
160 case '}':
161 return TokKind::kRbrace;
162 case '(':
163 return TokKind::kLparen;
164 case ')':
165 return TokKind::kRparen;
166 case '/': {
167 if (PeekCurrentChar() == '*') {
168 // This is the start of a /*...*/ delimited comment. Save the current
169 // location in case the comment is unterminated so the error message
170 // will point to the beginning of the comment.
171 const char* comment_start = current_ptr_;
172 current_ptr_++;
173 // Advance until '*/' is found.
174 while (true) {
175 int current = GetNextChar();
176 if (current == '*' && PeekCurrentChar() == '/') {
177 // End of comment.
178 current_ptr_++;
179 break;
180 }
181 if (current == kEOF) {
182 // Unterminated comment.
183 current_ptr_ = comment_start;
184 return TokKind::kError;
185 }
186 if (current == kError) {
187 return TokKind::kError;
188 }
189 }
190 // Return no token for the comment. Keep lexing.
191 continue;
192 } else if (PeekCurrentChar() == '/') {
193 // This is the start of a '//' delimited comment. Throw away
194 // everything until end of line or file. The end-of-line character(s)
195 // are left unlexed in the buffer which is harmless because these are
196 // skipped later by the lexer. This approach enables support for
197 // different end-of-line encodings.
198 while (true) {
199 int current = PeekCurrentChar();
200 if (current == kEOF || current == '\n' || current == '\r') {
201 break;
202 }
203 if (current == kError) {
204 return TokKind::kError;
205 }
206 current_ptr_++;
207 }
208 continue;
209 }
210 // A lone '/' is an error.
211 return TokKind::kError;
212 }
213 case '.':
214 if (PeekCurrentChar() == '.') {
215 current_ptr_++;
216 if (PeekCurrentChar() == '.') {
217 current_ptr_++;
218 return TokKind::kDots;
219 }
220 }
221 return TokKind::kError;
222 case '"':
223 return LexString();
224 }
225 }
226 }
227
LexNanPayload(absl::string_view & consumable)228 absl::optional<int64> HloLexer::LexNanPayload(absl::string_view& consumable) {
229 static LazyRE2 payload_pattern = {R"(\(0x[0-9a-fA-F]+\))"};
230 if (!RE2::Consume(&consumable, *payload_pattern)) {
231 return absl::nullopt;
232 }
233 auto slice = StringPieceFromPointers(current_ptr_, consumable.begin());
234 current_ptr_ = consumable.begin();
235 CHECK(absl::StartsWith(slice, "(0x"));
236 slice.remove_prefix(std::strlen("(0x"));
237 CHECK(absl::EndsWith(slice, ")"));
238 slice.remove_suffix(std::strlen(")"));
239 uint64 payload_value;
240 if (tensorflow::strings::HexStringToUint64(slice, &payload_value)) {
241 if (payload_value <= 0 || payload_value > NanPayloadBitMask<double>()) {
242 LOG(ERROR) << "NaN payload out of range: " << payload_value;
243 return absl::nullopt;
244 }
245 return payload_value;
246 }
247 return absl::nullopt;
248 }
249
250 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
251 // other identifiers.
252 //
253 // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
254 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
255 // keyword ::= HloModule, ENTRY, ...
256 // attribute_name ::= condition, body, dimensions, ...
257 // dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
258 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()259 TokKind HloLexer::LexIdentifier() {
260 while (IsIdentifierChar(PeekCurrentChar())) {
261 current_ptr_++;
262 }
263
264 // If followed by ':', it's a name.
265 if (PeekCurrentChar() == ':') {
266 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
267 current_ptr_++; // skip ':'
268 return TokKind::kName;
269 }
270
271 // If followed by '=', it's a attribute name.
272 if (PeekCurrentChar() == '=') {
273 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
274 current_ptr_++; // skip '='
275 return TokKind::kAttributeName;
276 }
277
278 absl::string_view identifier =
279 StringPieceFromPointers(token_state_.token_start, current_ptr_);
280
281 // Primitive type strings are reserved words. The exception is 'tuple' whose
282 // type is represented using nested parentheses without the string 'tuple'.
283 if (primitive_util::IsPrimitiveTypeName(identifier)) {
284 PrimitiveType primitive_type =
285 primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
286 if (primitive_type != TUPLE) {
287 token_state_.primitive_type_val = primitive_type;
288 return TokKind::kPrimitiveType;
289 }
290 }
291
292 if (identifier == "nan") {
293 absl::optional<int64_t> payload;
294 if (PeekCurrentChar() == '(') {
295 absl::string_view consumable =
296 StringPieceFromPointers(current_ptr_, buf_.end());
297 payload = LexNanPayload(consumable);
298 if (!payload.has_value()) {
299 return TokKind::kError;
300 }
301 }
302 token_state_.decimal_val = NanWithSignAndPayload<double>(
303 /*sign=*/false, payload.value_or(QuietNanWithoutPayload<double>()));
304 return TokKind::kDecimal;
305 }
306
307 // See if this is a keyword.
308 #define KEYWORD(STR) \
309 do { \
310 if (identifier == #STR) { \
311 return TokKind::kw_##STR; \
312 } \
313 } while (false)
314
315 KEYWORD(true);
316 KEYWORD(false);
317 KEYWORD(inf);
318 KEYWORD(HloModule);
319 KEYWORD(ENTRY);
320 KEYWORD(ROOT);
321 KEYWORD(maximal);
322 KEYWORD(replicated);
323 KEYWORD(manual);
324 KEYWORD(last_tile_dim_replicate);
325
326 #undef KEYWORD
327
328 {
329 absl::string_view consumable =
330 StringPieceFromPointers(token_state_.token_start, buf_.end());
331 static LazyRE2 dim_labels_pattern = {
332 R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
333 if (RE2::Consume(&consumable, *dim_labels_pattern)) {
334 current_ptr_ = consumable.begin();
335 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
336 return TokKind::kDimLabels;
337 }
338 }
339
340 token_state_.str_val = string(identifier);
341 return TokKind::kIdent;
342 }
343
344 // Lex names after a % character.
345 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()346 TokKind HloLexer::LexPercent() {
347 const char* name_start = current_ptr_;
348 if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
349 PeekCurrentChar() == '_') {
350 current_ptr_++;
351 while (IsIdentifierChar(PeekCurrentChar())) {
352 current_ptr_++;
353 }
354 token_state_.str_val.assign(name_start, current_ptr_);
355 return TokKind::kName;
356 }
357 return TokKind::kError;
358 }
359
360 // Lex integer and floating-point values, -inf, and patterns for dim labels,
361 // dxd (e.g. 1x2x3), and pad.
362 //
363 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
364 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
365 // dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
366 // dxd_pattern ::= [0-9]+(x[0-9]+)+
367 // pad_pattern ::=
368 // [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
369 // int ::= [-]?[0-9]+
370 // negative inf ::= '-inf'
LexNumberOrPattern()371 TokKind HloLexer::LexNumberOrPattern() {
372 absl::string_view consumable =
373 StringPieceFromPointers(token_state_.token_start, buf_.end());
374 static LazyRE2 float_pattern = {
375 R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
376 if (RE2::Consume(&consumable, *float_pattern)) {
377 current_ptr_ = consumable.begin();
378 CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
379 &token_state_.decimal_val));
380 return TokKind::kDecimal;
381 }
382
383 static LazyRE2 dim_labels_pattern = {
384 R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
385 static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
386 static LazyRE2 pad_pattern = {
387 R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
388
389 if (RE2::Consume(&consumable, *dim_labels_pattern)) {
390 current_ptr_ = consumable.begin();
391 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
392 return TokKind::kDimLabels;
393 }
394
395 if (RE2::Consume(&consumable, *dxd_pattern)) {
396 current_ptr_ = consumable.begin();
397 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
398 return TokKind::kDxD;
399 }
400
401 if (RE2::Consume(&consumable, *pad_pattern)) {
402 current_ptr_ = consumable.begin();
403 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
404 return TokKind::kPad;
405 }
406
407 static LazyRE2 int_pattern = {R"([-]?\d+)"};
408 if (RE2::Consume(&consumable, *int_pattern)) {
409 current_ptr_ = consumable.begin();
410 auto slice =
411 StringPieceFromPointers(token_state_.token_start, current_ptr_);
412 if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
413 return TokKind::kInt;
414 }
415 uint64 uint64_val;
416 if (absl::SimpleAtoi(slice, &uint64_val)) {
417 token_state_.int64_val = absl::bit_cast<int64>(uint64_val);
418 return TokKind::kInt;
419 }
420 LOG(ERROR) << "Failed to parse int literal: " << slice;
421 return TokKind::kError;
422 }
423
424 static LazyRE2 neg_inf = {"-inf"};
425 if (RE2::Consume(&consumable, *neg_inf)) {
426 current_ptr_ = consumable.begin();
427 return TokKind::kNegInf;
428 }
429
430 static LazyRE2 neg_nan = {"-nan"};
431 if (RE2::Consume(&consumable, *neg_nan)) {
432 current_ptr_ = consumable.begin();
433
434 absl::optional<int64_t> payload;
435 if (PeekCurrentChar() == '(') {
436 payload = LexNanPayload(consumable);
437 if (!payload.has_value()) {
438 return TokKind::kError;
439 }
440 }
441 token_state_.decimal_val = NanWithSignAndPayload<double>(
442 /*sign=*/true, payload.value_or(QuietNanWithoutPayload<double>()));
443 return TokKind::kDecimal;
444 }
445
446 return TokKind::kError;
447 }
448
GetLineAndColumn(LocTy location) const449 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
450 unsigned line_no = 1;
451 const char* start = buf_.begin();
452 const char* ptr = start;
453 if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
454 line_no_cache_.last_query <= location) {
455 ptr = line_no_cache_.last_query;
456 line_no = line_no_cache_.line_no_of_query;
457 }
458 for (; ptr != location; ptr++) {
459 CHECK_LT(ptr, buf_.end());
460 if (*ptr == '\n') {
461 line_no++;
462 }
463 }
464
465 // Update the line number cache.
466 line_no_cache_.last_query = ptr;
467 line_no_cache_.line_no_of_query = line_no;
468 size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
469 if (line_offset == absl::string_view::npos) {
470 line_offset = 0;
471 }
472 return {line_no, ptr - start - line_offset};
473 }
474
GetLine(LocTy loc) const475 absl::string_view HloLexer::GetLine(LocTy loc) const {
476 if (!CanDereference(loc)) {
477 return "LINE OUT OF RANGE";
478 }
479 size_t line_start =
480 StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
481 const char* start = line_start == absl::string_view::npos
482 ? buf_.begin()
483 : buf_.begin() + line_start + 1;
484 size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
485 const char* end =
486 line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
487
488 return StringPieceFromPointers(start, end);
489 }
490
491 // Lexes quoted string with escaping characters. If matched, the quoted string
492 // will be unescaped and stored to token_state_.str_val.
LexString()493 TokKind HloLexer::LexString() {
494 absl::string_view consumable =
495 StringPieceFromPointers(token_state_.token_start, buf_.end());
496 static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
497 if (RE2::Consume(&consumable, *escaping_pattern)) {
498 current_ptr_ = consumable.begin();
499 absl::string_view raw =
500 StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
501 string error;
502 if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
503 LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
504 return TokKind::kError;
505 }
506 return TokKind::kString;
507 }
508 return TokKind::kError;
509 }
510
511 string TokKindToString(TokKind kind) {
512 switch (kind) {
513 case TokKind::kEof:
514 return "kEof";
515 case TokKind::kError:
516 return "kError";
517 case TokKind::kEqual:
518 return "kEqaul";
519 case TokKind::kComma:
520 return "kComma";
521 case TokKind::kColon:
522 return "kColon";
523 case TokKind::kAsterisk:
524 return "kAsterisk";
525 case TokKind::kLsquare:
526 return "kLsquare";
527 case TokKind::kRsquare:
528 return "kRsquare";
529 case TokKind::kLbrace:
530 return "kLbrace";
531 case TokKind::kRbrace:
532 return "kRbrace";
533 case TokKind::kLparen:
534 return "kLparen";
535 case TokKind::kRparen:
536 return "kRparen";
537 case TokKind::kArrow:
538 return "kArrow";
539 case TokKind::kLeq:
540 return "kLeq";
541 case TokKind::kw_HloModule:
542 return "kw_HloModule";
543 case TokKind::kw_ENTRY:
544 return "kw_ENTRY";
545 case TokKind::kw_ROOT:
546 return "kw_ROOT";
547 case TokKind::kw_true:
548 return "kw_true";
549 case TokKind::kw_false:
550 return "kw_false";
551 case TokKind::kw_maximal:
552 return "kw_maximal";
553 case TokKind::kw_replicated:
554 return "kw_replicated";
555 case TokKind::kw_manual:
556 return "kw_manual";
557 case TokKind::kw_last_tile_dim_replicate:
558 return "kw_last_tile_dim_replicate";
559 case TokKind::kw_inf:
560 return "kw_inf";
561 case TokKind::kNegInf:
562 return "kNegInf";
563 case TokKind::kPrimitiveType:
564 return "kPrimitiveType";
565 case TokKind::kName:
566 return "kName";
567 case TokKind::kAttributeName:
568 return "kAttributeName";
569 case TokKind::kDimLabels:
570 return "kDimLabels";
571 case TokKind::kDxD:
572 return "kDxD";
573 case TokKind::kPad:
574 return "kPad";
575 case TokKind::kIdent:
576 return "kIdent";
577 case TokKind::kString:
578 return "kString";
579 case TokKind::kInt:
580 return "kInt";
581 case TokKind::kDecimal:
582 return "kDecimal";
583 case TokKind::kDots:
584 return "kDots";
585 }
586 }
587
588 } // namespace xla
589