1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19
20 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)21 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
22 XlaResourceOpKind op_kind) {
23 switch (op_kind) {
24 case XlaResourceOpKind::kRead:
25 return "Read";
26 case XlaResourceOpKind::kWrite:
27 return "Write";
28 case XlaResourceOpKind::kReadWrite:
29 return "Modify";
30 }
31 }
32
33 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()34 CreateResourceOpInfoMap() {
35 auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
36
37 auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
38 XlaResourceKind resource_kind) {
39 auto insert_result =
40 result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
41 CHECK(insert_result.second);
42 };
43
44 auto kRead = XlaResourceOpKind::kRead;
45 auto kWrite = XlaResourceOpKind::kWrite;
46 auto kReadWrite = XlaResourceOpKind::kReadWrite;
47
48 auto kVariable = XlaResourceKind::kVariable;
49 auto kStack = XlaResourceKind::kStack;
50 auto kTensorArray = XlaResourceKind::kTensorArray;
51
52 // clang-format off
53 add("AssignAddVariableOp" , kReadWrite, kVariable);
54 add("AssignSubVariableOp" , kReadWrite, kVariable);
55 add("AssignVariableOp" , kWrite, kVariable);
56 add("CollectiveReduceV2" , kRead, kVariable);
57 add("ReadVariableOp" , kRead, kVariable);
58 add("ResourceApplyAdaMax" , kReadWrite, kVariable);
59 add("ResourceApplyAdadelta" , kReadWrite, kVariable);
60 add("ResourceApplyAdagrad" , kReadWrite, kVariable);
61 add("ResourceApplyAdagradV2" , kReadWrite, kVariable),
62 add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
63 add("ResourceApplyAdam" , kReadWrite, kVariable);
64 add("ResourceApplyAddSign" , kReadWrite, kVariable);
65 add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable);
66 add("ResourceApplyFtrl" , kReadWrite, kVariable);
67 add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
68 add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
69 add("ResourceApplyMomentum" , kReadWrite, kVariable);
70 add("ResourceApplyKerasMomentum" , kReadWrite, kVariable);
71 add("ResourceApplyPowerSign" , kReadWrite, kVariable);
72 add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
73 add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
74 add("ResourceApplyRMSProp" , kReadWrite, kVariable);
75 add("ResourceGather" , kRead, kVariable);
76 add("ResourceScatterAdd" , kReadWrite, kVariable);
77 add("ResourceScatterDiv" , kReadWrite, kVariable);
78 add("ResourceScatterMax" , kReadWrite, kVariable);
79 add("ResourceScatterMin" , kReadWrite, kVariable);
80 add("ResourceScatterMul" , kReadWrite, kVariable);
81 add("ResourceScatterNdAdd" , kReadWrite, kVariable);
82 add("ResourceScatterNdSub" , kReadWrite, kVariable);
83 add("ResourceScatterNdUpdate" , kReadWrite, kVariable);
84 add("ResourceScatterSub" , kReadWrite, kVariable);
85 add("ResourceScatterUpdate" , kReadWrite, kVariable);
86 add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
87 add("RngReadAndSkip" , kReadWrite, kVariable);
88 add("RngSkip" , kReadWrite, kVariable);
89 add("StatefulStandardNormalV2" , kReadWrite, kVariable);
90 add("StatefulTruncatedNormal" , kReadWrite, kVariable);
91 add("StatefulUniform" , kReadWrite, kVariable);
92 add("StatefulUniformFullInt" , kReadWrite, kVariable);
93 add("StatefulUniformInt" , kReadWrite, kVariable);
94 add("VarIsInitializedOp" , kRead, kVariable);
95 add("VariableShape" , kRead, kVariable);
96
97 add("StackV2" , kWrite, kStack);
98 add("StackCloseV2" , kRead, kStack);
99 add("StackPopV2" , kReadWrite, kStack);
100 add("StackPushV2" , kReadWrite, kStack);
101
102 add("TensorArrayV3" , kWrite, kTensorArray);
103 add("TensorArrayConcatV3" , kRead, kTensorArray);
104 add("TensorArrayGatherV3" , kRead, kTensorArray);
105 add("TensorArrayScatterV3" , kWrite, kTensorArray);
106 add("TensorArrayGradV3" , kRead, kTensorArray);
107 add("TensorArrayCloseV3" , kRead, kTensorArray);
108 add("TensorArrayReadV3" , kRead, kTensorArray);
109 add("TensorArraySizeV3" , kRead, kTensorArray);
110 add("TensorArraySplitV3" , kWrite, kTensorArray);
111 add("TensorArrayWriteV3" , kWrite, kTensorArray);
112 // clang-format on
113
114 return result;
115 }
116
117 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()118 GetStaticResourceOpInfoMap() {
119 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
120 op_info_map = CreateResourceOpInfoMap();
121 return *op_info_map;
122 }
123
GetResourceOpInfoForOp(absl::string_view op)124 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
125 const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
126 GetStaticResourceOpInfoMap();
127 auto it = op_infos.find(op);
128 return it == op_infos.end() ? nullptr : &it->second;
129 }
130
131 namespace resource_op_table_internal {
GetKnownResourceOps()132 std::vector<absl::string_view> GetKnownResourceOps() {
133 std::vector<absl::string_view> result;
134 for (const auto& p : GetStaticResourceOpInfoMap()) {
135 result.push_back(p.first);
136 }
137 absl::c_sort(result);
138 return result;
139 }
140 } // namespace resource_op_table_internal
141 } // namespace tensorflow
142