/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.op.core;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import java.util.Arrays;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.TestUtil;
import org.tensorflow.op.Scope;

/**
 * Tests for the {@link Gradients} op wrapper.
 *
 * <p>Every numeric test builds the same small graph: a float placeholder named {@code "x1"} that
 * is passed through {@code TestUtil.square} twice (ops {@code "y0"} and {@code "y1"}), and then
 * feeds the placeholder with {@code 3.0f}. The expected values below are consistent with
 * {@code TestUtil.square} producing an element-wise square, i.e. {@code y0 = x^2} and
 * {@code y1 = x^4} (NOTE(review): the helper is defined outside this file - confirm).
 */
@RunWith(JUnit4.class)
public class GradientsTest {

  /**
   * Requests the gradients of a single output ({@code y1}) with respect to two inputs ({@code x}
   * and {@code y0}) and checks that one gradient is produced per input, in the same order.
   */
  @Test
  public void createGradients() {
    try (Graph graph = new Graph();
        Session session = new Session(graph)) {
      Scope scope = new Scope(graph);

      Output<Float> input = TestUtil.placeholder(graph, "x1", Float.class);
      Output<Float> firstSquare = TestUtil.square(graph, "y0", input);
      Output<Float> secondSquare = TestUtil.square(graph, "y1", firstSquare);

      Gradients gradients =
          Gradients.create(scope, secondSquare, Arrays.asList(input, firstSquare));

      assertNotNull(gradients);
      assertNotNull(gradients.dy());
      assertEquals(2, gradients.dy().size());

      try (Tensor<Float> fedValue = Tensors.create(3.0f);
          TestUtil.AutoCloseableList<Tensor<?>> results =
              new TestUtil.AutoCloseableList<>(
                  session
                      .runner()
                      .feed(input, fedValue)
                      .fetch(gradients.dy(0))
                      .fetch(gradients.dy(1))
                      .run())) {

        // With x = 3: d(x^4)/dx = 4 * 3^3 = 108 and d(y0^2)/dy0 = 2 * 9 = 18.
        assertEquals(108.0f, results.get(0).floatValue(), 0.0f);
        assertEquals(18.0f, results.get(1).floatValue(), 0.0f);
      }
    }
  }

  /**
   * Requests the gradient of several outputs ({@code y0} and {@code y1}) with respect to a single
   * input and checks that a single, combined gradient value is returned.
   */
  @Test
  public void createGradientsWithSum() {
    try (Graph graph = new Graph();
        Session session = new Session(graph)) {
      Scope scope = new Scope(graph);

      Output<Float> input = TestUtil.placeholder(graph, "x1", Float.class);
      Output<Float> firstSquare = TestUtil.square(graph, "y0", input);
      Output<Float> secondSquare = TestUtil.square(graph, "y1", firstSquare);

      Gradients gradients =
          Gradients.create(scope, Arrays.asList(firstSquare, secondSquare), Arrays.asList(input));

      assertNotNull(gradients);
      assertNotNull(gradients.dy());
      assertEquals(1, gradients.dy().size());

      try (Tensor<Float> fedValue = Tensors.create(3.0f);
          TestUtil.AutoCloseableList<Tensor<?>> results =
              new TestUtil.AutoCloseableList<>(
                  session.runner().feed(input, fedValue).fetch(gradients.dy(0)).run())) {

        // With x = 3: d(x^2)/dx + d(x^4)/dx = 6 + 108 = 114.
        assertEquals(114.0f, results.get(0).floatValue(), 0.0f);
      }
    }
  }

  /**
   * Chains two gradient computations: the result of the first ({@code dy1/dy0}) is passed as the
   * initial gradient ({@link Gradients#dx}) of the second ({@code dy0/dx}), and the final value
   * is checked against the directly computed {@code dy1/dx}.
   */
  @Test
  public void createGradientsWithInitialValues() {
    try (Graph graph = new Graph();
        Session session = new Session(graph)) {
      Scope scope = new Scope(graph);

      Output<Float> input = TestUtil.placeholder(graph, "x1", Float.class);
      Output<Float> firstSquare = TestUtil.square(graph, "y0", input);
      Output<Float> secondSquare = TestUtil.square(graph, "y1", firstSquare);

      Gradients outerGradients =
          Gradients.create(scope, secondSquare, Arrays.asList(firstSquare));
      Gradients chainedGradients =
          Gradients.create(
              scope, firstSquare, Arrays.asList(input), Gradients.dx(outerGradients.dy()));

      assertNotNull(chainedGradients);
      assertNotNull(chainedGradients.dy());
      assertEquals(1, chainedGradients.dy().size());

      try (Tensor<Float> fedValue = Tensors.create(3.0f);
          TestUtil.AutoCloseableList<Tensor<?>> results =
              new TestUtil.AutoCloseableList<>(
                  session.runner().feed(input, fedValue).fetch(chainedGradients.dy(0)).run())) {

        // With x = 3: (dy1/dy0) * (dy0/dx) = 18 * 6 = 108.
        assertEquals(108.0f, results.get(0).floatValue(), 0.0f);
      }
    }
  }

  /**
   * Checks the names of the generated gradient ops: they start with the sub-scope followed by
   * {@code "Gradients"} when no name is given, or by the name set through
   * {@link Scope#withName(String)} otherwise.
   */
  @Test
  public void validateGradientsNames() {
    try (Graph graph = new Graph()) {
      Scope subScope = new Scope(graph).withSubScope("sub");

      Output<Float> input = TestUtil.placeholder(graph, "x1", Float.class);
      Output<Float> squared = TestUtil.square(graph, "y", input);

      // No explicit name on the scope.
      Gradients unnamedGradients = Gradients.create(subScope, squared, Arrays.asList(input));
      String unnamedOpName = unnamedGradients.dy(0).op().name();
      assertTrue(unnamedOpName.startsWith("sub/Gradients/"));

      // Explicit name on the scope.
      Gradients namedGradients =
          Gradients.create(subScope.withName("MyGradients"), squared, Arrays.asList(input));
      String namedOpName = namedGradients.dy(0).op().name();
      assertTrue(namedOpName.startsWith("sub/MyGradients/"));
    }
  }
}