Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 97 additions & 4 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ use datafusion_expr::planner::{
use sqlparser::ast::{
AccessExpr, BinaryOperator, CastFormat, CastKind, CeilFloorKind,
DataType as SQLDataType, DateTimeField, DictionaryField, Expr as SQLExpr,
ExprWithAlias as SQLExprWithAlias, JsonPath, MapEntry, StructField, Subscript,
TrimWhereField, TypedString, Value, ValueWithSpan,
ExprWithAlias as SQLExprWithAlias, JsonPath, MapEntry, Spanned, StructField,
Subscript, TrimWhereField, TypedString, Value, ValueWithSpan,
};

use datafusion_common::{
DFSchema, Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
plan_err,
DFSchema, Diagnostic, Result, ScalarValue, Span, internal_datafusion_err,
internal_err, not_impl_err, plan_err,
};

use datafusion_expr::expr::ScalarFunction;
Expand All @@ -55,6 +55,99 @@ mod unary_op;
mod value;

impl<S: ContextProvider> SqlToRel<'_, S> {
pub(crate) fn warn_on_null_equality_predicate(&self, predicate: &SQLExpr) {
fn null_value_span(expr: &SQLExpr) -> Option<Span> {
match expr {
SQLExpr::Value(ValueWithSpan {
value: Value::Null,
span,
}) => Span::try_from_sqlparser_span(*span),
_ => None,
}
}

fn null_equality_warning(expr: &SQLExpr) -> Option<Diagnostic> {
let SQLExpr::BinaryOp { left, op, right } = expr else {
return None;
};

let (message, help) = match op {
BinaryOperator::Eq => (
"comparison with NULL using `=` always evaluates to NULL",
"use `IS NULL` to check for NULL values",
),
BinaryOperator::NotEq => (
"comparison with NULL using `<>` always evaluates to NULL",
"use `IS NOT NULL` to check for non-NULL values",
),
_ => return None,
};

let null_span = null_value_span(left).or_else(|| null_value_span(right));
null_span.map(|null_span| {
Diagnostic::new_warning(
message,
Span::try_from_sqlparser_span(expr.span()),
)
.with_help(help, Some(null_span))
})
}

fn collect_null_equality_warnings(
expr: &SQLExpr,
warnings: &mut Vec<Diagnostic>,
) {
if let Some(warning) = null_equality_warning(expr) {
warnings.push(warning);
}

match expr {
SQLExpr::BinaryOp { left, right, .. }
| SQLExpr::IsDistinctFrom(left, right)
| SQLExpr::IsNotDistinctFrom(left, right) => {
collect_null_equality_warnings(left, warnings);
collect_null_equality_warnings(right, warnings);
}
SQLExpr::Nested(expr)
| SQLExpr::UnaryOp { expr, .. }
| SQLExpr::IsFalse(expr)
| SQLExpr::IsNotFalse(expr)
| SQLExpr::IsTrue(expr)
| SQLExpr::IsNotTrue(expr)
| SQLExpr::IsUnknown(expr)
| SQLExpr::IsNotUnknown(expr)
| SQLExpr::OuterJoin(expr)
| SQLExpr::Prior(expr) => {
collect_null_equality_warnings(expr, warnings);
}
SQLExpr::Case {
operand,
conditions,
else_result,
..
} => {
if let Some(operand) = operand {
collect_null_equality_warnings(operand, warnings);
}
for condition in conditions {
collect_null_equality_warnings(&condition.condition, warnings);
collect_null_equality_warnings(&condition.result, warnings);
}
if let Some(else_result) = else_result {
collect_null_equality_warnings(else_result, warnings);
}
}
_ => {}
}
}

let mut warnings = vec![];
collect_null_equality_warnings(predicate, &mut warnings);
for warning in warnings {
self.add_warning(warning);
}
}

pub(crate) fn sql_expr_to_logical_expr_with_alias(
&self,
sql: SQLExprWithAlias,
Expand Down
21 changes: 20 additions & 1 deletion datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! [`SqlToRel`]: SQL Query Planner (produces [`LogicalPlan`] from SQL AST)
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::vec;

use crate::utils::make_decimal_type;
Expand Down Expand Up @@ -455,6 +455,7 @@ pub struct SqlToRel<'a, S: ContextProvider> {
pub(crate) context_provider: &'a S,
pub(crate) options: ParserOptions,
pub(crate) ident_normalizer: IdentNormalizer,
warnings: Mutex<Vec<Diagnostic>>,
}

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand All @@ -477,9 +478,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
context_provider,
options,
ident_normalizer: IdentNormalizer::new(ident_normalize),
warnings: Mutex::new(vec![]),
}
}

/// Appends `warning` to the warnings collected by this planner.
///
/// # Panics
///
/// Panics if the internal warnings lock has been poisoned.
pub(crate) fn add_warning(&self, warning: Diagnostic) {
    let mut collected = self
        .warnings
        .lock()
        .expect("warning diagnostic lock poisoned");
    collected.push(warning);
}

/// Removes and returns every non-fatal warning collected so far, leaving the
/// planner's internal list empty.
///
/// # Panics
///
/// Panics if the internal warnings lock has been poisoned.
pub fn take_warnings(&self) -> Vec<Diagnostic> {
    let mut collected = self
        .warnings
        .lock()
        .expect("warning diagnostic lock poisoned");
    std::mem::take(&mut *collected)
}

pub fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> {
let mut fields = Vec::with_capacity(columns.len());

Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/src/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
JoinConstraint::On(sql_expr) => {
let join_schema = left.schema().join(right.schema())?;
// parse ON expression
self.warn_on_null_equality_predicate(&sql_expr);
let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?;
LogicalPlanBuilder::from(left)
.join_on(right, join_type, Some(expr))?
Expand Down
2 changes: 2 additions & 0 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let having_expr_opt = select
.having
.map::<Result<Expr>, _>(|having_expr| {
self.warn_on_null_equality_predicate(&having_expr);
let having_expr = self.sql_expr_to_logical_expr(
having_expr,
&combined_schema,
Expand Down Expand Up @@ -865,6 +866,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
Some(predicate_expr) => {
let fallback_schemas = plan.fallback_normalize_schemas();

self.warn_on_null_equality_predicate(&predicate_expr);
let filter_expr =
self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?;

Expand Down
153 changes: 153 additions & 0 deletions datafusion/sql/tests/cases/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use datafusion_functions::string;
use insta::assert_snapshot;
use std::{collections::HashMap, sync::Arc};

use datafusion_common::diagnostic::DiagnosticKind;
use datafusion_common::{Diagnostic, Location, Result, Span};
use datafusion_sql::{
parser::{DFParser, DFParserBuilder},
Expand Down Expand Up @@ -51,6 +52,25 @@ fn do_query(sql: &'static str) -> Diagnostic {
}
}

/// Parses and plans `sql` with span collection enabled and returns the
/// warnings the planner recorded.
///
/// Panics if the parser cannot be created, the query does not parse, or
/// planning fails.
fn do_query_warnings(sql: &'static str) -> Vec<Diagnostic> {
    let mut parser = DFParserBuilder::new(sql)
        .build()
        .expect("unable to create parser");
    let statement = parser.parse_statement().expect("unable to parse query");

    let context = MockContextProvider {
        state: MockSessionState::default(),
    };
    let planner = SqlToRel::new_with_options(
        &context,
        ParserOptions {
            collect_spans: true,
            ..ParserOptions::default()
        },
    );

    planner
        .statement_to_plan(statement)
        .expect("expected planning to succeed");
    planner.take_warnings()
}

/// Given a query that contains tag delimited spans, returns a mapping from the
/// span name to the [`Span`]. Tags are comments of the form `/*tag*/`. In case
/// you want the same location to open two spans, or close open and open
Expand Down Expand Up @@ -390,3 +410,136 @@ fn test_syntax_error() -> Result<()> {
},
}
}

/// `col = NULL` in a WHERE clause yields exactly one warning covering the
/// comparison, with a help note anchored on the NULL literal.
#[test]
fn test_eq_null_warning_in_where() -> Result<()> {
    let query = "SELECT * FROM person WHERE /*cmp*/first_name = /*null*/NULL/*null+cmp*/";
    let spans = get_spans(query);
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 1);

    // Check the diagnostic itself.
    assert_eq!(warnings[0].kind, DiagnosticKind::Warning);
    assert_snapshot!(
        warnings[0].message,
        @"comparison with NULL using `=` always evaluates to NULL"
    );
    assert_eq!(warnings[0].span, Some(spans["cmp"]));

    // Check the attached help note.
    let help = &warnings[0].helps[0];
    assert_snapshot!(
        help.message,
        @"use `IS NULL` to check for NULL values"
    );
    assert_eq!(help.span, Some(spans["null"]));
    Ok(())
}

/// `NULL = col` (literal on the left-hand side) is reported the same way as
/// `col = NULL`, including the help message and its span on the literal.
#[test]
fn test_null_eq_warning_in_where() -> Result<()> {
    let query = "SELECT * FROM person WHERE /*cmp+null*/NULL/*null*/ = first_name/*cmp*/";
    let spans = get_spans(query);
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 1);

    let warning = &warnings[0];
    assert_eq!(warning.kind, DiagnosticKind::Warning);
    assert_snapshot!(
        warning.message,
        @"comparison with NULL using `=` always evaluates to NULL"
    );
    assert_eq!(warning.span, Some(spans["cmp"]));
    // Pin the help text too, as the sibling `=` / `<>` tests do.
    assert_snapshot!(
        warning.helps[0].message,
        @"use `IS NULL` to check for NULL values"
    );
    assert_eq!(warning.helps[0].span, Some(spans["null"]));
    Ok(())
}

/// `col <> NULL` in a WHERE clause yields one warning that suggests
/// `IS NOT NULL`, with the help note anchored on the NULL literal.
#[test]
fn test_not_eq_null_warning_in_where() -> Result<()> {
    let query =
        "SELECT * FROM person WHERE /*cmp*/first_name <> /*null*/NULL/*null+cmp*/";
    let spans = get_spans(query);
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 1);

    let warning = &warnings[0];
    assert_eq!(warning.kind, DiagnosticKind::Warning);
    assert_snapshot!(
        warning.message,
        @"comparison with NULL using `<>` always evaluates to NULL"
    );
    assert_eq!(warning.span, Some(spans["cmp"]));

    let help = &warning.helps[0];
    assert_snapshot!(
        help.message,
        @"use `IS NOT NULL` to check for non-NULL values"
    );
    assert_eq!(help.span, Some(spans["null"]));
    Ok(())
}

/// `= NULL` inside a JOIN ... ON condition is reported, with the help note
/// anchored on the NULL literal.
#[test]
fn test_eq_null_warning_in_join_on() -> Result<()> {
    let query =
        "SELECT * FROM person a JOIN person b ON /*cmp*/a.id = /*null*/NULL/*null+cmp*/";
    let spans = get_spans(query);
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 1);

    let warning = &warnings[0];
    assert_eq!(warning.kind, DiagnosticKind::Warning);
    assert_snapshot!(
        warning.message,
        @"comparison with NULL using `=` always evaluates to NULL"
    );
    assert_eq!(warning.span, Some(spans["cmp"]));
    // The query tags the NULL literal; verify the help note points at it.
    assert_snapshot!(
        warning.helps[0].message,
        @"use `IS NULL` to check for NULL values"
    );
    assert_eq!(warning.helps[0].span, Some(spans["null"]));
    Ok(())
}

/// `= NULL` inside a HAVING clause is reported, with the help note anchored on
/// the NULL literal.
#[test]
fn test_eq_null_warning_in_having() -> Result<()> {
    let query = "SELECT first_name FROM person GROUP BY first_name HAVING /*cmp*/1 = /*null*/NULL/*null+cmp*/";
    let spans = get_spans(query);
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 1);

    let warning = &warnings[0];
    assert_eq!(warning.kind, DiagnosticKind::Warning);
    assert_snapshot!(
        warning.message,
        @"comparison with NULL using `=` always evaluates to NULL"
    );
    assert_eq!(warning.span, Some(spans["cmp"]));
    // The query tags the NULL literal; verify the help note points at it.
    assert_snapshot!(
        warning.helps[0].message,
        @"use `IS NULL` to check for NULL values"
    );
    assert_eq!(warning.helps[0].span, Some(spans["null"]));
    Ok(())
}

/// `= NULL` inside a `CASE WHEN` condition used as a WHERE predicate is found
/// by the nested walk and reported with the usual message and help note.
#[test]
fn test_eq_null_warning_nested_in_case_predicate() -> Result<()> {
    let query = "SELECT * FROM person WHERE CASE WHEN /*cmp*/first_name = /*null*/NULL/*null+cmp*/ THEN true ELSE false END";
    let spans = get_spans(query);
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 1);

    let warning = &warnings[0];
    assert_eq!(warning.kind, DiagnosticKind::Warning);
    assert_snapshot!(
        warning.message,
        @"comparison with NULL using `=` always evaluates to NULL"
    );
    assert_eq!(warning.span, Some(spans["cmp"]));
    // The query tags the NULL literal; verify the help note points at it.
    assert_snapshot!(
        warning.helps[0].message,
        @"use `IS NULL` to check for NULL values"
    );
    assert_eq!(warning.helps[0].span, Some(spans["null"]));
    Ok(())
}

/// The recommended `IS NULL` form must not produce any warning.
#[test]
fn test_is_null_has_no_warning() -> Result<()> {
    let query = "SELECT * FROM person WHERE first_name IS NULL";
    let warnings = do_query_warnings(query);
    assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    Ok(())
}

/// `= NULL` in the SELECT list (not a predicate position) must not produce a
/// warning.
#[test]
fn test_eq_null_projection_has_no_warning() -> Result<()> {
    let query = "SELECT first_name = NULL FROM person";
    let warnings = do_query_warnings(query);
    assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    Ok(())
}

/// Two NULL comparisons joined by `OR` each produce a warning, reported in
/// source order (`=` first, then `<>`).
#[test]
fn test_multiple_null_comparison_warnings() -> Result<()> {
    let query = "SELECT * FROM person WHERE first_name = NULL OR last_name <> NULL";
    let warnings = do_query_warnings(query);
    assert_eq!(warnings.len(), 2);
    assert!(warnings.iter().all(|w| w.kind == DiagnosticKind::Warning));

    let (first, second) = (&warnings[0], &warnings[1]);
    assert_snapshot!(
        first.message,
        @"comparison with NULL using `=` always evaluates to NULL"
    );
    assert_snapshot!(
        second.message,
        @"comparison with NULL using `<>` always evaluates to NULL"
    );
    Ok(())
}