• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/python_api.h"
17 
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/python/framework/cpp_shape_inference.pb.h"
20 
21 namespace tensorflow {
22 
AddControlInput(TF_Graph * graph,TF_Operation * op,TF_Operation * input)23 void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
24   mutex_lock l(graph->mu);
25   graph->graph.AddControlEdge(&input->node, &op->node);
26   RecordMutation(graph, *op, "adding control input");
27 }
28 
SetAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Buffer * attr_value_proto,TF_Status * status)29 void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
30              TF_Buffer* attr_value_proto, TF_Status* status) {
31   AttrValue attr_val;
32   if (!attr_val.ParseFromArray(attr_value_proto->data,
33                                attr_value_proto->length)) {
34     status->status =
35         tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
36     return;
37   }
38 
39   mutex_lock l(graph->mu);
40   op->node.AddAttr(attr_name, attr_val);
41   RecordMutation(graph, *op, "setting attribute");
42 }
43 
ClearAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Status * status)44 void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
45                TF_Status* status) {
46 
47   mutex_lock l(graph->mu);
48   op->node.ClearAttr(attr_name);
49   RecordMutation(graph, *op, "clearing attribute");
50 }
51 
SetRequestedDevice(TF_Graph * graph,TF_Operation * op,const char * device)52 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
53   mutex_lock l(graph->mu);
54   op->node.set_requested_device(device);
55   RecordMutation(graph, *op, "setting device");
56 }
57 
UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)58 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
59                 TF_Status* status) {
60   mutex_lock l(graph->mu);
61   tensorflow::shape_inference::InferenceContext* ic =
62       graph->refiner.GetContext(&new_src.oper->node);
63 
64   if (ic->num_outputs() <= new_src.index) {
65     status->status = tensorflow::errors::OutOfRange(
66         "Cannot update edge. Output index [", new_src.index,
67         "] is greater than the number of total outputs [", ic->num_outputs(),
68         "].");
69     return;
70   }
71   tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
72 
73   tensorflow::shape_inference::InferenceContext* ic_dst =
74       graph->refiner.GetContext(&dst.oper->node);
75   if (ic_dst->num_inputs() <= dst.index) {
76     status->status = tensorflow::errors::OutOfRange(
77         "Cannot update edge. Input index [", dst.index,
78         "] is greater than the number of total inputs [", ic_dst->num_inputs(),
79         "].");
80     return;
81   }
82   if (!ic_dst->MergeInput(dst.index, shape)) {
83     status->status = tensorflow::errors::InvalidArgument(
84         "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
85         " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
86     return;
87   }
88   status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
89                                            &dst.oper->node, dst.index);
90 
91   if (TF_GetCode(status) == TF_OK) {
92     // This modification only updates the destination node for
93     // the purposes of running this graph in a session. Thus, we don't
94     // record the source node as being modified.
95     RecordMutation(graph, *dst.oper, "updating input tensor");
96   }
97 }
98 
RemoveAllControlInputs(TF_Graph * graph,TF_Operation * op)99 void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
100   mutex_lock l(graph->mu);
101   std::vector<const Edge*> control_edges;
102   for (const Edge* edge : op->node.in_edges()) {
103     if (!edge->IsControlEdge()) continue;
104     control_edges.push_back(edge);
105   }
106   for (const Edge* edge : control_edges) {
107     graph->graph.RemoveControlEdge(edge);
108   }
109 }
110 
SetRequireShapeInferenceFns(TF_Graph * graph,bool require)111 void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
112   mutex_lock l(graph->mu);
113   graph->refiner.set_require_shape_inference_fns(require);
114 }
115 
ExtendSession(TF_Session * session,TF_Status * status)116 void ExtendSession(TF_Session* session, TF_Status* status) {
117   ExtendSessionGraphHelper(session, status);
118   session->extend_before_run = false;
119 }
120 
GetHandleShapeAndType(TF_Graph * graph,TF_Output output)121 std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
122   Node* node = &output.oper->node;
123   CppShapeInferenceResult::HandleData handle_data;
124   handle_data.set_is_set(true);
125   {
126     mutex_lock l(graph->mu);
127     tensorflow::shape_inference::InferenceContext* ic =
128         graph->refiner.GetContext(node);
129     CHECK(ic != nullptr);
130     CHECK_LT(output.index, ic->num_outputs());
131     const auto* shapes_and_types =
132         ic->output_handle_shapes_and_types(output.index);
133     if (shapes_and_types == nullptr) return "";
134 
135     for (const auto& p : *shapes_and_types) {
136       auto* out_shape_and_type = handle_data.add_shape_and_type();
137       ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
138       out_shape_and_type->set_dtype(p.dtype);
139     }
140   }
141   string result;
142   handle_data.SerializeToString(&result);
143   return result;
144 }
145 
SetHandleShapeAndType(TF_Graph * graph,TF_Output output,const void * proto,size_t proto_len,TF_Status * status)146 void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
147                            size_t proto_len, TF_Status* status) {
148   tensorflow::CppShapeInferenceResult::HandleData handle_data;
149   if (!handle_data.ParseFromArray(proto, proto_len)) {
150     status->status = tensorflow::errors::InvalidArgument(
151         "Couldn't deserialize HandleData proto");
152     return;
153   }
154   DCHECK(handle_data.is_set());
155 
156   tensorflow::mutex_lock l(graph->mu);
157   tensorflow::shape_inference::InferenceContext* ic =
158       graph->refiner.GetContext(&output.oper->node);
159 
160   std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
161   for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
162     tensorflow::shape_inference::ShapeHandle shape;
163     status->status =
164         ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
165     if (TF_GetCode(status) != TF_OK) return;
166     shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
167   }
168   ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
169 }
170 
AddWhileInputHack(TF_Graph * graph,TF_Output new_src,TF_Operation * dst,TF_Status * status)171 void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
172                        TF_Status* status) {
173   mutex_lock l(graph->mu);
174   status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
175                                                   new_src.index, &dst->node);
176   if (TF_GetCode(status) == TF_OK) {
177     // This modification only updates the destination node for
178     // the purposes of running this graph in a session. Thus, we don't
179     // record the source node as being modified.
180     RecordMutation(graph, *dst, "adding input tensor");
181   }
182 }
183 
184 }  // namespace tensorflow
185