6
6
7
7
class layer (Layer ):
8
8
9
- def __init__ (self , label = None , name = None , ** kwargs ):
9
+ def __init__ (self , label = None , name = None , sparse = False , ** kwargs ):
10
10
super (layer , self ).__init__ (name = name , ** kwargs )
11
11
self .stateful = True
12
12
self .epsilon = K .constant (K .epsilon (), dtype = "float64" )
13
13
self .__name__ = name
14
+ self .sparse = sparse
14
15
15
16
# If layer metric is explicitly created to evaluate specified class,
16
17
# then use a binary transformation of the output arrays, otherwise
@@ -30,7 +31,10 @@ def _binary(self, y_true, y_pred, dtype, label=0):
30
31
column = slice (label , label + 1 )
31
32
32
33
# Slice a column of the output array.
33
- y_true = y_true [...,column ]
34
+ if self .sparse :
35
+ y_true = K .cast (K .equal (y_true , label ), dtype = dtype )
36
+ else :
37
+ y_true = y_true [...,column ]
34
38
y_pred = y_pred [...,column ]
35
39
return self ._categorical (y_true , y_pred , dtype )
36
40
@@ -65,8 +69,8 @@ class true_positive(layer):
65
69
positive class.
66
70
"""
67
71
68
- def __init__ (self , name = "true_positive" , ** kwargs ):
69
- super (true_positive , self ).__init__ (name = name , ** kwargs )
72
+ def __init__ (self , name = "true_positive" , label = None , sparse = False , ** kwargs ):
73
+ super (true_positive , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
70
74
self .tp = K .variable (0 , dtype = "int32" )
71
75
72
76
def reset_states (self ):
@@ -92,8 +96,8 @@ class true_negative(layer):
92
96
negative class.
93
97
"""
94
98
95
- def __init__ (self , name = "true_negative" , ** kwargs ):
96
- super (true_negative , self ).__init__ (name = name , ** kwargs )
99
+ def __init__ (self , name = "true_negative" , label = None , sparse = False , ** kwargs ):
100
+ super (true_negative , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
97
101
self .tn = K .variable (0 , dtype = "int32" )
98
102
99
103
def reset_states (self ):
@@ -122,8 +126,8 @@ class false_negative(layer):
122
126
negative class.
123
127
"""
124
128
125
- def __init__ (self , name = "false_negative" , ** kwargs ):
126
- super (false_negative , self ).__init__ (name = name , ** kwargs )
129
+ def __init__ (self , name = "false_negative" , label = None , sparse = False , ** kwargs ):
130
+ super (false_negative , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
127
131
self .fn = K .variable (0 , dtype = "int32" )
128
132
129
133
def reset_states (self ):
@@ -150,8 +154,8 @@ class false_positive(layer):
150
154
positive class.
151
155
"""
152
156
153
- def __init__ (self , name = "false_positive" , ** kwargs ):
154
- super (false_positive , self ).__init__ (name = name , ** kwargs )
157
+ def __init__ (self , name = "false_positive" , label = None , sparse = False , ** kwargs ):
158
+ super (false_positive , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
155
159
self .fp = K .variable (0 , dtype = "int32" )
156
160
157
161
def reset_states (self ):
@@ -177,11 +181,11 @@ class recall(layer):
177
181
Recall measures proportion of actual positives that was identified correctly.
178
182
"""
179
183
180
- def __init__ (self , name = "recall" , ** kwargs ):
181
- super (recall , self ).__init__ (name = name , ** kwargs )
184
+ def __init__ (self , name = "recall" , label = None , sparse = False , ** kwargs ):
185
+ super (recall , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
182
186
183
- self .tp = true_positive ()
184
- self .fn = false_negative ()
187
+ self .tp = true_positive (label = label , sparse = sparse )
188
+ self .fn = false_negative (label = label , sparse = sparse )
185
189
186
190
def reset_states (self ):
187
191
"""Reset the state of the metrics."""
@@ -208,11 +212,11 @@ class precision(layer):
208
212
actually correct.
209
213
"""
210
214
211
- def __init__ (self , name = "precision" , ** kwargs ):
212
- super (precision , self ).__init__ (name = name , ** kwargs )
215
+ def __init__ (self , name = "precision" , label = None , sparse = False , ** kwargs ):
216
+ super (precision , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
213
217
214
- self .tp = true_positive ()
215
- self .fp = false_positive ()
218
+ self .tp = true_positive (label = label , sparse = sparse )
219
+ self .fp = false_positive (label = label , sparse = sparse )
216
220
217
221
def reset_states (self ):
218
222
"""Reset the state of the metrics."""
@@ -238,11 +242,11 @@ class f1_score(layer):
238
242
The F1 score is the harmonic mean of precision and recall.
239
243
"""
240
244
241
- def __init__ (self , name = "f1_score" , label = None , ** kwargs ):
242
- super (f1_score , self ).__init__ (name = name , label = label , ** kwargs )
245
+ def __init__ (self , name = "f1_score" , label = None , sparse = False , ** kwargs ):
246
+ super (f1_score , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
243
247
244
- self .precision = precision (label = label )
245
- self .recall = recall (label = label )
248
+ self .precision = precision (label = label , sparse = sparse )
249
+ self .recall = recall (label = label , sparse = sparse )
246
250
247
251
def reset_states (self ):
248
252
"""Reset the state of the metrics."""
0 commit comments