1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17
18 #include <unordered_map>
19
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/escaping.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/lib/strings/numbers.h"
29 #include "tensorflow/core/platform/regexp.h"
30
31 namespace xla {
32 namespace {
33
34 using absl::string_view;
35
// Sentinel values that PeekCurrentChar() and GetNextChar() return in place of
// a character: kEOF once the end of the buffer is reached, and kError when an
// embedded '\0' is encountered. Real characters are returned as unsigned char
// values (0-255), so neither sentinel can collide with one.
constexpr int kEOF{-1};
constexpr int kError{-2};
38
39 // [a-zA-Z0-9_.-]
IsIdentifierChar(char c)40 bool IsIdentifierChar(char c) {
41 return absl::ascii_isalnum(static_cast<unsigned char>(c)) || c == '-' ||
42 c == '.' || c == '_';
43 }
44
45 } // namespace
46
GetNextChar()47 int HloLexer::GetNextChar() {
48 int current_char = PeekCurrentChar();
49 if (current_char != kEOF && current_char != kError) {
50 current_ptr_++;
51 }
52 return current_char;
53 }
54
PeekCurrentChar() const55 int HloLexer::PeekCurrentChar() const {
56 if (current_ptr_ == buf_.end()) {
57 return kEOF;
58 }
59 char current_char = *current_ptr_;
60 if (current_char == 0) {
61 // '\0' should not appear in the middle of the string.
62 return kError;
63 }
64 return static_cast<unsigned char>(current_char);
65 }
66
CanDereference(const char * ptr) const67 bool HloLexer::CanDereference(const char* ptr) const {
68 return ptr < buf_.end() && ptr >= buf_.begin();
69 }
70
StringPieceFromPointers(const char * begin,const char * end) const71 absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
72 const char* end) const {
73 CHECK(begin <= end);
74 CHECK(begin == buf_.end() || CanDereference(begin));
75 CHECK(end == buf_.end() || CanDereference(end));
76 return absl::string_view(begin, end - begin);
77 }
78
RegexpStringPieceFromPointers(const char * begin,const char * end) const79 tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
80 const char* begin, const char* end) const {
81 CHECK(begin <= end);
82 CHECK(begin == buf_.end() || CanDereference(begin));
83 CHECK(end == buf_.end() || CanDereference(end));
84 return tensorflow::RegexpStringPiece(begin, end - begin);
85 }
86
LookAhead()87 TokKind HloLexer::LookAhead() {
88 if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
89 return GetKind();
90 }
91
92 const char* old_current_ptr = current_ptr_;
93 TokenState old_token_state = token_state_;
94 Lex();
95 TokKind kind = GetKind();
96 token_state_ = old_token_state;
97 current_ptr_ = old_current_ptr;
98 return kind;
99 }
100
LexToken()101 TokKind HloLexer::LexToken() {
102 while (true) {
103 token_state_.token_start = current_ptr_;
104
105 int current_char = GetNextChar();
106 switch (current_char) {
107 default:
108 // [a-zA-Z_]
109 if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
110 current_char == '_') {
111 return LexIdentifier();
112 }
113 return TokKind::kError;
114 case kEOF:
115 // Hit the end of the input buffer.
116 return TokKind::kEof;
117 case kError:
118 // Hit an invalid character in the input buffer.
119 return TokKind::kError;
120 case ' ':
121 case '\t':
122 case '\n':
123 case '\r':
124 // Ignore whitespace.
125 continue;
126 case '0':
127 case '1':
128 case '2':
129 case '3':
130 case '4':
131 case '5':
132 case '6':
133 case '7':
134 case '8':
135 case '9':
136 case '-':
137 if (current_char == '-' && PeekCurrentChar() == '>') {
138 current_ptr_++;
139 return TokKind::kArrow;
140 }
141 return LexNumberOrPattern();
142 case '=':
143 return TokKind::kEqual;
144 case '<':
145 if (current_char == '<' && PeekCurrentChar() == '=') {
146 current_ptr_++;
147 return TokKind::kLeq;
148 }
149 return TokKind::kError;
150 case ',':
151 return TokKind::kComma;
152 case '%':
153 return LexPercent();
154 case ':':
155 return TokKind::kColon;
156 case '*':
157 return TokKind::kAsterisk;
158 case '[':
159 return TokKind::kLsquare;
160 case ']':
161 return TokKind::kRsquare;
162 case '{':
163 return TokKind::kLbrace;
164 case '}':
165 return TokKind::kRbrace;
166 case '(':
167 return TokKind::kLparen;
168 case ')':
169 return TokKind::kRparen;
170 case '/': {
171 if (PeekCurrentChar() == '*') {
172 // This is the start of a /*...*/ delimited comment. Save the current
173 // location in case the comment is unterminated so the error message
174 // will point to the beginning of the comment.
175 const char* comment_start = current_ptr_;
176 current_ptr_++;
177 // Advance until '*/' is found.
178 while (true) {
179 int current = GetNextChar();
180 if (current == '*' && PeekCurrentChar() == '/') {
181 // End of comment.
182 current_ptr_++;
183 break;
184 }
185 if (current == kEOF) {
186 // Unterminated comment.
187 current_ptr_ = comment_start;
188 return TokKind::kError;
189 }
190 if (current == kError) {
191 return TokKind::kError;
192 }
193 }
194 // Return no token for the comment. Keep lexing.
195 continue;
196 } else if (PeekCurrentChar() == '/') {
197 // This is the start of a '//' delimited comment. Throw away
198 // everything until end of line or file. The end-of-line character(s)
199 // are left unlexed in the buffer which is harmless because these are
200 // skipped later by the lexer. This approach enables support for
201 // different end-of-line encodings.
202 while (true) {
203 int current = PeekCurrentChar();
204 if (current == kEOF || current == '\n' || current == '\r') {
205 break;
206 }
207 if (current == kError) {
208 return TokKind::kError;
209 }
210 current_ptr_++;
211 }
212 continue;
213 }
214 // A lone '/' is an error.
215 return TokKind::kError;
216 }
217 case '.':
218 if (PeekCurrentChar() == '.') {
219 current_ptr_++;
220 if (PeekCurrentChar() == '.') {
221 current_ptr_++;
222 return TokKind::kDots;
223 }
224 }
225 return TokKind::kError;
226 case '"':
227 return LexString();
228 }
229 }
230 }
231
232 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
233 // other identifiers.
234 //
235 // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
236 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
237 // keyword ::= HloModule, ENTRY, ...
238 // attribute_name ::= condition, body, dimensions, ...
239 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
240 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()241 TokKind HloLexer::LexIdentifier() {
242 while (IsIdentifierChar(PeekCurrentChar())) {
243 current_ptr_++;
244 }
245
246 // If followed by ':', it's a name.
247 if (PeekCurrentChar() == ':') {
248 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
249 current_ptr_++; // skip ':'
250 return TokKind::kName;
251 }
252
253 // If followed by '=', it's a attribute name.
254 if (PeekCurrentChar() == '=') {
255 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
256 current_ptr_++; // skip '='
257 return TokKind::kAttributeName;
258 }
259
260 absl::string_view identifier =
261 StringPieceFromPointers(token_state_.token_start, current_ptr_);
262
263 // Primitive type strings are reserved words. The exception is 'tuple' whose
264 // type is represented using nested parentheses without the string 'tuple'.
265 if (primitive_util::IsPrimitiveTypeName(identifier)) {
266 PrimitiveType primitive_type =
267 primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
268 if (primitive_type != TUPLE) {
269 token_state_.primitive_type_val = primitive_type;
270 return TokKind::kPrimitiveType;
271 }
272 }
273
274 // See if this is a keyword.
275 #define KEYWORD(STR) \
276 do { \
277 if (identifier == #STR) { \
278 return TokKind::kw_##STR; \
279 } \
280 } while (false)
281
282 KEYWORD(true);
283 KEYWORD(false);
284 KEYWORD(inf);
285 KEYWORD(nan);
286 KEYWORD(HloModule);
287 KEYWORD(ENTRY);
288 KEYWORD(ROOT);
289 KEYWORD(maximal);
290 KEYWORD(replicated);
291 KEYWORD(sparse);
292
293 #undef KEYWORD
294
295 {
296 auto consumable =
297 RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
298 static LazyRE2 dim_labels_pattern = {
299 R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
300 if (RE2::Consume(&consumable, *dim_labels_pattern)) {
301 current_ptr_ = consumable.begin();
302 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
303 return TokKind::kDimLabels;
304 }
305 }
306
307 token_state_.str_val = string(identifier);
308 return TokKind::kIdent;
309 }
310
311 // Lex names after a % character.
312 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()313 TokKind HloLexer::LexPercent() {
314 const char* name_start = current_ptr_;
315 if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
316 PeekCurrentChar() == '_') {
317 current_ptr_++;
318 while (IsIdentifierChar(PeekCurrentChar())) {
319 current_ptr_++;
320 }
321 token_state_.str_val.assign(name_start, current_ptr_);
322 return TokKind::kName;
323 }
324 return TokKind::kError;
325 }
326
327 // Lex integer and floating-point values, -inf, and patterns for dim labels,
328 // dxd (e.g. 1x2x3), and pad.
329 //
330 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
331 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
332 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
333 // dxd_pattern ::= [0-9]+(x[0-9]+)+
334 // pad_pattern ::=
335 // [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
336 // int ::= [-]?[0-9]+
337 // negative inf ::= '-inf'
LexNumberOrPattern()338 TokKind HloLexer::LexNumberOrPattern() {
339 auto consumable =
340 RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
341 static LazyRE2 float_pattern = {
342 R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
343 if (RE2::Consume(&consumable, *float_pattern)) {
344 current_ptr_ = consumable.begin();
345 CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
346 &token_state_.decimal_val));
347 return TokKind::kDecimal;
348 }
349
350 static LazyRE2 dim_labels_pattern = {
351 R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
352 static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
353 static LazyRE2 pad_pattern = {
354 R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
355
356 if (RE2::Consume(&consumable, *dim_labels_pattern)) {
357 current_ptr_ = consumable.begin();
358 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
359 return TokKind::kDimLabels;
360 }
361
362 if (RE2::Consume(&consumable, *dxd_pattern)) {
363 current_ptr_ = consumable.begin();
364 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
365 return TokKind::kDxD;
366 }
367
368 if (RE2::Consume(&consumable, *pad_pattern)) {
369 current_ptr_ = consumable.begin();
370 token_state_.str_val.assign(token_state_.token_start, current_ptr_);
371 return TokKind::kPad;
372 }
373
374 static LazyRE2 int_pattern = {R"([-]?\d+)"};
375 if (RE2::Consume(&consumable, *int_pattern)) {
376 current_ptr_ = consumable.begin();
377 auto slice =
378 StringPieceFromPointers(token_state_.token_start, current_ptr_);
379 if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
380 return TokKind::kInt;
381 }
382 LOG(ERROR) << "Failed to parse int literal: " << slice;
383 return TokKind::kError;
384 }
385
386 static LazyRE2 neg_inf = {"-inf"};
387 if (RE2::Consume(&consumable, *neg_inf)) {
388 current_ptr_ = consumable.begin();
389 return TokKind::kNegInf;
390 }
391
392 return TokKind::kError;
393 }
394
GetLineAndColumn(LocTy location) const395 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
396 unsigned line_no = 1;
397 const char* start = buf_.begin();
398 const char* ptr = start;
399 if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
400 line_no_cache_.last_query <= location) {
401 ptr = line_no_cache_.last_query;
402 line_no = line_no_cache_.line_no_of_query;
403 }
404 for (; ptr != location; ptr++) {
405 CHECK_LT(ptr, buf_.end());
406 if (*ptr == '\n') {
407 line_no++;
408 }
409 }
410
411 // Update the line number cache.
412 line_no_cache_.last_query = ptr;
413 line_no_cache_.line_no_of_query = line_no;
414 size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
415 if (line_offset == absl::string_view::npos) {
416 line_offset = 0;
417 }
418 return {line_no, ptr - start - line_offset};
419 }
420
GetLine(LocTy loc) const421 absl::string_view HloLexer::GetLine(LocTy loc) const {
422 if (!CanDereference(loc)) {
423 return "LINE OUT OF RANGE";
424 }
425 size_t line_start =
426 StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
427 const char* start = line_start == absl::string_view::npos
428 ? buf_.begin()
429 : buf_.begin() + line_start + 1;
430 size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
431 const char* end =
432 line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
433
434 return StringPieceFromPointers(start, end);
435 }
436
437 // Lexes quoted string with escaping characters. If matched, the quoted string
438 // will be unescaped and stored to token_state_.str_val.
LexString()439 TokKind HloLexer::LexString() {
440 auto consumable =
441 RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
442 static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
443 if (RE2::Consume(&consumable, *escaping_pattern)) {
444 current_ptr_ = consumable.begin();
445 absl::string_view raw =
446 StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
447 string error;
448 if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
449 LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
450 return TokKind::kError;
451 }
452 return TokKind::kString;
453 }
454 return TokKind::kError;
455 }
456
457 string TokKindToString(TokKind kind) {
458 switch (kind) {
459 case TokKind::kEof:
460 return "kEof";
461 case TokKind::kError:
462 return "kError";
463 case TokKind::kEqual:
464 return "kEqaul";
465 case TokKind::kComma:
466 return "kComma";
467 case TokKind::kColon:
468 return "kColon";
469 case TokKind::kAsterisk:
470 return "kAsterisk";
471 case TokKind::kLsquare:
472 return "kLsquare";
473 case TokKind::kRsquare:
474 return "kRsquare";
475 case TokKind::kLbrace:
476 return "kLbrace";
477 case TokKind::kRbrace:
478 return "kRbrace";
479 case TokKind::kLparen:
480 return "kLparen";
481 case TokKind::kRparen:
482 return "kRparen";
483 case TokKind::kArrow:
484 return "kArrow";
485 case TokKind::kLeq:
486 return "kLeq";
487 case TokKind::kw_HloModule:
488 return "kw_HloModule";
489 case TokKind::kw_ENTRY:
490 return "kw_ENTRY";
491 case TokKind::kw_ROOT:
492 return "kw_ROOT";
493 case TokKind::kw_true:
494 return "kw_true";
495 case TokKind::kw_false:
496 return "kw_false";
497 case TokKind::kw_maximal:
498 return "kw_maximal";
499 case TokKind::kw_replicated:
500 return "kw_replicated";
501 case TokKind::kw_nan:
502 return "kw_nan";
503 case TokKind::kw_inf:
504 return "kw_inf";
505 case TokKind::kNegInf:
506 return "kNegInf";
507 case TokKind::kw_sparse:
508 return "kw_sparse";
509 case TokKind::kPrimitiveType:
510 return "kPrimitiveType";
511 case TokKind::kName:
512 return "kName";
513 case TokKind::kAttributeName:
514 return "kAttributeName";
515 case TokKind::kDimLabels:
516 return "kDimLabels";
517 case TokKind::kDxD:
518 return "kDxD";
519 case TokKind::kPad:
520 return "kPad";
521 case TokKind::kIdent:
522 return "kIdent";
523 case TokKind::kString:
524 return "kString";
525 case TokKind::kInt:
526 return "kInt";
527 case TokKind::kDecimal:
528 return "kDecimal";
529 case TokKind::kDots:
530 return "kDots";
531 }
532 }
533
534 } // namespace xla
535