@@ -82,20 +82,21 @@ struct ForestModel<rapids::HostMemory> {
82
82
auto output_buffer = rapids::Buffer<float >{
83
83
output.data (), output.size (), output.mem_type (), output.device (),
84
84
output.stream ()};
85
- auto output_size = output. size ();
85
+ auto const num_classes = tl_model_-> num_classes ();
86
86
// New FIL expects buffer of size samples * num_classes for multi-class
87
87
// classifiers, but output buffer may be smaller, so we need a temporary
88
88
// buffer
89
- auto const num_classes = tl_model_->num_classes ();
90
89
if (!predict_proba && tl_model_->config ().output_class &&
91
90
num_classes > 1 ) {
92
- output_size = samples * num_classes;
93
- if (output_size != output.size ()) {
94
- // If expected output size is not the same as the size of `output`,
95
- // create a temporary buffer of the correct size
96
- output_buffer =
97
- rapids::Buffer<float >{output_size, rapids::HostMemory};
98
- }
91
+ output_buffer =
92
+ rapids::Buffer<float >{samples * num_classes, rapids::HostMemory};
93
+ } else if (
94
+ predict_proba && tl_model_->config ().output_class &&
95
+ num_classes == 1 ) {
96
+ // Also use a temp buffer when probabilities are requested for
97
+ // a binary classifier. This is so that we can output probabilities
98
+ // for both positive and negative classes.
99
+ output_buffer = rapids::Buffer<float >{samples * 2 , rapids::HostMemory};
99
100
}
100
101
101
102
// TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
0 commit comments