• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.op.core;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertNotNull;
20 import static org.junit.Assert.assertTrue;
21 
22 import java.util.Arrays;
23 import org.junit.Test;
24 import org.junit.runner.RunWith;
25 import org.junit.runners.JUnit4;
26 import org.tensorflow.Graph;
27 import org.tensorflow.Output;
28 import org.tensorflow.Session;
29 import org.tensorflow.Tensor;
30 import org.tensorflow.Tensors;
31 import org.tensorflow.TestUtil;
32 import org.tensorflow.op.Scope;
33 
34 @RunWith(JUnit4.class)
35 public class GradientsTest {
36 
37   @Test
createGradients()38   public void createGradients() {
39     try (Graph g = new Graph();
40         Session sess = new Session(g)) {
41       Scope scope = new Scope(g);
42 
43       Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
44       Output<Float> y0 = TestUtil.square(g, "y0", x);
45       Output<Float> y1 = TestUtil.square(g, "y1", y0);
46 
47       Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
48 
49       assertNotNull(grads);
50       assertNotNull(grads.dy());
51       assertEquals(2, grads.dy().size());
52 
53       try (Tensor<Float> c = Tensors.create(3.0f);
54           TestUtil.AutoCloseableList<Tensor<?>> outputs =
55               new TestUtil.AutoCloseableList<>(
56                   sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
57 
58         assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
59         assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
60       }
61     }
62   }
63 
64   @Test
createGradientsWithSum()65   public void createGradientsWithSum() {
66     try (Graph g = new Graph();
67         Session sess = new Session(g)) {
68       Scope scope = new Scope(g);
69 
70       Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
71       Output<Float> y0 = TestUtil.square(g, "y0", x);
72       Output<Float> y1 = TestUtil.square(g, "y1", y0);
73 
74       Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
75 
76       assertNotNull(grads);
77       assertNotNull(grads.dy());
78       assertEquals(1, grads.dy().size());
79 
80       try (Tensor<Float> c = Tensors.create(3.0f);
81           TestUtil.AutoCloseableList<Tensor<?>> outputs =
82               new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
83 
84         assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
85       }
86     }
87   }
88 
89   @Test
createGradientsWithInitialValues()90   public void createGradientsWithInitialValues() {
91     try (Graph g = new Graph();
92         Session sess = new Session(g)) {
93       Scope scope = new Scope(g);
94 
95       Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
96       Output<Float> y0 = TestUtil.square(g, "y0", x);
97       Output<Float> y1 = TestUtil.square(g, "y1", y0);
98 
99       Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
100       Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
101 
102       assertNotNull(grads1);
103       assertNotNull(grads1.dy());
104       assertEquals(1, grads1.dy().size());
105 
106       try (Tensor<Float> c = Tensors.create(3.0f);
107           TestUtil.AutoCloseableList<Tensor<?>> outputs =
108               new TestUtil.AutoCloseableList<>(
109                   sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
110 
111         assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
112       }
113     }
114   }
115 
116   @Test
validateGradientsNames()117   public void validateGradientsNames() {
118     try (Graph g = new Graph()) {
119       Scope scope = new Scope(g).withSubScope("sub");
120 
121       Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
122       Output<Float> y = TestUtil.square(g, "y", x);
123 
124       Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
125       assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
126 
127       Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
128       assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
129     }
130   }
131 }
132