Skip to content
This repository was archived by the owner on Mar 25, 2025. It is now read-only.

Commit 60249f1

Browse files
authored
Remove internal ml argument and move GPU code to Acc visitor (#919)
1 parent aa5046e commit 60249f1

File tree

5 files changed

+83
-56
lines changed

5 files changed

+83
-56
lines changed

src/codegen/codegen_acc_visitor.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,26 @@ void CodegenAccVisitor::print_newtonspace_transfer_to_device() const {
265265
}
266266

267267

268-
void CodegenAccVisitor::print_instance_variable_transfer_to_device(
269-
std::vector<std::string> const& ptr_members) const {
268+
void CodegenAccVisitor::print_instance_struct_transfer_routine_declarations() {
270269
if (info.artificial_cell) {
271270
return;
272271
}
272+
printer->fmt_line(
273+
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst);",
274+
instance_struct());
275+
printer->fmt_line("static inline void delete_instance_from_device({}* inst);",
276+
instance_struct());
277+
}
278+
279+
280+
void CodegenAccVisitor::print_instance_struct_transfer_routines(
281+
std::vector<std::string> const& ptr_members) {
282+
if (info.artificial_cell) {
283+
return;
284+
}
285+
printer->fmt_start_block(
286+
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst)",
287+
instance_struct());
273288
printer->start_block("if (!nt->compute_gpu)");
274289
printer->add_line("return;");
275290
printer->end_block(1);
@@ -285,18 +300,33 @@ void CodegenAccVisitor::print_instance_variable_transfer_to_device(
285300
printer->add_line("auto* d_ml = cnrn_target_deviceptr(ml);");
286301
printer->add_line("void* d_inst_void = d_inst;");
287302
printer->add_line("cnrn_target_memcpy_to_device(&(d_ml->instance), &d_inst_void);");
303+
printer->end_block(2); // copy_instance_to_device
304+
305+
printer->fmt_start_block("static inline void delete_instance_from_device({}* inst)",
306+
instance_struct());
307+
printer->start_block("if (cnrn_target_is_present(inst))");
308+
printer->add_line("cnrn_target_delete(inst);");
309+
printer->end_block(1);
310+
printer->end_block(2); // delete_instance_from_device
288311
}
289312

290313

291-
void CodegenAccVisitor::print_instance_variable_deletion_from_device() const {
314+
void CodegenAccVisitor::print_instance_struct_copy_to_device() {
292315
if (info.artificial_cell) {
293316
return;
294317
}
295-
printer->start_block("if (cnrn_target_is_present(&inst))");
296-
printer->add_line("cnrn_target_delete(&inst);");
297-
printer->end_block(1);
318+
printer->add_line("copy_instance_to_device(nt, ml, inst);");
319+
}
320+
321+
322+
void CodegenAccVisitor::print_instance_struct_delete_from_device() {
323+
if (info.artificial_cell) {
324+
return;
325+
}
326+
printer->add_line("delete_instance_from_device(inst);");
298327
}
299328

329+
300330
void CodegenAccVisitor::print_deriv_advance_flag_transfer_to_device() const {
301331
printer->add_line("nrn_pragma_acc(update device (deriv_advance_flag) if(nt->compute_gpu))");
302332
printer->add_line("nrn_pragma_omp(target update to(deriv_advance_flag) if(nt->compute_gpu))");

src/codegen/codegen_acc_visitor.hpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,17 @@ class CodegenAccVisitor: public CodegenCVisitor {
9292
/// transfer newtonspace structure to device
9393
void print_newtonspace_transfer_to_device() const override;
9494

95-
/// copy the instance struct to the device
96-
void print_instance_variable_transfer_to_device(
97-
std::vector<std::string> const& ptr_members) const override;
95+
/// declare helper functions for copying the instance struct to the device
96+
void print_instance_struct_transfer_routine_declarations() override;
9897

99-
/// delete the instance struct from the device
100-
void print_instance_variable_deletion_from_device() const override;
98+
/// define helper functions for copying the instance struct to the device
99+
void print_instance_struct_transfer_routines(std::vector<std::string> const&) override;
100+
101+
/// call helper function for copying the instance struct to the device
102+
void print_instance_struct_copy_to_device() override;
103+
104+
/// call helper function that deletes the instance struct from the device
105+
void print_instance_struct_delete_from_device() override;
101106

102107
// update derivimplicit advance flag on the gpu device
103108
void print_deriv_advance_flag_transfer_to_device() const override;

src/codegen/codegen_c_visitor.cpp

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,17 +1041,6 @@ void CodegenCVisitor::print_channel_iteration_tiling_block_end() {
10411041
}
10421042

10431043

1044-
void CodegenCVisitor::print_instance_variable_transfer_to_device(
1045-
std::vector<std::string> const& ptr_members) const {
1046-
// backend specific, do nothing
1047-
}
1048-
1049-
1050-
void CodegenCVisitor::print_instance_variable_deletion_from_device() const {
1051-
// backend specific, do nothing
1052-
}
1053-
1054-
10551044
void CodegenCVisitor::print_deriv_advance_flag_transfer_to_device() const {
10561045
// backend specific, do nothing
10571046
}
@@ -1892,9 +1881,9 @@ void CodegenCVisitor::print_eigen_linear_solver(const std::string& float_type, i
18921881

18931882
std::string CodegenCVisitor::internal_method_arguments() {
18941883
if (ion_variable_struct_required()) {
1895-
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, ml, v";
1884+
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
18961885
}
1897-
return "id, pnodecount, inst, data, indexes, thread, nt, ml, v";
1886+
return "id, pnodecount, inst, data, indexes, thread, nt, v";
18981887
}
18991888

19001889

@@ -1926,7 +1915,6 @@ CodegenCVisitor::ParamVector CodegenCVisitor::internal_method_parameters() {
19261915
params.emplace_back("const ", "Datum*", "", "indexes");
19271916
params.emplace_back(param_type_qualifier(), "ThreadDatum*", "", "thread");
19281917
params.emplace_back(param_type_qualifier(), "NrnThread*", param_ptr_qualifier(), "nt");
1929-
params.emplace_back(param_type_qualifier(), "Memb_list*", param_ptr_qualifier(), "ml");
19301918
params.emplace_back("", "double", "", "v");
19311919
return params;
19321920
}
@@ -1961,9 +1949,9 @@ std::string CodegenCVisitor::nrn_thread_arguments() {
19611949
*/
19621950
std::string CodegenCVisitor::nrn_thread_internal_arguments() {
19631951
if (ion_variable_struct_required()) {
1964-
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, ml, v";
1952+
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
19651953
}
1966-
return "id, pnodecount, inst, data, indexes, thread, nt, ml, v";
1954+
return "id, pnodecount, inst, data, indexes, thread, nt, v";
19671955
}
19681956

19691957

@@ -3200,18 +3188,15 @@ void CodegenCVisitor::print_instance_variable_setup() {
32003188
printer->fmt_line("assert(ml->global_variables_size == sizeof({}));", global_struct());
32013189
};
32023190

3203-
printer->fmt_line(
3204-
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst);",
3205-
instance_struct());
3206-
printer->fmt_line("static inline void delete_instance_from_device({}& inst);",
3207-
instance_struct());
3208-
printer->add_newline();
3191+
// Must come before print_instance_struct_copy_to_device and
3192+
// print_instance_struct_delete_from_device
3193+
print_instance_struct_transfer_routine_declarations();
32093194

32103195
printer->add_line("// Deallocate the instance structure");
32113196
printer->fmt_start_block("static void {}(NrnThread* nt, Memb_list* ml, int type)",
32123197
method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD));
32133198
cast_inst_and_assert_validity();
3214-
printer->add_line("delete_instance_from_device(*inst);");
3199+
print_instance_struct_delete_from_device();
32153200
printer->add_line("delete inst;");
32163201
printer->add_line("ml->instance = nullptr;");
32173202
printer->add_line("ml->global_variables = nullptr;");
@@ -3269,20 +3254,10 @@ void CodegenCVisitor::print_instance_variable_setup() {
32693254
printer->fmt_line("inst->{} = {};", name, variable);
32703255
ptr_members.push_back(std::move(name));
32713256
}
3272-
printer->add_line("copy_instance_to_device(nt, ml, inst);");
3257+
print_instance_struct_copy_to_device();
32733258
printer->end_block(2); // setup_instance
32743259

3275-
printer->add_line("// Set up the device-side copy of the instance structure");
3276-
printer->fmt_start_block(
3277-
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst)",
3278-
instance_struct());
3279-
print_instance_variable_transfer_to_device(ptr_members);
3280-
printer->end_block(2); // copy_instance_to_device
3281-
3282-
printer->fmt_start_block("static inline void delete_instance_from_device({}& inst)",
3283-
instance_struct());
3284-
print_instance_variable_deletion_from_device();
3285-
printer->end_block(2); // delete_instance_from_device
3260+
print_instance_struct_transfer_routines(ptr_members);
32863261
}
32873262

32883263

src/codegen/codegen_c_visitor.hpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,19 +1080,39 @@ class CodegenCVisitor: public visitor::ConstAstVisitor {
10801080

10811081

10821082
/**
1083-
* Print the code to copy instance struct members to the device,
1084-
* substituting host pointers for device ones.
1083+
* Print declarations of the functions used by \ref
1084+
* print_instance_struct_copy_to_device and \ref
1085+
* print_instance_struct_delete_from_device.
1086+
*/
1087+
virtual void print_instance_struct_transfer_routine_declarations() {}
1088+
1089+
/**
1090+
* Print the definitions of the functions used by \ref
1091+
* print_instance_struct_copy_to_device and \ref
1092+
* print_instance_struct_delete_from_device. Declarations of these functions
1093+
* are printed by \ref print_instance_struct_transfer_routine_declarations.
1094+
*
1095+
* This updates the (pointer) member variables in the device copy of the
1096+
* instance struct to contain device pointers, which is why you must pass a
1097+
* list of names of those member variables.
10851098
*
1086-
* \param ptr_members Members to update.
1099+
* \param ptr_members List of instance struct member names.
10871100
*/
1088-
virtual void print_instance_variable_transfer_to_device(
1089-
std::vector<std::string> const& ptr_members) const;
1101+
virtual void print_instance_struct_transfer_routines(
1102+
std::vector<std::string> const& /* ptr_members */) {}
10901103

10911104

10921105
/**
1093-
* Print the code to delete the instance structure from the device.
1106+
* Transfer the instance struct to the device. This calls a function
1107+
* declared by \ref print_instance_struct_transfer_routine_declarations.
1108+
*/
1109+
virtual void print_instance_struct_copy_to_device() {}
1110+
1111+
/**
1112+
* Delete the instance struct from the device. This calls a function
1113+
* declared by \ref print_instance_struct_transfer_routine_declarations.
10941114
*/
1095-
virtual void print_instance_variable_deletion_from_device() const;
1115+
virtual void print_instance_struct_delete_from_device() {}
10961116

10971117

10981118
/**

test/unit/codegen/codegen_c_visitor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
102102
inst->ion_cao = nt->_data;
103103
inst->ion_ica = nt->_data;
104104
inst->ion_dicadv = nt->_data;
105-
copy_instance_to_device(nt, ml, inst);
106105
}
107106
)";
108107
auto const expected = reindent_text(generated_code);
@@ -151,7 +150,6 @@ SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
151150
inst->v_unused = ml->data+4*pnodecount;
152151
inst->ion_cai = nt->_data;
153152
inst->ion_cao = nt->_data;
154-
copy_instance_to_device(nt, ml, inst);
155153
}
156154
)";
157155

@@ -231,7 +229,6 @@ SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
231229
inst->ion_ilca = nt->_data;
232230
inst->ion_elca = nt->_data;
233231
inst->style_lca = ml->pdata;
234-
copy_instance_to_device(nt, ml, inst);
235232
}
236233
)";
237234

0 commit comments

Comments
 (0)