@@ -10,6 +10,7 @@ using StableRNGs
1010import Optimisers
1111import Zygote
1212import NNlib
13+ import CategoricalArrays
1314import CategoricalDistributions
1415import CategoricalDistributions: pdf, mode
1516import ComponentArrays
@@ -55,7 +56,7 @@ for the specified number of `epochs`.
5556- `perceptron`: component array with components `weights` and `bias`
5657- `optimiser`: optimiser from Optimisers.jl
5758- `X`: feature matrix, of size `(p, n)`
58- - `y_hot`: one-hot encoded target, of size `(nclasses , n)`
59+ - `y_hot`: one-hot encoded target, of size `(nlevels , n)`
5960- `epochs`: number of epochs
6061- `state`: optimiser state
6162
@@ -108,7 +109,7 @@ point predictions with `predict(model, Point(), Xnew)`.
108109
109110# Warm restart options
110111
111- update(model, newdata, :epochs=>n, other_replacements...; verbosity=1 )
112+ update(model, newdata, :epochs=>n, other_replacements...)
112113
113114If `Δepochs = n - perceptron.epochs` is non-negative, then return an updated model, with
114115the weights and bias of the previously learned perceptron used as the starting state in
@@ -117,7 +118,7 @@ instead of the previous training data. Any other hyperparameter `replacements`
117118adopted. If `Δepochs` is negative or not specified, instead return `fit(learner,
118119newdata)`, where `learner=LearnAPI.clone(learner; epochs=n, replacements...)`.
119120
120- update_observations(model, newdata, replacements...; verbosity=1 )
121+ update_observations(model, newdata, replacements...)
121122
122123Return an updated model, with the weights and bias of the previously learned perceptron
123124used as the starting state in new gradient descent updates. Adopt any specified
@@ -132,38 +133,38 @@ PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.defaul
132133struct PerceptronClassifierObs
133134 X:: Matrix{Float32}
134135 y_hot:: BitMatrix # one-hot encoded target
135- classes # the (ordered) pool of `y`, as `CategoricalValue`s
136+ levels # the (ordered) pool of `y`, as `CategoricalValue`s
136137end
137138
138139# For pre-processing the training data:
139140function LearnAPI. obs (:: PerceptronClassifier , data:: Tuple )
140141 X, y = data
141- classes = CategoricalDistributions . classes (y)
142- y_hot = classes .== permutedims (y) # one-hot encoding
143- return PerceptronClassifierObs (X, y_hot, classes )
142+ levels = CategoricalArrays . levels (y)
143+ y_hot = levels .== permutedims (y) # one-hot encoding
144+ return PerceptronClassifierObs (X, y_hot, levels )
144145end
145146LearnAPI. obs (:: PerceptronClassifier , observations:: PerceptronClassifierObs ) =
146147 observations # involutivity
147148
148149# helper:
149- function decode (y_hot, classes )
150+ function decode (y_hot, levels )
150151 n = size (y_hot, 2 )
151- [only (classes [y_hot[:,i]]) for i in 1 : n]
152+ [only (levels [y_hot[:,i]]) for i in 1 : n]
152153end
153154
154155# implement `RandomAccess()` interface for output of `obs`:
155156Base. length (observations:: PerceptronClassifierObs ) = size (observations. y_hot, 2 )
156157Base. getindex (observations:: PerceptronClassifierObs , I) = PerceptronClassifierObs (
157158 observations. X[:, I],
158159 observations. y_hot[:, I],
159- observations. classes ,
160+ observations. levels ,
160161)
161162
162163# training data deconstructors:
163164LearnAPI. target (
164165 learner:: PerceptronClassifier ,
165166 observations:: PerceptronClassifierObs ,
166- ) = decode (observations. y_hot, observations. classes )
167+ ) = decode (observations. y_hot, observations. levels )
167168LearnAPI. target (learner:: PerceptronClassifier , data) =
168169 LearnAPI. target (learner, obs (learner, data))
169170LearnAPI. features (
@@ -184,7 +185,7 @@ struct PerceptronClassifierFitted
184185 learner:: PerceptronClassifier
185186 perceptron # component array storing weights and bias
186187 state # optimiser state
187- classes # target classes
188+ levels # target levels
188189 losses
189190end
190191
@@ -194,7 +195,7 @@ LearnAPI.learner(model::PerceptronClassifierFitted) = model.learner
194195function LearnAPI. fit (
195196 learner:: PerceptronClassifier ,
196197 observations:: PerceptronClassifierObs ;
197- verbosity= 1 ,
198+ verbosity= LearnAPI . default_verbosity () ,
198199 )
199200
200201 # unpack hyperparameters:
@@ -205,20 +206,20 @@ function LearnAPI.fit(
205206 # unpack data:
206207 X = observations. X
207208 y_hot = observations. y_hot
208- classes = observations. classes
209- nclasses = length (classes )
209+ levels = observations. levels
210+ nlevels = length (levels )
210211
211212 # initialize bias and weights:
212- weights = randn (rng, Float32, nclasses , p)
213- bias = zeros (Float32, nclasses )
213+ weights = randn (rng, Float32, nlevels , p)
214+ bias = zeros (Float32, nlevels )
214215 perceptron = (; weights, bias) |> ComponentArrays. ComponentArray
215216
216217 # initialize optimiser:
217218 state = Optimisers. setup (optimiser, perceptron)
218219
219220 perceptron, state, losses = corefit (perceptron, X, y_hot, epochs, state, verbosity)
220221
221- return PerceptronClassifierFitted (learner, perceptron, state, classes , losses)
222+ return PerceptronClassifierFitted (learner, perceptron, state, levels , losses)
222223end
223224
224225# `fit` for unprocessed data:
@@ -230,16 +231,16 @@ function LearnAPI.update_observations(
230231 model:: PerceptronClassifierFitted ,
231232 observations_new:: PerceptronClassifierObs ,
232233 replacements... ;
233- verbosity= 1 ,
234+ verbosity= LearnAPI . default_verbosity () ,
234235 )
235236
236237 # unpack data:
237238 X = observations_new. X
238239 y_hot = observations_new. y_hot
239- classes = observations_new. classes
240- nclasses = length (classes )
240+ levels = observations_new. levels
241+ nlevels = length (levels )
241242
242- classes == model. classes || error (" New training target has incompatible classes ." )
243+ levels == model. levels || error (" New training target has incompatible levels ." )
243244
244245 learner_old = LearnAPI. learner (model)
245246 learner = LearnAPI. clone (learner_old, replacements... )
@@ -252,7 +253,7 @@ function LearnAPI.update_observations(
252253 perceptron, state, losses_new = corefit (perceptron, X, y_hot, epochs, state, verbosity)
253254 losses = vcat (losses, losses_new)
254255
255- return PerceptronClassifierFitted (learner, perceptron, state, classes , losses)
256+ return PerceptronClassifierFitted (learner, perceptron, state, levels , losses)
256257end
257258LearnAPI. update_observations (model:: PerceptronClassifierFitted , data, args... ; kwargs... ) =
258259 update_observations (model, obs (LearnAPI. learner (model), data), args... ; kwargs... )
@@ -262,16 +263,16 @@ function LearnAPI.update(
262263 model:: PerceptronClassifierFitted ,
263264 observations:: PerceptronClassifierObs ,
264265 replacements... ;
265- verbosity= 1 ,
266+ verbosity= LearnAPI . default_verbosity () ,
266267 )
267268
268269 # unpack data:
269270 X = observations. X
270271 y_hot = observations. y_hot
271- classes = observations. classes
272- nclasses = length (classes )
272+ levels = observations. levels
273+ nlevels = length (levels )
273274
274- classes == model. classes || error (" New training target has incompatible classes ." )
275+ levels == model. levels || error (" New training target has incompatible levels ." )
275276
276277 learner_old = LearnAPI. learner (model)
277278 learner = LearnAPI. clone (learner_old, replacements... )
@@ -289,7 +290,7 @@ function LearnAPI.update(
289290 corefit (perceptron, X, y_hot, Δepochs, state, verbosity)
290291 losses = vcat (losses, losses_new)
291292
292- return PerceptronClassifierFitted (learner, perceptron, state, classes , losses)
293+ return PerceptronClassifierFitted (learner, perceptron, state, levels , losses)
293294end
294295LearnAPI. update (model:: PerceptronClassifierFitted , data, args... ; kwargs... ) =
295296 update (model, obs (LearnAPI. learner (model), data), args... ; kwargs... )
@@ -299,9 +300,9 @@ LearnAPI.update(model::PerceptronClassifierFitted, data, args...; kwargs...) =
299300
300301function LearnAPI. predict (model:: PerceptronClassifierFitted , :: Distribution , Xnew)
301302 perceptron = model. perceptron
302- classes = model. classes
303+ levels = model. levels
303304 probs = perceptron. weights* Xnew .+ perceptron. bias |> NNlib. softmax
304- return CategoricalDistributions. UnivariateFinite (classes , probs' )
305+ return CategoricalDistributions. UnivariateFinite (levels , probs' )
305306end
306307
307308LearnAPI. predict (model:: PerceptronClassifierFitted , :: Point , Xnew) =
0 commit comments