|
1 | 1 | import contextlib
|
2 | 2 | from typing import Union, Dict, Any, List
|
3 | 3 |
|
| 4 | + |
4 | 5 | from ..record import Record
|
5 | 6 | from ..source.source import BaseSource
|
6 | 7 | from ..feature import Feature, Features
|
7 | 8 | from ..model import Model, ModelContext
|
8 | 9 | from ..util.internal import records_to_sources, list_records_to_dict
|
9 | 10 | from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
|
| 11 | +from ..tuner import Tuner, TunerContext |
10 | 12 |
|
11 | 13 |
|
12 | 14 | async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
|
@@ -293,3 +295,149 @@ async def predict(
|
293 | 295 | )
|
294 | 296 | if update:
|
295 | 297 | await sctx.update(record)
|
| 298 | + |
async def tune(
    model,
    tuner: Union[Tuner, TunerContext],
    accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
    features: Union[Feature, Features],
    train_ds: Union[BaseSource, Record, Dict[str, Any], List],
    valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:
    """
    Tune the hyperparameters of a model with a given tuner.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        models options.
    tuner: Tuner
        Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for
        tuner options.
    accuracy_scorer: AccuracyScorer
        Accuracy scorer used to evaluate each hyperparameter combination on
        the validation set.
    features : Features
        Features on which the model is scored.
    train_ds : list
        Input data for training. Could be a ``dict``, :py:class:`Record`,
        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
        with the extension being one of the data sources.
    valid_ds : list
        Validation data for testing. Could be a ``dict``, :py:class:`Record`,
        filename, one of the data :doc:`/plugins/dffml_source`, or a filename
        with the extension being one of the data sources.

    Returns
    -------
    float
        A decimal value representing the result of the accuracy scorer on the given
        test set. For instance, ClassificationAccuracy represents the percentage of correct
        classifications made by the model.

    Raises
    ------
    TypeError
        If ``features``, ``model``, ``accuracy_scorer``, or ``tuner`` is not
        of an accepted type.

    Examples
    --------

    >>> import asyncio
    >>> from dffml import *
    >>>
    >>> model = SLRModel(
    ...     features=Features(
    ...         Feature("Years", int, 1),
    ...     ),
    ...     predict=Feature("Salary", int, 1),
    ...     location="tempdir",
    ... )
    >>>
    >>> async def main():
    ...     score = await tune(
    ...         model,
    ...         ParameterGrid(objective="min"),
    ...         MeanSquaredErrorAccuracy(),
    ...         Features(
    ...             Feature("Years", float, 1),
    ...         ),
    ...         [
    ...             {"Years": 0, "Salary": 10},
    ...             {"Years": 1, "Salary": 20},
    ...             {"Years": 2, "Salary": 30},
    ...             {"Years": 3, "Salary": 40}
    ...         ],
    ...         [
    ...             {"Years": 6, "Salary": 70},
    ...             {"Years": 7, "Salary": 80}
    ...         ]
    ...
    ...     )
    ...     print(f"Tuner score: {score}")
    ...
    >>> asyncio.run(main())
    Tuner score: 0.0
    """

    if not isinstance(features, (Feature, Features)):
        raise TypeError(
            f"features was {type(features)}: {features!r}. Should have been Feature or Features"
        )
    if isinstance(features, Feature):
        features = Features(features)

    # Collect the name(s) of the prediction feature(s). Default to an empty
    # list so the list_records_to_dict() calls below do not raise NameError
    # when the model config has "features" but no "predict" attribute.
    predict_feature = []
    if hasattr(model.config, "predict"):
        if isinstance(model.config.predict, Features):
            predict_feature = [
                feature.name for feature in model.config.predict
            ]
        else:
            predict_feature = [model.config.predict.name]

    # Convert bare lists of feature values into dicts keyed by feature name,
    # matching the model's configured features plus the prediction feature(s).
    if hasattr(model.config, "features") and any(
        isinstance(td, list) for td in train_ds
    ):
        train_ds = list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *train_ds,
            model=model,
        )
    if hasattr(model.config, "features") and any(
        isinstance(td, list) for td in valid_ds
    ):
        valid_ds = list_records_to_dict(
            [feature.name for feature in model.config.features]
            + predict_feature,
            *valid_ds,
            model=model,
        )

    async with contextlib.AsyncExitStack() as astack:
        # Open sources
        train = await astack.enter_async_context(records_to_sources(*train_ds))
        test = await astack.enter_async_context(records_to_sources(*valid_ds))
        # Allow for keep models open
        if isinstance(model, Model):
            model = await astack.enter_async_context(model)
            mctx = await astack.enter_async_context(model())
        elif isinstance(model, ModelContext):
            mctx = model
        else:
            # Fail fast with a clear error instead of an unbound-mctx
            # NameError further down (consistent with the scorer/tuner checks).
            raise TypeError(f"{model} is not a Model")

        # Allow for keep models open
        if isinstance(accuracy_scorer, AccuracyScorer):
            accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
            actx = await astack.enter_async_context(accuracy_scorer())
        elif isinstance(accuracy_scorer, AccuracyContext):
            actx = accuracy_scorer
        else:
            # TODO Replace this with static type checking and maybe dynamic
            # through something like pydantic. See issue #36
            raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")

        if isinstance(tuner, Tuner):
            tuner = await astack.enter_async_context(tuner)
            tctx = await astack.enter_async_context(tuner())
        elif isinstance(tuner, TunerContext):
            tctx = tuner
        else:
            raise TypeError(f"{tuner} is not a Tuner")

        # NOTE(review): when a ModelContext is passed in, ``model`` here is
        # that context — this assumes it also exposes ``.config.predict``;
        # confirm against the ModelContext interface.
        return float(
            await tctx.optimize(mctx, model.config.predict, actx, train, test)
        )
| 443 | + |
0 commit comments