• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
17 #define TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
18 
19 #include "tensorflow/core/graph/graph.h"
20 #include "tensorflow/core/lib/core/status.h"
21 #include "tensorflow/core/lib/gtl/array_slice.h"
22 
23 namespace tensorflow {
24 
25 // Represents the output of 'node' at 'index'.
26 struct NodeOut {
27   Node* node;
28   int index;
29 
30   // Returns the string name that represents the output of this node.
31   string name() const;
32   // Returns the data type of the output of this node.
33   DataType dtype() const;
34 };
35 
36 // NOTE: This API is a work in progress and will likely be changing frequently.
37 //
38 // Given initial gradient-node outputs 'y_grad_node_outputs' (which compute the
39 // symbolic partial derivatives of some loss function 'L' w.r.t the node outputs
40 // 'y_node_outputs'), adds gradient nodes to 'graph' that compute the symbolic
41 // partial derivatives of 'L' w.r.t the node outputs 'x_node_outputs'.
42 //
43 // REQUIRES: Each node in 'x_node_outputs' to be unique, and so to have a single
44 // output (this restriction will be removed in a subsequent change).
45 
46 // TODO(andydavis) Add symbolic gradient support for general graphs (the current
47 // implementation only supports gradients for functions). In particular,
48 // the nodes in 'x_nodes' are currently restricted to have one output.
49 
50 Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
51                             gtl::ArraySlice<NodeOut> x_node_outputs,
52                             gtl::ArraySlice<NodeOut> y_grad_node_outputs,
53                             std::vector<NodeOut>* x_grad_node_outputs,
54                             Graph* graph);
55 
56 }  // namespace tensorflow
57 
58 #endif  // TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
59