diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 01e5ec4f149a6..a5094757b892c 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -22,13 +22,13 @@ use datafusion_expr::planner::{ use sqlparser::ast::{ AccessExpr, BinaryOperator, CastFormat, CastKind, CeilFloorKind, DataType as SQLDataType, DateTimeField, DictionaryField, Expr as SQLExpr, - ExprWithAlias as SQLExprWithAlias, JsonPath, MapEntry, StructField, Subscript, - TrimWhereField, TypedString, Value, ValueWithSpan, + ExprWithAlias as SQLExprWithAlias, JsonPath, MapEntry, Spanned, StructField, + Subscript, TrimWhereField, TypedString, Value, ValueWithSpan, }; use datafusion_common::{ - DFSchema, Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err, - plan_err, + DFSchema, Diagnostic, Result, ScalarValue, Span, internal_datafusion_err, + internal_err, not_impl_err, plan_err, }; use datafusion_expr::expr::ScalarFunction; @@ -55,6 +55,99 @@ mod unary_op; mod value; impl SqlToRel<'_, S> { + pub(crate) fn warn_on_null_equality_predicate(&self, predicate: &SQLExpr) { + fn null_value_span(expr: &SQLExpr) -> Option { + match expr { + SQLExpr::Value(ValueWithSpan { + value: Value::Null, + span, + }) => Span::try_from_sqlparser_span(*span), + _ => None, + } + } + + fn null_equality_warning(expr: &SQLExpr) -> Option { + let SQLExpr::BinaryOp { left, op, right } = expr else { + return None; + }; + + let (message, help) = match op { + BinaryOperator::Eq => ( + "comparison with NULL using `=` always evaluates to NULL", + "use `IS NULL` to check for NULL values", + ), + BinaryOperator::NotEq => ( + "comparison with NULL using `<>` always evaluates to NULL", + "use `IS NOT NULL` to check for non-NULL values", + ), + _ => return None, + }; + + let null_span = null_value_span(left).or_else(|| null_value_span(right)); + null_span.map(|null_span| { + Diagnostic::new_warning( + message, + Span::try_from_sqlparser_span(expr.span()), + ) + .with_help(help, Some(null_span)) + }) + } + + fn collect_null_equality_warnings( + expr: &SQLExpr, + warnings: &mut Vec, + ) { + if let Some(warning) = null_equality_warning(expr) { + warnings.push(warning); + } + + match expr { + SQLExpr::BinaryOp { left, right, .. } + | SQLExpr::IsDistinctFrom(left, right) + | SQLExpr::IsNotDistinctFrom(left, right) => { + collect_null_equality_warnings(left, warnings); + collect_null_equality_warnings(right, warnings); + } + SQLExpr::Nested(expr) + | SQLExpr::UnaryOp { expr, .. } + | SQLExpr::IsFalse(expr) + | SQLExpr::IsNotFalse(expr) + | SQLExpr::IsTrue(expr) + | SQLExpr::IsNotTrue(expr) + | SQLExpr::IsUnknown(expr) + | SQLExpr::IsNotUnknown(expr) + | SQLExpr::OuterJoin(expr) + | SQLExpr::Prior(expr) => { + collect_null_equality_warnings(expr, warnings); + } + SQLExpr::Case { + operand, + conditions, + else_result, + .. + } => { + if let Some(operand) = operand { + collect_null_equality_warnings(operand, warnings); + } + for condition in conditions { + collect_null_equality_warnings(&condition.condition, warnings); + collect_null_equality_warnings(&condition.result, warnings); + } + if let Some(else_result) = else_result { + collect_null_equality_warnings(else_result, warnings); + } + } + _ => {} + } + } + + let mut warnings = vec![]; + collect_null_equality_warnings(predicate, &mut warnings); + for warning in warnings { + self.add_warning(warning); + } + } + pub(crate) fn sql_expr_to_logical_expr_with_alias( &self, sql: SQLExprWithAlias, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 01215ae3434cf..20a80e4f8ae9b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -18,7 +18,7 @@ //! [`SqlToRel`]: SQL Query Planner (produces [`LogicalPlan`] from SQL AST) use std::collections::HashMap; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::vec; use crate::utils::make_decimal_type; @@ -455,6 +455,7 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) ident_normalizer: IdentNormalizer, + warnings: Mutex>, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -477,9 +478,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { context_provider, options, ident_normalizer: IdentNormalizer::new(ident_normalize), + warnings: Mutex::new(vec![]), } } + pub(crate) fn add_warning(&self, warning: Diagnostic) { + self.warnings + .lock() + .expect("warning diagnostic lock poisoned") + .push(warning); + } + + /// Drain and return non-fatal warnings collected during SQL planning. + pub fn take_warnings(&self) -> Vec { + std::mem::take( + &mut self + .warnings + .lock() + .expect("warning diagnostic lock poisoned"), + ) + } + pub fn build_schema(&self, columns: Vec) -> Result { let mut fields = Vec::with_capacity(columns.len()); diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 3343890c6dc1d..475d9a5b38099 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -122,6 +122,7 @@ impl SqlToRel<'_, S> { JoinConstraint::On(sql_expr) => { let join_schema = left.schema().join(right.schema())?; // parse ON expression + self.warn_on_null_equality_predicate(&sql_expr); let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?; LogicalPlanBuilder::from(left) .join_on(right, join_type, Some(expr))? diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index b0099b8a1dcc3..ba7353c424f4e 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -198,6 +198,7 @@ impl SqlToRel<'_, S> { let having_expr_opt = select .having .map::, _>(|having_expr| { + self.warn_on_null_equality_predicate(&having_expr); let having_expr = self.sql_expr_to_logical_expr( having_expr, &combined_schema, @@ -865,6 +866,7 @@ impl SqlToRel<'_, S> { Some(predicate_expr) => { let fallback_schemas = plan.fallback_normalize_schemas(); + self.warn_on_null_equality_predicate(&predicate_expr); let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 7a729739469d3..52a95c081b6af 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -19,6 +19,7 @@ use datafusion_functions::string; use insta::assert_snapshot; use std::{collections::HashMap, sync::Arc}; +use datafusion_common::diagnostic::DiagnosticKind; use datafusion_common::{Diagnostic, Location, Result, Span}; use datafusion_sql::{ parser::{DFParser, DFParserBuilder}, @@ -51,6 +52,25 @@ fn do_query(sql: &'static str) -> Diagnostic { } } +fn do_query_warnings(sql: &'static str) -> Vec { + let statement = DFParserBuilder::new(sql) + .build() + .expect("unable to create parser") + .parse_statement() + .expect("unable to parse query"); + let options = ParserOptions { + collect_spans: true, + ..ParserOptions::default() + }; + let state = MockSessionState::default(); + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new_with_options(&context, options); + sql_to_rel + .statement_to_plan(statement) + .expect("expected planning to succeed"); + sql_to_rel.take_warnings() +} + /// Given a query that contains tag delimited spans, returns a mapping from the /// span name to the [`Span`]. Tags are comments of the form `/*tag*/`. In case /// you want the same location to open two spans, or close open and open @@ -390,3 +410,136 @@ fn test_syntax_error() -> Result<()> { }, } } + +#[test] +fn test_eq_null_warning_in_where() -> Result<()> { + let query = "SELECT * FROM person WHERE /*cmp*/first_name = /*null*/NULL/*null+cmp*/"; + let spans = get_spans(query); + let warnings = do_query_warnings(query); + assert_eq!(warnings.len(), 1); + + let warning = &warnings[0]; + assert_eq!(warning.kind, DiagnosticKind::Warning); + assert_snapshot!( + warning.message, + @"comparison with NULL using `=` always evaluates to NULL" + ); + assert_eq!(warning.span, Some(spans["cmp"])); + assert_snapshot!( + warning.helps[0].message, + @"use `IS NULL` to check for NULL values" + ); + assert_eq!(warning.helps[0].span, Some(spans["null"])); + Ok(()) +} + +#[test] +fn test_null_eq_warning_in_where() -> Result<()> { + let query = "SELECT * FROM person WHERE /*cmp+null*/NULL/*null*/ = first_name/*cmp*/"; + let spans = get_spans(query); + let warnings = do_query_warnings(query); + assert_eq!(warnings.len(), 1); + assert_eq!(warnings[0].kind, DiagnosticKind::Warning); + assert_snapshot!( + warnings[0].message, + @"comparison with NULL using `=` always evaluates to NULL" + ); + assert_eq!(warnings[0].span, Some(spans["cmp"])); + assert_eq!(warnings[0].helps[0].span, Some(spans["null"])); + Ok(()) +} + +#[test] +fn test_not_eq_null_warning_in_where() -> Result<()> { + let query = + "SELECT * FROM person WHERE /*cmp*/first_name <> /*null*/NULL/*null+cmp*/"; + let spans = get_spans(query); + let warnings = do_query_warnings(query); + assert_eq!(warnings.len(), 1); + assert_eq!(warnings[0].kind, DiagnosticKind::Warning); + assert_snapshot!( + warnings[0].message, + @"comparison with NULL using `<>` always evaluates to NULL" + ); + assert_eq!(warnings[0].span, Some(spans["cmp"])); + assert_snapshot!( + warnings[0].helps[0].message, + @"use `IS NOT NULL` to check for non-NULL values" + ); + assert_eq!(warnings[0].helps[0].span, Some(spans["null"])); + Ok(()) +} + +#[test] +fn test_eq_null_warning_in_join_on() -> Result<()> { + let query = + "SELECT * FROM person a JOIN person b ON /*cmp*/a.id = /*null*/NULL/*null+cmp*/"; + let spans = get_spans(query); + let warnings = do_query_warnings(query); + assert_eq!(warnings.len(), 1); + assert_eq!(warnings[0].kind, DiagnosticKind::Warning); + assert_snapshot!( + warnings[0].message, + @"comparison with NULL using `=` always evaluates to NULL" + ); + assert_eq!(warnings[0].span, Some(spans["cmp"])); + Ok(()) +} + +#[test] +fn test_eq_null_warning_in_having() -> Result<()> { + let query = "SELECT first_name FROM person GROUP BY first_name HAVING /*cmp*/1 = /*null*/NULL/*null+cmp*/"; + let spans = get_spans(query); + let warnings = do_query_warnings(query); + assert_eq!(warnings.len(), 1); + assert_eq!(warnings[0].kind, DiagnosticKind::Warning); + assert_snapshot!( + warnings[0].message, + @"comparison with NULL using `=` always evaluates to NULL" + ); + assert_eq!(warnings[0].span, Some(spans["cmp"])); + Ok(()) +} + +#[test] +fn test_eq_null_warning_nested_in_case_predicate() -> Result<()> { + let query = "SELECT * FROM person WHERE CASE WHEN /*cmp*/first_name = /*null*/NULL/*null+cmp*/ THEN true ELSE false END"; + let spans = get_spans(query); + let warnings = do_query_warnings(query); + assert_eq!(warnings.len(), 1); + assert_eq!(warnings[0].kind, DiagnosticKind::Warning); + assert_eq!(warnings[0].span, Some(spans["cmp"])); + Ok(()) +} + +#[test] +fn test_is_null_has_no_warning() -> Result<()> { + let warnings = do_query_warnings("SELECT * FROM person WHERE first_name IS NULL"); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + Ok(()) +} + +#[test] +fn test_eq_null_projection_has_no_warning() -> Result<()> { + let warnings = do_query_warnings("SELECT first_name = NULL FROM person"); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + Ok(()) +} + +#[test] +fn test_multiple_null_comparison_warnings() -> Result<()> { + let warnings = do_query_warnings( + "SELECT * FROM person WHERE first_name = NULL OR last_name <> NULL", + ); + assert_eq!(warnings.len(), 2); + assert!(warnings.iter().all(|w| w.kind == DiagnosticKind::Warning)); + assert_snapshot!( + warnings[0].message, + @"comparison with NULL using `=` always evaluates to NULL" + ); + assert_snapshot!( + warnings[1].message, + @"comparison with NULL using `<>` always evaluates to NULL" + ); + Ok(()) +}