1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19
20 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)21 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
22 XlaResourceOpKind op_kind) {
23 switch (op_kind) {
24 case XlaResourceOpKind::kRead:
25 return "Read";
26 case XlaResourceOpKind::kWrite:
27 return "Write";
28 case XlaResourceOpKind::kReadWrite:
29 return "Modify";
30 }
31 }
32
33 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()34 CreateResourceOpInfoMap() {
35 auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
36
37 auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
38 XlaResourceKind resource_kind) {
39 auto insert_result =
40 result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
41 CHECK(insert_result.second);
42 };
43
44 auto kRead = XlaResourceOpKind::kRead;
45 auto kWrite = XlaResourceOpKind::kWrite;
46 auto kReadWrite = XlaResourceOpKind::kReadWrite;
47
48 auto kVariable = XlaResourceKind::kVariable;
49 auto kStack = XlaResourceKind::kStack;
50 auto kTensorArray = XlaResourceKind::kTensorArray;
51
52 // clang-format off
53 add("AssignAddVariableOp" , kReadWrite, kVariable);
54 add("AssignSubVariableOp" , kReadWrite, kVariable);
55 add("AssignVariableOp" , kWrite, kVariable);
56 add("ReadVariableOp" , kRead, kVariable);
57 add("ResourceApplyAdaMax" , kReadWrite, kVariable);
58 add("ResourceApplyAdadelta" , kReadWrite, kVariable);
59 add("ResourceApplyAdagrad" , kReadWrite, kVariable);
60 add("ResourceApplyAdagradV2" , kReadWrite, kVariable),
61 add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
62 add("ResourceApplyAdam" , kReadWrite, kVariable);
63 add("ResourceApplyAddSign" , kReadWrite, kVariable);
64 add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable);
65 add("ResourceApplyFtrl" , kReadWrite, kVariable);
66 add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
67 add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
68 add("ResourceApplyMomentum" , kReadWrite, kVariable);
69 add("ResourceApplyKerasMomentum" , kReadWrite, kVariable);
70 add("ResourceApplyPowerSign" , kReadWrite, kVariable);
71 add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
72 add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
73 add("ResourceApplyRMSProp" , kReadWrite, kVariable);
74 add("ResourceGather" , kRead, kVariable);
75 add("ResourceScatterAdd" , kReadWrite, kVariable);
76 add("ResourceScatterDiv" , kReadWrite, kVariable);
77 add("ResourceScatterMax" , kReadWrite, kVariable);
78 add("ResourceScatterMin" , kReadWrite, kVariable);
79 add("ResourceScatterMul" , kReadWrite, kVariable);
80 add("ResourceScatterNdAdd" , kReadWrite, kVariable);
81 add("ResourceScatterNdSub" , kReadWrite, kVariable);
82 add("ResourceScatterNdUpdate" , kReadWrite, kVariable);
83 add("ResourceScatterSub" , kReadWrite, kVariable);
84 add("ResourceScatterUpdate" , kReadWrite, kVariable);
85 add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
86 add("RngReadAndSkip" , kReadWrite, kVariable);
87 add("RngSkip" , kReadWrite, kVariable);
88 add("StatefulStandardNormalV2" , kReadWrite, kVariable);
89 add("StatefulTruncatedNormal" , kReadWrite, kVariable);
90 add("StatefulUniform" , kReadWrite, kVariable);
91 add("StatefulUniformFullInt" , kReadWrite, kVariable);
92 add("StatefulUniformInt" , kReadWrite, kVariable);
93 add("VarIsInitializedOp" , kRead, kVariable);
94 add("VariableShape" , kRead, kVariable);
95
96 add("StackV2" , kWrite, kStack);
97 add("StackCloseV2" , kRead, kStack);
98 add("StackPopV2" , kReadWrite, kStack);
99 add("StackPushV2" , kReadWrite, kStack);
100
101 add("TensorArrayV3" , kWrite, kTensorArray);
102 add("TensorArrayConcatV3" , kRead, kTensorArray);
103 add("TensorArrayGatherV3" , kRead, kTensorArray);
104 add("TensorArrayScatterV3" , kWrite, kTensorArray);
105 add("TensorArrayGradV3" , kRead, kTensorArray);
106 add("TensorArrayCloseV3" , kRead, kTensorArray);
107 add("TensorArrayReadV3" , kRead, kTensorArray);
108 add("TensorArraySizeV3" , kRead, kTensorArray);
109 add("TensorArraySplitV3" , kWrite, kTensorArray);
110 add("TensorArrayWriteV3" , kWrite, kTensorArray);
111 // clang-format on
112
113 return result;
114 }
115
116 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()117 GetStaticResourceOpInfoMap() {
118 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
119 op_info_map = CreateResourceOpInfoMap();
120 return *op_info_map;
121 }
122
GetResourceOpInfoForOp(absl::string_view op)123 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
124 const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
125 GetStaticResourceOpInfoMap();
126 auto it = op_infos.find(op);
127 return it == op_infos.end() ? nullptr : &it->second;
128 }
129
130 namespace resource_op_table_internal {
GetKnownResourceOps()131 std::vector<absl::string_view> GetKnownResourceOps() {
132 std::vector<absl::string_view> result;
133 for (const auto& p : GetStaticResourceOpInfoMap()) {
134 result.push_back(p.first);
135 }
136 absl::c_sort(result);
137 return result;
138 }
139 } // namespace resource_op_table_internal
140 } // namespace tensorflow
141