@@ -232,51 +232,33 @@ class average_recall(layer):
232
232
"""Create a metric for the average recall calculation.
233
233
"""
234
234
235
- def __init__ (self , name = "average_recall" , classes = 2 , ** kwargs ):
235
+ def __init__ (self , name = "average_recall" , labels = 1 , ** kwargs ):
236
236
super (average_recall , self ).__init__ (name = name , ** kwargs )
237
237
238
- if classes < 2 :
239
- raise ValueError ('argument classes must >= 2' )
238
+ self .labels = labels
240
239
241
- self .classes = classes
242
-
243
- self .true = K .zeros (classes , dtype = "int32" )
244
- self .pred = K .zeros (classes , dtype = "int32" )
240
+ self .tp = K .zeros (labels , dtype = "int32" )
241
+ self .fn = K .zeros (labels , dtype = "int32" )
245
242
246
243
def reset_states (self ):
247
- K .set_value (self .true , [0 for v in range ( self .classes )] )
248
- K .set_value (self .pred , [0 for v in range ( self .classes )] )
244
+ K .set_value (self .tp , [0 ] * self .labels )
245
+ K .set_value (self .fn , [0 ] * self .labels )
249
246
250
247
def __call__ (self , y_true , y_pred ):
251
- # Cast input
252
- t , p = self .cast (y_true , y_pred , dtype = "float64" )
253
-
254
- # Init a bias matrix
255
- b = K .variable ([truediv (1 , (v + 1 )) for v in range (self .classes )],
256
- dtype = "float64" )
257
-
258
- # Simulate to_categorical operation
259
- t , p = K .expand_dims (t , axis = - 1 ), K .expand_dims (p , axis = - 1 )
260
- t , p = (t + 1 ) * b - 1 , (p + 1 ) * b - 1
261
-
262
- # Make correct position filled with 1
263
- t , p = K .cast (t , "bool" ), K .cast (p , "bool" )
264
- t , p = 1 - K .cast (t , "int32" ), 1 - K .cast (p , "int32" )
265
-
266
- t , p = K .transpose (t ), K .transpose (p )
248
+ y_true = K .cast (K .round (y_true ), "int32" )
249
+ y_pred = K .cast (K .round (y_pred ), "int32" )
250
+ neg_y_pred = 1 - y_pred
267
251
268
- # Results for current batch
269
- batch_t = K .sum (t , axis = - 1 )
270
- batch_p = K .sum (t * p , axis = - 1 )
252
+ tp = K .sum (K .transpose (y_true * y_pred ), axis = - 1 )
253
+ fn = K .sum (K .transpose (y_true * neg_y_pred ), axis = - 1 )
271
254
272
- # Accumulated results
273
- total_t = self .true * 1 + batch_t
274
- total_p = self .pred * 1 + batch_p
255
+ current_tp = K .cast (self .tp + tp , self .epsilon .dtype )
256
+ current_fn = K .cast (self .fn + fn , self .epsilon .dtype )
275
257
276
- self . add_update ( K .update_add (self .true , batch_t ) )
277
- self . add_update ( K .update_add (self .pred , batch_p ) )
258
+ tp_update = K .update_add (self .tp , tp )
259
+ fn_update = K .update_add (self .fn , fn )
278
260
279
- tp = K . cast ( total_p , dtype = 'float64' )
280
- tt = K . cast ( total_t , dtype = 'float64' )
261
+ self . add_update ( tp_update , inputs = [ y_true , y_pred ] )
262
+ self . add_update ( fn_update , inputs = [ y_true , y_pred ] )
281
263
282
- return K .mean (truediv (tp , ( tt + self .epsilon ) ))
264
+ return K .mean (truediv (current_tp , current_tp + current_fn + self .epsilon ))
0 commit comments