Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlglot.generators.spark import SparkGenerator
from sqlglot.parsers.spark import SparkParser
from sqlglot.tokens import TokenType
from sqlglot.trie import new_trie
from sqlglot.typing.spark import EXPRESSION_METADATA


Expand All @@ -16,6 +17,13 @@ class Spark(Spark2):
ARRAY_FUNCS_PROPAGATES_NULLS = True
EXPRESSION_METADATA = EXPRESSION_METADATA.copy()

LENIENT_INVERSE_TIME_MAPPING = {v: k for k, v in Spark2.TIME_MAPPING.items()} | {
# Parse zero-padded months and days, as per strptime() behavior.
"%m": "M",
"%d": "d",
}
LENIENT_INVERSE_TIME_TRIE = new_trie(LENIENT_INVERSE_TIME_MAPPING)

class Tokenizer(Spark2.Tokenizer):
STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False

Expand Down
39 changes: 28 additions & 11 deletions sqlglot/generators/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
)


def _groupconcat_sql(self: SparkGenerator, expression: exp.GroupConcat) -> str:
if self.dialect.version < (4,):
expr = exp.ArrayToString(
this=exp.ArrayAgg(this=expression.this),
expression=expression.args.get("separator") or exp.Literal.string(""),
)
return self.sql(expr)

return groupconcat_sql(self, expression)


def _normalize_partition(e: exp.Expr) -> exp.Expr:
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
if isinstance(e, str):
Expand All @@ -30,6 +41,21 @@ def _normalize_partition(e: exp.Expr) -> exp.Expr:
return e


def _str_to_datetime_sql(self: SparkGenerator, expression: exp.StrToDate | exp.StrToTime) -> str:
from sqlglot.dialects.spark import Spark

assert isinstance(self.dialect, Spark)
return self.func(
f"TO_{'DATE' if isinstance(expression, exp.StrToDate) else 'TIMESTAMP'}",
expression.this,
self.format_time(
expression,
self.dialect.LENIENT_INVERSE_TIME_MAPPING,
self.dialect.LENIENT_INVERSE_TIME_TRIE,
),
)


def _dateadd_sql(self: SparkGenerator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
if not expression.unit or (
isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
Expand All @@ -54,17 +80,6 @@ def _dateadd_sql(self: SparkGenerator, expression: exp.TsOrDsAdd | exp.Timestamp
return this


def _groupconcat_sql(self: SparkGenerator, expression: exp.GroupConcat) -> str:
if self.dialect.version < (4,):
expr = exp.ArrayToString(
this=exp.ArrayAgg(this=expression.this),
expression=expression.args.get("separator") or exp.Literal.string(""),
)
return self.sql(expr)

return groupconcat_sql(self, expression)


class SparkGenerator(Spark2Generator):
SUPPORTS_TO_NUMBER = True
PAD_FILL_PATTERN_IS_REQUIRED = False
Expand Down Expand Up @@ -129,6 +144,8 @@ class SparkGenerator(Spark2Generator):
exp.SafeMultiply: rename_func("TRY_MULTIPLY"),
exp.SafeSubtract: rename_func("TRY_SUBTRACT"),
exp.StartsWith: rename_func("STARTSWITH"),
exp.StrToDate: _str_to_datetime_sql,
exp.StrToTime: _str_to_datetime_sql,
exp.TimeAdd: date_delta_to_binary_interval_op(cast=False),
exp.TimeSub: date_delta_to_binary_interval_op(cast=False),
exp.TsOrDsAdd: _dateadd_sql,
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/generators/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _map_sql(self: Spark2Generator, expression: exp.Map) -> str:
return self.func("MAP_FROM_ARRAYS", keys, values)


def _str_to_date(self: Spark2Generator, expression: exp.StrToDate) -> str:
def _str_to_date_sql(self: Spark2Generator, expression: exp.StrToDate) -> str:
time_format = self.format_time(expression)
if time_format == HIVE_DATE_FORMAT:
return self.func("TO_DATE", expression.this)
Expand Down Expand Up @@ -184,7 +184,7 @@ class Spark2Generator(HiveGenerator):
exp.SHA2Digest: lambda self, e: self.func(
"SHA2", e.this, e.args.get("length") or exp.Literal.number(256)
),
exp.StrToDate: _str_to_date,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.UnixToTime: _unix_to_time_sql,
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def test_time(self):
"presto": "DATE_PARSE(x, '%Y-%m-%dT%T')",
"drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH24:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-dTHH:mm:ss')",
},
)
self.validate_all(
Expand All @@ -776,7 +776,7 @@ def test_time(self):
"postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
"redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-M-d')",
},
)
self.validate_all(
Expand Down Expand Up @@ -1219,7 +1219,7 @@ def test_time(self):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)",
"spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
"spark": "TO_DATE(x, 'yyyy-M-dTHH:mm:ss')",
"doris": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
},
)
Expand All @@ -1231,7 +1231,7 @@ def test_time(self):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
"spark": "TO_DATE(x, 'yyyy-M-d')",
"doris": "STR_TO_DATE(x, '%Y-%m-%d')",
},
)
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_time(self):
"duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_PARSE(x, '%Y-%m-%d %T')",
"hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss')",
},
)
self.validate_all(
Expand All @@ -315,7 +315,7 @@ def test_time(self):
"duckdb": "STRPTIME(x, '%Y-%m-%d')",
"presto": "DATE_PARSE(x, '%Y-%m-%d')",
"hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-d')",
},
)
self.validate_all(
Expand All @@ -330,7 +330,7 @@ def test_time(self):
"duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
"presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
"hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-M-d')",
},
)
self.validate_all(
Expand All @@ -339,7 +339,7 @@ def test_time(self):
"duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
"presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
"hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')",
"spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-M-d')",
},
)
self.validate_all(
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,7 +1839,7 @@ def test_snowflake(self):
"bigquery": "SELECT PARSE_TIMESTAMP('%d-%m-%Y %I:%M:%S', col) FROM t",
"duckdb": "SELECT STRPTIME(col, '%d-%m-%Y %I:%M:%S') FROM t",
"snowflake": "SELECT TO_TIMESTAMP(col, 'DD-mm-yyyy hh12:mi:ss') FROM t",
"spark": "SELECT TO_TIMESTAMP(col, 'dd-MM-yyyy hh:mm:ss') FROM t",
"spark": "SELECT TO_TIMESTAMP(col, 'd-M-yyyy hh:mm:ss') FROM t",
},
)
self.validate_all(
Expand Down Expand Up @@ -1904,7 +1904,7 @@ def test_snowflake(self):
write={
"bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %T', '04/05/2013 01:02:03')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'M/d/yyyy HH:mm:ss')",
},
)
self.validate_all(
Expand Down
10 changes: 5 additions & 5 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,14 +659,14 @@ def test_spark(self):
},
)
self.validate_all(
"SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')",
"SELECT TO_TIMESTAMP('2016-1-1', 'yyyy-M-d')",
read={
"duckdb": "SELECT STRPTIME('2016-12-31', '%Y-%m-%d')",
"duckdb": "SELECT STRPTIME('2016-1-1', '%Y-%m-%d')",
},
write={
"": "SELECT STR_TO_TIME('2016-12-31', '%Y-%m-%d')",
"duckdb": "SELECT STRPTIME('2016-12-31', '%Y-%m-%d')",
"spark": "SELECT TO_TIMESTAMP('2016-12-31', 'yyyy-MM-dd')",
"": "SELECT STR_TO_TIME('2016-1-1', '%Y-%-m-%-d')",
"duckdb": "SELECT STRPTIME('2016-1-1', '%Y-%-m-%-d')",
"spark": "SELECT TO_TIMESTAMP('2016-1-1', 'yyyy-M-d')",
},
)
self.validate_all(
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def test_cast(self):
write={
"teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')",
"bigquery": "PARSE_DATE('%Y-%d', '1992-01')",
"databricks": "TO_DATE('1992-01', 'yyyy-dd')",
"databricks": "TO_DATE('1992-01', 'yyyy-d')",
"mysql": "STR_TO_DATE('1992-01', '%Y-%d')",
"spark": "TO_DATE('1992-01', 'yyyy-dd')",
"spark": "TO_DATE('1992-01', 'yyyy-d')",
"": "STR_TO_DATE('1992-01', '%Y-%d')",
},
)
Expand Down
6 changes: 3 additions & 3 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,21 +1749,21 @@ def test_convert(self):
self.validate_all(
"CONVERT(DATE, x, 121)",
write={
"spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"spark": "TO_DATE(x, 'yyyy-M-d HH:mm:ss.SSSSSS')",
"tsql": "CONVERT(DATE, x, 121)",
},
)
self.validate_all(
"CONVERT(DATETIME, x, 121)",
write={
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss.SSSSSS')",
"tsql": "CONVERT(DATETIME, x, 121)",
},
)
self.validate_all(
"CONVERT(DATETIME2, x, 121)",
write={
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-M-d HH:mm:ss.SSSSSS')",
"tsql": "CONVERT(DATETIME2, x, 121)",
},
)
Expand Down
Loading