• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
17 
#include <array>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 
23 namespace tensorflow {
24 
25 // Eager op rewrites should inherit from this class and
26 // implement the Run method.
27 class EagerOpRewrite {
28  public:
EagerOpRewrite(string name,string file,string line)29   EagerOpRewrite(string name, string file, string line) {
30     debug_info_.name = name;
31     debug_info_.file = file;
32     debug_info_.line = line;
33   }
34 
~EagerOpRewrite()35   virtual ~EagerOpRewrite() {}
36 
37   // To be implemented by an Eager op rewrite pass.
38   virtual Status Run(EagerOperation* orig_op,
39                      std::unique_ptr<tensorflow::EagerOperation>* out_op) = 0;
40 
41   // Holds information about the rewrite registration.
42   struct DebugInfo {
43     string name, file, line;
44   };
45 
46   // Returns information about the registered Eager op rewrite.
GetDebugInfo()47   DebugInfo GetDebugInfo() const { return debug_info_; }
48 
49  private:
50   DebugInfo debug_info_;
51 };
52 
53 class EagerOpRewriteRegistry {
54  public:
55   // Phases at which the Eager op rewrite pass should run.
56   // For now we only added PRE_EXECUTION. Expand as needed.
57   enum Phase {
58     PRE_EXECUTION = 0,  // right before executing an eager op
59     POST_PLACEMENT = 1  // after device placement
60   };
61 
62   // Add a rewrite pass to the registry.
63   // Only one rewrite pass is allowed per phase.
64   void Register(Phase phase, std::unique_ptr<EagerOpRewrite> pass);
65 
66   // Run the rewrite pass registered for a given phase.
67   Status RunRewrite(Phase phase, EagerOperation* orig_op,
68                     std::unique_ptr<tensorflow::EagerOperation>* out_op);
69 
70   // Returns the global registry of rewrite passes.
71   static EagerOpRewriteRegistry* Global();
72 
73  private:
74   static constexpr int32 kNumPhases = 2;
75   // Holds all the registered Eager op rewrites.
76   std::array<std::unique_ptr<EagerOpRewrite>, kNumPhases> rewrites_;
77 };
78 
79 namespace eager_rewrite_registration {
80 
81 // This class is used to register a new Eager Op rewrite.
82 class EagerRewriteRegistration {
83  public:
EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,std::unique_ptr<EagerOpRewrite> pass)84   EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,
85                            std::unique_ptr<EagerOpRewrite> pass) {
86     EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass));
87   }
88 };
89 
90 }  // namespace eager_rewrite_registration
91 
// REGISTER_REWRITE(phase, rewrite): registers a heap-allocated instance of the
// EagerOpRewrite subclass `rewrite` with the global EagerOpRewriteRegistry for
// `phase`, via the constructor of a static EagerRewriteRegistration object.
// The instance is constructed with the stringified class name, __FILE__ and
// the stringified __LINE__ of the registration site.
#define REGISTER_REWRITE(phase, rewrite)                               \
  REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, \
                               rewrite)

// Indirection layer: because these parameters are not used with `#` or `##`
// here, __COUNTER__ and __LINE__ are expanded to their numeric values before
// REGISTER_REWRITE_UNIQ pastes/stringifies them.
#define REGISTER_REWRITE_UNIQ_HELPER(uid, src_file, src_line, rewrite_phase, \
                                     rewrite_class)                          \
  REGISTER_REWRITE_UNIQ(uid, src_file, src_line, rewrite_phase, rewrite_class)

// Defines the static registration object. Its name is suffixed with `uid` so
// several registrations in one translation unit get distinct identifiers.
#define REGISTER_REWRITE_UNIQ(uid, src_file, src_line, rewrite_phase,       \
                              rewrite_class)                                \
  static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \
      register_rewrite_##uid(                                               \
          rewrite_phase,                                                    \
          ::std::unique_ptr<::tensorflow::EagerOpRewrite>(                  \
              new rewrite_class(#rewrite_class, src_file, #src_line)))
103 
104 }  // namespace tensorflow
105 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_
106