1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "absl/strings/ascii.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/strings/strip.h"
27 #include "tensorflow/core/platform/regexp.h"
28
29 namespace tensorflow {
30 namespace profiler {
namespace {

// First component of tf.data op names (e.g., Iterator::Batch::Map::TFRecord).
// Compared against the text before the first ':' when parsing an op's full
// name, and used as the prefix of dataset event names.
const absl::string_view kIterator = "Iterator";
// Delimiter between the components of a dataset/iterator name.
const absl::string_view kSeparator = "::";
// Delimiter between name scopes inside an op name (e.g., "scope/op").
constexpr char kNameScopeSeparator = '/';

}  // namespace
38
// Op type used when the type of an op could not be determined.
const absl::string_view kUnknownOp = "";  // op types are non-empty strings
// Op type assigned to tf.data ops (full names of the form "Iterator:...").
const absl::string_view kDatasetOp = "Dataset";
// Op types assigned to ops without a "name:type" form whose full name starts
// (case-insensitively) with "MEMCPYHToD" / "MEMCPYDToH" respectively.
const absl::string_view kMemcpyHToDOp = "MemcpyHToD";
const absl::string_view kMemcpyDToHOp = "MemcpyDToH";
43
IsTfOpName(absl::string_view op_name)44 bool IsTfOpName(absl::string_view op_name) {
45 // TODO(b/177602927): Confirm the naming convention with the TF team.
46 static const LazyRE2 kTfOpNameRegEx = {"[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*"};
47 return RE2::FullMatch(op_name, *kTfOpNameRegEx);
48 }
49
IsTfOpType(absl::string_view op_type)50 bool IsTfOpType(absl::string_view op_type) {
51 static const LazyRE2 kTfOpTypeRegEx = {"[A-Z_][a-zA-Z0-9_]*"};
52 return RE2::FullMatch(op_type, *kTfOpTypeRegEx);
53 }
54
IsJaxOpType(absl::string_view op_type)55 bool IsJaxOpType(absl::string_view op_type) {
56 static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_][a-z0-9_]*"};
57 return RE2::FullMatch(op_type, *kJaxOpTypeRegEx);
58 }
59
IsJaxOpNameAndType(absl::string_view op_name,absl::string_view op_type)60 bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type) {
61 if (op_name.empty() || !IsJaxOpType(op_type)) return false;
62 std::vector<absl::string_view> split_result =
63 absl::StrSplit(op_name, kNameScopeSeparator);
64 return absl::StrContains(split_result.back(), op_type);
65 }
66
ParseTfOpFullname(absl::string_view tf_op_fullname)67 TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
68 // TF Op names have the format "name:type".
69 TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
70 std::vector<absl::string_view> parts =
71 absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
72 if (parts.size() != 2) {
73 // GPU-related Ops that need to be tracked.
74 if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
75 tf_op.category = Category::kMemcpyHToD;
76 tf_op.type = kMemcpyHToDOp;
77 } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
78 tf_op.category = Category::kMemcpyDToH;
79 tf_op.type = kMemcpyDToHOp;
80 }
81 // TODO(ckluk): Include the corresponding Ops on TPU.
82 } else if (parts[0] == kIterator) {
83 // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
84 // format of TF Op names. But we still want to capture them for
85 // input-pipeline analysis.
86 tf_op.category = Category::kTfData;
87 tf_op.type = kDatasetOp;
88 } else if (IsTfOpType(parts[1]) && IsTfOpName(parts[0])) {
89 tf_op = {Category::kTensorFlow, parts[0], parts[1]};
90 } else if (IsJaxOpType(parts[1])) {
91 tf_op = {Category::kJax, parts[0], parts[1]};
92 } else if (parts[1].empty()) {
93 tf_op.name = parts[0]; // remove trailing ':'
94 }
95 return tf_op;
96 }
97
ParseTfNameScopes(const TfOp & tf_op)98 std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
99 std::vector<absl::string_view> name_scopes =
100 absl::StrSplit(tf_op.name, kNameScopeSeparator);
101 // The last element is an op name not TF name scope.
102 if (!name_scopes.empty()) name_scopes.pop_back();
103 return name_scopes;
104 }
105
TfOpEventName(const TfOp & tf_op)106 std::string TfOpEventName(const TfOp& tf_op) {
107 std::string event_name;
108 if (tf_op.category == Category::kUnknown) {
109 // Some TraceMe names contain trailing whitespace, remove it.
110 event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
111 } else if (tf_op.category == Category::kTfData) {
112 event_name = DatasetOpEventName(tf_op.name);
113 } else {
114 event_name = std::string(tf_op.type);
115 }
116 return event_name;
117 }
118
TfOpEventName(absl::string_view tf_op_fullname)119 std::string TfOpEventName(absl::string_view tf_op_fullname) {
120 return TfOpEventName(ParseTfOpFullname(tf_op_fullname));
121 }
122
DatasetOpEventName(absl::string_view full_name)123 std::string DatasetOpEventName(absl::string_view full_name) {
124 std::vector<absl::string_view> split_result =
125 absl::StrSplit(full_name, kSeparator);
126 return absl::StrCat(kIterator, kSeparator, split_result.back());
127 }
128
IteratorName(absl::string_view full_name)129 std::string IteratorName(absl::string_view full_name) {
130 std::vector<absl::string_view> split_result =
131 absl::StrSplit(full_name, kSeparator);
132 return std::string(split_result.back());
133 }
134
ParseTensorShapes(absl::string_view tensor_shapes)135 std::vector<absl::string_view> ParseTensorShapes(
136 absl::string_view tensor_shapes) {
137 absl::ConsumePrefix(&tensor_shapes, "(");
138 absl::ConsumeSuffix(&tensor_shapes, ")");
139 return absl::StrSplit(tensor_shapes, ';');
140 }
141
142 } // namespace profiler
143 } // namespace tensorflow
144