Skip to content

Commit 33bcbfa

Browse files
committed
feat(pnnx): convert prelu[num_parameters=1] to leakyrelu, so that it can be fused with conv
1 parent 45cea8f commit 33bcbfa

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tools/pnnx/src/pass_level5/fuse_static_prelu.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,44 @@ pnnx.Output output 1 0 out
3535
}
3636
};
3737

38+
class convert_prelu_to_leakyrelu : public GraphRewriterPass
39+
{
40+
public:
41+
const char* match_pattern_graph() const
42+
{
43+
return R"PNNXIR(7767517
44+
3 2
45+
pnnx.Input input 0 1 input
46+
nn.PReLU op_0 1 1 input out num_parameters=1
47+
pnnx.Output output 1 0 out
48+
)PNNXIR";
49+
}
50+
51+
const char* type_str() const
52+
{
53+
return "nn.LeakyReLU";
54+
}
55+
56+
const char* name_str() const
57+
{
58+
return "leakyrelu";
59+
}
60+
61+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
62+
{
63+
const Attribute& weight = captured_attrs.at("op_0.weight");
64+
op->params["negative_slope"] = weight.get_float32_data()[0];
65+
}
66+
};
67+
3868
void fuse_static_prelu(Graph& graph)
3969
{
4070
fuse_static_Fprelu_pass a;
71+
convert_prelu_to_leakyrelu b;
4172
int opindex = 0;
4273

4374
pnnx_graph_rewrite(graph, &a, opindex);
75+
pnnx_graph_rewrite(graph, &b, opindex);
4476
}
4577

4678
} // namespace pnnx

0 commit comments

Comments
 (0)