File tree 2 files changed +9
-3
lines changed 2 files changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -100,16 +100,15 @@ struct ForestModel<rapids::HostMemory> {
100
100
// for both positive and negative classes.
101
101
output_buffer = rapids::Buffer<float >{samples * 2 , rapids::HostMemory};
102
102
}
103
-
104
- // TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
105
- // input buffer
106
103
if (new_fil_model_->row_postprocessing () == filex::row_op::max_index) {
107
104
// Work around the limitation of max_index, to allow for probability
108
105
// output
109
106
// TODO(hcho3): Review new FIL to expose predict_proba(), to always
110
107
// output probabilities
111
108
new_fil_model_->set_row_postprocessing (filex::row_op::softmax);
112
109
}
110
+ // TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
111
+ // input buffer
113
112
new_fil_model_->predict (
114
113
raft_proto::handle_t {}, output_buffer.data (),
115
114
const_cast <float *>(input.data ()), samples,
Original file line number Diff line number Diff line change @@ -127,6 +127,13 @@ struct ForestModel<rapids::DeviceMemory> {
127
127
output_buffer =
128
128
rapids::Buffer<float >{samples * 2 , rapids::DeviceMemory};
129
129
}
130
+ if (new_fil_model_->row_postprocessing () == filex::row_op::max_index) {
131
+ // Work around the limitation of max_index, to allow for probability
132
+ // output
133
+ // TODO(hcho3): Review new FIL to expose predict_proba(), to always
134
+ // output probabilities
135
+ new_fil_model_->set_row_postprocessing (filex::row_op::softmax);
136
+ }
130
137
// TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
131
138
// input buffer
132
139
new_fil_model_->predict (
You can’t perform that action at this time.
0 commit comments