diff --git a/src-json/warning.json b/src-json/warning.json index 0db09c76a01..d9783ef392d 100644 --- a/src-json/warning.json +++ b/src-json/warning.json @@ -125,6 +125,11 @@ "parent": "WTyper", "enabled": false }, + { + "name": "WRedundantNullCheck", + "doc": "Value can't be null, so comparison with null is excessive", + "parent": "WTyper" + }, { "name": "WHxb", "doc": "Hxb (either --hxb output or haxe compiler cache) related warnings" diff --git a/src/context/common.ml b/src/context/common.ml index 3eef8c80c8b..d661230865d 100644 --- a/src/context/common.ml +++ b/src/context/common.ml @@ -75,7 +75,7 @@ class compiler_callbacks = object(self) method add_after_generation (f : unit -> unit) : unit = after_generation := f :: !after_generation - method add_null_safety_report (f : (string*pos) list -> unit) : unit = + method add_null_safety_report (f : (WarningList.warning option*string*pos) list -> unit) : unit = null_safety_report <- f :: null_safety_report method run handle_error r = diff --git a/src/macro/macroApi.ml b/src/macro/macroApi.ml index 784eb25b204..606fb8be7b3 100644 --- a/src/macro/macroApi.ml +++ b/src/macro/macroApi.ml @@ -2375,9 +2375,10 @@ let macro_api ccom get_api = ); "on_null_safety_report", vfun1 (fun f -> let f = prepare_callback f 1 in - (ccom()).callbacks#add_null_safety_report (fun (errors:(string*pos) list) -> - let encode_item (msg,pos) = - encode_obj [("msg", encode_string msg); ("pos", encode_pos pos)] + (ccom()).callbacks#add_null_safety_report (fun (errors:(WarningList.warning option*string*pos) list) -> + let encode_item (wtype,msg,pos) = + let wtype = match wtype with | Some _ -> "warning" | None -> "error" in + encode_obj [("type", encode_string wtype); ("msg", encode_string msg); ("pos", encode_pos pos)] in ignore(f [encode_array (List.map encode_item errors)]) ); diff --git a/src/typing/nullSafety.ml b/src/typing/nullSafety.ml index 8071b9a2990..a476242acca 100644 --- a/src/typing/nullSafety.ml +++ b/src/typing/nullSafety.ml @@ -5,16 +5,23 @@ open Type type safety_message = { sm_msg : string; sm_pos : pos; + sm_type : WarningList.warning option } type safety_report = { mutable sr_errors : safety_message list; + mutable sr_warnings: safety_message list; } let add_error report msg pos = - let error = { sm_msg = ("Null safety: " ^ msg); sm_pos = pos; } in + let error = { sm_type = None; sm_msg = ("Null safety: " ^ msg); sm_pos = pos; } in if not (List.mem error report.sr_errors) then - report.sr_errors <- error :: report.sr_errors; + report.sr_errors <- error :: report.sr_errors;; + +let add_warning report wtype msg pos = + let warning = { sm_type = Some wtype; sm_msg = ("Null safety: " ^ msg); sm_pos = pos; } in + if not (List.mem warning report.sr_warnings) then + report.sr_warnings <- warning :: report.sr_warnings; type scope_type = | STNormal @@ -447,7 +454,7 @@ let rec contains_safe_meta metadata = let safety_enabled meta = (contains_safe_meta meta) && not (contains_unsafe_meta meta) -let safety_mode (metadata:Ast.metadata) = +let get_safety_mode (metadata:Ast.metadata) = let rec traverse mode meta = match mode, meta with | Some SMOff, _ @@ -1053,7 +1060,6 @@ class expr_checker mode immediate_execution report = val mutable in_closure = false (* if this flag is `true` then spotted errors and warnings will not be reported *) val mutable is_pretending = false - (* val mutable cnt = 0 *) (** Get safety mode for this expression checker *) @@ -1072,6 +1078,33 @@ class expr_checker mode immediate_execution report = in add_error report msg (get_first_valid_pos positions) end + (** + Register a warning + *) + method warning wtype msg (positions:Globals.pos list) = + if not is_pretending then begin + let rec get_first_valid_pos positions = + match positions with + | [] -> null_pos + | p :: rest -> + if p <> null_pos then p + else get_first_valid_pos rest + in + add_warning report wtype msg (get_first_valid_pos positions) + end + + method private check_binop_redundant_null_checks e = + match e.eexpr with + | TBinop ((OpEq | OpNotEq), { eexpr = TConst TNull }, expr) + | TBinop ((OpEq | OpNotEq), expr, { eexpr = TConst TNull }) + | TBinop(OpAssignOp OpNullCoal, expr, _) + | TBinop (OpNullCoal, expr, _) -> + if not (is_nullable_type ~dynamic_is_nullable:true expr.etype) then + self#warning + WRedundantNullCheck + ("The operand type is not nullable, so null-check should be redundant.") + [expr.epos; e.epos]; + | _ -> () (** Check if `e` is nullable even if the type is reported not-nullable. Haxe type system lies sometimes. @@ -1180,7 +1213,9 @@ class expr_checker mode immediate_execution report = | TConst _ -> () | TLocal _ -> () | TArray (arr, idx) -> self#check_array_access arr idx e.epos - | TBinop (op, left_expr, right_expr) -> self#check_binop op left_expr right_expr e.epos + | TBinop (op, left_expr, right_expr) -> + self#check_binop_redundant_null_checks e; + self#check_binop op left_expr right_expr e.epos | TField (target, access) -> self#check_field target access e.epos | TTypeExpr _ -> () | TParenthesis e -> self#check_expr e @@ -1539,7 +1574,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) = object (self) val is_safe_class = (safety_enabled cls_meta) val mutable checker = new expr_checker SMLoose immediate_execution report - val mutable mode = None + val mutable mode : safety_mode option = None (** Entry point for checking a class *) @@ -1549,7 +1584,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) = self#check_var_fields; let check_field is_static f = if not (has_class_field_flag f CfPostProcessed) then begin validate_safety_meta report f.cf_meta; - match (safety_mode (cls_meta @ f.cf_meta)) with + match (get_safety_mode (cls_meta @ f.cf_meta)) with | SMOff -> () | mode -> (match f.cf_expr with @@ -1560,7 +1595,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) = self#check_accessors is_static f end in if is_safe_class then - Option.may ((self#get_checker (safety_mode cls_meta))#check_root_expr) (TClass.get_cl_init cls); + Option.may ((self#get_checker (get_safety_mode cls_meta))#check_root_expr) (TClass.get_cl_init cls); Option.may (check_field false) cls.cl_constructor; List.iter (check_field false) cls.cl_ordered_fields; List.iter (check_field true) cls.cl_ordered_statics; @@ -1601,7 +1636,7 @@ class class_checker cls immediate_execution report (main_expr : texpr option) = match mode with | Some mode -> mode | None -> - let m = safety_mode cls_meta in + let m = get_safety_mode cls_meta in mode <- Some m; m (** @@ -1784,7 +1819,10 @@ class class_checker cls immediate_execution report (main_expr : texpr option) = *) let run (com:Common.context) (types:module_type list) = let report = Timer.time com.timer_ctx ["null safety"] (fun () -> - let report = { sr_errors = [] } in + let report = { + sr_errors = []; + sr_warnings = []; + } in let immediate_execution = new immediate_execution in let traverse module_type = match module_type with @@ -1798,11 +1836,21 @@ let run (com:Common.context) (types:module_type list) = ) () in match com.callbacks#get_null_safety_report with | [] -> - List.iter (fun err -> Common.display_error com err.sm_msg err.sm_pos) (List.rev report.sr_errors) + List.iter (fun warn -> + com.warning (Option.get warn.sm_type) [] warn.sm_msg warn.sm_pos + ) (List.rev report.sr_warnings); + + List.iter (fun err -> + Common.display_error com err.sm_msg err.sm_pos + ) (List.rev report.sr_errors) | callbacks -> + let warnings = + List.map (fun warn -> (warn.sm_type, warn.sm_msg, warn.sm_pos)) report.sr_warnings + in let errors = - List.map (fun err -> (err.sm_msg, err.sm_pos)) report.sr_errors + List.map (fun err -> (err.sm_type, err.sm_msg, err.sm_pos)) report.sr_errors in - List.iter (fun fn -> fn errors) callbacks + let all = warnings @ errors in + List.iter (fun fn -> fn all) callbacks ;; diff --git a/tests/nullsafety/src/Validator.hx b/tests/nullsafety/src/Validator.hx index eb7d11696d0..654dd8defe4 100644 --- a/tests/nullsafety/src/Validator.hx +++ b/tests/nullsafety/src/Validator.hx @@ -2,19 +2,21 @@ import haxe.macro.Context; import haxe.macro.Expr; -typedef SafetyMessage = {msg:String, pos:Position} -typedef ExpectedMessage = {symbol:String, pos:Position} +typedef SafetyMessage = {type:String, msg:String, pos:Position} +typedef ExpectedMessage = {type:String, symbol:String, pos:Position} #end class Validator { #if macro static var expectedErrors:Array = []; + static var expectedWarnings:Array = []; static dynamic function onNullSafetyReport(callback:(errors:Array)->Void):Void { } static public function register() { expectedErrors = []; + expectedWarnings = []; onNullSafetyReport = @:privateAccess Context.load("on_null_safety_report", 1); onNullSafetyReport(validate); } @@ -25,7 +27,13 @@ class Validator { if(meta.name == ':shouldFail') { var fieldPosInfos = Context.getPosInfos(field.pos); fieldPosInfos.min = Context.getPosInfos(meta.pos).max + 1; - expectedErrors.push({symbol: field.name, pos:Context.makePosition(fieldPosInfos)}); + expectedErrors.push({type: "error", symbol: field.name, pos:Context.makePosition(fieldPosInfos)}); + break; + } + if(meta.name == ':shouldWarn') { + var fieldPosInfos = Context.getPosInfos(field.pos); + fieldPosInfos.min = Context.getPosInfos(meta.pos).max + 1; + expectedWarnings.push({type: "warning", symbol: field.name, pos:Context.makePosition(fieldPosInfos)}); break; } } @@ -34,7 +42,7 @@ class Validator { } static function validate(errors:Array) { - var errors = check(expectedErrors.copy(), errors.copy()); + var errors = check(expectedErrors.concat(expectedWarnings), errors.copy()); if(errors.ok) { Sys.println('${errors.passed} expected errors spotted'); Sys.println('Compile-time tests passed.'); @@ -50,6 +58,7 @@ class Validator { var actualEvent = actual[i]; var wasExpected = false; for(expectedEvent in expected) { + if (expectedEvent.type != actualEvent.type) continue; if(posContains(expectedEvent.pos, actualEvent.pos)) { expected.remove(expectedEvent); wasExpected = true; @@ -85,7 +94,12 @@ class Validator { #end macro static public function shouldFail(expr:Expr):Expr { - expectedErrors.push({symbol:Context.getLocalMethod(), pos:expr.pos}); + expectedErrors.push({type: "error", symbol:Context.getLocalMethod(), pos:expr.pos}); + return expr; + } + + macro static public function shouldWarn(expr:Expr):Expr { + expectedWarnings.push({type: "warning", symbol:Context.getLocalMethod(), pos:expr.pos}); return expr; } -} \ No newline at end of file +} diff --git a/tests/nullsafety/src/cases/TestLoose.hx b/tests/nullsafety/src/cases/TestLoose.hx index cab9654ae9e..38931644781 100644 --- a/tests/nullsafety/src/cases/TestLoose.hx +++ b/tests/nullsafety/src/cases/TestLoose.hx @@ -1,6 +1,7 @@ package cases; import Validator.shouldFail; +import Validator.shouldWarn; typedef NotNullAnon = { a:String @@ -133,7 +134,7 @@ class TestLoose { } static function nullCoal_returnNull_shouldPass(token:{children:Array}):Null { - final children = token.children ?? return null; + final children = shouldWarn(token.children ?? return null); var i = children.length; return null; } diff --git a/tests/nullsafety/src/cases/TestNonNullable.hx b/tests/nullsafety/src/cases/TestNonNullable.hx new file mode 100644 index 00000000000..b7c838986b4 --- /dev/null +++ b/tests/nullsafety/src/cases/TestNonNullable.hx @@ -0,0 +1,60 @@ +package cases; + +import Validator.shouldWarn; +import Validator.shouldFail; + +typedef Data = { + var foo:String; +} + +class TestNonNullable { + static function main() { + final foo = 0; + if (shouldWarn(foo) == null) {} + + final dyn:Dynamic = null; + if (dyn == null) {} + + final dyn:Any = 1; + if (shouldWarn(dyn) == null) {} + + var data:Data = haxe.Json.parse("{}"); + data.foo.length; + + switch shouldWarn(data.foo) { + case null if (shouldWarn(data.foo) == null): + final v = shouldWarn(data.foo) == null; + } + + final v = shouldWarn(data.foo) == null; + shouldWarn(data.foo) != null && true; + true && shouldWarn(data.foo) != null; + data.foo != null || true; + true || shouldWarn(data.foo) != null; + + throw shouldWarn(data.foo) == null; + + function foo():Bool { + return shouldWarn(data.foo) == null; + } + + while (shouldWarn(data.foo) == null) {} + + shouldWarn(data.foo) ??= ""; + final foo = shouldWarn(data.foo ?? ""); + if (null == shouldWarn(data.foo)) { + trace(1); + } + if (shouldWarn(data.foo) == null) { + data.foo = "default"; + } + } +} + +@:build(Validator.checkFields()) +class BasicErrors { + @:shouldFail static var foo2:Int; + public function new() { + shouldFail(var foo:Int = null); + } +} diff --git a/tests/nullsafety/test.hxml b/tests/nullsafety/test.hxml index 9780163dec4..591082bdaa5 100644 --- a/tests/nullsafety/test.hxml +++ b/tests/nullsafety/test.hxml @@ -5,8 +5,10 @@ cases.TestStrictThreaded cases.TestLoose cases.TestSafeFieldInUnsafeClass cases.TestAbstract +cases.TestNonNullable --macro nullSafety('cases.TestLoose', Loose) --macro nullSafety('cases.TestStrict', Strict) --macro nullSafety('cases.TestStrictThreaded', StrictThreaded) ---macro Validator.register() \ No newline at end of file +--macro nullSafety('cases.TestNonNullable', Loose) +--macro Validator.register()