Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/ast/dml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,19 @@ use sqlparser_derive::{Visit, VisitMut};
use crate::display_utils::{indented_list, Indent, SpaceOrNewline};

use super::{
display_comma_separated, query::InputFormatClause, Assignment, Expr, FromTable, Ident,
InsertAliases, MysqlInsertPriority, ObjectName, OnInsert, OrderByExpr, Query, SelectItem,
Setting, SqliteOnConflict, TableObject, TableWithJoins, UpdateTableFromKind,
display_comma_separated, helpers::attached_token::AttachedToken, query::InputFormatClause,
Assignment, Expr, FromTable, Ident, InsertAliases, MysqlInsertPriority, ObjectName, OnInsert,
OrderByExpr, Query, SelectItem, Setting, SqliteOnConflict, TableObject, TableWithJoins,
UpdateTableFromKind,
};

/// INSERT statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Insert {
/// Token for the `INSERT` keyword (or its substitutes)
pub insert_token: AttachedToken,
/// Only for Sqlite
pub or: Option<SqliteOnConflict>,
/// Only for mysql
Expand Down Expand Up @@ -179,6 +182,8 @@ impl Display for Insert {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Delete {
/// Token for the `DELETE` keyword
pub delete_token: AttachedToken,
/// Multi tables delete are supported in mysql
pub tables: Vec<ObjectName>,
/// FROM
Expand Down Expand Up @@ -246,6 +251,8 @@ impl Display for Delete {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Update {
/// Token for the `UPDATE` keyword
pub update_token: AttachedToken,
/// TABLE
pub table: TableWithJoins,
/// Column assignments
Expand Down
9 changes: 2 additions & 7 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2787,10 +2787,11 @@ impl fmt::Display for Declare {
}

/// Sql options of a `CREATE TABLE` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CreateTableOptions {
#[default]
None,
/// Options specified using the `WITH` keyword.
/// e.g. `WITH (description = "123")`
Expand Down Expand Up @@ -2819,12 +2820,6 @@ pub enum CreateTableOptions {
TableProperties(Vec<SqlOption>),
}

impl Default for CreateTableOptions {
fn default() -> Self {
Self::None
}
}

impl fmt::Display for CreateTableOptions {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand Down
51 changes: 38 additions & 13 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ impl Spanned for CopySource {
impl Spanned for Delete {
fn span(&self) -> Span {
let Delete {
delete_token,
tables,
from,
using,
Expand All @@ -847,26 +848,29 @@ impl Spanned for Delete {
} = self;

union_spans(
tables
.iter()
.map(|i| i.span())
.chain(core::iter::once(from.span()))
.chain(
using
.iter()
.map(|u| union_spans(u.iter().map(|i| i.span()))),
)
.chain(selection.iter().map(|i| i.span()))
.chain(returning.iter().flat_map(|i| i.iter().map(|k| k.span())))
.chain(order_by.iter().map(|i| i.span()))
.chain(limit.iter().map(|i| i.span())),
core::iter::once(delete_token.0.span).chain(
tables
.iter()
.map(|i| i.span())
.chain(core::iter::once(from.span()))
.chain(
using
.iter()
.map(|u| union_spans(u.iter().map(|i| i.span()))),
)
.chain(selection.iter().map(|i| i.span()))
.chain(returning.iter().flat_map(|i| i.iter().map(|k| k.span())))
.chain(order_by.iter().map(|i| i.span()))
.chain(limit.iter().map(|i| i.span())),
),
)
}
}

impl Spanned for Update {
fn span(&self) -> Span {
let Update {
update_token,
table,
assignments,
from,
Expand All @@ -878,6 +882,7 @@ impl Spanned for Update {

union_spans(
core::iter::once(table.span())
.chain(core::iter::once(update_token.0.span))
.chain(assignments.iter().map(|i| i.span()))
.chain(from.iter().map(|i| i.span()))
.chain(selection.iter().map(|i| i.span()))
Expand Down Expand Up @@ -1212,6 +1217,7 @@ impl Spanned for AlterIndexOperation {
impl Spanned for Insert {
fn span(&self) -> Span {
let Insert {
insert_token,
or: _, // enum, sqlite specific
ignore: _, // bool
into: _, // bool
Expand All @@ -1235,6 +1241,7 @@ impl Spanned for Insert {

union_spans(
core::iter::once(table.span())
.chain(core::iter::once(insert_token.0.span))
.chain(table_alias.as_ref().map(|i| i.span))
.chain(columns.iter().map(|i| i.span))
.chain(source.as_ref().map(|q| q.span()))
Expand Down Expand Up @@ -2535,4 +2542,22 @@ ALTER TABLE users
assert_eq!(stmt_span.start, (2, 13).into());
assert_eq!(stmt_span.end, (4, 11).into());
}

#[test]
fn test_update_statement_span() {
let sql = r#"-- foo
UPDATE foo
/* bar */
SET bar = 3
WHERE quux > 42 ;
"#;

let r = Parser::parse_sql(&crate::dialect::GenericDialect, sql).unwrap();
assert_eq!(1, r.len());

let stmt_span = r[0].span();

assert_eq!(stmt_span.start, (2, 7).into());
assert_eq!(stmt_span.end, (5, 17).into());
}
}
2 changes: 1 addition & 1 deletion src/dialect/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl Dialect for SQLiteDialect {
fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::REPLACE) {
parser.prev_token();
Some(parser.parse_insert())
Some(parser.parse_insert(parser.get_current_token().clone()))
} else {
None
}
Expand Down
51 changes: 33 additions & 18 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,11 @@ impl<'a> Parser<'a> {
Keyword::DISCARD => self.parse_discard(),
Keyword::DECLARE => self.parse_declare(),
Keyword::FETCH => self.parse_fetch_statement(),
Keyword::DELETE => self.parse_delete(),
Keyword::INSERT => self.parse_insert(),
Keyword::REPLACE => self.parse_replace(),
Keyword::DELETE => self.parse_delete(next_token),
Keyword::INSERT => self.parse_insert(next_token),
Keyword::REPLACE => self.parse_replace(next_token),
Keyword::UNCACHE => self.parse_uncache_table(),
Keyword::UPDATE => self.parse_update(),
Keyword::UPDATE => self.parse_update(next_token),
Keyword::ALTER => self.parse_alter(),
Keyword::CALL => self.parse_call(),
Keyword::COPY => self.parse_copy(),
Expand Down Expand Up @@ -11812,8 +11812,11 @@ impl<'a> Parser<'a> {
/// Parse a DELETE statement, returning a `Box`ed SetExpr
///
/// This is used to reduce the size of the stack frames in debug builds
fn parse_delete_setexpr_boxed(&mut self) -> Result<Box<SetExpr>, ParserError> {
Ok(Box::new(SetExpr::Delete(self.parse_delete()?)))
fn parse_delete_setexpr_boxed(
&mut self,
delete_token: TokenWithSpan,
) -> Result<Box<SetExpr>, ParserError> {
Ok(Box::new(SetExpr::Delete(self.parse_delete(delete_token)?)))
}

/// Parse a MERGE statement, returning a `Box`ed SetExpr
Expand All @@ -11823,7 +11826,7 @@ impl<'a> Parser<'a> {
Ok(Box::new(SetExpr::Merge(self.parse_merge()?)))
}

pub fn parse_delete(&mut self) -> Result<Statement, ParserError> {
pub fn parse_delete(&mut self, delete_token: TokenWithSpan) -> Result<Statement, ParserError> {
let (tables, with_from_keyword) = if !self.parse_keyword(Keyword::FROM) {
// `FROM` keyword is optional in BigQuery SQL.
// https://cloud.google.com/bigquery/docs/reference/standard-sql/dml-syntax#delete_statement
Expand Down Expand Up @@ -11866,6 +11869,7 @@ impl<'a> Parser<'a> {
};

Ok(Statement::Delete(Delete {
delete_token: delete_token.into(),
tables,
from: if with_from_keyword {
FromTable::WithFromKeyword(from)
Expand Down Expand Up @@ -11995,7 +11999,7 @@ impl<'a> Parser<'a> {
if self.parse_keyword(Keyword::INSERT) {
Ok(Query {
with,
body: self.parse_insert_setexpr_boxed()?,
body: self.parse_insert_setexpr_boxed(self.get_current_token().clone())?,
order_by: None,
limit_clause: None,
fetch: None,
Expand All @@ -12009,7 +12013,7 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::UPDATE) {
Ok(Query {
with,
body: self.parse_update_setexpr_boxed()?,
body: self.parse_update_setexpr_boxed(self.get_current_token().clone())?,
order_by: None,
limit_clause: None,
fetch: None,
Expand All @@ -12023,7 +12027,7 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::DELETE) {
Ok(Query {
with,
body: self.parse_delete_setexpr_boxed()?,
body: self.parse_delete_setexpr_boxed(self.get_current_token().clone())?,
limit_clause: None,
order_by: None,
fetch: None,
Expand Down Expand Up @@ -15450,15 +15454,18 @@ impl<'a> Parser<'a> {
}

/// Parse an REPLACE statement
pub fn parse_replace(&mut self) -> Result<Statement, ParserError> {
pub fn parse_replace(
&mut self,
replace_token: TokenWithSpan,
) -> Result<Statement, ParserError> {
if !dialect_of!(self is MySqlDialect | GenericDialect) {
return parser_err!(
"Unsupported statement REPLACE",
self.peek_token().span.start
);
}

let mut insert = self.parse_insert()?;
let mut insert = self.parse_insert(replace_token)?;
if let Statement::Insert(Insert { replace_into, .. }) = &mut insert {
*replace_into = true;
}
Expand All @@ -15469,12 +15476,15 @@ impl<'a> Parser<'a> {
/// Parse an INSERT statement, returning a `Box`ed SetExpr
///
/// This is used to reduce the size of the stack frames in debug builds
fn parse_insert_setexpr_boxed(&mut self) -> Result<Box<SetExpr>, ParserError> {
Ok(Box::new(SetExpr::Insert(self.parse_insert()?)))
fn parse_insert_setexpr_boxed(
&mut self,
insert_token: TokenWithSpan,
) -> Result<Box<SetExpr>, ParserError> {
Ok(Box::new(SetExpr::Insert(self.parse_insert(insert_token)?)))
}

/// Parse an INSERT statement
pub fn parse_insert(&mut self) -> Result<Statement, ParserError> {
pub fn parse_insert(&mut self, insert_token: TokenWithSpan) -> Result<Statement, ParserError> {
let or = self.parse_conflict_clause();
let priority = if !dialect_of!(self is MySqlDialect | GenericDialect) {
None
Expand Down Expand Up @@ -15643,6 +15653,7 @@ impl<'a> Parser<'a> {
};

Ok(Statement::Insert(Insert {
insert_token: insert_token.into(),
or,
table: table_object,
table_alias,
Expand Down Expand Up @@ -15734,11 +15745,14 @@ impl<'a> Parser<'a> {
/// Parse an UPDATE statement, returning a `Box`ed SetExpr
///
/// This is used to reduce the size of the stack frames in debug builds
fn parse_update_setexpr_boxed(&mut self) -> Result<Box<SetExpr>, ParserError> {
Ok(Box::new(SetExpr::Update(self.parse_update()?)))
fn parse_update_setexpr_boxed(
&mut self,
update_token: TokenWithSpan,
) -> Result<Box<SetExpr>, ParserError> {
Ok(Box::new(SetExpr::Update(self.parse_update(update_token)?)))
}

pub fn parse_update(&mut self) -> Result<Statement, ParserError> {
pub fn parse_update(&mut self, update_token: TokenWithSpan) -> Result<Statement, ParserError> {
let or = self.parse_conflict_clause();
let table = self.parse_table_and_joins()?;
let from_before_set = if self.parse_keyword(Keyword::FROM) {
Expand Down Expand Up @@ -15773,6 +15787,7 @@ impl<'a> Parser<'a> {
None
};
Ok(Update {
update_token: update_token.into(),
table,
assignments,
from,
Expand Down
16 changes: 14 additions & 2 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ use sqlparser::dialect::{
};
use sqlparser::keywords::{Keyword, ALL_KEYWORDS};
use sqlparser::parser::{Parser, ParserError, ParserOptions};
use sqlparser::tokenizer::Tokenizer;
use sqlparser::tokenizer::{Location, Span};
use sqlparser::tokenizer::{Location, Span, TokenWithSpan};
use sqlparser::tokenizer::{Token, Tokenizer};
use test_utils::{
all_dialects, all_dialects_where, all_dialects_with_options, alter_table_op, assert_eq_vec,
call, expr_from_projection, join, number, only, table, table_alias, table_from_name,
Expand Down Expand Up @@ -440,6 +440,10 @@ fn parse_update_set_from() {
assert_eq!(
stmt,
Statement::Update(Update {
update_token: AttachedToken(TokenWithSpan {
token: Token::make_keyword("UPDATE"),
span: Span::new((1, 1).into(), (1, 7).into()),
}),
table: TableWithJoins {
relation: table_from_name(ObjectName::from(vec![Ident::new("t1")])),
joins: vec![],
Expand Down Expand Up @@ -535,6 +539,7 @@ fn parse_update_with_table_alias() {
returning,
or: None,
limit: None,
update_token,
}) => {
assert_eq!(
TableWithJoins {
Expand Down Expand Up @@ -583,6 +588,13 @@ fn parse_update_with_table_alias() {
selection
);
assert_eq!(None, returning);
assert_eq!(
AttachedToken(TokenWithSpan {
token: Token::make_keyword("UPDATE"),
span: Span::new((1, 1).into(), (1, 7).into()),
}),
update_token
);
}
_ => unreachable!(),
}
Expand Down
10 changes: 9 additions & 1 deletion tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use sqlparser::ast::MysqlInsertPriority::{Delayed, HighPriority, LowPriority};
use sqlparser::ast::*;
use sqlparser::dialect::{GenericDialect, MySqlDialect};
use sqlparser::parser::{ParserError, ParserOptions};
use sqlparser::tokenizer::Span;
use sqlparser::tokenizer::Token;
use sqlparser::tokenizer::{Span, TokenWithSpan};
use test_utils::*;

#[macro_use]
Expand Down Expand Up @@ -2623,6 +2623,7 @@ fn parse_update_with_joins() {
returning,
or: None,
limit: None,
update_token,
}) => {
assert_eq!(
TableWithJoins {
Expand Down Expand Up @@ -2697,6 +2698,13 @@ fn parse_update_with_joins() {
selection
);
assert_eq!(None, returning);
assert_eq!(
AttachedToken(TokenWithSpan {
token: Token::make_keyword("UPDATE"),
span: Span::new((1, 1).into(), (1, 7).into()),
}),
update_token
);
}
_ => unreachable!(),
}
Expand Down
Loading