• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
17 #define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
18 
19 #include "tensorflow/c/conversion_macros.h"
20 #include "tensorflow/c/eager/abstract_context.h"
21 #include "tensorflow/c/eager/abstract_operation.h"
22 #include "tensorflow/c/eager/abstract_tensor_handle.h"
23 #include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
24 #include "tensorflow/core/platform/refcount.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 namespace tensorflow {
28 
29 class OpHandlerOperation;
30 
31 // Op handlers are a convenient way to intercept and transform computation.
32 //
33 // The implementation is currently experimental and incomplete, but aims
34 // eventually to support tracing and replay of function bodies, gradients
35 // through copy operations, and a variety of hooks for things like debug
36 // strings. A public C API for op handlers is planned.
37 class OpHandler : public core::RefCounted {
38  public:
39   // Called on operation->Execute when operation->get_handler() == this.
40   //
41   // Allows the handler to customize or inspect `operation`'s execution.
42   virtual Status Execute(OpHandlerOperation* operation,
43                          absl::Span<AbstractTensorHandle*> retvals,
44                          int* num_retvals) = 0;
45   // Creates a new handler by merging this handler with `next_handler`.
46   //
47   // The new handler is expected to transform operations first with this handler
48   // and then execute the resulting operations on `next_handler` (by calling
49   // `OpHandlerOperation::set_handler` and passing `next_handler`). If this is
50   // not possible then the merge operation should fail.
51   virtual Status Merge(OpHandler* next_handler,
52                        core::RefCountPtr<OpHandler>& merged_handler) = 0;
53 };
54 
55 // Keeps some handler-specific metadata, but otherwise wraps a single
56 // AbstractOperation in the underlying context. The operation is created, its
57 // attributes set, etc., and at execution time it is presented to its handler,
58 // which may choose to execute it or simply inspect it and do something else.
59 //
60 // This is somewhat different than the Context approach, where the operation's
61 // construction is streamed through each layered Context. The streaming approach
62 // would require a much larger op handler public API, one function pointer per
63 // attribute type, and there is some ambiguity before an op is finalized about
64 // whether it should be presented as-is to handlers (regular operations) or
65 // replayed (function calls and control flow operations).
66 class OpHandlerOperation : public WrapperOperation {
67  public:
68   explicit OpHandlerOperation(AbstractOperation*);
69   OpHandler* get_handler();
70   void set_handler(OpHandler* handler);
71   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
72                  int* num_retvals) override;
73 
74  protected:
75   core::RefCountPtr<OpHandler> handler_;
76 };
77 
78 // A context which allows a default handler to be set for new operations. It
79 // otherwise defers to the context it wraps.
80 //
81 // TODO(allenl): A stack of contexts and a stack of handlers look pretty similar
82 // in some ways. Having each handler be its own context seems almost doable,
83 // with things like copy operations and function/control flow replay being
84 // somewhat tricky (since they should be generated at the top of the handler
85 // stack and "caught" at the bottom). After handlers have evolved for a bit we
86 // should re-evaluate whether the handler+context concepts can be merged.
87 class OpHandlerContext : public AbstractContext {
88  public:
89   explicit OpHandlerContext(AbstractContext*);
90   void Release() override;
91   OpHandlerOperation* CreateOperation() override;
92   Status RegisterFunction(AbstractFunction*) override;
93   Status RemoveFunction(const string&) override;
94   // For LLVM style RTTI.
classof(const AbstractContext * ptr)95   static bool classof(const AbstractContext* ptr) {
96     return ptr->getKind() == kOpHandler;
97   }
98   ~OpHandlerContext() override;
99 
100   void set_default_handler(OpHandler* handler);
101 
102  private:
103   AbstractContext* parent_ctx_;  // Not owned.
104   core::RefCountPtr<OpHandler> default_handler_;
105 };
106 
107 class ReleaseOpHandlerOperation {
108  public:
operator()109   void operator()(OpHandlerOperation* operation) { operation->Release(); }
110 };
111 
112 typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation>
113     OpHandlerOperationPtr;
114 
115 }  // namespace tensorflow
116 
117 #endif  // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
118