1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Ops tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.learn.python.learn import ops 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import random_seed 27from tensorflow.python.ops import variables 28from tensorflow.python.ops import array_ops 29from tensorflow.python.platform import test 30 31 32class OpsTest(test.TestCase): 33 """Ops tests.""" 34 35 def test_softmax_classifier(self): 36 with self.cached_session() as session: 37 features = array_ops.placeholder(dtypes.float32, [None, 3]) 38 labels = array_ops.placeholder(dtypes.float32, [None, 2]) 39 weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]]) 40 biases = constant_op.constant([0.2, 0.3]) 41 class_weight = constant_op.constant([0.1, 0.9]) 42 prediction, loss = ops.softmax_classifier(features, labels, weights, 43 biases, class_weight) 44 self.assertEqual(prediction.get_shape()[1], 2) 45 self.assertEqual(loss.get_shape(), []) 46 value = session.run(loss, {features: [[0.2, 0.3, 0.2]], labels: [[0, 1]]}) 47 self.assertAllClose(value, 0.55180627) 48 49 def test_embedding_lookup(self): 50 d_embed = 5 51 n_embed = 10 52 ids_shape = (2, 3, 4) 53 embeds = np.random.randn(n_embed, d_embed) 54 ids = np.random.randint(0, n_embed, ids_shape) 55 with self.cached_session(): 56 embed_np = embeds[ids] 57 embed_tf = ops.embedding_lookup(embeds, ids).eval() 58 self.assertEqual(embed_np.shape, embed_tf.shape) 59 self.assertAllClose(embed_np, embed_tf) 60 61 def test_categorical_variable(self): 62 random_seed.set_random_seed(42) 63 with self.cached_session() as sess: 64 cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2]) 65 embeddings = ops.categorical_variable( 66 cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var") 67 sess.run(variables.global_variables_initializer()) 68 emb1 = sess.run(embeddings, 69 feed_dict={cat_var_idx.name: [[0, 1], [2, 3]]}) 70 emb2 = sess.run(embeddings, 71 feed_dict={cat_var_idx.name: [[0, 2], [1, 3]]}) 72 self.assertEqual(emb1.shape, emb2.shape) 73 self.assertAllEqual(np.transpose(emb2, axes=[1, 0, 2]), emb1) 74 75 76if __name__ == "__main__": 77 test.main() 78