Commit 5a363b3

Allocate a temp buffer for binary predict_proba
1 parent ce6f2d5 commit 5a363b3

File tree

src/cpu_forest_model.h
src/gpu_forest_model.h

2 files changed: +21 −18 lines changed


src/cpu_forest_model.h

Lines changed: 10 additions & 9 deletions
@@ -82,20 +82,21 @@ struct ForestModel<rapids::HostMemory> {
     auto output_buffer = rapids::Buffer<float>{
         output.data(), output.size(), output.mem_type(), output.device(),
         output.stream()};
-    auto output_size = output.size();
+    auto const num_classes = tl_model_->num_classes();
     // New FIL expects buffer of size samples * num_classes for multi-class
     // classifiers, but output buffer may be smaller, so we need a temporary
     // buffer
-    auto const num_classes = tl_model_->num_classes();
     if (!predict_proba && tl_model_->config().output_class &&
         num_classes > 1) {
-      output_size = samples * num_classes;
-      if (output_size != output.size()) {
-        // If expected output size is not the same as the size of `output`,
-        // create a temporary buffer of the correct size
-        output_buffer =
-            rapids::Buffer<float>{output_size, rapids::HostMemory};
-      }
+      output_buffer =
+          rapids::Buffer<float>{samples * num_classes, rapids::HostMemory};
+    } else if (
+        predict_proba && tl_model_->config().output_class &&
+        num_classes == 1) {
+      // Also use a temp buffer when probabilities are requested for
+      // a binary classifier. This is so that we can output probabilities
+      // for both positive and negative classes.
+      output_buffer = rapids::Buffer<float>{samples * 2, rapids::HostMemory};
     }
 
     // TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for

src/gpu_forest_model.h

Lines changed: 11 additions & 9 deletions
@@ -108,20 +108,22 @@ struct ForestModel<rapids::DeviceMemory> {
     auto output_buffer = rapids::Buffer<float>{
         output.data(), output.size(), output.mem_type(), output.device(),
         output.stream()};
-    auto output_size = output.size();
+    auto const num_classes = tl_model_->num_classes();
     // New FIL expects buffer of size samples * num_classes for multi-class
     // classifiers, but output buffer may be smaller, so we need a temporary
     // buffer
-    auto const num_classes = tl_model_->num_classes();
     if (!predict_proba && tl_model_->config().output_class &&
         num_classes > 1) {
-      output_size = samples * num_classes;
-      if (output_size != output.size()) {
-        // If expected output size is not the same as the size of `output`,
-        // create a temporary buffer of the correct size
-        output_buffer =
-            rapids::Buffer<float>{output_size, rapids::DeviceMemory};
-      }
+      output_buffer =
+          rapids::Buffer<float>{samples * num_classes, rapids::DeviceMemory};
+    } else if (
+        predict_proba && tl_model_->config().output_class &&
+        num_classes == 1) {
+      // Also use a temp buffer when probabilities are requested for
+      // a binary classifier. This is so that we can output probabilities
+      // for both positive and negative classes.
+      output_buffer =
+          rapids::Buffer<float>{samples * 2, rapids::DeviceMemory};
     }
     // TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
     // input buffer
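Note on the sizing: the diff only reserves room for both classes by making the temporary buffer samples * 2; it does not show how the (negative, positive) pairs get filled in. Below is a minimal, hypothetical sketch of the in-place expansion such a buffer makes possible, assuming new FIL writes a single positive-class probability per sample into the front of the buffer. The helper name and layout are illustrative only and are not part of this commit.

    #include <cstddef>

    // Hypothetical illustration only: expand `samples` positive-class
    // probabilities stored at the front of a buffer of size samples * 2 into
    // (negative, positive) = (1 - p, p) pairs, in place. Iterating back to
    // front keeps not-yet-read entries from being overwritten.
    inline void expand_binary_probabilities(float* buffer, std::size_t samples) {
      for (std::size_t i = samples; i-- > 0;) {
        auto const p = buffer[i];
        buffer[2 * i] = 1.0f - p;  // probability of the negative class
        buffer[2 * i + 1] = p;     // probability of the positive class
      }
    }

On the device-memory path, the equivalent expansion would have to run in device code or after a copy back to the host; that is outside the scope of this sketch.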
