Skip to content

Commit 88eb877

Browse files
committed
feat(pnnx): convert prelu[num_parameters=1] to leakyrelu, so that it can be fused with conv
1 parent 45cea8f commit 88eb877

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

tools/pnnx/src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,10 @@ set(pnnx_pass_level5_SRCS
412412
set(pnnx_pass_ncnn_SRCS
413413
pass_ncnn/convert_attribute.cpp
414414
pass_ncnn/convert_custom_op.cpp
415-
pass_ncnn/convert_module_op.cpp
416415
pass_ncnn/convert_half_to_float.cpp
417416
pass_ncnn/convert_input.cpp
417+
pass_ncnn/convert_module_op.cpp
418+
pass_ncnn/convert_prelu.cpp
418419
pass_ncnn/convert_reshape_interp_expression.cpp
419420
pass_ncnn/convert_slice_expression.cpp
420421
pass_ncnn/convert_torch_cat.cpp
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright 2025 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "pass_ncnn.h"
5+
6+
namespace pnnx {
7+
8+
namespace ncnn {
9+
10+
class convert_prelu : public GraphRewriterPass
11+
{
12+
public:
13+
const char* match_pattern_graph() const
14+
{
15+
return R"PNNXIR(7767517
16+
3 2
17+
pnnx.Input input 0 1 input
18+
PReLU op_0 1 1 input out 0=1
19+
pnnx.Output output 1 0 out
20+
)PNNXIR";
21+
}
22+
23+
const char* type_str() const
24+
{
25+
return "ReLU";
26+
}
27+
28+
const char* name_str() const
29+
{
30+
return "relu";
31+
}
32+
33+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
34+
{
35+
const Attribute& slope = captured_attrs.at("op_0.0");
36+
op->params["0"] = slope.get_float32_data()[0];
37+
}
38+
};
39+
40+
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(convert_prelu, 99)
41+
42+
} // namespace ncnn
43+
44+
} // namespace pnnx
45+

0 commit comments

Comments
 (0)