/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/python_api.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"

namespace tensorflow {
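
// Adds `input` as a control input of `op` (a control edge from input->node to
// op->node) and records the graph mutation.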
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
  mutex_lock l(graph->mu);
  graph->graph.AddControlEdge(&input->node, &op->node);
  RecordMutation(graph, *op, "adding control input");
}
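
// Parses `attr_value_proto` as an AttrValue and sets it on `op` under
// `attr_name`. On a malformed proto, reports InvalidArgument via `status`
// and leaves the node untouched.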
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
             TF_Buffer* attr_value_proto, TF_Status* status) {
  AttrValue attr_val;
  if (!attr_val.ParseFromArray(attr_value_proto->data,
                               attr_value_proto->length)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
    return;
  }

  mutex_lock l(graph->mu);
  op->node.AddAttr(attr_name, attr_val);
  RecordMutation(graph, *op, "setting attribute");
}
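
// Removes the attribute `attr_name` from `op` and records the mutation.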
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
               TF_Status* status) {
  mutex_lock l(graph->mu);
  op->node.ClearAttr(attr_name);
  RecordMutation(graph, *op, "clearing attribute");
}
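
// Sets the requested device of `op` to `device`.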
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
  mutex_lock l(graph->mu);
  op->node.set_requested_device(device);
  RecordMutation(graph, *op, "setting device");
}
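
// Reconnects the edge feeding `dst` so that it originates from `new_src`.
// Thin wrapper around the C API's TF_UpdateEdge.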
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                TF_Status* status) {
  TF_UpdateEdge(graph, new_src, dst, status);
}
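
// Removes every control edge terminating at `op`. The edges are collected
// first so that in_edges() is not modified while being iterated.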
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
  mutex_lock l(graph->mu);
  std::vector<const Edge*> control_edges;
  for (const Edge* edge : op->node.in_edges()) {
    if (!edge->IsControlEdge()) continue;
    control_edges.push_back(edge);
  }
  for (const Edge* edge : control_edges) {
    graph->graph.RemoveControlEdge(edge);
  }
}
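
// Toggles whether the graph's shape refiner requires ops to have a
// registered shape-inference function.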
void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
  mutex_lock l(graph->mu);
  graph->refiner.set_require_shape_inference_fns(require);
}
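
// Propagates pending graph additions to the session so the next run does not
// need to extend it again.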
void ExtendSession(TF_Session* session, TF_Status* status) {
  ExtendSessionGraphHelper(session, status);
  session->extend_before_run = false;
}
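
// Returns the shapes and dtypes recorded for the resource handle produced at
// `output`, serialized as a CppShapeInferenceResult::HandleData proto, or the
// empty string if no handle data has been recorded.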
std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
  Node* node = &output.oper->node;
  CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);
  {
    mutex_lock l(graph->mu);
    tensorflow::shape_inference::InferenceContext* ic =
        graph->refiner.GetContext(node);
    CHECK(ic != nullptr);
    CHECK_LT(output.index, ic->num_outputs());
    const auto* shapes_and_types =
        ic->output_handle_shapes_and_types(output.index);
    if (shapes_and_types == nullptr) return "";

    for (const auto& p : *shapes_and_types) {
      auto* out_shape_and_type = handle_data.add_shape_and_type();
      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
      out_shape_and_type->set_dtype(p.dtype);
      out_shape_and_type->set_specialized_type(p.specialized_type);
    }
  }
  string result;
  handle_data.SerializeToString(&result);
  return result;
}
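
// Parses `proto` as a CppShapeInferenceResult::HandleData and installs the
// resulting shapes and dtypes on the shape refiner's context for `output`.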
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
                           size_t proto_len, TF_Status* status) {
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  if (!handle_data.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Couldn't deserialize HandleData proto");
    return;
  }
  DCHECK(handle_data.is_set());

  tensorflow::mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&output.oper->node);

  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
    tensorflow::shape_inference::ShapeHandle shape;
    status->status =
        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
    if (TF_GetCode(status) != TF_OK) return;
    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
                                  shape_and_type_proto.specialized_type());
  }
  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}
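
// Appends `new_src` as a data input of `dst`. Only valid for nodes that allow
// inputs to be added after construction (used when building while loops
// through the C API).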
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
                       TF_Status* status) {
  mutex_lock l(graph->mu);
  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
                                                  new_src.index, &dst->node);
  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst, "adding input tensor");
  }
}

}  // namespace tensorflow