• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.learn.python.learn import ops
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import random_seed
27from tensorflow.python.ops import variables
28from tensorflow.python.ops import array_ops
29from tensorflow.python.platform import test
30
31
class OpsTest(test.TestCase):
  """Tests for the graph-building helpers in `contrib.learn` `ops`.

  Each test builds a small graph with the helper under test, evaluates it in
  the cached test session, and checks shapes/values against a hand-computed
  constant or a NumPy reference computation.
  """

  def test_softmax_classifier(self):
    """Checks output shapes and the weighted loss of `softmax_classifier`."""
    with self.cached_session() as session:
      features = array_ops.placeholder(dtypes.float32, [None, 3])
      labels = array_ops.placeholder(dtypes.float32, [None, 2])
      weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
      biases = constant_op.constant([0.2, 0.3])
      class_weight = constant_op.constant([0.1, 0.9])
      prediction, loss = ops.softmax_classifier(features, labels, weights,
                                                biases, class_weight)
      # Predictions have one column per class; the loss is a scalar.
      self.assertEqual(prediction.get_shape()[1], 2)
      self.assertEqual(loss.get_shape(), [])
      value = session.run(loss, {features: [[0.2, 0.3, 0.2]], labels: [[0, 1]]})
      # Reference loss for this fixed single-example input.
      self.assertAllClose(value, 0.55180627)

  def test_embedding_lookup(self):
    """Checks `embedding_lookup` against NumPy fancy indexing."""
    d_embed = 5
    n_embed = 10
    ids_shape = (2, 3, 4)
    # A local, seeded RNG keeps the test data reproducible across runs and
    # does not disturb NumPy's global random state shared with other tests.
    rng = np.random.RandomState(42)
    embeds = rng.randn(n_embed, d_embed)
    ids = rng.randint(0, n_embed, ids_shape)
    # Expected result: the rows of `embeds` selected by `ids`, giving shape
    # ids_shape + (d_embed,).
    embed_np = embeds[ids]
    with self.cached_session():
      embed_tf = ops.embedding_lookup(embeds, ids).eval()
    self.assertEqual(embed_np.shape, embed_tf.shape)
    self.assertAllClose(embed_np, embed_tf)

  def test_categorical_variable(self):
    """Checks that `categorical_variable` embeds equal indices identically."""
    # Fix the graph-level seed so the initial embedding values are
    # reproducible between runs.
    random_seed.set_random_seed(42)
    with self.cached_session() as sess:
      cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
      embeddings = ops.categorical_variable(
          cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
      sess.run(variables.global_variables_initializer())
      # Feed keyed by the placeholder tensor itself, consistent with the
      # other tests in this class, rather than by its name string.
      emb1 = sess.run(embeddings, feed_dict={cat_var_idx: [[0, 1], [2, 3]]})
      emb2 = sess.run(embeddings, feed_dict={cat_var_idx: [[0, 2], [1, 3]]})
    self.assertEqual(emb1.shape, emb2.shape)
    # The second index matrix is the transpose of the first. Swapping the two
    # leading axes of emb2 must therefore reproduce emb1 exactly.
    self.assertAllEqual(np.transpose(emb2, axes=[1, 0, 2]), emb1)
74
75
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point when the
  # file is executed directly.
  test.main()
78