/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19
20 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)21 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
22 XlaResourceOpKind op_kind) {
23 switch (op_kind) {
24 case XlaResourceOpKind::kRead:
25 return "Read";
26 case XlaResourceOpKind::kWrite:
27 return "Write";
28 case XlaResourceOpKind::kReadWrite:
29 return "Modify";
30 }
31 }
32
33 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()34 CreateResourceOpInfoMap() {
35 auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
36
37 auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
38 XlaResourceKind resource_kind) {
39 auto insert_result =
40 result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
41 CHECK(insert_result.second);
42 };
43
44 auto kRead = XlaResourceOpKind::kRead;
45 auto kWrite = XlaResourceOpKind::kWrite;
46 auto kReadWrite = XlaResourceOpKind::kReadWrite;
47
48 auto kVariable = XlaResourceKind::kVariable;
49 auto kStack = XlaResourceKind::kStack;
50 auto kTensorArray = XlaResourceKind::kTensorArray;
51
52 // clang-format off
53 add("AssignAddVariableOp" , kReadWrite, kVariable);
54 add("AssignSubVariableOp" , kReadWrite, kVariable);
55 add("AssignVariableOp" , kWrite, kVariable);
56 add("ReadVariableOp" , kRead, kVariable);
57 add("ResourceApplyAdaMax" , kReadWrite, kVariable);
58 add("ResourceApplyAdadelta" , kReadWrite, kVariable);
59 add("ResourceApplyAdagrad" , kReadWrite, kVariable);
60 add("ResourceApplyAdagradDA" , kReadWrite, kVariable);
61 add("ResourceApplyAdam" , kReadWrite, kVariable);
62 add("ResourceApplyAddSign" , kReadWrite, kVariable);
63 add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable);
64 add("ResourceApplyFtrl" , kReadWrite, kVariable);
65 add("ResourceApplyFtrlV2" , kReadWrite, kVariable);
66 add("ResourceApplyGradientDescent" , kReadWrite, kVariable);
67 add("ResourceApplyMomentum" , kReadWrite, kVariable);
68 add("ResourceApplyKerasMomentum" , kReadWrite, kVariable);
69 add("ResourceApplyPowerSign" , kReadWrite, kVariable);
70 add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable);
71 add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
72 add("ResourceApplyRMSProp" , kReadWrite, kVariable);
73 add("ResourceGather" , kRead, kVariable);
74 add("ResourceScatterAdd" , kReadWrite, kVariable);
75 add("ResourceScatterDiv" , kReadWrite, kVariable);
76 add("ResourceScatterMax" , kReadWrite, kVariable);
77 add("ResourceScatterMin" , kReadWrite, kVariable);
78 add("ResourceScatterMul" , kReadWrite, kVariable);
79 add("ResourceScatterNdAdd" , kReadWrite, kVariable);
80 add("ResourceScatterNdSub" , kReadWrite, kVariable);
81 add("ResourceScatterNdUpdate" , kReadWrite, kVariable);
82 add("ResourceScatterSub" , kReadWrite, kVariable);
83 add("ResourceScatterUpdate" , kReadWrite, kVariable);
84 add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
85 add("StatefulStandardNormalV2" , kReadWrite, kVariable);
86 add("StatefulUniformFullInt" , kReadWrite, kVariable);
87 add("StatefulUniformInt" , kReadWrite, kVariable);
88 add("VarIsInitializedOp" , kRead, kVariable);
89 add("VariableShape" , kRead, kVariable);
90
91 add("StackV2" , kWrite, kStack);
92 add("StackCloseV2" , kRead, kStack);
93 add("StackPopV2" , kReadWrite, kStack);
94 add("StackPushV2" , kReadWrite, kStack);
95
96 add("TensorArrayV3" , kWrite, kTensorArray);
97 add("TensorArrayConcatV3" , kRead, kTensorArray);
98 add("TensorArrayGatherV3" , kRead, kTensorArray);
99 add("TensorArrayScatterV3" , kWrite, kTensorArray);
100 add("TensorArrayGradV3" , kRead, kTensorArray);
101 add("TensorArrayCloseV3" , kRead, kTensorArray);
102 add("TensorArrayReadV3" , kRead, kTensorArray);
103 add("TensorArraySizeV3" , kRead, kTensorArray);
104 add("TensorArraySplitV3" , kWrite, kTensorArray);
105 add("TensorArrayWriteV3" , kWrite, kTensorArray);
106 // clang-format on
107
108 return result;
109 }
110
111 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()112 GetStaticResourceOpInfoMap() {
113 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
114 op_info_map = CreateResourceOpInfoMap();
115 return *op_info_map;
116 }
117
GetResourceOpInfoForOp(absl::string_view op)118 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
119 const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
120 GetStaticResourceOpInfoMap();
121 auto it = op_infos.find(op);
122 return it == op_infos.end() ? nullptr : &it->second;
123 }
124
125 namespace resource_op_table_internal {
GetKnownResourceOps()126 std::vector<absl::string_view> GetKnownResourceOps() {
127 std::vector<absl::string_view> result;
128 for (const auto& p : GetStaticResourceOpInfoMap()) {
129 result.push_back(p.first);
130 }
131 absl::c_sort(result);
132 return result;
133 }
134 } // namespace resource_op_table_internal
135 } // namespace tensorflow
136