1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_PYTHON_API_H_ 17 #define TENSORFLOW_C_PYTHON_API_H_ 18 19 #include <string> 20 21 #include "tensorflow/c/c_api.h" 22 23 // These functions can be removed without notice. They exist to facilitate some 24 // refactoring of graph construction code in the Python API. 25 26 namespace tensorflow { 27 28 void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); 29 30 // Changes an attr value in the node_def Protocol Buffer and sets a status upon 31 // completion. 32 void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, 33 TF_Buffer* attr_value_proto, TF_Status* status); 34 35 // Clears the attr in the node_def Protocol Buffer and sets a status upon 36 // completion. 37 void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, 38 TF_Status* status); 39 40 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); 41 42 // Updates 'dst' to consume 'new_src'. 43 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, 44 TF_Status* status); 45 46 void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); 47 48 // Sets whether ops missing a shape inference function should trigger an 49 // error. The default is true. 50 void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); 51 52 // Extends `session` with any new operations added to its associated graph. 53 // Usually this happens automatically in TF_SessionRun. After this is called, 54 // TF_SessionRun will no longer extend the session on every call. 55 // 56 // We expose this here to allow fine-grained synchronization in multi-threaded 57 // workloads, which is required since the Python implementation depends on the 58 // above mutation methods. This allows us to prevent modifications to nodes in 59 // the graph after the session has been made aware of them. 60 void ExtendSession(TF_Session* session, TF_Status* status); 61 62 // Returns the serialized CppShapeInferenceResult::HandleData proto for 63 // `output` if its a resource or variant tensor, or otherwise returns the empty 64 // string. 65 std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); 66 67 // Sets `output` based on `proto`, which should be a serialized 68 // CppShapeInferenceResult::HandleData proto. `output` should be a resource 69 // or variant tensor. 70 // NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string 71 // because I couldn't get SWIG to work otherwise. 72 void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, 73 size_t proto_len, TF_Status* status); 74 75 // This method is used to add a new input edge to 'dst', which must be a While 76 // op. The While op's "T" attribute must have already been updated to include 77 // the new edge. This is used to construct tf.while_loop gradients. 78 void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, 79 TF_Status* status); 80 81 } // namespace tensorflow 82 83 #endif // TENSORFLOW_C_PYTHON_API_H_ 84