1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ 17 #define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ 18 19 #include "tensorflow/c/conversion_macros.h" 20 #include "tensorflow/c/eager/abstract_context.h" 21 #include "tensorflow/c/eager/abstract_operation.h" 22 #include "tensorflow/c/eager/abstract_tensor_handle.h" 23 #include "tensorflow/c/experimental/op_handler/wrapper_operation.h" 24 #include "tensorflow/core/platform/refcount.h" 25 #include "tensorflow/core/platform/types.h" 26 27 namespace tensorflow { 28 29 class OpHandlerOperation; 30 31 // Op handlers are a convenient way to intercept and transform computation. 32 // 33 // The implementation is currently experimental and incomplete, but aims 34 // eventually to support tracing and replay of function bodies, gradients 35 // through copy operations, and a variety of hooks for things like debug 36 // strings. A public C API for op handlers is planned. 37 class OpHandler : public core::RefCounted { 38 public: 39 // Called on operation->Execute when operation->get_handler() == this. 40 // 41 // Allows the handler to customize or inspect `operation`'s execution. 42 virtual Status Execute(OpHandlerOperation* operation, 43 absl::Span<AbstractTensorHandle*> retvals, 44 int* num_retvals) = 0; 45 // Creates a new handler by merging this handler with `next_handler`. 46 // 47 // The new handler is expected to transform operations first with this handler 48 // and then execute the resulting operations on `next_handler` (by calling 49 // `OpHandlerOperation::set_handler` and passing `next_handler`). If this is 50 // not possible then the merge operation should fail. 51 virtual Status Merge(OpHandler* next_handler, 52 core::RefCountPtr<OpHandler>& merged_handler) = 0; 53 }; 54 55 // Keeps some handler-specific metadata, but otherwise wraps a single 56 // AbstractOperation in the underlying context. The operation is created, its 57 // attributes set, etc., and at execution time it is presented to its handler, 58 // which may choose to execute it or simply inspect it and do something else. 59 // 60 // This is somewhat different than the Context approach, where the operation's 61 // construction is streamed through each layered Context. The streaming approach 62 // would require a much larger op handler public API, one function pointer per 63 // attribute type, and there is some ambiguity before an op is finalized about 64 // whether it should be presented as-is to handlers (regular operations) or 65 // replayed (function calls and control flow operations). 66 class OpHandlerOperation : public WrapperOperation { 67 public: 68 explicit OpHandlerOperation(AbstractOperation*); 69 OpHandler* get_handler(); 70 void set_handler(OpHandler* handler); 71 Status Execute(absl::Span<AbstractTensorHandle*> retvals, 72 int* num_retvals) override; 73 74 protected: 75 core::RefCountPtr<OpHandler> handler_; 76 }; 77 78 // A context which allows a default handler to be set for new operations. It 79 // otherwise defers to the context it wraps. 80 // 81 // TODO(allenl): A stack of contexts and a stack of handlers look pretty similar 82 // in some ways. Having each handler be its own context seems almost doable, 83 // with things like copy operations and function/control flow replay being 84 // somewhat tricky (since they should be generated at the top of the handler 85 // stack and "caught" at the bottom). After handlers have evolved for a bit we 86 // should re-evaluate whether the handler+context concepts can be merged. 87 class OpHandlerContext : public AbstractContext { 88 public: 89 explicit OpHandlerContext(AbstractContext*); 90 void Release() override; 91 OpHandlerOperation* CreateOperation() override; 92 Status RegisterFunction(AbstractFunction*) override; 93 Status RemoveFunction(const string&) override; 94 // For LLVM style RTTI. classof(const AbstractContext * ptr)95 static bool classof(const AbstractContext* ptr) { 96 return ptr->getKind() == kOpHandler; 97 } 98 ~OpHandlerContext() override; 99 100 void set_default_handler(OpHandler* handler); 101 102 private: 103 AbstractContext* parent_ctx_; // Not owned. 104 core::RefCountPtr<OpHandler> default_handler_; 105 }; 106 107 class ReleaseOpHandlerOperation { 108 public: operator()109 void operator()(OpHandlerOperation* operation) { operation->Release(); } 110 }; 111 112 typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation> 113 OpHandlerOperationPtr; 114 115 } // namespace tensorflow 116 117 #endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ 118