/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/mnist_gradients_testutil.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace gradients {
namespace internal {

using std::vector;

//===================== Test Models to run =========================

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(inputs[0]);  // Watch x.
  tape->Watch(inputs[1]);  // Watch y.
  std::vector<AbstractTensorHandle*> add_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(
      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
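  // ComputeGradient writes the gradients of the sum with respect to
  // inputs[0] and inputs[1] into outputs[0] and outputs[1].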
  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
                                           /*sources=*/inputs,
                                           /*output_gradients=*/{}, outputs));
  for (auto add_output : add_outputs) {
    add_output->Unref();
  }
  delete tape;
  return Status::OK();
}

// Computes
// y = matmul(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       const GradientRegistry& registry) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(inputs[0]);  // Watch x.
  tape->Watch(inputs[1]);  // Watch y.
  vector<AbstractTensorHandle*> mm_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
                                 absl::MakeSpan(mm_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute matmul(x, y).

  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mm_outputs,
                                           /*sources=*/inputs,
                                           /*output_gradients=*/{}, outputs));
  for (auto mm_output : mm_outputs) {
    mm_output->Unref();
  }
  delete tape;
  return Status::OK();
}

// Forward pass for a 2-layer fully connected net.
Status MNISTForwardModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
                         absl::Span<AbstractTensorHandle*> outputs,
                         const GradientRegistry& registry) {
  /**
   * We will trace a 2-layer fully connected network for an MNIST model:
   *
   *   def mnist_forward(X, W1, W2, y_labels):
   *     mm_out_1 = tf.matmul(X,W1)
   *     hidden_layer = tf.nn.relu(mm_out_1)
   *     scores = tf.matmul(hidden_layer,W2)
   *     softmax =
   *        tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
   *                                                       y_labels)
   *     return scores, softmax
   *
   * Use this convention for inputs:
   *
   *   inputs = [X, W1, W2, y_labels]
   *
   */
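  // Outputs follow the convention: outputs[0] = scores, outputs[1] = loss.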
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, absl::MakeSpan(temp_outputs),
                                 "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  TF_RETURN_IF_ERROR(ops::Relu(ctx, {temp_outputs[0]},
                               absl::MakeSpan(temp_outputs),
                               "relu"));  // Compute Relu(X*W1)

  TF_RETURN_IF_ERROR(ops::MatMul(
      ctx, {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

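  // SparseSoftmaxCrossEntropyWithLogits emits two tensors (per-example loss
  // and the backprop gradient), so make room for both.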
  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      ctx, {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmax_loss"));  // Compute Softmax(Scores,labels)

  AbstractTensorHandle* loss_vals = temp_outputs[0];

  outputs[0] = scores;
  outputs[1] = loss_vals;
  return Status::OK();
}

Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];

  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, outputs, "matmul0",
                                 /*transpose_a=*/true,
                                 /*transpose_b=*/false));  // Compute X^T*W1
  return Status::OK();
}

Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  auto tape = new Tape(/*persistent=*/true);
  tape->Watch(X);   // Watch X.
  tape->Watch(W1);  // Watch W1.
  tape->Watch(W2);  // Watch W2.
  vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
                                 absl::MakeSpan(temp_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  AbstractTensorHandle* mm = temp_outputs[0];

  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
                               "relu0"));

  AbstractTensorHandle* hidden = temp_outputs[0];

  TF_RETURN_IF_ERROR(ops::MatMul(
      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmaxloss"));  // Compute softmax loss(scores, y_labels)

  AbstractTensorHandle* loss = temp_outputs[0];

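  // Gradients of the loss with respect to W1 and W2 are written into
  // outputs[0] and outputs[1]; outputs[2] is set to the loss below.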
  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/{loss},
                                           /*sources=*/{W1, W2},
                                           /*output_gradients=*/{},
                                           outputs.subspan(0, 2)));

  // Only release 2nd temp output as first holds loss values.
  temp_outputs[1]->Unref();
  outputs[2] = loss;
  delete tape;
  return Status::OK();
}

Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  return ops::Mul(ctx, inputs, outputs,
                  "scalarMul0");  // Compute eta*A
}

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry) {
  return ops::MatMul(ctx, inputs, outputs, "matmul0",
                     /*transpose_a=*/false,
                     /*transpose_b=*/false);  // Compute matmul(inputs[0], inputs[1])
}

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs,
                const GradientRegistry& registry) {
  return ops::Mul(ctx, inputs, outputs,
                  "mul0");  // Compute x*y
}

// ============================= End Models ================================

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow