 from typing import Tuple
 from unittest.mock import patch
 
-import numpy as np
 import pytest
 import sklearn
 import torch
@@ -28,85 +27,97 @@ def test_no_sklearn(mock_no_sklearn):
         pr_curve.compute()
 
 
-def test_precision_recall_curve():
+def test_precision_recall_curve(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
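+    # scikit-learn reference values, computed from CPU copies of the tensors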
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
+        y_true.cpu().numpy(), y_pred.cpu().numpy()
+    )
 
-    precision_recall_curve_metric = PrecisionRecallCurve()
-    y_pred = torch.from_numpy(np_y_pred)
-    y = torch.from_numpy(np_y)
+    precision_recall_curve_metric = PrecisionRecallCurve(device=available_device)
+    assert precision_recall_curve_metric._device == torch.device(available_device)
 
-    precision_recall_curve_metric.update((y_pred, y))
+    precision_recall_curve_metric.update((y_pred, y_true))
     precision, recall, thresholds = precision_recall_curve_metric.compute()
-    precision = precision.numpy()
-    recall = recall.numpy()
-    thresholds = thresholds.numpy()
 
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
-    # assert thresholds almost equal, due to numpy->torch->numpy conversion
-    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+    precision = precision.cpu().numpy()
+    recall = recall.cpu().numpy()
+    thresholds = thresholds.cpu().numpy()
+
+    assert pytest.approx(precision) == expected_precision
+    assert pytest.approx(recall) == expected_recall
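+    # thresholds may differ slightly after the numpy -> torch -> numpy round trip, hence the tolerance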
+    assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)
 
 
-def test_integration_precision_recall_curve_with_output_transform():
-    np.random.seed(1)
+def test_integration_precision_recall_curve_with_output_transform(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    np.random.shuffle(np_y)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
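+    # shuffle targets and predictions with the same permutation so pairs stay aligned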
+    perm = torch.randperm(size)
+    y_pred = y_pred[perm]
+    y_true = y_true[perm]
 
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
+        y_true.cpu().numpy(), y_pred.cpu().numpy()
+    )
 
     batch_size = 10
 
     def update_fn(engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
-        y_true_batch = np_y[idx : idx + batch_size]
-        y_pred_batch = np_y_pred[idx : idx + batch_size]
-        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+        y_true_batch = y_true[idx : idx + batch_size]
+        y_pred_batch = y_pred[idx : idx + batch_size]
+        return idx, y_pred_batch, y_true_batch
 
     engine = Engine(update_fn)
 
-    precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (x[1], x[2]))
+    precision_recall_curve_metric = PrecisionRecallCurve(
+        output_transform=lambda x: (x[1], x[2]), device=available_device
+    )
+    assert precision_recall_curve_metric._device == torch.device(available_device)
     precision_recall_curve_metric.attach(engine, "precision_recall_curve")
 
     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-    precision = precision.numpy()
-    recall = recall.numpy()
-    thresholds = thresholds.numpy()
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
-    # assert thresholds almost equal, due to numpy->torch->numpy conversion
-    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+    precision = precision.cpu().numpy()
+    recall = recall.cpu().numpy()
+    thresholds = thresholds.cpu().numpy()
+    assert pytest.approx(precision) == expected_precision
+    assert pytest.approx(recall) == expected_recall
+    assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)
 
 
-def test_integration_precision_recall_curve_with_activated_output_transform():
-    np.random.seed(1)
+def test_integration_precision_recall_curve_with_activated_output_transform(available_device):
     size = 100
-    np_y_pred = np.random.rand(size, 1)
-    np_y_pred_sigmoid = torch.sigmoid(torch.from_numpy(np_y_pred)).numpy()
-    np_y = np.zeros((size,))
-    np_y[size // 2 :] = 1
-    np.random.shuffle(np_y)
-
-    sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred_sigmoid)
+    y_pred = torch.rand(size, 1, dtype=torch.float32, device=available_device)
+    y_true = torch.zeros(size, dtype=torch.float32, device=available_device)
+    y_true[size // 2 :] = 1.0
+    perm = torch.randperm(size)
+    y_pred = y_pred[perm]
+    y_true = y_true[perm]
+
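+    # reference values use sigmoid-activated scores, matching the metric's output_transform below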
+    sigmoid_y_pred = torch.sigmoid(y_pred).cpu().numpy()
+    expected_precision, expected_recall, expected_thresholds = precision_recall_curve(
+        y_true.cpu().numpy(), sigmoid_y_pred
+    )
 
     batch_size = 10
 
     def update_fn(engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
-        y_true_batch = np_y[idx : idx + batch_size]
-        y_pred_batch = np_y_pred[idx : idx + batch_size]
-        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+        y_true_batch = y_true[idx : idx + batch_size]
+        y_pred_batch = y_pred[idx : idx + batch_size]
+        return idx, y_pred_batch, y_true_batch
 
     engine = Engine(update_fn)
 
-    precision_recall_curve_metric = PrecisionRecallCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
+    precision_recall_curve_metric = PrecisionRecallCurve(
+        output_transform=lambda x: (torch.sigmoid(x[1]), x[2]), device=available_device
+    )
+    assert precision_recall_curve_metric._device == torch.device(available_device)
     precision_recall_curve_metric.attach(engine, "precision_recall_curve")
 
     data = list(range(size // batch_size))
@@ -115,25 +126,26 @@ def update_fn(engine, batch):
     recall = recall.cpu().numpy()
     thresholds = thresholds.cpu().numpy()
 
-    assert pytest.approx(precision) == sk_precision
-    assert pytest.approx(recall) == sk_recall
-    # assert thresholds almost equal, due to numpy->torch->numpy conversion
-    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+    assert pytest.approx(precision) == expected_precision
+    assert pytest.approx(recall) == expected_recall
+    assert thresholds == pytest.approx(expected_thresholds, rel=1e-6)
 
 
-def test_check_compute_fn():
+def test_check_compute_fn(available_device):
     y_pred = torch.zeros((8, 13))
     y_pred[:, 1] = 1
     y_true = torch.zeros_like(y_pred)
     output = (y_pred, y_true)
 
-    em = PrecisionRecallCurve(check_compute_fn=True)
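+    # the (8, 13) multilabel-shaped tensors are not a valid input for sklearn's precision_recall_curve,
+    # so check_compute_fn=True is expected to surface the failure as a warning on update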
+    em = PrecisionRecallCurve(check_compute_fn=True, device=available_device)
+    assert em._device == torch.device(available_device)
 
     em.reset()
     with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
         em.update(output)
 
-    em = PrecisionRecallCurve(check_compute_fn=False)
+    em = PrecisionRecallCurve(check_compute_fn=False, device=available_device)
+    assert em._device == torch.device(available_device)
     em.update(output)
 
 
@@ -225,14 +237,14 @@ def update(engine, i):
         np_y_true = y_true.cpu().numpy().ravel()
         np_y_preds = y_preds.cpu().numpy().ravel()
 
-        sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
+        expected_precision, expected_recall, expected_thresholds = precision_recall_curve(np_y_true, np_y_preds)
 
-        assert precision.shape == sk_precision.shape
-        assert recall.shape == sk_recall.shape
-        assert thresholds.shape == sk_thresholds.shape
-        assert pytest.approx(precision.cpu().numpy()) == sk_precision
-        assert pytest.approx(recall.cpu().numpy()) == sk_recall
-        assert pytest.approx(thresholds.cpu().numpy()) == sk_thresholds
+        assert precision.shape == expected_precision.shape
+        assert recall.shape == expected_recall.shape
+        assert thresholds.shape == expected_thresholds.shape
+        assert pytest.approx(precision.cpu().numpy()) == expected_precision
+        assert pytest.approx(recall.cpu().numpy()) == expected_recall
+        assert pytest.approx(thresholds.cpu().numpy()) == expected_thresholds
 
     metric_devices = ["cpu"]
     if device.type != "xla":