• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 
20 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)21 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
22     XlaResourceOpKind op_kind) {
23   switch (op_kind) {
24     case XlaResourceOpKind::kRead:
25       return "Read";
26     case XlaResourceOpKind::kWrite:
27       return "Write";
28     case XlaResourceOpKind::kReadWrite:
29       return "Modify";
30   }
31 }
32 
33 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()34 CreateResourceOpInfoMap() {
35   auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
36 
37   auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
38                  XlaResourceKind resource_kind) {
39     auto insert_result =
40         result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
41     CHECK(insert_result.second);
42   };
43 
44   auto kRead = XlaResourceOpKind::kRead;
45   auto kWrite = XlaResourceOpKind::kWrite;
46   auto kReadWrite = XlaResourceOpKind::kReadWrite;
47 
48   auto kVariable = XlaResourceKind::kVariable;
49   auto kStack = XlaResourceKind::kStack;
50   auto kTensorArray = XlaResourceKind::kTensorArray;
51 
52   // clang-format off
53   add("AssignAddVariableOp"                  , kReadWrite, kVariable);
54   add("AssignSubVariableOp"                  , kReadWrite, kVariable);
55   add("AssignVariableOp"                     , kWrite,     kVariable);
56   add("CollectiveReduceV2"                   , kRead,      kVariable);
57   add("ReadVariableOp"                       , kRead,      kVariable);
58   add("ResourceApplyAdaMax"                  , kReadWrite, kVariable);
59   add("ResourceApplyAdadelta"                , kReadWrite, kVariable);
60   add("ResourceApplyAdagrad"                 , kReadWrite, kVariable);
61   add("ResourceApplyAdagradV2"               , kReadWrite, kVariable),
62   add("ResourceApplyAdagradDA"               , kReadWrite, kVariable);
63   add("ResourceApplyAdam"                    , kReadWrite, kVariable);
64   add("ResourceApplyAddSign"                 , kReadWrite, kVariable);
65   add("ResourceApplyCenteredRMSProp"         , kReadWrite, kVariable);
66   add("ResourceApplyFtrl"                    , kReadWrite, kVariable);
67   add("ResourceApplyFtrlV2"                  , kReadWrite, kVariable);
68   add("ResourceApplyGradientDescent"         , kReadWrite, kVariable);
69   add("ResourceApplyMomentum"                , kReadWrite, kVariable);
70   add("ResourceApplyKerasMomentum"           , kReadWrite, kVariable);
71   add("ResourceApplyPowerSign"               , kReadWrite, kVariable);
72   add("ResourceApplyProximalAdagrad"         , kReadWrite, kVariable);
73   add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
74   add("ResourceApplyRMSProp"                 , kReadWrite, kVariable);
75   add("ResourceGather"                       , kRead,      kVariable);
76   add("ResourceScatterAdd"                   , kReadWrite, kVariable);
77   add("ResourceScatterDiv"                   , kReadWrite, kVariable);
78   add("ResourceScatterMax"                   , kReadWrite, kVariable);
79   add("ResourceScatterMin"                   , kReadWrite, kVariable);
80   add("ResourceScatterMul"                   , kReadWrite, kVariable);
81   add("ResourceScatterNdAdd"                 , kReadWrite, kVariable);
82   add("ResourceScatterNdSub"                 , kReadWrite, kVariable);
83   add("ResourceScatterNdUpdate"              , kReadWrite, kVariable);
84   add("ResourceScatterSub"                   , kReadWrite, kVariable);
85   add("ResourceScatterUpdate"                , kReadWrite, kVariable);
86   add("ResourceStridedSliceAssign"           , kReadWrite, kVariable);
87   add("RngReadAndSkip"                       , kReadWrite, kVariable);
88   add("RngSkip"                              , kReadWrite, kVariable);
89   add("StatefulStandardNormalV2"             , kReadWrite, kVariable);
90   add("StatefulTruncatedNormal"              , kReadWrite, kVariable);
91   add("StatefulUniform"                      , kReadWrite, kVariable);
92   add("StatefulUniformFullInt"               , kReadWrite, kVariable);
93   add("StatefulUniformInt"                   , kReadWrite, kVariable);
94   add("VarIsInitializedOp"                   , kRead,      kVariable);
95   add("VariableShape"                        , kRead,      kVariable);
96 
97   add("StackV2"                              , kWrite,     kStack);
98   add("StackCloseV2"                         , kRead,      kStack);
99   add("StackPopV2"                           , kReadWrite, kStack);
100   add("StackPushV2"                          , kReadWrite, kStack);
101 
102   add("TensorArrayV3"                        , kWrite,     kTensorArray);
103   add("TensorArrayConcatV3"                  , kRead,      kTensorArray);
104   add("TensorArrayGatherV3"                  , kRead,      kTensorArray);
105   add("TensorArrayScatterV3"                 , kWrite,     kTensorArray);
106   add("TensorArrayGradV3"                    , kRead,      kTensorArray);
107   add("TensorArrayCloseV3"                   , kRead,      kTensorArray);
108   add("TensorArrayReadV3"                    , kRead,      kTensorArray);
109   add("TensorArraySizeV3"                    , kRead,      kTensorArray);
110   add("TensorArraySplitV3"                   , kWrite,     kTensorArray);
111   add("TensorArrayWriteV3"                   , kWrite,     kTensorArray);
112   // clang-format on
113 
114   return result;
115 }
116 
117 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()118 GetStaticResourceOpInfoMap() {
119   static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
120       op_info_map = CreateResourceOpInfoMap();
121   return *op_info_map;
122 }
123 
GetResourceOpInfoForOp(absl::string_view op)124 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
125   const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
126       GetStaticResourceOpInfoMap();
127   auto it = op_infos.find(op);
128   return it == op_infos.end() ? nullptr : &it->second;
129 }
130 
131 namespace resource_op_table_internal {
GetKnownResourceOps()132 std::vector<absl::string_view> GetKnownResourceOps() {
133   std::vector<absl::string_view> result;
134   for (const auto& p : GetStaticResourceOpInfoMap()) {
135     result.push_back(p.first);
136   }
137   absl::c_sort(result);
138   return result;
139 }
140 }  // namespace resource_op_table_internal
141 }  // namespace tensorflow
142