1
- from functools import partial
2
1
from keras import backend as K
3
2
from keras .layers import Layer
4
3
from operator import truediv
5
4
6
5
7
6
class layer (Layer ):
8
7
9
- def __init__ (self , label = None , name = None , sparse = False , ** kwargs ):
8
+ def __init__ (self , name = None , label = 0 , cast_strategy = None , ** kwargs ):
10
9
super (layer , self ).__init__ (name = name , ** kwargs )
10
+
11
11
self .stateful = True
12
+ self .label = label
13
+ self .cast_strategy = cast_strategy
12
14
self .epsilon = K .constant (K .epsilon (), dtype = "float64" )
13
- self .__name__ = name
14
- self .sparse = sparse
15
-
16
- # If layer metric is explicitly created to evaluate specified class,
17
- # then use a binary transformation of the output arrays, otherwise
18
- # calculate an "overall" metric.
19
- if label :
20
- self .cast_strategy = partial (self ._binary , label = label )
21
- else :
22
- self .cast_strategy = self ._categorical
23
15
24
16
def cast (self , y_true , y_pred , dtype = "int32" ):
25
17
"""Convert the specified true and predicted output to the specified
26
18
destination type (int32 by default).
27
19
"""
28
- return self .cast_strategy (y_true , y_pred , dtype = dtype )
29
-
30
- def _binary (self , y_true , y_pred , dtype , label = 0 ):
31
- column = slice (label , label + 1 )
32
-
33
- # Slice a column of the output array.
34
- if self .sparse :
35
- y_true = K .cast (K .equal (y_true , label ), dtype = dtype )
36
- else :
37
- y_true = y_true [...,column ]
38
- y_pred = y_pred [...,column ]
39
- return self ._categorical (y_true , y_pred , dtype )
40
-
41
- def _categorical (self , y_true , y_pred , dtype ):
42
- # In case when user did not specify the label, and the shape
43
- # of the output vector has exactly two elements, we can choose
44
- # the label automatically.
45
- #
46
- # When the shape had dimension 3 and more and the label is
47
- # not specified, we should throw an error as long as calculated
48
- # metric is incorrect.
49
- labels = y_pred .shape [- 1 ]
50
- if labels == 2 :
51
- return self ._binary (y_true , y_pred , dtype , label = 1 )
52
- elif labels > 2 :
53
- raise ValueError ("With 2 and more output classes a "
54
- "metric label must be specified" )
55
-
56
- y_true = K .cast (K .round (y_true ), dtype )
57
- y_pred = K .cast (K .round (y_pred ), dtype )
58
- return y_true , y_pred
20
+ return self .cast_strategy (
21
+ y_true , y_pred , dtype = dtype , label = self .label )
59
22
60
23
def __getattribute__ (self , name ):
61
24
if name == "get_config" :
62
25
raise AttributeError
63
26
return object .__getattribute__ (self , name )
64
27
28
+
65
29
class true_positive (layer ):
66
30
"""Create a metric for model's true positives amount calculation.
67
31
68
32
A true positive is an outcome where the model correctly predicts the
69
33
positive class.
70
34
"""
71
35
72
- def __init__ (self , name = "true_positive" , label = None , sparse = False , ** kwargs ):
73
- super (true_positive , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
36
+ def __init__ (self , name = "true_positive" , ** kwargs ):
37
+ super (true_positive , self ).__init__ (name = name , ** kwargs )
74
38
self .tp = K .variable (0 , dtype = "int32" )
75
39
76
40
def reset_states (self ):
@@ -96,8 +60,8 @@ class true_negative(layer):
96
60
negative class.
97
61
"""
98
62
99
- def __init__ (self , name = "true_negative" , label = None , sparse = False , ** kwargs ):
100
- super (true_negative , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
63
+ def __init__ (self , name = "true_negative" , ** kwargs ):
64
+ super (true_negative , self ).__init__ (name = name , ** kwargs )
101
65
self .tn = K .variable (0 , dtype = "int32" )
102
66
103
67
def reset_states (self ):
@@ -126,8 +90,8 @@ class false_negative(layer):
126
90
negative class.
127
91
"""
128
92
129
- def __init__ (self , name = "false_negative" , label = None , sparse = False , ** kwargs ):
130
- super (false_negative , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
93
+ def __init__ (self , name = "false_negative" , ** kwargs ):
94
+ super (false_negative , self ).__init__ (name = name , ** kwargs )
131
95
self .fn = K .variable (0 , dtype = "int32" )
132
96
133
97
def reset_states (self ):
@@ -154,8 +118,8 @@ class false_positive(layer):
154
118
positive class.
155
119
"""
156
120
157
- def __init__ (self , name = "false_positive" , label = None , sparse = False , ** kwargs ):
158
- super (false_positive , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
121
+ def __init__ (self , name = "false_positive" , ** kwargs ):
122
+ super (false_positive , self ).__init__ (name = name , ** kwargs )
159
123
self .fp = K .variable (0 , dtype = "int32" )
160
124
161
125
def reset_states (self ):
@@ -178,14 +142,15 @@ def __call__(self, y_true, y_pred):
178
142
class recall (layer ):
179
143
"""Create a metric for model's recall calculation.
180
144
181
- Recall measures proportion of actual positives that was identified correctly.
145
+ Recall measures proportion of actual positives that was identified
146
+ correctly.
182
147
"""
183
148
184
- def __init__ (self , name = "recall" , label = None , sparse = False , ** kwargs ):
185
- super (recall , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
149
+ def __init__ (self , name = "recall" , ** kwargs ):
150
+ super (recall , self ).__init__ (name = name , ** kwargs )
186
151
187
- self .tp = true_positive (label = label , sparse = sparse )
188
- self .fn = false_negative (label = label , sparse = sparse )
152
+ self .tp = true_positive (** kwargs )
153
+ self .fn = false_negative (** kwargs )
189
154
190
155
def reset_states (self ):
191
156
"""Reset the state of the metrics."""
@@ -212,11 +177,11 @@ class precision(layer):
212
177
actually correct.
213
178
"""
214
179
215
- def __init__ (self , name = "precision" , label = None , sparse = False , ** kwargs ):
216
- super (precision , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
180
+ def __init__ (self , name = "precision" , ** kwargs ):
181
+ super (precision , self ).__init__ (name = name , ** kwargs )
217
182
218
- self .tp = true_positive (label = label , sparse = sparse )
219
- self .fp = false_positive (label = label , sparse = sparse )
183
+ self .tp = true_positive (** kwargs )
184
+ self .fp = false_positive (** kwargs )
220
185
221
186
def reset_states (self ):
222
187
"""Reset the state of the metrics."""
@@ -242,11 +207,11 @@ class f1_score(layer):
242
207
The F1 score is the harmonic mean of precision and recall.
243
208
"""
244
209
245
- def __init__ (self , name = "f1_score" , label = None , sparse = False , ** kwargs ):
246
- super (f1_score , self ).__init__ (name = name , label = label , sparse = sparse , ** kwargs )
210
+ def __init__ (self , name = "f1_score" , ** kwargs ):
211
+ super (f1_score , self ).__init__ (name = name , ** kwargs )
247
212
248
- self .precision = precision (label = label , sparse = sparse )
249
- self .recall = recall (label = label , sparse = sparse )
213
+ self .precision = precision (** kwargs )
214
+ self .recall = recall (** kwargs )
250
215
251
216
def reset_states (self ):
252
217
"""Reset the state of the metrics."""
0 commit comments