@@ -67,7 +67,46 @@ def test_trace_with_incompatible_seeds(self):
67
67
with self .assertRaises (AttributeError ):
68
68
trace (self .f2 , seeds = np .array ([[1 ,3 ,2 ],[1 ,3 ,2 ],[2 ,3 ,3 ]]), x = np .array (
69
69
[[2 , 3 ]]), y = np .array ([[1 , 2 ]]))
70
- #def test_trace_with_vector_inputs_with_seeds(self):
70
+
71
+ def test_trace_with_incompatible_seeds_reverse (self ):
72
+ with self .assertRaises (AttributeError ):
73
+ trace (self .f2 , seeds = np .array ([[1 ,3 ,2 ],[1 ,3 ,2 ],[2 ,3 ,3 ]]), mode = "reverse" , x = np .array (
74
+ [[2 , 3 ]]), y = np .array ([[1 , 2 ]]))
75
+
76
+ def test_trace_with_scalar_inputs_seeds (self ):
77
+ with self .assertRaises (TypeError ):
78
+ trace (self .f1 , seeds = np .array ([[1 ,0 ]]), x = 1 )
79
+ self .assertEqual (trace (self .f1 , seeds = 2 , x = 1 ), (3 ,2 ))
80
+ self .assertEqual (trace (self .f2 , seeds = np .array ([[1 ,0 ],[0 ,1 ]]), x = 2 , y = 999 )[1 ], [[1 ], [1 ]])
81
+ with self .assertRaises (TypeError ):
82
+ trace (self .f2 , seeds = 1 , x = 2 , y = 999 )
83
+ with self .assertRaises (TypeError ):
84
+ trace (self .f2 , seeds = "seed" , x = 2 , y = 999 )
85
+ with self .assertRaises (TypeError ):
86
+ trace (self .f2 , seeds = np .array ([[1 ,0 ],[0 ,1 ]]), x = "2" , y = 999 )
87
+ with self .assertRaises (AttributeError ):
88
+ trace (self .f2 , seeds = np .array ([[1 ],[0 ]]), x = 2 , y = 999 )
89
+
90
+ def test_trace_with_scalar_inputs_seeds_reverse (self ):
91
+ with self .assertRaises (TypeError ):
92
+ trace (self .f1 , seeds = np .array ([[1 ,0 ]]), mode = 'reverse' , x = 1 )
93
+ self .assertEqual (trace (self .f1 , seeds = 2 , mode = 'reverse' , x = 1 ), (3 ,2 ))
94
+ self .assertEqual (trace (self .f2 , seeds = np .array ([[1 ,0 ],[0 ,1 ]]), mode = 'reverse' , x = 2 , y = 999 )[1 ], [[1 ], [1 ]])
95
+ with self .assertRaises (TypeError ):
96
+ trace (self .f2 , seeds = 1 , mode = 'reverse' , x = 2 , y = 999 )
97
+ with self .assertRaises (TypeError ):
98
+ trace (self .f2 , seeds = "seed" , mode = 'reverse' , x = 2 , y = 999 )
99
+ with self .assertRaises (TypeError ):
100
+ trace (self .f2 , seeds = np .array ([[1 ,0 ],[0 ,1 ]]), mode = 'reverse' , x = "2" , y = 999 )
101
+ with self .assertRaises (AttributeError ):
102
+ trace (self .f2 , seeds = np .array ([[1 ],[0 ]]), mode = 'reverse' , x = 2 , y = 999 )
103
+
104
+ def test_trace_with_vector_inputs_seeds (self ):
105
+ self .assertEqual (trace (self .f2 , seeds = np .array ([[1 ,0 ],[0 ,1 ]]), x = np .array ([[2 , 3 ]]), y = np .array ([[1 , 2 ]]))[1 ], [[1 , 1 ], [1 , 1 ]])
106
+
107
+ def test_trace_with_vector_inputs_seeds_reverse (self ):
108
+ self .assertEqual (trace (self .f2 , seeds = np .array ([[1 ,0 ],[0 ,1 ]]), mode = 'reverse' , x = np .array ([[2 , 3 ]]), y = np .array ([[1 , 2 ]]))[1 ], [[1 , 1 ], [1 , 1 ]])
109
+
71
110
def test_trace_with_different_moded (self ):
72
111
self .assertEqual (trace (self .f1 , x = 2 ), (4 , 1 ))
73
112
self .assertEqual (trace (self .f1 , mode = 'forward' , x = 2 ), (4 , 1 ))
@@ -111,10 +150,18 @@ def test_trace_multiple_vector_inputs(self):
111
150
self .assertEqual (trace (self .f3 , mode = 'reverse' , x = np .array (
112
151
[[2 , 2 ]]), y = np .array ([[4 , 4 ]]))[1 ], [[2. , 2. ], [0.25 , 0.25 ]])
113
152
114
- def test_trace_single_vector_inputs (self ):
153
+ def test_trace_single_vector_input (self ):
115
154
self .assertEqual (trace (self .f1 , x = np .array ([[2 , 2 ]]))[1 ], [[1 , 1 ]])
155
+ with self .assertRaises (TypeError ):
156
+ trace (self .f1 , x = np .array ([]))
116
157
158
+ def test_trace_non_lambda_function (self ):
159
+ with self .assertRaises (TypeError ):
160
+ trace ("Function" , x = 1 )
117
161
162
+ def test_trace_vector_functions (self ):
163
+ self .assertEqual (trace ([self .f2 ,self .f3 ], x = 2 , y = 4 )[1 ].tolist (), [[[1.0 ], [1.0 ]], [[2.0 ], [0.25 ]]])
164
+
118
165
def test_mixed_inputs (self ):
119
166
self .assertEqual (trace (self .f3 , x = np .array ([[2 , 2 ]]), y = 4 )[
120
167
0 ].tolist (), [[6 , 6 ]])
@@ -125,7 +172,8 @@ def test_mixed_inputs(self):
125
172
1 ], [[2. , 2. ], [0.25 , 0.25 ]])
126
173
self .assertEqual (trace (self .f3 , mode = 'reverse' , x = np .array ([[2 , 2 ]]), y = 4 )[
127
174
1 ], [[2. , 2. ], [0.25 , 0.25 ]])
128
-
175
+ self .assertEqual (trace (self .f3 , mode = 'reverse' , x = np .array ([[2 , 2 ]]), y = np .array ([[4 , 4 ]]))[
176
+ 1 ], [[2. , 2. ], [0.25 , 0.25 ]])
129
177
130
178
if __name__ == "__main__" :
131
179
unittest .main ()
0 commit comments