Skip to content

Commit 8452dea

Browse files
committed
Fix GPU inference
1 parent 9605843 commit 8452dea

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/cpu_forest_model.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,15 @@ struct ForestModel<rapids::HostMemory> {
100100
// for both positive and negative classes.
101101
output_buffer = rapids::Buffer<float>{samples * 2, rapids::HostMemory};
102102
}
103-
104-
// TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
105-
// input buffer
106103
if (new_fil_model_->row_postprocessing() == filex::row_op::max_index) {
107104
// Work around the limitation of max_index, to allow for probability
108105
// output
109106
// TODO(hcho3): Review new FIL to expose predict_proba(), to always
110107
// output probabilities
111108
new_fil_model_->set_row_postprocessing(filex::row_op::softmax);
112109
}
110+
// TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
111+
// input buffer
113112
new_fil_model_->predict(
114113
raft_proto::handle_t{}, output_buffer.data(),
115114
const_cast<float*>(input.data()), samples,

src/gpu_forest_model.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ struct ForestModel<rapids::DeviceMemory> {
127127
output_buffer =
128128
rapids::Buffer<float>{samples * 2, rapids::DeviceMemory};
129129
}
130+
if (new_fil_model_->row_postprocessing() == filex::row_op::max_index) {
131+
// Work around the limitation of max_index, to allow for probability
132+
// output
133+
// TODO(hcho3): Review new FIL to expose predict_proba(), to always
134+
// output probabilities
135+
new_fil_model_->set_row_postprocessing(filex::row_op::softmax);
136+
}
130137
// TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
131138
// input buffer
132139
new_fil_model_->predict(

0 commit comments

Comments
 (0)