1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/c/python_api.h"

#include <limits>
#include <string>
#include <vector>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
21
22 namespace tensorflow {
23
AddControlInput(TF_Graph * graph,TF_Operation * op,TF_Operation * input)24 void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
25 mutex_lock l(graph->mu);
26 graph->graph.AddControlEdge(&input->node, &op->node);
27 RecordMutation(graph, *op, "adding control input");
28 }
29
SetAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Buffer * attr_value_proto,TF_Status * status)30 void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
31 TF_Buffer* attr_value_proto, TF_Status* status) {
32 AttrValue attr_val;
33 if (!attr_val.ParseFromArray(attr_value_proto->data,
34 attr_value_proto->length)) {
35 status->status =
36 tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
37 return;
38 }
39
40 mutex_lock l(graph->mu);
41 op->node.AddAttr(attr_name, attr_val);
42 RecordMutation(graph, *op, "setting attribute");
43 }
44
ClearAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Status * status)45 void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
46 TF_Status* status) {
47 mutex_lock l(graph->mu);
48 op->node.ClearAttr(attr_name);
49 RecordMutation(graph, *op, "clearing attribute");
50 }
51
SetFullType(TF_Graph * graph,TF_Operation * op,const FullTypeDef & full_type)52 void SetFullType(TF_Graph* graph, TF_Operation* op,
53 const FullTypeDef& full_type) {
54 mutex_lock l(graph->mu);
55 *op->node.mutable_def()->mutable_experimental_type() = full_type;
56 RecordMutation(graph, *op, "setting fulltype");
57 }
58
SetRequestedDevice(TF_Graph * graph,TF_Operation * op,const char * device)59 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
60 mutex_lock l(graph->mu);
61 op->node.set_requested_device(device);
62 RecordMutation(graph, *op, "setting device");
63 }
64
UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)65 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
66 TF_Status* status) {
67 TF_UpdateEdge(graph, new_src, dst, status);
68 }
69
RemoveAllControlInputs(TF_Graph * graph,TF_Operation * op)70 void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
71 mutex_lock l(graph->mu);
72 std::vector<const Edge*> control_edges;
73 for (const Edge* edge : op->node.in_edges()) {
74 if (!edge->IsControlEdge()) continue;
75 control_edges.push_back(edge);
76 }
77 for (const Edge* edge : control_edges) {
78 graph->graph.RemoveControlEdge(edge);
79 }
80 }
81
SetRequireShapeInferenceFns(TF_Graph * graph,bool require)82 void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
83 mutex_lock l(graph->mu);
84 graph->refiner.set_require_shape_inference_fns(require);
85 }
86
ExtendSession(TF_Session * session,TF_Status * status)87 void ExtendSession(TF_Session* session, TF_Status* status) {
88 ExtendSessionGraphHelper(session, status);
89 session->extend_before_run = false;
90 }
91
GetHandleShapeAndType(TF_Graph * graph,TF_Output output)92 std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
93 Node* node = &output.oper->node;
94 CppShapeInferenceResult::HandleData handle_data;
95 handle_data.set_is_set(true);
96 {
97 mutex_lock l(graph->mu);
98 tensorflow::shape_inference::InferenceContext* ic =
99 graph->refiner.GetContext(node);
100 CHECK(ic != nullptr);
101 CHECK_LT(output.index, ic->num_outputs());
102 const auto* shapes_and_types =
103 ic->output_handle_shapes_and_types(output.index);
104 if (shapes_and_types == nullptr) return "";
105
106 for (const auto& p : *shapes_and_types) {
107 auto* out_shape_and_type = handle_data.add_shape_and_type();
108 ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
109 out_shape_and_type->set_dtype(p.dtype);
110 *out_shape_and_type->mutable_type() = p.type;
111 }
112 }
113 string result;
114 handle_data.SerializeToString(&result);
115 return result;
116 }
117
SetHandleShapeAndType(TF_Graph * graph,TF_Output output,const void * proto,size_t proto_len,TF_Status * status)118 void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
119 size_t proto_len, TF_Status* status) {
120 tensorflow::CppShapeInferenceResult::HandleData handle_data;
121 if (!handle_data.ParseFromArray(proto, proto_len)) {
122 status->status = tensorflow::errors::InvalidArgument(
123 "Couldn't deserialize HandleData proto");
124 return;
125 }
126 DCHECK(handle_data.is_set());
127
128 tensorflow::mutex_lock l(graph->mu);
129 tensorflow::shape_inference::InferenceContext* ic =
130 graph->refiner.GetContext(&output.oper->node);
131
132 std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
133 for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
134 tensorflow::shape_inference::ShapeHandle shape;
135 status->status =
136 ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
137 if (TF_GetCode(status) != TF_OK) return;
138 shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
139 shape_and_type_proto.type());
140 }
141 ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
142 }
143
AddWhileInputHack(TF_Graph * graph,TF_Output new_src,TF_Operation * dst,TF_Status * status)144 void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
145 TF_Status* status) {
146 mutex_lock l(graph->mu);
147 status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
148 new_src.index, &dst->node);
149 if (TF_GetCode(status) == TF_OK) {
150 // This modification only updates the destination node for
151 // the purposes of running this graph in a session. Thus, we don't
152 // record the source node as being modified.
153 RecordMutation(graph, *dst, "adding input tensor");
154 }
155 }
156
157 } // namespace tensorflow
158