Skip to content

Commit 57ce8c8

Browse files
be-marcsumnypfistfl
authored
feat: add mbo defaults (#192)
* push latest so_config attic * new so_config * fix: always add dropped columns to xdt during surrogate predict * new so_config * use current bbotk main branch * LearnerRegrRangerCustom * yahpo experiments * fix init_design_size for random_interleave, acqopt logging and warmstart * tweaks in acq opt and warmstart * some more infill optimization tweaks for so config * .. * new ac setup * .. * .. * .. * .. * .. * .. * .. * .. * .. * play around with log scaling * chore: AcqFunctionCB cleanup * merge acqf_ttei from acqf_ttei branch for now * .. * prepare new so_config * .. * new cd version * some so config try-outs * mies baseline * new eval scheme for so_config * .. * .. * .. * .. * .. * .. * .. * .. * some changes to char to fct handling, ... * .. * ... * .. * .. * add learner * fix impl errors, add docs * update * add to dict? * .. * reiterate lfbo * rerun docs * .. * manually require setting the lfbo direction * .. * .. * .. * .. * .. * .. * fix: fix_xdt_missing for logical NA * .. * .. * fix: remove reqired from parameters with default * fix: bbotk update * fix: optimizerchain * update * refactor: local search with new paradox * fix: cols_y * fix: remove browser() * chore: browser * fix: local search * fix: minimize * refactor: remove chain and local search * feat: add epsilon to log ei * .. * feat: rework LearnerRegrRangerMbo * .. * fix: make log ei more robost * se >= 1e-8 for LearnerRegrRangerMbo and law of total variance * .. * perf: speed up surrogate predictions * draft SurrogateGP * typo * .. * .. * .. * .. * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * add libcmaesr * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * fix direction * feat: add mbo defaults * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... * ... --------- Co-authored-by: Lennart Schneider <[email protected]> Co-authored-by: pfistfl <[email protected]>
1 parent 3cd01af commit 57ce8c8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1718
-247
lines changed

.lintr

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
linters: linters_with_defaults(
2-
# lintr defaults: https://github.com/jimhester/lintr#available-linters
2+
# lintr defaults: https://lintr.r-lib.org/reference/default_linters.html
33
# the following setup changes/removes certain linters
44
assignment_linter = NULL, # do not force using <- for assignments
5-
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
6-
cyclocomp_linter = NULL, # do not check function complexity
5+
object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
76
commented_code_linter = NULL, # allow code in comments
8-
line_length_linter = line_length_linter(120),
9-
indentation_linter(indent = 2L, hanging_indent_style = "never")
10-
)
7+
line_length_linter(200L),
8+
object_length_linter(40L),
9+
undesirable_function_linter(fun = c(
10+
# base messaging
11+
cat = "use catf()",
12+
stop = "use stopf()",
13+
warning = "use warningf()",
14+
message = "use messagef()",
15+
# perf
16+
ifelse = "use fifelse()",
17+
rank = "use frank()"
18+
))
19+
)

DESCRIPTION

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ License: LGPL-3
3939
URL: https://mlr3mbo.mlr-org.com, https://github.com/mlr-org/mlr3mbo
4040
BugReports: https://github.com/mlr-org/mlr3mbo/issues
4141
Depends:
42+
mlr3 (>= 1.2.0),
4243
mlr3tuning (>= 1.4.0),
4344
R (>= 3.1.0)
4445
Imports:
45-
bbotk (>= 1.6.0),
46+
bbotk (>= 1.8.0),
4647
checkmate (>= 2.0.0),
4748
data.table,
4849
lgr (>= 0.3.4),
49-
mlr3 (>= 0.23.0),
5050
mlr3misc (>= 0.15.1),
5151
paradox (>= 1.0.1),
5252
spacefillr,
@@ -56,8 +56,8 @@ Suggests:
5656
emoa,
5757
fastGHQuad,
5858
lhs,
59+
mlr3learners (>= 0.12.0),
5960
mirai,
60-
mlr3learners (>= 0.7.0),
6161
mlr3pipelines (>= 0.5.2),
6262
nloptr,
6363
ranger,
@@ -66,14 +66,14 @@ Suggests:
6666
redux,
6767
rush,
6868
stringi,
69-
testthat (>= 3.0.0)
69+
testthat (>= 3.0.0),
7070
ByteCompile: no
7171
Encoding: UTF-8
7272
Config/testthat/edition: 3
7373
Config/testthat/parallel: false
7474
NeedsCompilation: yes
7575
Roxygen: list(markdown = TRUE, r6 = TRUE)
76-
RoxygenNote: 7.3.2
76+
RoxygenNote: 7.3.3
7777
Collate:
7878
'mlr_acqfunctions.R'
7979
'AcqFunction.R'
@@ -92,6 +92,11 @@ Collate:
9292
'AcqFunctionStochasticCB.R'
9393
'AcqFunctionStochasticEI.R'
9494
'AcqOptimizer.R'
95+
'mlr_acqoptimizers.R'
96+
'AcqOptimizerDirect.R'
97+
'AcqOptimizerLbfgsb.R'
98+
'AcqOptimizerLocalSearch.R'
99+
'AcqOptimzerRandomSearch.R'
95100
'mlr_input_trafos.R'
96101
'InputTrafo.R'
97102
'InputTrafoUnitcube.R'

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(as.data.table,DictionaryAcqFunction)
4+
S3method(as.data.table,DictionaryAcqOptimizer)
45
S3method(as.data.table,DictionaryInputTrafo)
56
S3method(as.data.table,DictionaryLoopFunction)
67
S3method(as.data.table,DictionaryOutputTrafo)
@@ -22,6 +23,10 @@ export(AcqFunctionSmsEgo)
2223
export(AcqFunctionStochasticCB)
2324
export(AcqFunctionStochasticEI)
2425
export(AcqOptimizer)
26+
export(AcqOptimizerDirect)
27+
export(AcqOptimizerLbfgsb)
28+
export(AcqOptimizerLocalSearch)
29+
export(AcqOptimizerRandomSearch)
2530
export(InputTrafo)
2631
export(InputTrafoUnitcube)
2732
export(OptimizerADBO)
@@ -56,6 +61,7 @@ export(default_rf)
5661
export(default_surrogate)
5762
export(it)
5863
export(mlr_acqfunctions)
64+
export(mlr_acqoptimizers)
5965
export(mlr_input_trafos)
6066
export(mlr_loop_functions)
6167
export(mlr_output_trafos)

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# mlr3mbo (development version)
22

3+
* feat: Added `mlr_acqoptimizers` dictionary with predefined acquisition function optimizers.
4+
* perf: Added `AcqOptimizerDirect`, `AcqOptimizerLbfgsb`, `AcqOptimizerLocalSearch`, and `AcqOptimizerRandomSearch`.
5+
* feat: `default_*` helpers return new empirical based default values.
6+
37
# mlr3mbo 0.3.3
48

59
* compatibility: bbotk 1.7.0

R/AcqOptimizerDirect.R

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#' @title Direct Acquisition Function Optimizer
2+
#'
3+
#' @include AcqOptimizer.R mlr_acqoptimizers.R
4+
#'
5+
#' @description
6+
#' Direct acquisition function optimizer.
7+
#' Calls `nloptr()` from \CRANpkg{nloptr}.
8+
#' In its default setting, the algorithm restarts `5 * D` times and runs at most for `100 * D^2` function evaluations, where `D` is the dimension of the search space.
9+
#' Each run stops when the relative tolerance of the parameters is less than `10^-4`.
10+
#' The first iteration starts with the best point in the archive and the next iterations start from a random point.
11+
#'
12+
#' @section Parameters:
13+
#' \describe{
14+
#' \item{`restart_strategy`}{`character(1)`\cr
15+
#' Restart strategy.
16+
#' Can be `"none"` or `"random"`.
17+
#' Default is `"none"`.
18+
#' }
19+
#' \item{`max_restarts`}{`integer(1)`\cr
20+
#' Maximum number of restarts.
21+
#' Default is `5 * D` (Default).}
22+
#' }
23+
#'
24+
#' @note
25+
#' If the restart strategy is `"none"`, the optimizer starts with the best point in the archive.
26+
#' The optimization stops when one of the stopping criteria is met.
27+
#'
28+
#' If `restart_strategy` is `"random"`, the optimizer runs at least for `maxeval` iterations.
29+
#' The first iteration starts with the best point in the archive and stops when one of the stopping criteria is met.
30+
#' The next iterations start from a random point.
31+
#'
32+
#' @section Termination Parameters:
33+
#' The following termination parameters can be used.
34+
#'
35+
#' \describe{
36+
#' \item{`stopval`}{`numeric(1)`\cr
37+
#' Stop value.
38+
#' Deactivate with `-Inf` (Default).}
39+
#' \item{`maxeval`}{`integer(1)`\cr
40+
#' Maximum number of evaluations.
41+
#' Default is `100 * D^2`, where `D` is the dimension of the search space.
42+
#' Deactivate with `-1L`.}
43+
#' \item{`xtol_rel`}{`numeric(1)`\cr
44+
#' Relative tolerance of the parameters.
45+
#' Default is `10^-4`.
46+
#' Deactivate with `-1`.}
47+
#' \item{`xtol_abs`}{`numeric(1)`\cr
48+
#' Absolute tolerance of the parameters.
49+
#' Deactivate with `-1` (Default).}
50+
#' \item{`ftol_rel`}{`numeric(1)`\cr
51+
#' Relative tolerance of the objective function.
52+
#' Deactivate with `-1`. (Default).}
53+
#' \item{`ftol_abs`}{`numeric(1)`\cr
54+
#' Absolute tolerance of the objective function.
55+
#' Deactivate with `-1` (Default).}
56+
#' }
57+
#'
58+
#' @export
59+
#' @examples
60+
#' if (requireNamespace("nloptr")) {
61+
#' acqo("direct")
62+
#' }
63+
AcqOptimizerDirect = R6Class("AcqOptimizerDirect",
64+
inherit = AcqOptimizer,
65+
public = list(
66+
67+
#' @field state (`list()`)\cr
68+
#' List of [nloptr::nloptr()] results.
69+
state = NULL,
70+
71+
#' @description
72+
#' Creates a new instance of this [R6][R6::R6Class] class.
73+
#'
74+
#' @param acq_function (`NULL` | [AcqFunction]).
75+
initialize = function(acq_function = NULL) {
76+
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
77+
param_set = ps(
78+
maxeval = p_int(),
79+
stopval = p_dbl(default = -Inf, lower = -Inf, upper = Inf),
80+
xtol_rel = p_dbl(default = 1e-04, lower = 0, upper = Inf, special_vals = list(-1L)),
81+
xtol_abs = p_dbl(default = 0, lower = 0, upper = Inf, special_vals = list(-1L)),
82+
ftol_rel = p_dbl(default = 0, lower = 0, upper = Inf, special_vals = list(-1L)),
83+
ftol_abs = p_dbl(default = 0, lower = 0, upper = Inf, special_vals = list(-1L)),
84+
minf_max = p_dbl(default = -Inf),
85+
restart_strategy = p_fct(levels = c("none", "random"), init = "random"),
86+
max_restarts = p_int(lower = 0L),
87+
catch_errors = p_lgl(init = TRUE)
88+
)
89+
private$.param_set = param_set
90+
},
91+
92+
#' @description
93+
#' Optimize the acquisition function.
94+
#'
95+
#' @return [data.table::data.table()] with 1 row per candidate.
96+
optimize = function() {
97+
pv = self$param_set$values
98+
restart_strategy = pv$restart_strategy
99+
max_restarts = pv$max_restarts
100+
maxeval = pv$maxeval
101+
pv$max_restarts = NULL
102+
pv$restart_strategy = NULL
103+
pv$maxeval = NULL
104+
105+
if (restart_strategy == "none") {
106+
max_restarts = 0L
107+
} else if (restart_strategy == "random" && is.null(max_restarts)) {
108+
max_restarts = 5 * self$acq_function$domain$length
109+
}
110+
111+
if (is.null(maxeval)) {
112+
maxeval = 100 * self$acq_function$domain$length^2
113+
}
114+
115+
wrapper = function(x, fun, constants, direction) {
116+
xdt = as.data.table(as.list(set_names(x, self$acq_function$domain$ids())))
117+
res = mlr3misc::invoke(fun, xdt = xdt, .args = constants)[[1]]
118+
res * direction
119+
}
120+
121+
fun = get_private(self$acq_function)$.fun
122+
constants = self$acq_function$constants$values
123+
direction = self$acq_function$codomain$direction
124+
125+
y = Inf
126+
n_evals = 0L
127+
n_restarts = 0L
128+
while (n_evals < maxeval || maxeval < 0 && n_restarts <= max_restarts) {
129+
n_restarts = n_restarts + 1L
130+
131+
x0 = if (n_restarts == 1L) {
132+
as.numeric(self$acq_function$archive$best()[, self$acq_function$domain$ids(), with = FALSE])
133+
} else {
134+
# random restart
135+
as.numeric(generate_design_random(self$acq_function$domain, n = 1)$data)
136+
}
137+
138+
optimize = function() {
139+
invoke(nloptr::nloptr,
140+
eval_f = wrapper,
141+
lb = self$acq_function$domain$lower,
142+
ub = self$acq_function$domain$upper,
143+
opts = insert_named(pv, list(algorithm = "NLOPT_GN_DIRECT_L", maxeval = maxeval - n_evals)),
144+
eval_grad_f = NULL,
145+
x0 = x0,
146+
fun = fun,
147+
constants = constants,
148+
direction = direction)
149+
}
150+
151+
if (pv$catch_errors) {
152+
tryCatch({
153+
res = optimize()
154+
}, error = function(error_condition) {
155+
lg$warn(error_condition$message)
156+
stop(set_class(list(message = error_condition$message, call = NULL), classes = c("acq_optimizer_error", "mbo_error", "error", "condition")))
157+
})
158+
} else {
159+
res = optimize()
160+
}
161+
162+
if (res$objective < y) {
163+
y = res$objective
164+
x = res$solution
165+
}
166+
167+
n_evals = n_evals + res$iterations
168+
169+
self$state = c(self$state, set_names(list(list(model = res, start = x0)), paste0("iteration_", n_restarts)))
170+
171+
if (restart_strategy == "none") break
172+
}
173+
as.data.table(as.list(set_names(c(x, y * direction), c(self$acq_function$domain$ids(), self$acq_function$codomain$ids()))))
174+
}
175+
),
176+
177+
active = list(
178+
#' @template field_print_id
179+
print_id = function(rhs) {
180+
assert_ro_binding(rhs)
181+
"(OptimizerDirect)"
182+
}
183+
)
184+
)
185+
186+
mlr_acqoptimizers$add("direct", AcqOptimizerDirect)

0 commit comments

Comments
 (0)