1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17// Tests for the generated code of some operations. 18 19package op 20 21import ( 22 "strings" 23 "testing" 24 25 tf "github.com/tensorflow/tensorflow/tensorflow/go" 26) 27 28func TestPlaceholder(t *testing.T) { 29 s := NewScope() 30 Placeholder(s.SubScope("x"), tf.Float, PlaceholderShape(tf.MakeShape(-1, 10))) 31 Placeholder(s.SubScope("y"), tf.Float, PlaceholderShape(tf.ScalarShape())) 32 Placeholder(s.SubScope("z"), tf.Float, PlaceholderShape(tf.Shape{})) 33 if _, err := s.Finalize(); err != nil { 34 t.Fatal(err) 35 } 36} 37 38func TestAddOperationFailure(t *testing.T) { 39 // Inspired from https://github.com/tensorflow/tensorflow/issues/9931 40 s := NewScope() 41 42 resize := ResizeArea(s, Placeholder(s, tf.Float), Const(s, []int64{80, 80})) 43 if err := s.Err(); err == nil { 44 t.Fatal("ResizeArea expects an int32 Tensor for size, should fail when an int64 is provided") 45 } 46 // And any use of resize should panic with an error message more informative than SIGSEGV 47 defer func() { 48 r := recover() 49 if r == nil { 50 return 51 } 52 s, ok := r.(string) 53 if ok && strings.Contains(s, "see Scope.Err() for details") { 54 return 55 } 56 t.Errorf("Expected panic string to Scope.Err(), found %T: %q", r, r) 57 }() 58 _ = resize.Shape() 59 t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created") 60} 61 62func TestShapeAttribute(t *testing.T) { 63 s := NewScope() 64 x := Placeholder(s.SubScope("x"), tf.Int32, PlaceholderShape(tf.MakeShape(1))) 65 y := Placeholder(s.SubScope("y"), tf.Int32, PlaceholderShape(tf.Shape{})) 66 z := Add(s, x, y) 67 graph, err := s.Finalize() 68 if err != nil { 69 t.Fatal(err) 70 } 71 sess, err := tf.NewSession(graph, nil) 72 if err != nil { 73 t.Fatal(err) 74 } 75 76 value, err := tf.NewTensor([]int32{7}) 77 if err != nil { 78 t.Fatal(err) 79 } 80 feeds := map[tf.Output]*tf.Tensor{ 81 x: value, 82 y: value, 83 } 84 fetched, err := sess.Run(feeds, []tf.Output{z}, nil) 85 if err != nil { 86 t.Fatal(err) 87 } 88 if got, want := len(fetched), 1; got != want { 89 t.Fatalf("Fetched %d tensors, expected %d", got, want) 90 } 91 if got, want := fetched[0].Value().([]int32), []int32{14}; len(got) != len(want) || len(got) != 1 || got[0] != want[0] { 92 t.Fatalf("Got %v, want %v", got, want) 93 } 94} 95 96func TestDataset(t *testing.T) { 97 var ( 98 s = NewScope() 99 100 // The use of a non-scalar here is inspired by 101 // https://github.com/tensorflow/tensorflow/issues/14891 102 c = Const(s, []int32{21718, 31415}) 103 types = []tf.DataType{c.DataType()} 104 shapes = []tf.Shape{c.Shape()} 105 dataset = TensorDataset(s, []tf.Output{c}, shapes) 106 107 iterator = Iterator(s, "", "", types, shapes) 108 next = IteratorGetNext(s, iterator, types, shapes) 109 init = MakeIterator(s, dataset, iterator) 110 ) 111 graph, err := s.Finalize() 112 if err != nil { 113 t.Fatal(err) 114 } 115 sess, err := tf.NewSession(graph, nil) 116 if err != nil { 117 t.Fatal(err) 118 } 119 if _, err := sess.Run(nil, nil, []*tf.Operation{init}); err != nil { 120 t.Fatal(err) 121 } 122 results, err := sess.Run(nil, next, nil) 123 if err != nil { 124 t.Fatal(err) 125 } 126 got := results[0].Value().([]int32) 127 if len(got) != 2 || got[0] != 21718 || got[1] != 31415 { 128 t.Errorf("Got %v, want {21718, 31415}", got) 129 } 130 if _, err := sess.Run(nil, next, nil); err == nil { 131 t.Errorf("Expected sess.Run() to fail since the iterator should have reached the end of the dataset") 132 } 133} 134