• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python

# Copyright Jim Bosch & Ankit Daftery 2010-2012.
# Distributed under the Boost Software License, Version 1.0.
#    (See accompanying file LICENSE_1_0.txt or copy at
#          http://www.boost.org/LICENSE_1_0.txt)

8import ufunc_ext
9import unittest
10import numpy
11from numpy.testing.utils import assert_array_almost_equal
12
13class TestUnary(unittest.TestCase):
14
15    def testScalar(self):
16        f = ufunc_ext.UnaryCallable()
17        assert_array_almost_equal(f(1.0), 2.0)
18        assert_array_almost_equal(f(3.0), 6.0)
19
20    def testArray(self):
21        f = ufunc_ext.UnaryCallable()
22        a = numpy.arange(5, dtype=float)
23        b = f(a)
24        assert_array_almost_equal(b, a*2.0)
25        c = numpy.zeros(5, dtype=float)
26        d = f(a,output=c)
27        self.assert_(c is d)
28        assert_array_almost_equal(d, a*2.0)
29
30    def testList(self):
31        f = ufunc_ext.UnaryCallable()
32        a = range(5)
33        b = f(a)
34        assert_array_almost_equal(b/2.0, a)
35
36class TestBinary(unittest.TestCase):
37
38    def testScalar(self):
39        f = ufunc_ext.BinaryCallable()
40        assert_array_almost_equal(f(1.0, 3.0), 11.0)
41        assert_array_almost_equal(f(3.0, 2.0), 12.0)
42
43    def testArray(self):
44        f = ufunc_ext.BinaryCallable()
45        a = numpy.random.randn(5)
46        b = numpy.random.randn(5)
47        assert_array_almost_equal(f(a,b), (a*2+b*3))
48        c = numpy.zeros(5, dtype=float)
49        d = f(a,b,output=c)
50        self.assert_(c is d)
51        assert_array_almost_equal(d, a*2 + b*3)
52        assert_array_almost_equal(f(a, 2.0), a*2 + 6.0)
53        assert_array_almost_equal(f(1.0, b), 2.0 + b*3)
54
55
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()
