• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17 
18 #include <limits>
19 #include <string>
20 #include <unordered_map>
21 
22 #include "absl/base/casts.h"
23 #include "absl/strings/ascii.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_split.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/lib/strings/numbers.h"
32 #include "tensorflow/core/platform/regexp.h"
33 
34 namespace xla {
35 namespace {
36 
37 using absl::string_view;
38 
39 constexpr int kEOF = -1;
40 constexpr int kError = -2;
41 
42 // [a-zA-Z0-9_.-]
IsIdentifierChar(char c)43 bool IsIdentifierChar(char c) {
44   return absl::ascii_isalnum(static_cast<unsigned char>(c)) || c == '-' ||
45          c == '.' || c == '_';
46 }
47 
48 }  // namespace
49 
GetNextChar()50 int HloLexer::GetNextChar() {
51   int current_char = PeekCurrentChar();
52   if (current_char != kEOF && current_char != kError) {
53     current_ptr_++;
54   }
55   return current_char;
56 }
57 
PeekCurrentChar() const58 int HloLexer::PeekCurrentChar() const {
59   if (current_ptr_ == buf_.end()) {
60     return kEOF;
61   }
62   char current_char = *current_ptr_;
63   if (current_char == 0) {
64     // '\0' should not appear in the middle of the string.
65     return kError;
66   }
67   return static_cast<unsigned char>(current_char);
68 }
69 
CanDereference(const char * ptr) const70 bool HloLexer::CanDereference(const char* ptr) const {
71   return ptr < buf_.end() && ptr >= buf_.begin();
72 }
73 
StringPieceFromPointers(const char * begin,const char * end) const74 absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
75                                                     const char* end) const {
76   CHECK(begin <= end);
77   CHECK(begin == buf_.end() || CanDereference(begin));
78   CHECK(end == buf_.end() || CanDereference(end));
79   return absl::string_view(begin, end - begin);
80 }
81 
LookAhead()82 TokKind HloLexer::LookAhead() {
83   if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
84     return GetKind();
85   }
86 
87   const char* old_current_ptr = current_ptr_;
88   TokenState old_token_state = token_state_;
89   Lex();
90   TokKind kind = GetKind();
91   token_state_ = old_token_state;
92   current_ptr_ = old_current_ptr;
93   return kind;
94 }
95 
LexToken()96 TokKind HloLexer::LexToken() {
97   while (true) {
98     token_state_.token_start = current_ptr_;
99 
100     int current_char = GetNextChar();
101     switch (current_char) {
102       default:
103         // [a-zA-Z_]
104         if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
105             current_char == '_') {
106           return LexIdentifier();
107         }
108         return TokKind::kError;
109       case kEOF:
110         // Hit the end of the input buffer.
111         return TokKind::kEof;
112       case kError:
113         // Hit an invalid character in the input buffer.
114         return TokKind::kError;
115       case ' ':
116       case '\t':
117       case '\n':
118       case '\r':
119         // Ignore whitespace.
120         continue;
121       case '0':
122       case '1':
123       case '2':
124       case '3':
125       case '4':
126       case '5':
127       case '6':
128       case '7':
129       case '8':
130       case '9':
131       case '-':
132       case '?':
133         if (current_char == '-' && PeekCurrentChar() == '>') {
134           current_ptr_++;
135           return TokKind::kArrow;
136         }
137         return LexNumberOrPattern();
138       case '=':
139         return TokKind::kEqual;
140       case '<':
141         if (current_char == '<' && PeekCurrentChar() == '=') {
142           current_ptr_++;
143           return TokKind::kLeq;
144         }
145         return TokKind::kError;
146       case ',':
147         return TokKind::kComma;
148       case '%':
149         return LexPercent();
150       case ':':
151         return TokKind::kColon;
152       case '*':
153         return TokKind::kAsterisk;
154       case '[':
155         return TokKind::kLsquare;
156       case ']':
157         return TokKind::kRsquare;
158       case '{':
159         return TokKind::kLbrace;
160       case '}':
161         return TokKind::kRbrace;
162       case '(':
163         return TokKind::kLparen;
164       case ')':
165         return TokKind::kRparen;
166       case '/': {
167         if (PeekCurrentChar() == '*') {
168           // This is the start of a /*...*/ delimited comment. Save the current
169           // location in case the comment is unterminated so the error message
170           // will point to the beginning of the comment.
171           const char* comment_start = current_ptr_;
172           current_ptr_++;
173           // Advance until '*/' is found.
174           while (true) {
175             int current = GetNextChar();
176             if (current == '*' && PeekCurrentChar() == '/') {
177               // End of comment.
178               current_ptr_++;
179               break;
180             }
181             if (current == kEOF) {
182               // Unterminated comment.
183               current_ptr_ = comment_start;
184               return TokKind::kError;
185             }
186             if (current == kError) {
187               return TokKind::kError;
188             }
189           }
190           // Return no token for the comment. Keep lexing.
191           continue;
192         } else if (PeekCurrentChar() == '/') {
193           // This is the start of a '//' delimited comment. Throw away
194           // everything until end of line or file. The end-of-line character(s)
195           // are left unlexed in the buffer which is harmless because these are
196           // skipped later by the lexer. This approach enables support for
197           // different end-of-line encodings.
198           while (true) {
199             int current = PeekCurrentChar();
200             if (current == kEOF || current == '\n' || current == '\r') {
201               break;
202             }
203             if (current == kError) {
204               return TokKind::kError;
205             }
206             current_ptr_++;
207           }
208           continue;
209         }
210         // A lone '/' is an error.
211         return TokKind::kError;
212       }
213       case '.':
214         if (PeekCurrentChar() == '.') {
215           current_ptr_++;
216           if (PeekCurrentChar() == '.') {
217             current_ptr_++;
218             return TokKind::kDots;
219           }
220         }
221         return TokKind::kError;
222       case '"':
223         return LexString();
224     }
225   }
226 }
227 
LexNanPayload(absl::string_view & consumable)228 absl::optional<int64> HloLexer::LexNanPayload(absl::string_view& consumable) {
229   static LazyRE2 payload_pattern = {R"(\(0x[0-9a-fA-F]+\))"};
230   if (!RE2::Consume(&consumable, *payload_pattern)) {
231     return absl::nullopt;
232   }
233   auto slice = StringPieceFromPointers(current_ptr_, consumable.begin());
234   current_ptr_ = consumable.begin();
235   CHECK(absl::StartsWith(slice, "(0x"));
236   slice.remove_prefix(std::strlen("(0x"));
237   CHECK(absl::EndsWith(slice, ")"));
238   slice.remove_suffix(std::strlen(")"));
239   uint64 payload_value;
240   if (tensorflow::strings::HexStringToUint64(slice, &payload_value)) {
241     if (payload_value <= 0 || payload_value > NanPayloadBitMask<double>()) {
242       LOG(ERROR) << "NaN payload out of range: " << payload_value;
243       return absl::nullopt;
244     }
245     return payload_value;
246   }
247   return absl::nullopt;
248 }
249 
250 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
251 // other identifiers.
252 //
253 // shape    ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
254 // name     ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
255 // keyword  ::= HloModule, ENTRY, ...
256 // attribute_name ::= condition, body, dimensions, ...
257 // dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
258 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()259 TokKind HloLexer::LexIdentifier() {
260   while (IsIdentifierChar(PeekCurrentChar())) {
261     current_ptr_++;
262   }
263 
264   // If followed by ':', it's a name.
265   if (PeekCurrentChar() == ':') {
266     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
267     current_ptr_++;  // skip ':'
268     return TokKind::kName;
269   }
270 
271   // If followed by '=', it's a attribute name.
272   if (PeekCurrentChar() == '=') {
273     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
274     current_ptr_++;  // skip '='
275     return TokKind::kAttributeName;
276   }
277 
278   absl::string_view identifier =
279       StringPieceFromPointers(token_state_.token_start, current_ptr_);
280 
281   // Primitive type strings are reserved words. The exception is 'tuple' whose
282   // type is represented using nested parentheses without the string 'tuple'.
283   if (primitive_util::IsPrimitiveTypeName(identifier)) {
284     PrimitiveType primitive_type =
285         primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
286     if (primitive_type != TUPLE) {
287       token_state_.primitive_type_val = primitive_type;
288       return TokKind::kPrimitiveType;
289     }
290   }
291 
292   if (identifier == "nan") {
293     absl::optional<int64_t> payload;
294     if (PeekCurrentChar() == '(') {
295       absl::string_view consumable =
296           StringPieceFromPointers(current_ptr_, buf_.end());
297       payload = LexNanPayload(consumable);
298       if (!payload.has_value()) {
299         return TokKind::kError;
300       }
301     }
302     token_state_.decimal_val = NanWithSignAndPayload<double>(
303         /*sign=*/false, payload.value_or(QuietNanWithoutPayload<double>()));
304     return TokKind::kDecimal;
305   }
306 
307   // See if this is a keyword.
308 #define KEYWORD(STR)            \
309   do {                          \
310     if (identifier == #STR) {   \
311       return TokKind::kw_##STR; \
312     }                           \
313   } while (false)
314 
315   KEYWORD(true);
316   KEYWORD(false);
317   KEYWORD(inf);
318   KEYWORD(HloModule);
319   KEYWORD(ENTRY);
320   KEYWORD(ROOT);
321   KEYWORD(maximal);
322   KEYWORD(replicated);
323   KEYWORD(manual);
324   KEYWORD(last_tile_dim_replicate);
325 
326 #undef KEYWORD
327 
328   {
329     absl::string_view consumable =
330         StringPieceFromPointers(token_state_.token_start, buf_.end());
331     static LazyRE2 dim_labels_pattern = {
332         R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
333     if (RE2::Consume(&consumable, *dim_labels_pattern)) {
334       current_ptr_ = consumable.begin();
335       token_state_.str_val.assign(token_state_.token_start, current_ptr_);
336       return TokKind::kDimLabels;
337     }
338   }
339 
340   token_state_.str_val = string(identifier);
341   return TokKind::kIdent;
342 }
343 
344 // Lex names after a % character.
345 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()346 TokKind HloLexer::LexPercent() {
347   const char* name_start = current_ptr_;
348   if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
349       PeekCurrentChar() == '_') {
350     current_ptr_++;
351     while (IsIdentifierChar(PeekCurrentChar())) {
352       current_ptr_++;
353     }
354     token_state_.str_val.assign(name_start, current_ptr_);
355     return TokKind::kName;
356   }
357   return TokKind::kError;
358 }
359 
360 // Lex integer and floating-point values, -inf, and patterns for dim labels,
361 // dxd (e.g. 1x2x3), and pad.
362 //
363 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
364 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
365 // dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
366 // dxd_pattern ::= [0-9]+(x[0-9]+)+
367 // pad_pattern ::=
368 //   [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
369 // int ::=  [-]?[0-9]+
370 // negative inf ::= '-inf'
LexNumberOrPattern()371 TokKind HloLexer::LexNumberOrPattern() {
372   absl::string_view consumable =
373       StringPieceFromPointers(token_state_.token_start, buf_.end());
374   static LazyRE2 float_pattern = {
375       R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
376   if (RE2::Consume(&consumable, *float_pattern)) {
377     current_ptr_ = consumable.begin();
378     CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
379                            &token_state_.decimal_val));
380     return TokKind::kDecimal;
381   }
382 
383   static LazyRE2 dim_labels_pattern = {
384       R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
385   static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
386   static LazyRE2 pad_pattern = {
387       R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
388 
389   if (RE2::Consume(&consumable, *dim_labels_pattern)) {
390     current_ptr_ = consumable.begin();
391     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
392     return TokKind::kDimLabels;
393   }
394 
395   if (RE2::Consume(&consumable, *dxd_pattern)) {
396     current_ptr_ = consumable.begin();
397     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
398     return TokKind::kDxD;
399   }
400 
401   if (RE2::Consume(&consumable, *pad_pattern)) {
402     current_ptr_ = consumable.begin();
403     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
404     return TokKind::kPad;
405   }
406 
407   static LazyRE2 int_pattern = {R"([-]?\d+)"};
408   if (RE2::Consume(&consumable, *int_pattern)) {
409     current_ptr_ = consumable.begin();
410     auto slice =
411         StringPieceFromPointers(token_state_.token_start, current_ptr_);
412     if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
413       return TokKind::kInt;
414     }
415     uint64 uint64_val;
416     if (absl::SimpleAtoi(slice, &uint64_val)) {
417       token_state_.int64_val = absl::bit_cast<int64>(uint64_val);
418       return TokKind::kInt;
419     }
420     LOG(ERROR) << "Failed to parse int literal: " << slice;
421     return TokKind::kError;
422   }
423 
424   static LazyRE2 neg_inf = {"-inf"};
425   if (RE2::Consume(&consumable, *neg_inf)) {
426     current_ptr_ = consumable.begin();
427     return TokKind::kNegInf;
428   }
429 
430   static LazyRE2 neg_nan = {"-nan"};
431   if (RE2::Consume(&consumable, *neg_nan)) {
432     current_ptr_ = consumable.begin();
433 
434     absl::optional<int64_t> payload;
435     if (PeekCurrentChar() == '(') {
436       payload = LexNanPayload(consumable);
437       if (!payload.has_value()) {
438         return TokKind::kError;
439       }
440     }
441     token_state_.decimal_val = NanWithSignAndPayload<double>(
442         /*sign=*/true, payload.value_or(QuietNanWithoutPayload<double>()));
443     return TokKind::kDecimal;
444   }
445 
446   return TokKind::kError;
447 }
448 
GetLineAndColumn(LocTy location) const449 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
450   unsigned line_no = 1;
451   const char* start = buf_.begin();
452   const char* ptr = start;
453   if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
454       line_no_cache_.last_query <= location) {
455     ptr = line_no_cache_.last_query;
456     line_no = line_no_cache_.line_no_of_query;
457   }
458   for (; ptr != location; ptr++) {
459     CHECK_LT(ptr, buf_.end());
460     if (*ptr == '\n') {
461       line_no++;
462     }
463   }
464 
465   // Update the line number cache.
466   line_no_cache_.last_query = ptr;
467   line_no_cache_.line_no_of_query = line_no;
468   size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
469   if (line_offset == absl::string_view::npos) {
470     line_offset = 0;
471   }
472   return {line_no, ptr - start - line_offset};
473 }
474 
GetLine(LocTy loc) const475 absl::string_view HloLexer::GetLine(LocTy loc) const {
476   if (!CanDereference(loc)) {
477     return "LINE OUT OF RANGE";
478   }
479   size_t line_start =
480       StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
481   const char* start = line_start == absl::string_view::npos
482                           ? buf_.begin()
483                           : buf_.begin() + line_start + 1;
484   size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
485   const char* end =
486       line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
487 
488   return StringPieceFromPointers(start, end);
489 }
490 
491 // Lexes quoted string with escaping characters. If matched, the quoted string
492 // will be unescaped and stored to token_state_.str_val.
LexString()493 TokKind HloLexer::LexString() {
494   absl::string_view consumable =
495       StringPieceFromPointers(token_state_.token_start, buf_.end());
496   static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
497   if (RE2::Consume(&consumable, *escaping_pattern)) {
498     current_ptr_ = consumable.begin();
499     absl::string_view raw =
500         StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
501     string error;
502     if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
503       LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
504       return TokKind::kError;
505     }
506     return TokKind::kString;
507   }
508   return TokKind::kError;
509 }
510 
511 string TokKindToString(TokKind kind) {
512   switch (kind) {
513     case TokKind::kEof:
514       return "kEof";
515     case TokKind::kError:
516       return "kError";
517     case TokKind::kEqual:
518       return "kEqaul";
519     case TokKind::kComma:
520       return "kComma";
521     case TokKind::kColon:
522       return "kColon";
523     case TokKind::kAsterisk:
524       return "kAsterisk";
525     case TokKind::kLsquare:
526       return "kLsquare";
527     case TokKind::kRsquare:
528       return "kRsquare";
529     case TokKind::kLbrace:
530       return "kLbrace";
531     case TokKind::kRbrace:
532       return "kRbrace";
533     case TokKind::kLparen:
534       return "kLparen";
535     case TokKind::kRparen:
536       return "kRparen";
537     case TokKind::kArrow:
538       return "kArrow";
539     case TokKind::kLeq:
540       return "kLeq";
541     case TokKind::kw_HloModule:
542       return "kw_HloModule";
543     case TokKind::kw_ENTRY:
544       return "kw_ENTRY";
545     case TokKind::kw_ROOT:
546       return "kw_ROOT";
547     case TokKind::kw_true:
548       return "kw_true";
549     case TokKind::kw_false:
550       return "kw_false";
551     case TokKind::kw_maximal:
552       return "kw_maximal";
553     case TokKind::kw_replicated:
554       return "kw_replicated";
555     case TokKind::kw_manual:
556       return "kw_manual";
557     case TokKind::kw_last_tile_dim_replicate:
558       return "kw_last_tile_dim_replicate";
559     case TokKind::kw_inf:
560       return "kw_inf";
561     case TokKind::kNegInf:
562       return "kNegInf";
563     case TokKind::kPrimitiveType:
564       return "kPrimitiveType";
565     case TokKind::kName:
566       return "kName";
567     case TokKind::kAttributeName:
568       return "kAttributeName";
569     case TokKind::kDimLabels:
570       return "kDimLabels";
571     case TokKind::kDxD:
572       return "kDxD";
573     case TokKind::kPad:
574       return "kPad";
575     case TokKind::kIdent:
576       return "kIdent";
577     case TokKind::kString:
578       return "kString";
579     case TokKind::kInt:
580       return "kInt";
581     case TokKind::kDecimal:
582       return "kDecimal";
583     case TokKind::kDots:
584       return "kDots";
585   }
586 }
587 
588 }  // namespace xla
589