1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/python_api.h"
17
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/python/framework/cpp_shape_inference.pb.h"
20
21 namespace tensorflow {
22
AddControlInput(TF_Graph * graph,TF_Operation * op,TF_Operation * input)23 void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
24 mutex_lock l(graph->mu);
25 graph->graph.AddControlEdge(&input->node, &op->node);
26 RecordMutation(graph, *op, "adding control input");
27 }
28
SetAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Buffer * attr_value_proto,TF_Status * status)29 void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
30 TF_Buffer* attr_value_proto, TF_Status* status) {
31 AttrValue attr_val;
32 if (!attr_val.ParseFromArray(attr_value_proto->data,
33 attr_value_proto->length)) {
34 status->status =
35 tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
36 return;
37 }
38
39 mutex_lock l(graph->mu);
40 op->node.AddAttr(attr_name, attr_val);
41 RecordMutation(graph, *op, "setting attribute");
42 }
43
ClearAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Status * status)44 void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
45 TF_Status* status) {
46
47 mutex_lock l(graph->mu);
48 op->node.ClearAttr(attr_name);
49 RecordMutation(graph, *op, "clearing attribute");
50 }
51
SetRequestedDevice(TF_Graph * graph,TF_Operation * op,const char * device)52 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
53 mutex_lock l(graph->mu);
54 op->node.set_requested_device(device);
55 RecordMutation(graph, *op, "setting device");
56 }
57
UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)58 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
59 TF_Status* status) {
60 TF_UpdateEdge(graph, new_src, dst, status);
61 }
62
RemoveAllControlInputs(TF_Graph * graph,TF_Operation * op)63 void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
64 mutex_lock l(graph->mu);
65 std::vector<const Edge*> control_edges;
66 for (const Edge* edge : op->node.in_edges()) {
67 if (!edge->IsControlEdge()) continue;
68 control_edges.push_back(edge);
69 }
70 for (const Edge* edge : control_edges) {
71 graph->graph.RemoveControlEdge(edge);
72 }
73 }
74
SetRequireShapeInferenceFns(TF_Graph * graph,bool require)75 void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
76 mutex_lock l(graph->mu);
77 graph->refiner.set_require_shape_inference_fns(require);
78 }
79
ExtendSession(TF_Session * session,TF_Status * status)80 void ExtendSession(TF_Session* session, TF_Status* status) {
81 ExtendSessionGraphHelper(session, status);
82 session->extend_before_run = false;
83 }
84
GetHandleShapeAndType(TF_Graph * graph,TF_Output output)85 std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
86 Node* node = &output.oper->node;
87 CppShapeInferenceResult::HandleData handle_data;
88 handle_data.set_is_set(true);
89 {
90 mutex_lock l(graph->mu);
91 tensorflow::shape_inference::InferenceContext* ic =
92 graph->refiner.GetContext(node);
93 CHECK(ic != nullptr);
94 CHECK_LT(output.index, ic->num_outputs());
95 const auto* shapes_and_types =
96 ic->output_handle_shapes_and_types(output.index);
97 if (shapes_and_types == nullptr) return "";
98
99 for (const auto& p : *shapes_and_types) {
100 auto* out_shape_and_type = handle_data.add_shape_and_type();
101 ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
102 out_shape_and_type->set_dtype(p.dtype);
103 out_shape_and_type->set_specialized_type(p.specialized_type);
104 }
105 }
106 string result;
107 handle_data.SerializeToString(&result);
108 return result;
109 }
110
SetHandleShapeAndType(TF_Graph * graph,TF_Output output,const void * proto,size_t proto_len,TF_Status * status)111 void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
112 size_t proto_len, TF_Status* status) {
113 tensorflow::CppShapeInferenceResult::HandleData handle_data;
114 if (!handle_data.ParseFromArray(proto, proto_len)) {
115 status->status = tensorflow::errors::InvalidArgument(
116 "Couldn't deserialize HandleData proto");
117 return;
118 }
119 DCHECK(handle_data.is_set());
120
121 tensorflow::mutex_lock l(graph->mu);
122 tensorflow::shape_inference::InferenceContext* ic =
123 graph->refiner.GetContext(&output.oper->node);
124
125 std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
126 for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
127 tensorflow::shape_inference::ShapeHandle shape;
128 status->status =
129 ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
130 if (TF_GetCode(status) != TF_OK) return;
131 shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
132 shape_and_type_proto.specialized_type());
133 }
134 ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
135 }
136
AddWhileInputHack(TF_Graph * graph,TF_Output new_src,TF_Operation * dst,TF_Status * status)137 void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
138 TF_Status* status) {
139 mutex_lock l(graph->mu);
140 status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
141 new_src.index, &dst->node);
142 if (TF_GetCode(status) == TF_OK) {
143 // This modification only updates the destination node for
144 // the purposes of running this graph in a session. Thus, we don't
145 // record the source node as being modified.
146 RecordMutation(graph, *dst, "adding input tensor");
147 }
148 }
149
150 } // namespace tensorflow
151