• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17 
18 #include <unordered_map>
19 
20 #include "absl/base/casts.h"
21 #include "absl/strings/ascii.h"
22 #include "absl/strings/escaping.h"
23 #include "absl/strings/numbers.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/platform/regexp.h"
31 
32 namespace xla {
33 namespace {
34 
35 using absl::string_view;
36 
// Sentinel results of PeekCurrentChar()/GetNextChar(). Both are negative, so
// they can never be confused with a real character, which those functions
// report as an unsigned byte value in [0, 255].
constexpr int kEOF = -1;    // The end of the input buffer was reached.
constexpr int kError = -2;  // An embedded '\0' was encountered.

// Returns whether `c` may appear after the first character of an identifier,
// i.e. whether it belongs to [a-zA-Z0-9_.-]. Only ASCII letters and digits
// qualify; bytes >= 0x80 (including the sentinels above once narrowed to
// char) are rejected.
bool IsIdentifierChar(char c) {
  switch (c) {
    case '-':
    case '.':
    case '_':
      return true;
    default:
      return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
             (c >= '0' && c <= '9');
  }
}
45 
46 }  // namespace
47 
GetNextChar()48 int HloLexer::GetNextChar() {
49   int current_char = PeekCurrentChar();
50   if (current_char != kEOF && current_char != kError) {
51     current_ptr_++;
52   }
53   return current_char;
54 }
55 
PeekCurrentChar() const56 int HloLexer::PeekCurrentChar() const {
57   if (current_ptr_ == buf_.end()) {
58     return kEOF;
59   }
60   char current_char = *current_ptr_;
61   if (current_char == 0) {
62     // '\0' should not appear in the middle of the string.
63     return kError;
64   }
65   return static_cast<unsigned char>(current_char);
66 }
67 
CanDereference(const char * ptr) const68 bool HloLexer::CanDereference(const char* ptr) const {
69   return ptr < buf_.end() && ptr >= buf_.begin();
70 }
71 
StringPieceFromPointers(const char * begin,const char * end) const72 absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
73                                                     const char* end) const {
74   CHECK(begin <= end);
75   CHECK(begin == buf_.end() || CanDereference(begin));
76   CHECK(end == buf_.end() || CanDereference(end));
77   return absl::string_view(begin, end - begin);
78 }
79 
LookAhead()80 TokKind HloLexer::LookAhead() {
81   if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
82     return GetKind();
83   }
84 
85   const char* old_current_ptr = current_ptr_;
86   TokenState old_token_state = token_state_;
87   Lex();
88   TokKind kind = GetKind();
89   token_state_ = old_token_state;
90   current_ptr_ = old_current_ptr;
91   return kind;
92 }
93 
LexToken()94 TokKind HloLexer::LexToken() {
95   while (true) {
96     token_state_.token_start = current_ptr_;
97 
98     int current_char = GetNextChar();
99     switch (current_char) {
100       default:
101         // [a-zA-Z_]
102         if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
103             current_char == '_') {
104           return LexIdentifier();
105         }
106         return TokKind::kError;
107       case kEOF:
108         // Hit the end of the input buffer.
109         return TokKind::kEof;
110       case kError:
111         // Hit an invalid character in the input buffer.
112         return TokKind::kError;
113       case ' ':
114       case '\t':
115       case '\n':
116       case '\r':
117         // Ignore whitespace.
118         continue;
119       case '0':
120       case '1':
121       case '2':
122       case '3':
123       case '4':
124       case '5':
125       case '6':
126       case '7':
127       case '8':
128       case '9':
129       case '-':
130         if (current_char == '-' && PeekCurrentChar() == '>') {
131           current_ptr_++;
132           return TokKind::kArrow;
133         }
134         return LexNumberOrPattern();
135       case '=':
136         return TokKind::kEqual;
137       case '<':
138         if (current_char == '<' && PeekCurrentChar() == '=') {
139           current_ptr_++;
140           return TokKind::kLeq;
141         }
142         return TokKind::kError;
143       case ',':
144         return TokKind::kComma;
145       case '%':
146         return LexPercent();
147       case ':':
148         return TokKind::kColon;
149       case '*':
150         return TokKind::kAsterisk;
151       case '[':
152         return TokKind::kLsquare;
153       case ']':
154         return TokKind::kRsquare;
155       case '{':
156         return TokKind::kLbrace;
157       case '}':
158         return TokKind::kRbrace;
159       case '(':
160         return TokKind::kLparen;
161       case ')':
162         return TokKind::kRparen;
163       case '/': {
164         if (PeekCurrentChar() == '*') {
165           // This is the start of a /*...*/ delimited comment. Save the current
166           // location in case the comment is unterminated so the error message
167           // will point to the beginning of the comment.
168           const char* comment_start = current_ptr_;
169           current_ptr_++;
170           // Advance until '*/' is found.
171           while (true) {
172             int current = GetNextChar();
173             if (current == '*' && PeekCurrentChar() == '/') {
174               // End of comment.
175               current_ptr_++;
176               break;
177             }
178             if (current == kEOF) {
179               // Unterminated comment.
180               current_ptr_ = comment_start;
181               return TokKind::kError;
182             }
183             if (current == kError) {
184               return TokKind::kError;
185             }
186           }
187           // Return no token for the comment. Keep lexing.
188           continue;
189         } else if (PeekCurrentChar() == '/') {
190           // This is the start of a '//' delimited comment. Throw away
191           // everything until end of line or file. The end-of-line character(s)
192           // are left unlexed in the buffer which is harmless because these are
193           // skipped later by the lexer. This approach enables support for
194           // different end-of-line encodings.
195           while (true) {
196             int current = PeekCurrentChar();
197             if (current == kEOF || current == '\n' || current == '\r') {
198               break;
199             }
200             if (current == kError) {
201               return TokKind::kError;
202             }
203             current_ptr_++;
204           }
205           continue;
206         }
207         // A lone '/' is an error.
208         return TokKind::kError;
209       }
210       case '.':
211         if (PeekCurrentChar() == '.') {
212           current_ptr_++;
213           if (PeekCurrentChar() == '.') {
214             current_ptr_++;
215             return TokKind::kDots;
216           }
217         }
218         return TokKind::kError;
219       case '"':
220         return LexString();
221     }
222   }
223 }
224 
225 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
226 // other identifiers.
227 //
228 // shape    ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
229 // name     ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
230 // keyword  ::= HloModule, ENTRY, ...
231 // attribute_name ::= condition, body, dimensions, ...
232 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
233 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()234 TokKind HloLexer::LexIdentifier() {
235   while (IsIdentifierChar(PeekCurrentChar())) {
236     current_ptr_++;
237   }
238 
239   // If followed by ':', it's a name.
240   if (PeekCurrentChar() == ':') {
241     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
242     current_ptr_++;  // skip ':'
243     return TokKind::kName;
244   }
245 
246   // If followed by '=', it's a attribute name.
247   if (PeekCurrentChar() == '=') {
248     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
249     current_ptr_++;  // skip '='
250     return TokKind::kAttributeName;
251   }
252 
253   absl::string_view identifier =
254       StringPieceFromPointers(token_state_.token_start, current_ptr_);
255 
256   // Primitive type strings are reserved words. The exception is 'tuple' whose
257   // type is represented using nested parentheses without the string 'tuple'.
258   if (primitive_util::IsPrimitiveTypeName(identifier)) {
259     PrimitiveType primitive_type =
260         primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
261     if (primitive_type != TUPLE) {
262       token_state_.primitive_type_val = primitive_type;
263       return TokKind::kPrimitiveType;
264     }
265   }
266 
267   // See if this is a keyword.
268 #define KEYWORD(STR)            \
269   do {                          \
270     if (identifier == #STR) {   \
271       return TokKind::kw_##STR; \
272     }                           \
273   } while (false)
274 
275   KEYWORD(true);
276   KEYWORD(false);
277   KEYWORD(inf);
278   KEYWORD(nan);
279   KEYWORD(HloModule);
280   KEYWORD(ENTRY);
281   KEYWORD(ROOT);
282   KEYWORD(maximal);
283   KEYWORD(replicated);
284   KEYWORD(manual);
285   KEYWORD(last_tile_dim_replicate);
286 
287 #undef KEYWORD
288 
289   {
290     absl::string_view consumable =
291         StringPieceFromPointers(token_state_.token_start, buf_.end());
292     static LazyRE2 dim_labels_pattern = {
293         R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
294     if (RE2::Consume(&consumable, *dim_labels_pattern)) {
295       current_ptr_ = consumable.begin();
296       token_state_.str_val.assign(token_state_.token_start, current_ptr_);
297       return TokKind::kDimLabels;
298     }
299   }
300 
301   token_state_.str_val = string(identifier);
302   return TokKind::kIdent;
303 }
304 
305 // Lex names after a % character.
306 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()307 TokKind HloLexer::LexPercent() {
308   const char* name_start = current_ptr_;
309   if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
310       PeekCurrentChar() == '_') {
311     current_ptr_++;
312     while (IsIdentifierChar(PeekCurrentChar())) {
313       current_ptr_++;
314     }
315     token_state_.str_val.assign(name_start, current_ptr_);
316     return TokKind::kName;
317   }
318   return TokKind::kError;
319 }
320 
321 // Lex integer and floating-point values, -inf, and patterns for dim labels,
322 // dxd (e.g. 1x2x3), and pad.
323 //
324 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
325 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
326 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
327 // dxd_pattern ::= [0-9]+(x[0-9]+)+
328 // pad_pattern ::=
329 //   [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
330 // int ::=  [-]?[0-9]+
331 // negative inf ::= '-inf'
LexNumberOrPattern()332 TokKind HloLexer::LexNumberOrPattern() {
333   absl::string_view consumable =
334       StringPieceFromPointers(token_state_.token_start, buf_.end());
335   static LazyRE2 float_pattern = {
336       R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
337   if (RE2::Consume(&consumable, *float_pattern)) {
338     current_ptr_ = consumable.begin();
339     CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
340                            &token_state_.decimal_val));
341     return TokKind::kDecimal;
342   }
343 
344   static LazyRE2 dim_labels_pattern = {
345       R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
346   static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
347   static LazyRE2 pad_pattern = {
348       R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
349 
350   if (RE2::Consume(&consumable, *dim_labels_pattern)) {
351     current_ptr_ = consumable.begin();
352     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
353     return TokKind::kDimLabels;
354   }
355 
356   if (RE2::Consume(&consumable, *dxd_pattern)) {
357     current_ptr_ = consumable.begin();
358     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
359     return TokKind::kDxD;
360   }
361 
362   if (RE2::Consume(&consumable, *pad_pattern)) {
363     current_ptr_ = consumable.begin();
364     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
365     return TokKind::kPad;
366   }
367 
368   static LazyRE2 int_pattern = {R"([-]?\d+)"};
369   if (RE2::Consume(&consumable, *int_pattern)) {
370     current_ptr_ = consumable.begin();
371     auto slice =
372         StringPieceFromPointers(token_state_.token_start, current_ptr_);
373     if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
374       return TokKind::kInt;
375     }
376     uint64 uint64_val;
377     if (absl::SimpleAtoi(slice, &uint64_val)) {
378       token_state_.int64_val = absl::bit_cast<int64>(uint64_val);
379       return TokKind::kInt;
380     }
381     LOG(ERROR) << "Failed to parse int literal: " << slice;
382     return TokKind::kError;
383   }
384 
385   static LazyRE2 neg_inf = {"-inf"};
386   if (RE2::Consume(&consumable, *neg_inf)) {
387     current_ptr_ = consumable.begin();
388     return TokKind::kNegInf;
389   }
390 
391   static LazyRE2 neg_nan = {"-nan"};
392   if (RE2::Consume(&consumable, *neg_nan)) {
393     current_ptr_ = consumable.begin();
394     return TokKind::kNegNan;
395   }
396 
397   return TokKind::kError;
398 }
399 
GetLineAndColumn(LocTy location) const400 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
401   unsigned line_no = 1;
402   const char* start = buf_.begin();
403   const char* ptr = start;
404   if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
405       line_no_cache_.last_query <= location) {
406     ptr = line_no_cache_.last_query;
407     line_no = line_no_cache_.line_no_of_query;
408   }
409   for (; ptr != location; ptr++) {
410     CHECK_LT(ptr, buf_.end());
411     if (*ptr == '\n') {
412       line_no++;
413     }
414   }
415 
416   // Update the line number cache.
417   line_no_cache_.last_query = ptr;
418   line_no_cache_.line_no_of_query = line_no;
419   size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
420   if (line_offset == absl::string_view::npos) {
421     line_offset = 0;
422   }
423   return {line_no, ptr - start - line_offset};
424 }
425 
GetLine(LocTy loc) const426 absl::string_view HloLexer::GetLine(LocTy loc) const {
427   if (!CanDereference(loc)) {
428     return "LINE OUT OF RANGE";
429   }
430   size_t line_start =
431       StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
432   const char* start = line_start == absl::string_view::npos
433                           ? buf_.begin()
434                           : buf_.begin() + line_start + 1;
435   size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
436   const char* end =
437       line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
438 
439   return StringPieceFromPointers(start, end);
440 }
441 
442 // Lexes quoted string with escaping characters. If matched, the quoted string
443 // will be unescaped and stored to token_state_.str_val.
LexString()444 TokKind HloLexer::LexString() {
445   absl::string_view consumable =
446       StringPieceFromPointers(token_state_.token_start, buf_.end());
447   static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
448   if (RE2::Consume(&consumable, *escaping_pattern)) {
449     current_ptr_ = consumable.begin();
450     absl::string_view raw =
451         StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
452     string error;
453     if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
454       LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
455       return TokKind::kError;
456     }
457     return TokKind::kString;
458   }
459   return TokKind::kError;
460 }
461 
462 string TokKindToString(TokKind kind) {
463   switch (kind) {
464     case TokKind::kEof:
465       return "kEof";
466     case TokKind::kError:
467       return "kError";
468     case TokKind::kEqual:
469       return "kEqaul";
470     case TokKind::kComma:
471       return "kComma";
472     case TokKind::kColon:
473       return "kColon";
474     case TokKind::kAsterisk:
475       return "kAsterisk";
476     case TokKind::kLsquare:
477       return "kLsquare";
478     case TokKind::kRsquare:
479       return "kRsquare";
480     case TokKind::kLbrace:
481       return "kLbrace";
482     case TokKind::kRbrace:
483       return "kRbrace";
484     case TokKind::kLparen:
485       return "kLparen";
486     case TokKind::kRparen:
487       return "kRparen";
488     case TokKind::kArrow:
489       return "kArrow";
490     case TokKind::kLeq:
491       return "kLeq";
492     case TokKind::kw_HloModule:
493       return "kw_HloModule";
494     case TokKind::kw_ENTRY:
495       return "kw_ENTRY";
496     case TokKind::kw_ROOT:
497       return "kw_ROOT";
498     case TokKind::kw_true:
499       return "kw_true";
500     case TokKind::kw_false:
501       return "kw_false";
502     case TokKind::kw_maximal:
503       return "kw_maximal";
504     case TokKind::kw_replicated:
505       return "kw_replicated";
506     case TokKind::kw_manual:
507       return "kw_manual";
508     case TokKind::kw_last_tile_dim_replicate:
509       return "kw_last_tile_dim_replicate";
510     case TokKind::kw_nan:
511       return "kw_nan";
512     case TokKind::kw_inf:
513       return "kw_inf";
514     case TokKind::kNegNan:
515       return "kNegNan";
516     case TokKind::kNegInf:
517       return "kNegInf";
518     case TokKind::kPrimitiveType:
519       return "kPrimitiveType";
520     case TokKind::kName:
521       return "kName";
522     case TokKind::kAttributeName:
523       return "kAttributeName";
524     case TokKind::kDimLabels:
525       return "kDimLabels";
526     case TokKind::kDxD:
527       return "kDxD";
528     case TokKind::kPad:
529       return "kPad";
530     case TokKind::kIdent:
531       return "kIdent";
532     case TokKind::kString:
533       return "kString";
534     case TokKind::kInt:
535       return "kInt";
536     case TokKind::kDecimal:
537       return "kDecimal";
538     case TokKind::kDots:
539       return "kDots";
540   }
541 }
542 
543 }  // namespace xla
544