• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 
20 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)21 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
22     XlaResourceOpKind op_kind) {
23   switch (op_kind) {
24     case XlaResourceOpKind::kRead:
25       return "Read";
26     case XlaResourceOpKind::kWrite:
27       return "Write";
28     case XlaResourceOpKind::kReadWrite:
29       return "Modify";
30   }
31 }
32 
33 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()34 CreateResourceOpInfoMap() {
35   auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
36 
37   auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
38                  XlaResourceKind resource_kind) {
39     auto insert_result =
40         result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
41     CHECK(insert_result.second);
42   };
43 
44   auto kRead = XlaResourceOpKind::kRead;
45   auto kWrite = XlaResourceOpKind::kWrite;
46   auto kReadWrite = XlaResourceOpKind::kReadWrite;
47 
48   auto kVariable = XlaResourceKind::kVariable;
49   auto kStack = XlaResourceKind::kStack;
50   auto kTensorArray = XlaResourceKind::kTensorArray;
51 
52   // clang-format off
53   add("AssignAddVariableOp"                  , kReadWrite, kVariable);
54   add("AssignSubVariableOp"                  , kReadWrite, kVariable);
55   add("AssignVariableOp"                     , kWrite,     kVariable);
56   add("ReadVariableOp"                       , kRead,      kVariable);
57   add("ResourceApplyAdaMax"                  , kReadWrite, kVariable);
58   add("ResourceApplyAdadelta"                , kReadWrite, kVariable);
59   add("ResourceApplyAdagrad"                 , kReadWrite, kVariable);
60   add("ResourceApplyAdagradDA"               , kReadWrite, kVariable);
61   add("ResourceApplyAdam"                    , kReadWrite, kVariable);
62   add("ResourceApplyAddSign"                 , kReadWrite, kVariable);
63   add("ResourceApplyCenteredRMSProp"         , kReadWrite, kVariable);
64   add("ResourceApplyFtrl"                    , kReadWrite, kVariable);
65   add("ResourceApplyFtrlV2"                  , kReadWrite, kVariable);
66   add("ResourceApplyGradientDescent"         , kReadWrite, kVariable);
67   add("ResourceApplyMomentum"                , kReadWrite, kVariable);
68   add("ResourceApplyKerasMomentum"           , kReadWrite, kVariable);
69   add("ResourceApplyPowerSign"               , kReadWrite, kVariable);
70   add("ResourceApplyProximalAdagrad"         , kReadWrite, kVariable);
71   add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
72   add("ResourceApplyRMSProp"                 , kReadWrite, kVariable);
73   add("ResourceGather"                       , kRead,      kVariable);
74   add("ResourceScatterAdd"                   , kReadWrite, kVariable);
75   add("ResourceScatterDiv"                   , kReadWrite, kVariable);
76   add("ResourceScatterMax"                   , kReadWrite, kVariable);
77   add("ResourceScatterMin"                   , kReadWrite, kVariable);
78   add("ResourceScatterMul"                   , kReadWrite, kVariable);
79   add("ResourceScatterNdAdd"                 , kReadWrite, kVariable);
80   add("ResourceScatterNdSub"                 , kReadWrite, kVariable);
81   add("ResourceScatterNdUpdate"              , kReadWrite, kVariable);
82   add("ResourceScatterSub"                   , kReadWrite, kVariable);
83   add("ResourceScatterUpdate"                , kReadWrite, kVariable);
84   add("ResourceStridedSliceAssign"           , kReadWrite, kVariable);
85   add("StatefulStandardNormalV2"             , kReadWrite, kVariable);
86   add("StatefulUniformFullInt"               , kReadWrite, kVariable);
87   add("StatefulUniformInt"                   , kReadWrite, kVariable);
88   add("VarIsInitializedOp"                   , kRead,      kVariable);
89   add("VariableShape"                        , kRead,      kVariable);
90 
91   add("StackV2"                              , kWrite,     kStack);
92   add("StackCloseV2"                         , kRead,      kStack);
93   add("StackPopV2"                           , kReadWrite, kStack);
94   add("StackPushV2"                          , kReadWrite, kStack);
95 
96   add("TensorArrayV3"                        , kWrite,     kTensorArray);
97   add("TensorArrayConcatV3"                  , kRead,      kTensorArray);
98   add("TensorArrayGatherV3"                  , kRead,      kTensorArray);
99   add("TensorArrayScatterV3"                 , kWrite,     kTensorArray);
100   add("TensorArrayGradV3"                    , kRead,      kTensorArray);
101   add("TensorArrayCloseV3"                   , kRead,      kTensorArray);
102   add("TensorArrayReadV3"                    , kRead,      kTensorArray);
103   add("TensorArraySizeV3"                    , kRead,      kTensorArray);
104   add("TensorArraySplitV3"                   , kWrite,     kTensorArray);
105   add("TensorArrayWriteV3"                   , kWrite,     kTensorArray);
106   // clang-format on
107 
108   return result;
109 }
110 
111 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()112 GetStaticResourceOpInfoMap() {
113   static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
114       op_info_map = CreateResourceOpInfoMap();
115   return *op_info_map;
116 }
117 
GetResourceOpInfoForOp(absl::string_view op)118 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
119   const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
120       GetStaticResourceOpInfoMap();
121   auto it = op_infos.find(op);
122   return it == op_infos.end() ? nullptr : &it->second;
123 }
124 
125 namespace resource_op_table_internal {
GetKnownResourceOps()126 std::vector<absl::string_view> GetKnownResourceOps() {
127   std::vector<absl::string_view> result;
128   for (const auto& p : GetStaticResourceOpInfoMap()) {
129     result.push_back(p.first);
130   }
131   absl::c_sort(result);
132   return result;
133 }
134 }  // namespace resource_op_table_internal
135 }  // namespace tensorflow
136