• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/data_flow_ops.h"
17 #include "tensorflow/cc/ops/data_flow_ops_internal.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 
20 #include "tensorflow/cc/framework/grad_op_registry.h"
21 #include "tensorflow/cc/framework/gradients.h"
22 
23 namespace tensorflow {
24 namespace ops {
25 namespace {
26 
27 REGISTER_NO_GRADIENT_OP("Queue");
28 REGISTER_NO_GRADIENT_OP("QueueEnqueue");
29 REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
30 REGISTER_NO_GRADIENT_OP("QueueDequeue");
31 REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
32 REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
33 REGISTER_NO_GRADIENT_OP("QueueClose");
34 REGISTER_NO_GRADIENT_OP("QueueSize");
35 REGISTER_NO_GRADIENT_OP("Stack");
36 REGISTER_NO_GRADIENT_OP("StackPush");
37 REGISTER_NO_GRADIENT_OP("StackPop");
38 REGISTER_NO_GRADIENT_OP("StackClose");
39 REGISTER_NO_GRADIENT_OP("GetSessionHandle");
40 REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
41 REGISTER_NO_GRADIENT_OP("GetSessionTensor");
42 REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");
43 
DynamicPartitionGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)44 Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
45                             const std::vector<Output>& grad_inputs,
46                             std::vector<Output>* grad_outputs) {
47   // DynamicPartition only moves input values into various positions
48   // in the output, so the gradient operation only has to map incoming
49   // gradients into their input source locations.
50   // running example:
51   // data = [10, 20, 30, 40, 50]
52   // partitions = [0, 0, 1, 1, 0]
53   // num_partitions = 2
54   // dynamic_partition(data, partitions, num_partitions) = {
55   //   [10, 20, 50],
56   //   [30, 40]
57   // }
58   // grads = {
59   //   [g1, g2, g3],
60   //   [g4, g5]
61   // }
62   // The desired propagation of the gradients back to the data inputs is:
63   // [g1, g2, g4, g5, g3]
64   auto data = op.input(0);
65   auto partitions = op.input(1);
66   int32 num_partitions;
67   TF_RETURN_IF_ERROR(
68       GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));
69 
70   // Note: the shape of the partitions is a prefix of the data shape.
71   // shape(partitions) = [5]
72   auto partitions_shape = Shape(scope, partitions);
73   // We now create a partitions-shaped tensor with integers from
74   // [0..size(partitions)) This will be dynamic_partitioned with the
75   // input parameters, providing the destination index for a given
76   // source item.
77   // partitions_size = prod([5]) = 5
78   // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
79   auto zero = Const(scope, 0);
80   auto one = Const(scope, 1);
81   auto original_indices = Reshape(
82       scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
83       partitions_shape);
84   // dynamic_partition(
85   //   [0, 1, 2, 3, 4],
86   //   [0, 0, 1, 1, 0], 2)
87   //  = { [0, 1, 4],
88   //      [2, 3] }
89   auto partitioned_indices =
90       DynamicPartition(scope, original_indices, partitions, num_partitions);
91 
92   // Invert these indices with dynamic_stitch to map the incoming
93   // gradients to their source inputs.
94   // dynamic_stitch(
95   //   { [0, 1, 4], [2, 3] },
96   //   { [g1, g2, g3], [g4, g5] })
97   // = [g1, g2, g4, g5, g3]
98   auto reconstructed =
99       DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
100   // reshape back into a data-shaped tensor to propagate gradients for the data
101   // input.
102   grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
103   // Stop propagation along the partitions input
104   grad_outputs->push_back(NoGradient());
105   return scope.status();
106 }
107 REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);
108 
DynamicStitchGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)109 Status DynamicStitchGrad(const Scope& scope, const Operation& op,
110                          const std::vector<Output>& grad_inputs,
111                          std::vector<Output>* grad_outputs) {
112   // Running example:
113   // indices = {2, [1, 0]}
114   // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
115   // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
116   // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]
117 
118   // indices and data are two equal-sized lists passed
119   // into DynamicStitch.
120   // num_values = 2
121   int32 num_values = op.num_inputs() / 2;
122 
123   // Stop propagation along the indices list
124   for (int32 i = 0; i < num_values; i++) {
125     grad_outputs->push_back(NoGradient());
126   }
127 
128   // DynamicStitch shuffles its data to the output (using items in
129   // indices) so the gradient propagated to a given data input simply
130   // selects the gradient for its output position.
131   for (int32 i = 0; i < num_values; i++) {
132     // index has the destination positions for the i'th data
133     // element. We cast it into an int32 if necessary, so we can use
134     // it from a Gather op.
135     // i = 0: index = 2
136     // i = 1: index = [1, 0]
137     auto index = op.input(i);
138     if (index.type() != DT_INT32) {
139       index = Cast(scope, index, DT_INT32);
140     }
141     // Gather the index specified locations in the gradient and
142     // propagate it as the gradient for the i'th data item.
143     // i = 0: gather(grad, 2) = [g_5, g_6]
144     // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
145     grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
146   }
147 
148   return scope.status();
149 }
150 REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
151 REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);
152 
153 }  // anonymous namespace
154 }  // namespace ops
155 }  // namespace tensorflow
156