Skip to content

FIX:rule_00112 中bigint和整数误触发 #3051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions sqle/driver/mysql/rule/ai/rule_00112.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
"github.com/pingcap/tidb/types"
parserdriver "github.com/pingcap/tidb/types/parser_driver"

"github.com/actiontech/sqle/sqle/driver/mysql/plocale"
Expand Down Expand Up @@ -96,6 +95,16 @@ func RuleSQLE00112(input *rulepkg.RuleHandlerInput) error {
return defaultTable
}

// 内部辅助函数:判断TP是否为整数类型
isIntegerType := func(tp byte) bool {
switch tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
return true
default:
return false
}
}

// 内部辅助函数:获取表达式的类型
getExprType := func(expr ast.ExprNode) (byte, error) {
switch node := expr.(type) {
Expand Down Expand Up @@ -131,32 +140,9 @@ func RuleSQLE00112(input *rulepkg.RuleHandlerInput) error {
}
return 0, fmt.Errorf("不支持的函数: %s", strings.Join(names, "."))
case *parserdriver.ValueExpr:
switch node.Datum.Kind() {
case types.KindInt64, types.KindUint64:
return mysql.TypeLong, nil
case types.KindFloat32, types.KindFloat64:
return mysql.TypeDouble, nil
case types.KindString:
return mysql.TypeVarchar, nil
case types.KindBytes:
return mysql.TypeBlob, nil
case types.KindBinaryLiteral:
return mysql.TypeBit, nil
case types.KindMysqlDecimal:
return mysql.TypeNewDecimal, nil
case types.KindMysqlDuration:
return mysql.TypeDuration, nil
case types.KindMysqlTime:
return mysql.TypeDatetime, nil
case types.KindMysqlEnum:
return mysql.TypeEnum, nil
case types.KindMysqlSet:
return mysql.TypeSet, nil
case types.KindMysqlJSON:
return mysql.TypeJSON, nil
default:
return 0, fmt.Errorf("不支持的常量类型: %d", node.Datum.Kind())
}
// 处理常量表达式
return node.Type.Tp, nil

default:
return 0, fmt.Errorf("不支持的表达式类型: %T", expr)
}
Expand Down Expand Up @@ -225,8 +211,12 @@ func RuleSQLE00112(input *rulepkg.RuleHandlerInput) error {

// 比较类型是否一致
if leftType != rightType {
// 报告违规
rulepkg.AddResult(input.Res, input.Rule, SQLE00112)
//如果不都是整数类型
if !(isIntegerType(leftType) && isIntegerType(rightType)) {
// 报告违规
rulepkg.AddResult(input.Res, input.Rule, SQLE00112)
}

}
return false
}, expr)
Expand Down Expand Up @@ -311,8 +301,12 @@ func RuleSQLE00112(input *rulepkg.RuleHandlerInput) error {

// 比较类型是否一致
if leftType != rightType {
// 报告违规
rulepkg.AddResult(input.Res, input.Rule, SQLE00112)
//如果不都是整数类型
if !(isIntegerType(leftType) && isIntegerType(rightType)) {
// 报告违规
rulepkg.AddResult(input.Res, input.Rule, SQLE00112)
}

}
return false
}, expr)
Expand Down Expand Up @@ -368,8 +362,12 @@ func RuleSQLE00112(input *rulepkg.RuleHandlerInput) error {

// 比较类型是否一致
if leftType != rightType {
// 报告违规
rulepkg.AddResult(input.Res, input.Rule, SQLE00112)
//如果不都是整数类型
if !(isIntegerType(leftType) && isIntegerType(rightType)) {
// 报告违规
rulepkg.AddResult(input.Res, input.Rule, SQLE00112)
}

}
return false
}, expr)
Expand Down
2 changes: 1 addition & 1 deletion sqle/driver/mysql/rule_00108_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestRuleSQLE00108(t *testing.T) {
// "WITH CTE AS (SELECT * FROM exist_db.exist_tb_1 WHERE id IN (SELECT id FROM exist_db.exist_tb_1 WHERE id IN (SELECT id FROM exist_db.exist_tb_1 WHERE id IN (SELECT id FROM exist_db.exist_tb_1 WHERE id IN (SELECT id FROM exist_db.exist_tb_1 WHERE id = 'value'))))) SELECT * FROM CTE",
// nil, nil, newTestResult())

runAIRuleCase(rule, t, "case 13_tes: SELECT语句where中包含6层嵌套子查询,使用示例中的表结构",
runAIRuleCase(rule, t, "case 13: SELECT语句from中包含2层嵌套子查询,使用示例中的表结构",
"SELECT AVG(subquery_middle.subquery_grade) AS subquery_middle_avg FROM (SELECT grade AS subquery_grade FROM st1 WHERE st1.cid IN (SELECT cid FROM st_class WHERE cname = 'class2')) subquery_middle;",
session.NewAIMockContext().WithSQL("CREATE TABLE st1 (id bigint, name VARCHAR(32), cid bigint, grade NUMERIC); CREATE TABLE st_class (cid bigint, cname VARCHAR(32));"),
nil, newTestResult())
Expand Down
6 changes: 6 additions & 0 deletions sqle/driver/mysql/rule_00112_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ func TestRuleSQLE00112(t *testing.T) {
session.NewAIMockContext().WithSQL("CREATE TABLE orders (c_id VARCHAR(100), order_date DATE);").
WithSQL("CREATE TABLE customers (c_id INT, name VARCHAR(100));"),
nil, newTestResult().addResult(ruleName))

runAIRuleCase(rule, t, "case 33: UPDATE语句中WHERE子句比较t1.id (bigint)与常量2838923,预期通过",
"UPDATE t1 SET name = 'jack' WHERE id = 2838923;",
session.NewAIMockContext().WithSQL("CREATE TABLE t1 (id BIGINT UNSIGNED not null, name VARCHAR(100));"),
nil, newTestResult())

}

// ==== Rule test code end ====