/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/python_api.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"

namespace tensorflow {

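// Adds a control edge from `input` to `op` and records the graph mutation.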
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
  mutex_lock l(graph->mu);
  graph->graph.AddControlEdge(&input->node, &op->node);
  RecordMutation(graph, *op, "adding control input");
}

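// Sets the attr `attr_name` on `op` from a serialized AttrValue proto.
// Fails with InvalidArgument if the buffer cannot be parsed.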
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
             TF_Buffer* attr_value_proto, TF_Status* status) {
  AttrValue attr_val;
  if (!attr_val.ParseFromArray(attr_value_proto->data,
                               attr_value_proto->length)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
    return;
  }

  mutex_lock l(graph->mu);
  op->node.AddAttr(attr_name, attr_val);
  RecordMutation(graph, *op, "setting attribute");
}

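// Removes the attr `attr_name` from `op`, if it is set.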
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
               TF_Status* status) {
  mutex_lock l(graph->mu);
  op->node.ClearAttr(attr_name);
  RecordMutation(graph, *op, "clearing attribute");
}

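// Sets the requested device of `op`'s underlying node.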
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
  mutex_lock l(graph->mu);
  op->node.set_requested_device(device);
  RecordMutation(graph, *op, "setting device");
}

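// Rewires input `dst.index` of `dst.oper` to read from output `new_src`.
// Validates both indices against the shape refiner and merges the source
// output shape into the destination input shape before updating the edge.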
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                TF_Status* status) {
  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);

  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }
  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}

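// Removes all incoming control edges from `op`'s node.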
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
  mutex_lock l(graph->mu);
  std::vector<const Edge*> control_edges;
  for (const Edge* edge : op->node.in_edges()) {
    if (!edge->IsControlEdge()) continue;
    control_edges.push_back(edge);
  }
  for (const Edge* edge : control_edges) {
    graph->graph.RemoveControlEdge(edge);
  }
}

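// Controls whether the graph's shape refiner requires every op to have a
// registered shape inference function.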
void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
  mutex_lock l(graph->mu);
  graph->refiner.set_require_shape_inference_fns(require);
}

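// Extends the session's graph with nodes added since the last extension and
// marks the session as not needing extension before the next run.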
void ExtendSession(TF_Session* session, TF_Status* status) {
  ExtendSessionGraphHelper(session, status);
  session->extend_before_run = false;
}

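// Returns a serialized CppShapeInferenceResult::HandleData proto describing
// the handle shapes and dtypes inferred for `output`, or an empty string if
// none have been recorded.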
std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
  Node* node = &output.oper->node;
  CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);
  {
    mutex_lock l(graph->mu);
    tensorflow::shape_inference::InferenceContext* ic =
        graph->refiner.GetContext(node);
    CHECK(ic != nullptr);
    CHECK_LT(output.index, ic->num_outputs());
    const auto* shapes_and_types =
        ic->output_handle_shapes_and_types(output.index);
    if (shapes_and_types == nullptr) return "";

    for (const auto& p : *shapes_and_types) {
      auto* out_shape_and_type = handle_data.add_shape_and_type();
      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
      out_shape_and_type->set_dtype(p.dtype);
    }
  }
  string result;
  handle_data.SerializeToString(&result);
  return result;
}

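// Sets the inferred handle shapes and dtypes for `output` from a serialized
// CppShapeInferenceResult::HandleData proto.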
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
                           size_t proto_len, TF_Status* status) {
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  if (!handle_data.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Couldn't deserialize HandleData proto");
    return;
  }
  DCHECK(handle_data.is_set());

  tensorflow::mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&output.oper->node);

  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
    tensorflow::shape_inference::ShapeHandle shape;
    status->status =
        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
    if (TF_GetCode(status) != TF_OK) return;
    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
  }
  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}

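// Appends `new_src` as an input of `dst` via Graph::AddWhileInputHack, which
// is only intended for use while constructing while loops. Only the
// destination node is recorded as mutated.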
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
                       TF_Status* status) {
  mutex_lock l(graph->mu);
  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
                                                  new_src.index, &dst->node);
  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst, "adding input tensor");
  }
}

}  // namespace tensorflow