• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python

# Copyright Jim Bosch & Ankit Daftery 2010-2012.
# Distributed under the Boost Software License, Version 1.0.
#    (See accompanying file LICENSE_1_0.txt or copy at
#          http://www.boost.org/LICENSE_1_0.txt)
7
8import ndarray_ext
9import unittest
10import numpy
11
class TestNdarray(unittest.TestCase):
    """Tests for the ndarray helpers exposed by the ``ndarray_ext`` module.

    Each test builds arrays through ``ndarray_ext`` and compares the result
    (shape and/or contents) against the equivalent array built directly with
    NumPy.
    """

    # Element types exercised by the dtype-parameterized tests.
    _DTYPES = (numpy.int16, numpy.int32, numpy.float32, numpy.complex128)
    # Shapes used by the tests; every one holds exactly 60 elements so a
    # 60-element reference array can be reshaped to each of them.
    _SHAPES = ((60,), (6, 10), (4, 3, 5), (2, 2, 3, 5))

    def testNdzeros(self):
        """ndarray_ext.zeros yields an all-zero array of the requested shape."""
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            reference = numpy.zeros(60, dtype=dtp)
            for shape in self._SHAPES:
                a1 = ndarray_ext.zeros(shape, dt)
                # Check the shape before reshaping the reference, so a wrong
                # shape is reported as an assertion failure rather than a
                # ValueError from reshape().
                self.assertEqual(shape, a1.shape)
                self.assertTrue((a1 == reference.reshape(shape)).all())

    def testNdzeros_matrix(self):
        """ndarray_ext.zeros_matrix matches numpy.matrix(numpy.zeros(...))."""
        shape = (6, 10)
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            a1 = ndarray_ext.zeros_matrix(shape, dt)
            a2 = numpy.matrix(numpy.zeros(shape, dtype=dtp))
            self.assertEqual(shape, a1.shape)
            self.assertTrue((a1 == a2).all())
            # The result must be the matrix subclass, not a plain ndarray.
            self.assertEqual(type(a1), type(a2))

    def testNdarray(self):
        """ndarray_ext.array, with or without a dtype, matches numpy.array."""
        a = range(0, 60)
        for dtp in self._DTYPES:
            v = numpy.array(a, dtype=dtp)
            dt = numpy.dtype(dtp)
            a1 = ndarray_ext.array(a)
            a2 = ndarray_ext.array(a, dt)
            self.assertTrue((a1 == v).all())
            self.assertTrue((a2 == v).all())
            # The returned arrays must be reshapeable to every 60-element shape.
            for shape in self._SHAPES:
                self.assertEqual(shape, a1.reshape(shape).shape)
                self.assertEqual(shape, a2.reshape(shape).shape)

    def testNdempty(self):
        """ndarray_ext.empty and c_empty allocate arrays of the requested shape."""
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            for shape in self._SHAPES:
                a1 = ndarray_ext.empty(shape, dt)
                a2 = ndarray_ext.c_empty(shape, dt)
                self.assertEqual(shape, a1.shape)
                self.assertEqual(shape, a2.shape)

    def testTranspose(self):
        """ndarray_ext.transpose produces the same shape as ndarray.transpose."""
        # 1-D arrays are skipped: transposing them is a no-op.
        for dtp in self._DTYPES:
            dt = numpy.dtype(dtp)
            for shape in self._SHAPES[1:]:
                original = numpy.empty(shape, dt)
                expected = original.transpose()
                result = ndarray_ext.transpose(original)
                self.assertEqual(result.shape, expected.shape)

    def testSqueeze(self):
        """ndarray_ext.squeeze drops length-1 axes like ndarray.squeeze."""
        a1 = numpy.array([[[3, 4, 5]]])
        a2 = a1.squeeze()
        a1 = ndarray_ext.squeeze(a1)
        self.assertEqual(a1.shape, a2.shape)

    def testReshape(self):
        """ndarray_ext.reshape returns an array with the requested shape."""
        a1 = numpy.empty((2, 2))
        a2 = ndarray_ext.reshape(a1, (1, 4))
        self.assertEqual(a2.shape, (1, 4))

    def testShapeIndex(self):
        """shape_index agrees with a.shape[i] for positive and negative i."""
        a = numpy.arange(24).reshape(1, 2, 3, 4)
        # -4..3 covers every valid axis index, both negative and positive.
        for i in range(-a.ndim, a.ndim):
            self.assertEqual(ndarray_ext.shape_index(a, i), a.shape[i],
                             msg="axis index %d" % i)
        with self.assertRaises(IndexError):
            ndarray_ext.shape_index(a, 4)  # one past the last axis

    def testStridesIndex(self):
        """strides_index agrees with a.strides[i] for positive and negative i."""
        a = numpy.arange(24).reshape(1, 2, 3, 4)
        # -4..3 covers every valid axis index, both negative and positive.
        for i in range(-a.ndim, a.ndim):
            self.assertEqual(ndarray_ext.strides_index(a, i), a.strides[i],
                             msg="axis index %d" % i)
        with self.assertRaises(IndexError):
            ndarray_ext.strides_index(a, 4)  # one past the last axis


# Run the test suite when executed directly as a script.
if __name__ == "__main__":
    unittest.main()