• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2import pytest
3from pybind11_tests import operators as m
4from pybind11_tests import ConstructorStats
5
6
def test_operator_overloading():
    """Exercise the operators bound on ``m.Vector2`` and its ConstructorStats.

    Covers ``str``, unary minus, binary arithmetic with vector and scalar
    operands on either side, equality, hashing and the in-place operators,
    then checks the instance bookkeeping as the three named vectors are
    deleted one by one.
    """
    v1 = m.Vector2(1, 2)
    # v2 has to be a Vector2 like its siblings: the alive() count of 3 below is
    # read from ConstructorStats.get(m.Vector2) and includes v2.
    v2 = m.Vector2(3, -1)
    v3 = m.Vector2(1, 2)  # Same value as v1, but different instance.
    assert v1 is not v3

    assert str(v1) == "[1.000000, 2.000000]"
    assert str(v2) == "[3.000000, -1.000000]"

    assert str(-v2) == "[-3.000000, 1.000000]"

    # Binary operators: vector/vector, vector/scalar and scalar/vector.
    assert str(v1 + v2) == "[4.000000, 1.000000]"
    assert str(v1 - v2) == "[-2.000000, 3.000000]"
    assert str(v1 - 8) == "[-7.000000, -6.000000]"
    assert str(v1 + 8) == "[9.000000, 10.000000]"
    assert str(v1 * 8) == "[8.000000, 16.000000]"
    assert str(v1 / 8) == "[0.125000, 0.250000]"
    assert str(8 - v1) == "[7.000000, 6.000000]"
    assert str(8 + v1) == "[9.000000, 10.000000]"
    assert str(8 * v1) == "[8.000000, 16.000000]"
    assert str(8 / v1) == "[8.000000, 4.000000]"
    assert str(v1 * v2) == "[3.000000, -2.000000]"
    assert str(v2 / v1) == "[3.000000, -0.500000]"

    assert v1 == v3
    assert v1 != v2
    assert hash(v1) == 4
    # TODO(eric.cousineau): Make this work.
    # assert abs(v1) == "abs(Vector2)"

    # In-place operators mutate v1 / v2; each assert shows the running value.
    v1 += 2 * v2
    assert str(v1) == "[7.000000, 0.000000]"
    v1 -= v2
    assert str(v1) == "[4.000000, 1.000000]"
    v1 *= 2
    assert str(v1) == "[8.000000, 2.000000]"
    v1 /= 16
    assert str(v1) == "[0.500000, 0.125000]"
    v1 *= v2
    assert str(v1) == "[1.500000, -0.125000]"
    v2 /= v1
    assert str(v2) == "[2.000000, 8.000000]"

    # Only the three named vectors should still be alive; every temporary
    # produced by the expressions above must already have been destroyed.
    cstats = ConstructorStats.get(m.Vector2)
    assert cstats.alive() == 3
    del v1
    assert cstats.alive() == 2
    del v2
    assert cstats.alive() == 1
    del v3
    assert cstats.alive() == 0
    # Values recorded by ConstructorStats, in construction order: the three
    # named vectors, then the result of each operator expression above.
    assert cstats.values() == [
        "[1.000000, 2.000000]",
        "[3.000000, -1.000000]",
        "[1.000000, 2.000000]",
        "[-3.000000, 1.000000]",
        "[4.000000, 1.000000]",
        "[-2.000000, 3.000000]",
        "[-7.000000, -6.000000]",
        "[9.000000, 10.000000]",
        "[8.000000, 16.000000]",
        "[0.125000, 0.250000]",
        "[7.000000, 6.000000]",
        "[9.000000, 10.000000]",
        "[8.000000, 16.000000]",
        "[8.000000, 4.000000]",
        "[3.000000, -2.000000]",
        "[3.000000, -0.500000]",
        "[6.000000, -2.000000]",
    ]
    assert cstats.default_constructions == 0
    assert cstats.copy_constructions == 0
    assert cstats.move_constructions >= 10
    assert cstats.copy_assignments == 0
    assert cstats.move_assignments == 0
82
83
def test_operators_notimplemented():
    """#393: operators need to return NotSupported so that arithmetic between
    the two classes behaves correctly."""

    first = m.C1()
    second = m.C2()

    # (left operand, right operand, expected sum): same-type pairs first,
    # then both mixed-type orderings.
    cases = (
        (first, first, 11),
        (second, second, 22),
        (second, first, 21),
        (first, second, 12),
    )
    for lhs, rhs, expected in cases:
        assert lhs + rhs == expected
92
93
def test_nested():
    """Regression test for #328: the first member in a class could not be
    used in operators."""

    nest_a = m.NestA()
    nest_b = m.NestB()
    nest_c = m.NestC()

    # In-place add on a NestA: directly, then through one and two levels of
    # member access.
    nest_a += 10
    assert m.get_NestA(nest_a) == 13
    nest_b.a += 100
    assert m.get_NestA(nest_b.a) == 103
    nest_c.b.a += 1000
    assert m.get_NestA(nest_c.b.a) == 1003
    # In-place subtract on a NestB: directly, then as a member of NestC.
    nest_b -= 1
    assert m.get_NestB(nest_b) == 3
    nest_c.b -= 3
    assert m.get_NestB(nest_c.b) == 1
    nest_c *= 7
    assert m.get_NestC(nest_c) == 35

    base_view = nest_a.as_base()
    assert base_view.value == -2
    # A change made through a second as_base() call must be visible through
    # the first one.
    nest_a.as_base().value += 44
    assert base_view.value == 42
    assert nest_c.b.a.as_base().value == -2
    nest_c.b.a.as_base().value += 44
    assert nest_c.b.a.as_base().value == 42

    del nest_c
    pytest.gc_collect()
    del nest_a  # Shouldn't delete while base_view is still alive
    pytest.gc_collect()

    # base_view must still be readable after its owner's name was dropped.
    assert base_view.value == 42
    del base_view, nest_b
    pytest.gc_collect()
130
131
def test_overriding_eq_reset_hash():
    """Check equality and hashability of Comparable, Hashable and Hashable2.

    Comparable instances compare equal by value but must not be hashable;
    Hashable and Hashable2 instances compare equal and hash to their value.
    """
    value = 15
    hashable_types = (m.Hashable, m.Hashable2)

    # Distinct objects, equal by value.
    assert m.Comparable(value) is not m.Comparable(value)
    assert m.Comparable(value) == m.Comparable(value)

    with pytest.raises(TypeError):
        hash(m.Comparable(value))  # TypeError: unhashable type: 'm.Comparable'

    for cls in hashable_types:
        assert cls(value) is not cls(value)
        assert cls(value) == cls(value)

        assert hash(cls(value)) == 15
        assert hash(cls(value)) == hash(cls(value))
146