1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/hlo_lexer.h"

#include <string>
#include <unordered_map>

#include "absl/base/casts.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/regexp.h"
31
32 namespace xla {
33 namespace {
34
35 using absl::string_view;
36
// Sentinel values that PeekCurrentChar()/GetNextChar() return instead of a
// character. Real bytes are returned as unsigned char values in [0, 255], so
// these negative values can never collide with input text.
constexpr int kEOF{-1};    // The cursor sits at the end of the buffer.
constexpr int kError{-2};  // The cursor sits on an invalid ('\0') byte.
39
// Reports whether `c` may continue an identifier or name, i.e. whether it is in
// the ASCII set [a-zA-Z0-9_.-]. Bytes outside ASCII (including negative values
// of a signed char) are never identifier characters.
bool IsIdentifierChar(char c) {
  if (c == '_' || c == '.' || c == '-') {
    return true;
  }
  const bool is_lower = c >= 'a' && c <= 'z';
  const bool is_upper = c >= 'A' && c <= 'Z';
  const bool is_digit = c >= '0' && c <= '9';
  return is_lower || is_upper || is_digit;
}
45
46 } // namespace
47
GetNextChar()48 int HloLexer::GetNextChar() {
49 int current_char = PeekCurrentChar();
50 if (current_char != kEOF && current_char != kError) {
51 current_ptr_++;
52 }
53 return current_char;
54 }
55
PeekCurrentChar() const56 int HloLexer::PeekCurrentChar() const {
57 if (current_ptr_ == buf_.end()) {
58 return kEOF;
59 }
60 char current_char = *current_ptr_;
61 if (current_char == 0) {
62 // '\0' should not appear in the middle of the string.
63 return kError;
64 }
65 return static_cast<unsigned char>(current_char);
66 }
67
CanDereference(const char * ptr) const68 bool HloLexer::CanDereference(const char* ptr) const {
69 return ptr < buf_.end() && ptr >= buf_.begin();
70 }
71
StringPieceFromPointers(const char * begin,const char * end) const72 absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
73 const char* end) const {
74 CHECK(begin <= end);
75 CHECK(begin == buf_.end() || CanDereference(begin));
76 CHECK(end == buf_.end() || CanDereference(end));
77 return absl::string_view(begin, end - begin);
78 }
79
LookAhead()80 TokKind HloLexer::LookAhead() {
81 if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
82 return GetKind();
83 }
84
85 const char* old_current_ptr = current_ptr_;
86 TokenState old_token_state = token_state_;
87 Lex();
88 TokKind kind = GetKind();
89 token_state_ = old_token_state;
90 current_ptr_ = old_current_ptr;
91 return kind;
92 }
93
// Lexes the next token, advancing current_ptr_ past it, and returns its kind.
// Whitespace and comments produce no token; they are skipped by looping.
//
// token_state_.token_start is reset to the cursor at the top of every loop
// iteration, so it marks the first character of the token finally returned.
// Identifiers, numbers/patterns and strings are handed off to LexIdentifier,
// LexNumberOrPattern and LexString, which re-read the token starting at
// token_start. Names after '%' are handed off to LexPercent.
TokKind HloLexer::LexToken() {
  while (true) {
    token_state_.token_start = current_ptr_;

    int current_char = GetNextChar();
    switch (current_char) {
      default:
        // [a-zA-Z_]
        if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
            current_char == '_') {
          return LexIdentifier();
        }
        return TokKind::kError;
      case kEOF:
        // Hit the end of the input buffer.
        return TokKind::kEof;
      case kError:
        // Hit an invalid character in the input buffer.
        return TokKind::kError;
      case ' ':
      case '\t':
      case '\n':
      case '\r':
        // Ignore whitespace.
        continue;
      case '0':
      case '1':
      case '2':
      case '3':
      case '4':
      case '5':
      case '6':
      case '7':
      case '8':
      case '9':
      case '-':
        // "->" is the arrow token. Anything else that starts with a digit or
        // '-' (numbers, -inf/-nan, dim labels, dxd and pad patterns) is lexed
        // by LexNumberOrPattern.
        if (current_char == '-' && PeekCurrentChar() == '>') {
          current_ptr_++;
          return TokKind::kArrow;
        }
        return LexNumberOrPattern();
      case '=':
        return TokKind::kEqual;
      case '<':
        // "<=" is the only valid token that starts with '<'.
        if (current_char == '<' && PeekCurrentChar() == '=') {
          current_ptr_++;
          return TokKind::kLeq;
        }
        return TokKind::kError;
      case ',':
        return TokKind::kComma;
      case '%':
        return LexPercent();
      case ':':
        return TokKind::kColon;
      case '*':
        return TokKind::kAsterisk;
      case '[':
        return TokKind::kLsquare;
      case ']':
        return TokKind::kRsquare;
      case '{':
        return TokKind::kLbrace;
      case '}':
        return TokKind::kRbrace;
      case '(':
        return TokKind::kLparen;
      case ')':
        return TokKind::kRparen;
      case '/': {
        if (PeekCurrentChar() == '*') {
          // This is the start of a /*...*/ delimited comment. Save the current
          // location in case the comment is unterminated so the error message
          // will point to the beginning of the comment. The '/' has already
          // been consumed, so comment_start points at the opening '*'.
          const char* comment_start = current_ptr_;
          // Skip the opening '*' so that "/*/" is not mistaken for a complete
          // comment.
          current_ptr_++;
          // Advance until '*/' is found.
          while (true) {
            int current = GetNextChar();
            if (current == '*' && PeekCurrentChar() == '/') {
              // End of comment.
              current_ptr_++;
              break;
            }
            if (current == kEOF) {
              // Unterminated comment.
              current_ptr_ = comment_start;
              return TokKind::kError;
            }
            if (current == kError) {
              return TokKind::kError;
            }
          }
          // Return no token for the comment. Keep lexing.
          continue;
        } else if (PeekCurrentChar() == '/') {
          // This is the start of a '//' delimited comment. Throw away
          // everything until end of line or file. The end-of-line character(s)
          // are left unlexed in the buffer which is harmless because these are
          // skipped later by the lexer. This approach enables support for
          // different end-of-line encodings.
          while (true) {
            int current = PeekCurrentChar();
            if (current == kEOF || current == '\n' || current == '\r') {
              break;
            }
            if (current == kError) {
              return TokKind::kError;
            }
            current_ptr_++;
          }
          continue;
        }
        // A lone '/' is an error.
        return TokKind::kError;
      }
      case '.':
        // Only "..." is a valid token. A single '.' or ".." followed by
        // anything else is an error.
        if (PeekCurrentChar() == '.') {
          current_ptr_++;
          if (PeekCurrentChar() == '.') {
            current_ptr_++;
            return TokKind::kDots;
          }
        }
        return TokKind::kError;
      case '"':
        return LexString();
    }
  }
}
224
225 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
226 // other identifiers.
227 //
228 // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
229 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
230 // keyword ::= HloModule, ENTRY, ...
231 // attribute_name ::= condition, body, dimensions, ...
232 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
233 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()234 TokKind HloLexer::LexIdentifier() {
235 while (IsIdentifierChar(PeekCurrentChar())) {
236 current_ptr_++;
237 }
238
239 // If followed by ':', it's a name.
240 if (PeekCurrentChar() == ':') {
241 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
242 current_ptr_++; // skip ':'
243 return TokKind::kName;
244 }
245
246 // If followed by '=', it's a attribute name.
247 if (PeekCurrentChar() == '=') {
248 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
249 current_ptr_++; // skip '='
250 return TokKind::kAttributeName;
251 }
252
253 absl::string_view identifier =
254 StringPieceFromPointers(token_state_.token_start, current_ptr_);
255
256 // Primitive type strings are reserved words. The exception is 'tuple' whose
257 // type is represented using nested parentheses without the string 'tuple'.
258 if (primitive_util::IsPrimitiveTypeName(identifier)) {
259 PrimitiveType primitive_type =
260 primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
261 if (primitive_type != TUPLE) {
262 token_state_.primitive_type_val = primitive_type;
263 return TokKind::kPrimitiveType;
264 }
265 }
266
267 // See if this is a keyword.
268 #define KEYWORD(STR) \
269 do { \
270 if (identifier == #STR) { \
271 return TokKind::kw_##STR; \
272 } \
273 } while (false)
274
275 KEYWORD(true);
276 KEYWORD(false);
277 KEYWORD(inf);
278 KEYWORD(nan);
279 KEYWORD(HloModule);
280 KEYWORD(ENTRY);
281 KEYWORD(ROOT);
282 KEYWORD(maximal);
283 KEYWORD(replicated);
284 KEYWORD(manual);
285 KEYWORD(last_tile_dim_replicate);
286
287 #undef KEYWORD
288
289 {
290 absl::string_view consumable =
291 StringPieceFromPointers(token_state_.token_start, buf_.end());
292 static LazyRE2 dim_labels_pattern = {
293 R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
294 if (RE2::Consume(&consumable, *dim_labels_pattern)) {
295 current_ptr_ = consumable.begin();
296 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
297 return TokKind::kDimLabels;
298 }
299 }
300
301 token_state_.str_val = string(identifier);
302 return TokKind::kIdent;
303 }
304
305 // Lex names after a % character.
306 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()307 TokKind HloLexer::LexPercent() {
308 const char* name_start = current_ptr_;
309 if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
310 PeekCurrentChar() == '_') {
311 current_ptr_++;
312 while (IsIdentifierChar(PeekCurrentChar())) {
313 current_ptr_++;
314 }
315 token_state_.str_val.assign(name_start, current_ptr_);
316 return TokKind::kName;
317 }
318 return TokKind::kError;
319 }
320
321 // Lex integer and floating-point values, -inf, and patterns for dim labels,
322 // dxd (e.g. 1x2x3), and pad.
323 //
324 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
325 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
326 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
327 // dxd_pattern ::= [0-9]+(x[0-9]+)+
328 // pad_pattern ::=
329 // [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
330 // int ::= [-]?[0-9]+
331 // negative inf ::= '-inf'
LexNumberOrPattern()332 TokKind HloLexer::LexNumberOrPattern() {
333 absl::string_view consumable =
334 StringPieceFromPointers(token_state_.token_start, buf_.end());
335 static LazyRE2 float_pattern = {
336 R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
337 if (RE2::Consume(&consumable, *float_pattern)) {
338 current_ptr_ = consumable.begin();
339 CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
340 &token_state_.decimal_val));
341 return TokKind::kDecimal;
342 }
343
344 static LazyRE2 dim_labels_pattern = {
345 R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
346 static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
347 static LazyRE2 pad_pattern = {
348 R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
349
350 if (RE2::Consume(&consumable, *dim_labels_pattern)) {
351 current_ptr_ = consumable.begin();
352 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
353 return TokKind::kDimLabels;
354 }
355
356 if (RE2::Consume(&consumable, *dxd_pattern)) {
357 current_ptr_ = consumable.begin();
358 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
359 return TokKind::kDxD;
360 }
361
362 if (RE2::Consume(&consumable, *pad_pattern)) {
363 current_ptr_ = consumable.begin();
364 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
365 return TokKind::kPad;
366 }
367
368 static LazyRE2 int_pattern = {R"([-]?\d+)"};
369 if (RE2::Consume(&consumable, *int_pattern)) {
370 current_ptr_ = consumable.begin();
371 auto slice =
372 StringPieceFromPointers(token_state_.token_start, current_ptr_);
373 if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
374 return TokKind::kInt;
375 }
376 uint64 uint64_val;
377 if (absl::SimpleAtoi(slice, &uint64_val)) {
378 token_state_.int64_val = absl::bit_cast<int64>(uint64_val);
379 return TokKind::kInt;
380 }
381 LOG(ERROR) << "Failed to parse int literal: " << slice;
382 return TokKind::kError;
383 }
384
385 static LazyRE2 neg_inf = {"-inf"};
386 if (RE2::Consume(&consumable, *neg_inf)) {
387 current_ptr_ = consumable.begin();
388 return TokKind::kNegInf;
389 }
390
391 static LazyRE2 neg_nan = {"-nan"};
392 if (RE2::Consume(&consumable, *neg_nan)) {
393 current_ptr_ = consumable.begin();
394 return TokKind::kNegNan;
395 }
396
397 return TokKind::kError;
398 }
399
GetLineAndColumn(LocTy location) const400 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
401 unsigned line_no = 1;
402 const char* start = buf_.begin();
403 const char* ptr = start;
404 if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
405 line_no_cache_.last_query <= location) {
406 ptr = line_no_cache_.last_query;
407 line_no = line_no_cache_.line_no_of_query;
408 }
409 for (; ptr != location; ptr++) {
410 CHECK_LT(ptr, buf_.end());
411 if (*ptr == '\n') {
412 line_no++;
413 }
414 }
415
416 // Update the line number cache.
417 line_no_cache_.last_query = ptr;
418 line_no_cache_.line_no_of_query = line_no;
419 size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
420 if (line_offset == absl::string_view::npos) {
421 line_offset = 0;
422 }
423 return {line_no, ptr - start - line_offset};
424 }
425
GetLine(LocTy loc) const426 absl::string_view HloLexer::GetLine(LocTy loc) const {
427 if (!CanDereference(loc)) {
428 return "LINE OUT OF RANGE";
429 }
430 size_t line_start =
431 StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
432 const char* start = line_start == absl::string_view::npos
433 ? buf_.begin()
434 : buf_.begin() + line_start + 1;
435 size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
436 const char* end =
437 line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
438
439 return StringPieceFromPointers(start, end);
440 }
441
442 // Lexes quoted string with escaping characters. If matched, the quoted string
443 // will be unescaped and stored to token_state_.str_val.
LexString()444 TokKind HloLexer::LexString() {
445 absl::string_view consumable =
446 StringPieceFromPointers(token_state_.token_start, buf_.end());
447 static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
448 if (RE2::Consume(&consumable, *escaping_pattern)) {
449 current_ptr_ = consumable.begin();
450 absl::string_view raw =
451 StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
452 string error;
453 if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
454 LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
455 return TokKind::kError;
456 }
457 return TokKind::kString;
458 }
459 return TokKind::kError;
460 }
461
462 string TokKindToString(TokKind kind) {
463 switch (kind) {
464 case TokKind::kEof:
465 return "kEof";
466 case TokKind::kError:
467 return "kError";
468 case TokKind::kEqual:
469 return "kEqaul";
470 case TokKind::kComma:
471 return "kComma";
472 case TokKind::kColon:
473 return "kColon";
474 case TokKind::kAsterisk:
475 return "kAsterisk";
476 case TokKind::kLsquare:
477 return "kLsquare";
478 case TokKind::kRsquare:
479 return "kRsquare";
480 case TokKind::kLbrace:
481 return "kLbrace";
482 case TokKind::kRbrace:
483 return "kRbrace";
484 case TokKind::kLparen:
485 return "kLparen";
486 case TokKind::kRparen:
487 return "kRparen";
488 case TokKind::kArrow:
489 return "kArrow";
490 case TokKind::kLeq:
491 return "kLeq";
492 case TokKind::kw_HloModule:
493 return "kw_HloModule";
494 case TokKind::kw_ENTRY:
495 return "kw_ENTRY";
496 case TokKind::kw_ROOT:
497 return "kw_ROOT";
498 case TokKind::kw_true:
499 return "kw_true";
500 case TokKind::kw_false:
501 return "kw_false";
502 case TokKind::kw_maximal:
503 return "kw_maximal";
504 case TokKind::kw_replicated:
505 return "kw_replicated";
506 case TokKind::kw_manual:
507 return "kw_manual";
508 case TokKind::kw_last_tile_dim_replicate:
509 return "kw_last_tile_dim_replicate";
510 case TokKind::kw_nan:
511 return "kw_nan";
512 case TokKind::kw_inf:
513 return "kw_inf";
514 case TokKind::kNegNan:
515 return "kNegNan";
516 case TokKind::kNegInf:
517 return "kNegInf";
518 case TokKind::kPrimitiveType:
519 return "kPrimitiveType";
520 case TokKind::kName:
521 return "kName";
522 case TokKind::kAttributeName:
523 return "kAttributeName";
524 case TokKind::kDimLabels:
525 return "kDimLabels";
526 case TokKind::kDxD:
527 return "kDxD";
528 case TokKind::kPad:
529 return "kPad";
530 case TokKind::kIdent:
531 return "kIdent";
532 case TokKind::kString:
533 return "kString";
534 case TokKind::kInt:
535 return "kInt";
536 case TokKind::kDecimal:
537 return "kDecimal";
538 case TokKind::kDots:
539 return "kDots";
540 }
541 }
542
543 } // namespace xla
544