/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/python_api.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"

namespace tensorflow {

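// Adds a control edge from `input` to `op` and records the mutation so that
// live sessions are notified that `op` has changed.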
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
  mutex_lock l(graph->mu);
  graph->graph.AddControlEdge(&input->node, &op->node);
  RecordMutation(graph, *op, "adding control input");
}

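// Sets attribute `attr_name` on `op` from a serialized AttrValue proto.
// Reports InvalidArgument in `status` if the proto cannot be parsed.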
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
             TF_Buffer* attr_value_proto, TF_Status* status) {
  AttrValue attr_val;
  if (!attr_val.ParseFromArray(attr_value_proto->data,
                               attr_value_proto->length)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
    return;
  }

  mutex_lock l(graph->mu);
  op->node.AddAttr(attr_name, attr_val);
  RecordMutation(graph, *op, "setting attribute");
}

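// Removes attribute `attr_name` from `op`, if present.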
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
               TF_Status* status) {
  mutex_lock l(graph->mu);
  op->node.ClearAttr(attr_name);
  RecordMutation(graph, *op, "clearing attribute");
}

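// Sets the requested device placement string for `op`.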
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
  mutex_lock l(graph->mu);
  op->node.set_requested_device(device);
  RecordMutation(graph, *op, "setting device");
}

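// Rewires input `dst` to read from `new_src` instead of its current source.
// The source's inferred output shape must be compatible with the shape
// already known for the destination input; otherwise an error is returned in
// `status` and the edge is left unchanged.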
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                TF_Status* status) {
  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);

  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }
  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}

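// Removes every control edge feeding into `op`. Data edges are left intact.
// The edges are collected first so removal does not invalidate the iteration
// over in_edges().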
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
  mutex_lock l(graph->mu);
  std::vector<const Edge*> control_edges;
  for (const Edge* edge : op->node.in_edges()) {
    if (!edge->IsControlEdge()) continue;
    control_edges.push_back(edge);
  }
  for (const Edge* edge : control_edges) {
    graph->graph.RemoveControlEdge(edge);
  }
}

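// Controls whether the graph's shape refiner requires every op to have a
// registered shape inference function.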
void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
  mutex_lock l(graph->mu);
  graph->refiner.set_require_shape_inference_fns(require);
}

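// Pushes any pending graph changes to the session so the next run does not
// need to extend the graph again.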
void ExtendSession(TF_Session* session, TF_Status* status) {
  ExtendSessionGraphHelper(session, status);
  session->extend_before_run = false;
}

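// Returns the serialized CppShapeInferenceResult::HandleData proto describing
// the handle shapes and dtypes known for `output`, or an empty string if no
// handle data is available.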
std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
  Node* node = &output.oper->node;
  CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);
  {
    mutex_lock l(graph->mu);
    tensorflow::shape_inference::InferenceContext* ic =
        graph->refiner.GetContext(node);
    CHECK(ic != nullptr);
    CHECK_LT(output.index, ic->num_outputs());
    const auto* shapes_and_types =
        ic->output_handle_shapes_and_types(output.index);
    if (shapes_and_types == nullptr) return "";

    for (const auto& p : *shapes_and_types) {
      auto* out_shape_and_type = handle_data.add_shape_and_type();
      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
      out_shape_and_type->set_dtype(p.dtype);
    }
  }
  string result;
  handle_data.SerializeToString(&result);
  return result;
}

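// Sets the handle shapes and dtypes for `output` from a serialized
// CppShapeInferenceResult::HandleData proto. Reports InvalidArgument in
// `status` if the proto cannot be parsed.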
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
                           size_t proto_len, TF_Status* status) {
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  if (!handle_data.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Couldn't deserialize HandleData proto");
    return;
  }
  DCHECK(handle_data.is_set());

  tensorflow::mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&output.oper->node);

  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
    tensorflow::shape_inference::ShapeHandle shape;
    status->status =
        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
    if (TF_GetCode(status) != TF_OK) return;
    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
  }
  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}

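// Adds `new_src` as a new input of `dst` via Graph::AddWhileInputHack, which
// also updates dst's NodeDef; used when appending inputs to an
// already-constructed While node (hence the "Hack" in the name).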
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
                       TF_Status* status) {
  mutex_lock l(graph->mu);
  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
                                                  new_src.index, &dst->node);
  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst, "adding input tensor");
  }
}

}  // namespace tensorflow