diff --git a/src/snowflake/connector/aio/_pandas_tools.py b/src/snowflake/connector/aio/_pandas_tools.py new file mode 100644 index 000000000..1c36e064d --- /dev/null +++ b/src/snowflake/connector/aio/_pandas_tools.py @@ -0,0 +1,561 @@ +from __future__ import annotations + +import os +import warnings +from logging import getLogger +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Any, Literal, Sequence + +from snowflake.connector import ProgrammingError +from snowflake.connector.options import pandas +from snowflake.connector.telemetry import TelemetryData, TelemetryField + +from .._utils import ( + TempObjectType, + get_temp_type_for_object, + random_name_for_temp_object, +) +from ..constants import _PARAM_USE_SCOPED_TEMP_FOR_PANDAS_TOOLS + +# Import utilities from sync version +from ..pandas_tools import ( + _iceberg_config_statement_helper, + build_location_helper, + chunk_helper, +) +from ._cursor import SnowflakeCursor + +if TYPE_CHECKING: # pragma: no cover + from ._connection import SnowflakeConnection + + try: + import sqlalchemy + except ImportError: + sqlalchemy = None + +logger = getLogger(__name__) + + +async def _do_create_temp_stage( + cursor: SnowflakeCursor, + stage_location: str, + compression: str, + auto_create_table: bool, + overwrite: bool, + use_scoped_temp_object: bool, +) -> None: + create_stage_sql = f"CREATE {get_temp_type_for_object(use_scoped_temp_object)} STAGE /* Python:snowflake.connector.aio._pandas_tools.write_pandas() */ identifier(?) FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})" + params = (stage_location,) + logger.debug(f"creating stage with '{create_stage_sql}'. params: %s", params) + await cursor.execute( + create_stage_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + + +async def _create_temp_stage( + cursor: SnowflakeCursor, + database: str | None, + schema: str | None, + quote_identifiers: bool, + compression: str, + auto_create_table: bool, + overwrite: bool, + use_scoped_temp_object: bool = False, +) -> str: + stage_name = random_name_for_temp_object(TempObjectType.STAGE) + stage_location = build_location_helper( + database=database, + schema=schema, + name=stage_name, + quote_identifiers=quote_identifiers, + ) + try: + await _do_create_temp_stage( + cursor, + stage_location, + compression, + auto_create_table, + overwrite, + use_scoped_temp_object, + ) + except ProgrammingError as e: + # User may not have the privilege to create stage on the target schema, so fall back to use current schema as + # the old behavior. + logger.debug( + f"creating stage {stage_location} failed. Exception {str(e)}. Fall back to use current schema" + ) + stage_location = stage_name + await _do_create_temp_stage( + cursor, + stage_location, + compression, + auto_create_table, + overwrite, + use_scoped_temp_object, + ) + + return stage_location + + +async def _do_create_temp_file_format( + cursor: SnowflakeCursor, + file_format_location: str, + compression: str, + sql_use_logical_type: str, + use_scoped_temp_object: bool, +) -> None: + file_format_sql = ( + f"CREATE {get_temp_type_for_object(use_scoped_temp_object)} FILE FORMAT identifier(?) " + f"/* Python:snowflake.connector.aio._pandas_tools.write_pandas() */ " + f"TYPE=PARQUET COMPRESSION={compression}{sql_use_logical_type}" + ) + params = (file_format_location,) + logger.debug(f"creating file format with '{file_format_sql}'. 
params: %s", params) + await cursor.execute( + file_format_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + + +async def _create_temp_file_format( + cursor: SnowflakeCursor, + database: str | None, + schema: str | None, + quote_identifiers: bool, + compression: str, + sql_use_logical_type: str, + use_scoped_temp_object: bool = False, +) -> str: + file_format_name = random_name_for_temp_object(TempObjectType.FILE_FORMAT) + file_format_location = build_location_helper( + database=database, + schema=schema, + name=file_format_name, + quote_identifiers=quote_identifiers, + ) + try: + await _do_create_temp_file_format( + cursor, + file_format_location, + compression, + sql_use_logical_type, + use_scoped_temp_object, + ) + except ProgrammingError as e: + # User may not have the privilege to create file format on the target schema, so fall back to use current schema + # as the old behavior. + logger.debug( + f"creating stage {file_format_location} failed. Exception {str(e)}. Fall back to use current schema" + ) + file_format_location = file_format_name + await _do_create_temp_file_format( + cursor, + file_format_location, + compression, + sql_use_logical_type, + use_scoped_temp_object, + ) + + return file_format_location + + +async def write_pandas( + conn: SnowflakeConnection, + df: pandas.DataFrame, + table_name: str, + database: str | None = None, + schema: str | None = None, + chunk_size: int | None = None, + compression: str = "gzip", + on_error: str = "abort_statement", + parallel: int = 4, + quote_identifiers: bool = True, + infer_schema: bool = False, + auto_create_table: bool = False, + overwrite: bool = False, + table_type: Literal["", "temp", "temporary", "transient"] = "", + use_logical_type: bool | None = None, + iceberg_config: dict[str, str] | None = None, + bulk_upload_chunks: bool = False, + use_vectorized_scanner: bool = False, + **kwargs: Any, +) -> tuple[ + bool, + int, + int, + Sequence[ + tuple[ + str, + str, + int, + int, + int, + int, + str | None, + int | None, + int | None, + str | None, + ] + ], +]: + """Allows users to most efficiently write back a pandas DataFrame to Snowflake. + + It works by dumping the DataFrame into Parquet files, uploading them and finally copying their data into the table. + + Returns whether all files were ingested correctly, number of chunks uploaded, and number of rows ingested + with all of the COPY INTO command's output for debugging purposes. + + Example usage: + import pandas + from snowflake.connector.aio import SnowflakeConnection + from snowflake.connector.aio.pandas_tools import write_pandas + + async with SnowflakeConnection(...) as conn: + df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance']) + success, nchunks, nrows, _ = await write_pandas(conn, df, 'customers') + + Args: + conn: Connection to be used to communicate with Snowflake. + df: Dataframe we'd like to write back. + table_name: Table name where we want to insert into. + database: Database schema and table is in, if not provided the default one will be used (Default value = None). + schema: Schema table is in, if not provided the default one will be used (Default value = None). + chunk_size: Number of elements to be inserted once, if not provided all elements will be dumped once + (Default value = None). + compression: The compression used on the Parquet files, can only be gzip, or snappy. Gzip gives supposedly a + better compression, while snappy is faster. 
Use whichever is more appropriate (Default value = 'gzip').
+        on_error: Action to take when COPY INTO statements fail, default follows documentation at:
+            https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions
+            (Default value = 'abort_statement').
+        use_vectorized_scanner: Boolean that specifies whether to use a vectorized scanner for loading Parquet files. See details at
+            `copy options `_.
+        parallel: Number of threads to be used when uploading chunks, default follows documentation at:
+            https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters (Default value = 4).
+        quote_identifiers: By default, identifiers, specifically database, schema, table and column names
+            (from df.columns) will be quoted. If set to False, identifiers are passed on to Snowflake without quoting.
+            I.e. identifiers will be coerced to uppercase by Snowflake. (Default value = True)
+        infer_schema: Perform explicit schema inference on the data in the DataFrame and use the inferred data types
+            when selecting columns from the DataFrame. (Default value = False)
+        auto_create_table: When true, will automatically create a table with corresponding columns for each column in
+            the passed in DataFrame. The table will not be created if it already exists.
+        overwrite: When true, and if auto_create_table is true, then it drops the table. Otherwise, it
+            truncates the table. In both cases it will replace the existing contents of the table with that of the
+            passed in Pandas DataFrame.
+        table_type: The table type of the to-be-created table. The supported table types include ``temp``/``temporary``
+            and ``transient``. Empty means permanent table as per SQL convention.
+        use_logical_type: Boolean that specifies whether to use Parquet logical types. With this file format option,
+            Snowflake can interpret Parquet logical types during data loading. To enable Parquet logical types,
+            set use_logical_type as True. Set to None to use Snowflake's default. For more information, see:
+            https://docs.snowflake.com/en/sql-reference/sql/create-file-format
+        iceberg_config: A dictionary that can contain the following iceberg configuration values:
+            * external_volume: specifies the identifier for the external volume where
+              the Iceberg table stores its metadata files and data in Parquet format
+            * catalog: specifies either Snowflake or a catalog integration to use for this table
+            * base_location: the base directory that Snowflake can write Iceberg metadata and files to
+            * catalog_sync: optionally sets the catalog integration configured for Polaris Catalog
+            * storage_serialization_policy: specifies the storage serialization policy for the table
+        bulk_upload_chunks: If set to True, the upload uses the wildcard upload method. This is faster, but instead
+            of uploading and cleaning up each chunk separately, all chunks are uploaded at once and the locally
+            stored chunks are cleaned up afterwards.
+
+    Returns:
+        Returns the COPY INTO command's results to verify ingestion in the form of a tuple of whether all chunks were
+        ingested correctly, # of chunks, # of ingested rows, and ingest's output.
+ """ + if database is not None and schema is None: + raise ProgrammingError( + "Schema has to be provided to write_pandas when a database is provided" + ) + # This dictionary maps the compression algorithm to Snowflake put copy into command type + # https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet + compression_map = {"gzip": "auto", "snappy": "snappy"} + if compression not in compression_map.keys(): + raise ProgrammingError( + f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}" + ) + + # TODO(SNOW-1505026): Get rid of this when the BCR to always create scoped temp for intermediate results is done. + _use_scoped_temp_object = ( + conn._session_parameters.get(_PARAM_USE_SCOPED_TEMP_FOR_PANDAS_TOOLS, False) + if conn._session_parameters + else False + ) + + if table_type and table_type.lower() not in ["temp", "temporary", "transient"]: + raise ValueError( + "Unsupported table type. Expected table types: temp/temporary, transient" + ) + + if table_type.lower() in ["temp", "temporary"]: + # Add scoped keyword when applicable. + table_type = get_temp_type_for_object(_use_scoped_temp_object).lower() + + if chunk_size is None: + chunk_size = len(df) + + if not ( + isinstance(df.index, pandas.RangeIndex) + and 1 == df.index.step + and 0 == df.index.start + ): + warnings.warn( + f"Pandas Dataframe has non-standard index of type {str(type(df.index))} which will not be written." + f" Consider changing the index to pd.RangeIndex(start=0,...,step=1) or " + f"call reset_index() to keep index as column(s)", + UserWarning, + stacklevel=2, + ) + + # use_logical_type should be True when dataframe contains datetimes with timezone. + # https://github.com/snowflakedb/snowflake-connector-python/issues/1687 + if not use_logical_type and any( + [pandas.api.types.is_datetime64tz_dtype(df[c]) for c in df.columns] + ): + warnings.warn( + "Dataframe contains a datetime with timezone column, but " + f"'{use_logical_type=}'. This can result in datetimes " + "being incorrectly written to Snowflake. 
Consider setting " + "'use_logical_type = True'", + UserWarning, + stacklevel=2, + ) + + if use_logical_type is None: + sql_use_logical_type = "" + elif use_logical_type: + sql_use_logical_type = " USE_LOGICAL_TYPE = TRUE" + else: + sql_use_logical_type = " USE_LOGICAL_TYPE = FALSE" + + cursor = conn.cursor() + stage_location = await _create_temp_stage( + cursor, + database, + schema, + quote_identifiers, + compression, + auto_create_table, + overwrite, + _use_scoped_temp_object, + ) + + with TemporaryDirectory() as tmp_folder: + for i, chunk in chunk_helper(df, chunk_size): + chunk_path = os.path.join(tmp_folder, f"file{i}.txt") + # Dump chunk into parquet file + chunk.to_parquet(chunk_path, compression=compression, **kwargs) + if not bulk_upload_chunks: + # Upload parquet file chunk right away + path = chunk_path.replace("\\", "\\\\").replace("'", "\\'") + await cursor._upload( + local_file_name=f"'file://{path}'", + stage_location="@" + stage_location, + options={"parallel": parallel, "source_compression": "auto_detect"}, + ) + + # Remove chunk file + os.remove(chunk_path) + + if bulk_upload_chunks: + # Upload tmp directory with parquet chunks + path = tmp_folder.replace("\\", "\\\\").replace("'", "\\'") + await cursor._upload( + local_file_name=f"'file://{path}/*'", + stage_location="@" + stage_location, + options={"parallel": parallel, "source_compression": "auto_detect"}, + ) + + # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly + # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html) + if quote_identifiers: + quote = '"' + # if the column name contains a double quote, we need to escape it by replacing with two double quotes + # https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers + snowflake_column_names = [str(c).replace('"', '""') for c in df.columns] + else: + quote = "" + snowflake_column_names = list(df.columns) + columns = quote + f"{quote},{quote}".join(snowflake_column_names) + quote + + async def drop_object(name: str, object_type: str) -> None: + drop_sql = f"DROP {object_type.upper()} IF EXISTS identifier(?) /* Python:snowflake.connector.aio._pandas_tools.write_pandas() */" + params = (name,) + logger.debug(f"dropping {object_type} with '{drop_sql}'. params: %s", params) + + await cursor.execute( + drop_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + + if auto_create_table or overwrite or infer_schema: + file_format_location = await _create_temp_file_format( + cursor, + database, + schema, + quote_identifiers, + compression_map[compression], + sql_use_logical_type, + _use_scoped_temp_object, + ) + infer_schema_sql = "SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>?, file_format=>?))" + params = (f"@{stage_location}", file_format_location) + logger.debug(f"inferring schema with '{infer_schema_sql}'. 
params: %s", params) + column_type_mapping = dict( + await ( + await cursor.execute( + infer_schema_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + ).fetchall() + ) + # Infer schema can return the columns out of order depending on the chunking we do when uploading + # so we have to iterate through the dataframe columns to make sure we create the table with its + # columns in order + create_table_columns = ", ".join( + [ + f"{quote}{snowflake_col}{quote} {column_type_mapping[col]}" + for snowflake_col, col in zip(snowflake_column_names, df.columns) + ] + ) + + target_table_location = build_location_helper( + database, + schema, + ( + random_name_for_temp_object(TempObjectType.TABLE) + if (overwrite and auto_create_table) + else table_name + ), + quote_identifiers, + ) + + if auto_create_table or overwrite: + iceberg = "ICEBERG " if iceberg_config else "" + iceberg_config_statement = _iceberg_config_statement_helper( + iceberg_config or {} + ) + + create_table_sql = ( + f"CREATE {table_type.upper()} {iceberg}TABLE IF NOT EXISTS identifier(?) " + f"({create_table_columns}) {iceberg_config_statement}" + f" /* Python:snowflake.connector.aio._pandas_tools.write_pandas() */ " + ) + params = (target_table_location,) + logger.debug( + f"auto creating table with '{create_table_sql}'. params: %s", params + ) + await cursor.execute( + create_table_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + + # need explicit casting when the underlying table schema is inferred + parquet_columns = "$1:" + ",$1:".join( + f"{quote}{snowflake_col}{quote}::{column_type_mapping[col]}" + for snowflake_col, col in zip(snowflake_column_names, df.columns) + ) + else: + target_table_location = build_location_helper( + database=database, + schema=schema, + name=table_name, + quote_identifiers=quote_identifiers, + ) + parquet_columns = "$1:" + ",$1:".join( + f"{quote}{snowflake_col}{quote}" for snowflake_col in snowflake_column_names + ) + + try: + if overwrite and (not auto_create_table): + truncate_sql = "TRUNCATE TABLE identifier(?) /* Python:snowflake.connector.aio._pandas_tools.write_pandas() */" + params = (target_table_location,) + logger.debug(f"truncating table with '{truncate_sql}'. params: %s", params) + await cursor.execute( + truncate_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + + copy_stage_location = "@" + stage_location.replace("'", "\\'") + copy_into_sql = ( + f"COPY INTO identifier(?) /* Python:snowflake.connector.aio._pandas_tools.write_pandas() */ " + f"({columns}) " + f"FROM (SELECT {parquet_columns} FROM '{copy_stage_location}') " + f"FILE_FORMAT=(" + f"TYPE=PARQUET " + f"USE_VECTORIZED_SCANNER={use_vectorized_scanner} " + f"COMPRESSION={compression_map[compression]}" + f"{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite or infer_schema else ''}" + f"{sql_use_logical_type}" + f") " + f"PURGE=TRUE ON_ERROR=?" + ) + params = ( + target_table_location, + on_error, + ) + logger.debug(f"copying into with '{copy_into_sql}'. 
params: %s", params) + copy_results = await ( + await cursor.execute( + copy_into_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + ).fetchall() + + if overwrite and auto_create_table: + original_table_location = build_location_helper( + database=database, + schema=schema, + name=table_name, + quote_identifiers=quote_identifiers, + ) + await drop_object(original_table_location, "table") + rename_table_sql = "ALTER TABLE identifier(?) RENAME TO identifier(?) /* Python:snowflake.connector.aio._pandas_tools.write_pandas() */" + params = (target_table_location, original_table_location) + logger.debug(f"rename table with '{rename_table_sql}'. params: %s", params) + await cursor.execute( + rename_table_sql, + _is_internal=True, + _force_qmark_paramstyle=True, + params=params, + num_statements=1, + ) + except ProgrammingError: + if overwrite and auto_create_table: + # drop table only if we created a new one with a random name + await drop_object(target_table_location, "table") + raise + finally: + await cursor._log_telemetry_job_data( + TelemetryField.PANDAS_WRITE, TelemetryData.TRUE + ) + await cursor.close() + + return ( + all(e[1] == "LOADED" for e in copy_results), + len(copy_results), + sum(int(e[3]) for e in copy_results), + copy_results, + ) diff --git a/test/integ/aio_it/test_arrow_result_async.py b/test/integ/aio_it/test_arrow_result_async.py index 7974d39f8..033f4650a 100644 --- a/test/integ/aio_it/test_arrow_result_async.py +++ b/test/integ/aio_it/test_arrow_result_async.py @@ -42,6 +42,13 @@ serialize, ) +try: + from snowflake.connector.aio._pandas_tools import write_pandas + from snowflake.connector.options import pandas +except ImportError: + pandas = None + write_pandas = None + @pytest.fixture(scope="module") def structured_type_support(module_conn_cnx): @@ -1152,6 +1159,43 @@ async def iterate_over_test_chunk( assert str(arrow_res[0]) == expected[i] +@pytest.mark.skipif(not pandas_available, reason="test requires pandas") +async def test_iceberg_write_pandas(conn_cnx, iceberg_support, structured_type_support): + if not structured_type_support: + pytest.skip("Test requires structured type support.") + if not iceberg_support: + pytest.skip("Test requires iceberg support.") + table_name = f"write_pandas_iceberg_test_table_{random_string(5)}" + + data = ( + 1, + "A", + # Server side infer schema can only create VARIANTS for pandas structured data + # [1, 2, 3], + # {"a": 1}, + # {"b": 1, "c": "d"}, + ) + + pdf = pandas.DataFrame([data], columns=["A", "B"]) + config = { + "CATALOG": "SNOWFLAKE", + "EXTERNAL_VOLUME": "python_connector_iceberg_exvol", + "BASE_LOCATION": "python_connector_merge_gate", + } + + async with conn_cnx() as conn: + try: + await write_pandas( + conn, pdf, table_name, auto_create_table=True, iceberg_config=config + ) + results = await ( + await conn.cursor().execute(f'select * from "{table_name}"') + ).fetchall() + assert results == [data] + finally: + await conn.cursor().execute(f"drop table IF EXISTS {table_name};") + + @pytest.mark.parametrize("debug_arrow_chunk", [True, False]) @pytest.mark.asyncio async def test_arrow_bad_data(conn_cnx, caplog, debug_arrow_chunk): diff --git a/test/integ/aio_it/test_direct_file_operation_utils_async.py b/test/integ/aio_it/test_direct_file_operation_utils_async.py index 350b50675..8bb031bf4 100644 --- a/test/integ/aio_it/test_direct_file_operation_utils_async.py +++ b/test/integ/aio_it/test_direct_file_operation_utils_async.py @@ -9,14 +9,8 @@ try: from 
snowflake.connector.options import pandas - from snowflake.connector.pandas_tools import ( - _iceberg_config_statement_helper, - write_pandas, - ) except ImportError: pandas = None - write_pandas = None - _iceberg_config_statement_helper = None if TYPE_CHECKING: from snowflake.connector.aio import SnowflakeConnection, SnowflakeCursor diff --git a/test/integ/aio_it/test_pandas_tools_async.py b/test/integ/aio_it/test_pandas_tools_async.py new file mode 100644 index 000000000..05246ebe5 --- /dev/null +++ b/test/integ/aio_it/test_pandas_tools_async.py @@ -0,0 +1,1266 @@ +#!/usr/bin/env python +from __future__ import annotations + +import math +import re +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Callable, Generator +from unittest import mock +from unittest.mock import MagicMock + +import numpy.random +import pytest + +from snowflake.connector import ProgrammingError +from snowflake.connector.aio import DictCursor, SnowflakeCursor + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from ...randomize import random_string + +from ...lazy_var import LazyVar + +try: + from snowflake.connector.aio._pandas_tools import write_pandas + from snowflake.connector.options import pandas + from snowflake.connector.pandas_tools import _iceberg_config_statement_helper +except ImportError: + pandas = None + write_pandas = None + _iceberg_config_statement_helper = None + +if TYPE_CHECKING: + from snowflake.connector.aio import SnowflakeConnection + +sf_connector_version_data = [ + ("snowflake-connector-python", "1.2.23"), + ("snowflake-sqlalchemy", "1.1.1"), + ("snowflake-connector-go", "0.0.1"), + ("snowflake-go", "1.0.1"), + ("snowflake-odbc", "3.12.3"), +] + +sf_connector_version_df = LazyVar( + lambda: pandas.DataFrame( + sf_connector_version_data, columns=["name", "newest_version"] + ) +) + + +async def assert_result_equals( + cnx: SnowflakeConnection, + num_of_chunks: int, + sql: str, + expected_data: list[tuple[Any, ...]], +): + if num_of_chunks == 1: + # Note: since we used one chunk order is conserved + assert await (await cnx.cursor().execute(sql)).fetchall() == expected_data + else: + # Note: since we used more than one chunk order is NOT conserved + assert set(await (await cnx.cursor().execute(sql)).fetchall()) == set( + expected_data + ) + + +async def test_fix_snow_746341( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + cat = '"cat"' + df = pandas.DataFrame([[1], [2]], columns=[f"col_'{cat}'"]) + table_name = random_string(5, "snow746341_") + async with conn_cnx() as conn: + await write_pandas( + conn, df, table_name, auto_create_table=True, table_type="temporary" + ) + assert await ( + await conn.cursor().execute(f'select * from "{table_name}"') + ).fetchall() == [ + (1,), + (2,), + ] + + +@pytest.mark.parametrize("quote_identifiers", [True, False]) +@pytest.mark.parametrize("auto_create_table", [True, False]) +@pytest.mark.parametrize("index", [False]) +async def test_write_pandas_with_overwrite( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + quote_identifiers: bool, + auto_create_table: bool, + index: bool, +): + """Tests whether overwriting table using a Pandas DataFrame works as expected.""" + random_table_name = random_string(5, "userspoints_") + df1_data = [("John", 10), ("Jane", 20)] + df1 = pandas.DataFrame(df1_data, columns=["name", "points"]) + df2_data = [("Dash", 50)] + df2 = pandas.DataFrame(df2_data, columns=["name", "points"]) + df3_data = [(2022, "Jan", 10000), 
(2022, "Feb", 10220)] + df3 = pandas.DataFrame(df3_data, columns=["year", "month", "revenue"]) + df4_data = [("Frank", 100)] + df4 = pandas.DataFrame(df4_data, columns=["name%", "points"]) + + if quote_identifiers: + table_name = '"' + random_table_name + '"' + col_id = '"id"' + col_name = '"name"' + col_points = '"points"' + else: + table_name = random_table_name + col_id = "id" + col_name = "name" + col_points = "points" + + create_sql = ( + f"CREATE OR REPLACE TABLE {table_name}" + f"({col_name} STRING, {col_points} INT, {col_id} INT AUTOINCREMENT)" + ) + + select_sql = f"SELECT * FROM {table_name}" + select_count_sql = f"SELECT count(*) FROM {table_name}" + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + async with conn_cnx() as cnx: # type: SnowflakeConnection + await cnx.execute_string(create_sql) + try: + # Write dataframe with 2 rows + await write_pandas( + cnx, + df1, + random_table_name, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=True, + index=index, + ) + # Write dataframe with 1 row + success, nchunks, nrows, _ = await write_pandas( + cnx, + df2, + random_table_name, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=True, + index=index, + ) + # Check write_pandas output + assert success + assert nchunks == 1 + result = await ( + await cnx.cursor(DictCursor).execute(select_count_sql) + ).fetchone() + # Check number of rows + assert result["COUNT(*)"] == 1 + + # Write dataframe with a different schema + if auto_create_table: + # Should drop table and SUCCEED because the new table will be created with new schema of df3 + success, nchunks, nrows, _ = await write_pandas( + cnx, + df3, + random_table_name, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=True, + index=index, + ) + # Check write_pandas output + assert success + assert nchunks == 1 + result = await cnx.execute_string(select_sql) + # Check column names + assert ( + "year" + if quote_identifiers + else "YEAR" in [col.name for col in result[0].description] + ) + else: + # Should fail because the table will be truncated and df3 schema doesn't match + # (since df3 should at least have a subset of the columns of the target table) + with pytest.raises(ProgrammingError, match="invalid identifier"): + await write_pandas( + cnx, + df3, + random_table_name, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=True, + index=index, + ) + + # Check that we have truncated the table but not dropped it in case or error. 
+ result = await ( + await cnx.cursor(DictCursor).execute(select_count_sql) + ).fetchone() + assert result["COUNT(*)"] == 0 + + if not quote_identifiers: + original_result = await ( + await cnx.cursor(DictCursor).execute(select_count_sql) + ).fetchone() + # the column name contains special char which should fail + with pytest.raises(ProgrammingError, match="unexpected '%'"): + await write_pandas( + cnx, + df4, + random_table_name, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + overwrite=True, + index=index, + ) + # the original table shouldn't have any change + assert ( + original_result + == await ( + await cnx.cursor(DictCursor).execute(select_count_sql) + ).fetchone() + ) + + finally: + await cnx.execute_string(drop_sql) + + +@pytest.mark.parametrize("chunk_size", [5, 1]) +@pytest.mark.parametrize( + "compression", + [ + "gzip", + ], +) +@pytest.mark.parametrize("quote_identifiers", [True, False]) +@pytest.mark.parametrize("auto_create_table", [True, False]) +@pytest.mark.parametrize("create_temp_table", [True, False]) +@pytest.mark.parametrize("index", [False]) +async def test_write_pandas( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + db_parameters: dict[str, str], + compression: str, + chunk_size: int, + quote_identifiers: bool, + auto_create_table: bool, + create_temp_table: bool, + index: bool, +): + num_of_chunks = math.ceil(len(sf_connector_version_data) / chunk_size) + + async with conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + ) as cnx: + table_name = "driver_versions" + + if quote_identifiers: + create_sql = 'CREATE OR REPLACE TABLE "{}" ("name" STRING, "newest_version" STRING)'.format( + table_name + ) + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + else: + create_sql = "CREATE OR REPLACE TABLE {} (name STRING, newest_version STRING)".format( + table_name + ) + select_sql = f"SELECT * FROM {table_name}" + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + + if not auto_create_table: + await cnx.execute_string(create_sql) + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + table_name, + compression=compression, + chunk_size=chunk_size, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + create_temp_table=create_temp_table, + index=index, + ) + + await assert_result_equals( + cnx, num_of_chunks, select_sql, sf_connector_version_data + ) + + # Make sure all files were loaded and no error occurred + assert success + # Make sure overall as many rows were ingested as we tried to insert + assert nrows == len(sf_connector_version_data) + # Make sure we uploaded in as many chunk as we wanted to + assert nchunks == num_of_chunks + # Check to see if this is a temporary or regular table if we auto-created this table + if auto_create_table: + table_info = await ( + await cnx.cursor(DictCursor).execute( + f"show tables like '{table_name}'" + ) + ).fetchall() + assert table_info[0]["kind"] == ( + "TEMPORARY" if create_temp_table else "TABLE" + ) + finally: + await cnx.execute_string(drop_sql) + + +async def test_write_non_range_index_pandas( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + db_parameters: dict[str, str], +): + compression = "gzip" + chunk_size = 3 + quote_identifiers: bool = False + auto_create_table: bool = True + create_temp_table: bool = False + index: bool = False + + # use pandas dataframe with float index + n_rows = 17 + pandas_df = pandas.DataFrame( + 
pandas.DataFrame( + numpy.random.normal(size=(n_rows, 4)), + columns=["a", "b", "c", "d"], + index=numpy.random.normal(size=n_rows), + ) + ) + + # convert to list of tuples to compare to received output + pandas_df_data = [tuple(row) for row in list(pandas_df.values)] + + num_of_chunks = math.ceil(len(pandas_df_data) / chunk_size) + + async with conn_cnx() as cnx: + table_name = "driver_versions" + + if quote_identifiers: + create_sql = 'CREATE OR REPLACE TABLE "{}" ("name" STRING, "newest_version" STRING)'.format( + table_name + ) + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + else: + create_sql = "CREATE OR REPLACE TABLE {} (name STRING, newest_version STRING)".format( + table_name + ) + select_sql = f"SELECT * FROM {table_name}" + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + + if not auto_create_table: + await cnx.execute_string(create_sql) + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, + pandas_df, + table_name, + compression=compression, + chunk_size=chunk_size, + quote_identifiers=quote_identifiers, + auto_create_table=auto_create_table, + create_temp_table=create_temp_table, + index=index, + ) + + await assert_result_equals(cnx, num_of_chunks, select_sql, pandas_df_data) + + # Make sure all files were loaded and no error occurred + assert success + # Make sure overall as many rows were ingested as we tried to insert + assert nrows == len(pandas_df_data) + # Make sure we uploaded in as many chunk as we wanted to + assert nchunks == num_of_chunks + # Check to see if this is a temporary or regular table if we auto-created this table + if auto_create_table: + table_info = await ( + await cnx.cursor(DictCursor).execute( + f"show tables like '{table_name}'" + ) + ).fetchall() + assert table_info[0]["kind"] == ( + "TEMPORARY" if create_temp_table else "TABLE" + ) + finally: + await cnx.execute_string(drop_sql) + + +@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"]) +async def test_write_pandas_table_type( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + table_type: str, +): + async with conn_cnx() as cnx: + table_name = random_string(5, "write_pandas_table_type_") + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + try: + success, _, _, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + table_name, + table_type=table_type, + auto_create_table=True, + ) + table_info = await ( + await cnx.cursor(DictCursor).execute(f"show tables like '{table_name}'") + ).fetchall() + assert success + if not table_type: + expected_table_kind = "TABLE" + elif table_type == "temp": + expected_table_kind = "TEMPORARY" + else: + expected_table_kind = table_type.upper() + assert table_info[0]["kind"] == expected_table_kind + finally: + await cnx.execute_string(drop_sql) + + +async def test_write_pandas_create_temp_table_deprecation_warning( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + async with conn_cnx() as cnx: + table_name = random_string(5, "driver_versions_") + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + try: + with pytest.deprecated_call(match="create_temp_table is deprecated"): + success, _, _, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + table_name, + create_temp_table=True, + auto_create_table=True, + ) + + assert success + table_info = await ( + await cnx.cursor(DictCursor).execute(f"show tables like '{table_name}'") + ).fetchall() + assert table_info[0]["kind"] == "TEMPORARY" + finally: + await cnx.execute_string(drop_sql) + + 
+@pytest.mark.parametrize("use_logical_type", [None, True, False]) +async def test_write_pandas_use_logical_type( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + use_logical_type: bool | None, +): + table_name = random_string(5, "USE_LOCAL_TYPE_").upper() + col_name = "DT" + create_sql = f"CREATE OR REPLACE TABLE {table_name} ({col_name} TIMESTAMP_TZ)" + select_sql = f"SELECT * FROM {table_name}" + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + timestamp = datetime( + year=2020, + month=1, + day=2, + hour=3, + minute=4, + second=5, + microsecond=6, + tzinfo=timezone(timedelta(hours=2)), + ) + df_write = pandas.DataFrame({col_name: [timestamp]}) + + async with conn_cnx() as cnx: # type: SnowflakeConnection + await (await cnx.cursor().execute(create_sql)).fetchall() + + write_pandas_kwargs = dict( + conn=cnx, + df=df_write, + use_logical_type=use_logical_type, + auto_create_table=False, + table_name=table_name, + ) + + try: + # When use_logical_type = True, datetimes with timestamps should be + # correctly written to Snowflake. + if use_logical_type: + await write_pandas(**write_pandas_kwargs) + df_read = await ( + await cnx.cursor().execute(select_sql) + ).fetch_pandas_all() + assert all(df_write == df_read) + # For other use_logical_type values, a UserWarning should be displayed. + else: + with pytest.warns(UserWarning, match="Dataframe contains a datetime.*"): + await write_pandas(**write_pandas_kwargs) + finally: + await cnx.execute_string(drop_sql) + + +async def test_invalid_table_type_write_pandas( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + async with conn_cnx() as cnx: + with pytest.raises(ValueError, match="Unsupported table type"): + await write_pandas( + cnx, + sf_connector_version_df.get(), + "invalid_table_type", + table_type="invalid", + ) + + +async def test_empty_dataframe_write_pandas( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + table_name = random_string(5, "empty_dataframe_") + df = pandas.DataFrame([], columns=["name", "balance"]) + async with conn_cnx() as cnx: + success, num_chunks, num_rows, _ = await write_pandas( + cnx, df, table_name, auto_create_table=True, table_type="temp" + ) + assert ( + success and num_chunks == 1 and num_rows == 0 + ), f"sucess: {success}, num_chunks: {num_chunks}, num_rows: {num_rows}" + + +@pytest.mark.parametrize( + "database,schema,quote_identifiers,expected_location", + [ + ("database", "schema", True, '"database"."schema"."table"'), + ("database", "schema", False, "database.schema.table"), + (None, "schema", True, '"schema"."table"'), + (None, "schema", False, "schema.table"), + (None, None, True, '"table"'), + (None, None, False, "table"), + ], +) +async def test_table_location_building( + conn_cnx, + database: str | None, + schema: str | None, + quote_identifiers: bool, + expected_location: str, +): + """This tests that write_pandas constructs table location correctly with database, schema, and table name.""" + + async with conn_cnx() as cnx: + + async def mocked_execute(*args, **kwargs): + if len(args) >= 1 and args[0].startswith("COPY INTO"): + assert kwargs["params"][0] == expected_location + cur = SnowflakeCursor(cnx) + cur._result = iter([]) + return cur + + with mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor.execute", + side_effect=mocked_execute, + ) as m_execute, mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor._upload", + side_effect=MagicMock(), + ) as _: + success, nchunks, nrows, _ = await write_pandas( + cnx, + 
sf_connector_version_df.get(), + "table", + database=database, + schema=schema, + quote_identifiers=quote_identifiers, + ) + assert m_execute.called and any( + map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list) + ) + + +@pytest.mark.parametrize( + "database,schema,quote_identifiers,expected_db_schema", + [ + ("database", "schema", True, '"database"."schema"'), + ("database", "schema", False, "database.schema"), + (None, "schema", True, '"schema"'), + (None, "schema", False, "schema"), + (None, None, True, ""), + (None, None, False, ""), + ("data'base", "schema", True, '"data\'base"."schema"'), + ("data'base", "schema", False, '"data\'base".schema'), + ], +) +async def test_stage_location_building( + conn_cnx, + database: str | None, + schema: str | None, + quote_identifiers: bool, + expected_db_schema: str, +): + """This tests that write_pandas constructs stage location correctly with database and schema.""" + + async with conn_cnx() as cnx: + + async def mocked_execute(*args, **kwargs): + if len(args) >= 1 and args[0].startswith("create temporary stage"): + db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1]) + assert db_schema == expected_db_schema + cur = SnowflakeCursor(cnx) + cur._result = iter([]) + return cur + + with mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor.execute", + side_effect=mocked_execute, + ) as m_execute, mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor._upload", + side_effect=MagicMock(), + ) as _: + success, nchunks, nrows, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + "table", + database=database, + schema=schema, + quote_identifiers=quote_identifiers, + ) + assert m_execute.called and any( + map( + lambda e: ("CREATE TEMP STAGE" in str(e[0])), + m_execute.call_args_list, + ) + ) + + +@pytest.mark.skip("scoped object isn't used yet.") +@pytest.mark.parametrize( + "database,schema,quote_identifiers,expected_db_schema", + [ + ("database", "schema", True, '"database"."schema"'), + ("database", "schema", False, "database.schema"), + (None, "schema", True, '"schema"'), + (None, "schema", False, "schema"), + (None, None, True, ""), + (None, None, False, ""), + ], +) +async def test_use_scoped_object( + conn_cnx, + database: str | None, + schema: str | None, + quote_identifiers: bool, + expected_db_schema: str, +): + """This tests that write_pandas constructs stage location correctly with database and schema.""" + + async with conn_cnx() as cnx: + + async def mocked_execute(*args, **kwargs): + if len(args) >= 1 and args[0].startswith("create temporary stage"): + db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1]) + assert db_schema == expected_db_schema + cur = SnowflakeCursor(cnx) + cur._result = iter([]) + return cur + + with mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor.execute", + side_effect=mocked_execute, + ) as m_execute, mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor._upload", + side_effect=MagicMock(), + ) as _: + await cnx._update_parameters( + {"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": True} + ) + success, nchunks, nrows, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + "table", + database=database, + schema=schema, + quote_identifiers=quote_identifiers, + ) + assert m_execute.called and any( + map( + lambda e: ("CREATE SCOPED TEMPORARY STAGE" in str(e[0])), + m_execute.call_args_list, + ) + ) + + +@pytest.mark.parametrize( + "database,schema,quote_identifiers,expected_db_schema", + [ + ("database", "schema", True, 
'"database"."schema"'), + ("database", "schema", False, "database.schema"), + (None, "schema", True, '"schema"'), + (None, "schema", False, "schema"), + (None, None, True, ""), + (None, None, False, ""), + ], +) +async def test_file_format_location_building( + conn_cnx, + database: str | None, + schema: str | None, + quote_identifiers: bool, + expected_db_schema: str, +): + """This tests that write_pandas constructs file format location correctly with database and schema.""" + + async with conn_cnx() as cnx: + + async def mocked_execute(*args, **kwargs): + if len(args) >= 1 and args[0].startswith("CREATE FILE FORMAT"): + db_schema = ".".join(args[0].split(" ")[3].split(".")[:-1]) + assert db_schema == expected_db_schema + cur = SnowflakeCursor(cnx) + if args[0].startswith("SELECT"): + cur._rownumber = 0 + cur._result = iter( + [(col, "") for col in sf_connector_version_df.get().columns] + ) + else: + cur._result = iter([]) + return cur + + with mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor.execute", + side_effect=mocked_execute, + ) as m_execute, mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor._upload", + side_effect=MagicMock(), + ) as _: + success, nchunks, nrows, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + "table", + database=database, + schema=schema, + quote_identifiers=quote_identifiers, + auto_create_table=True, + ) + assert m_execute.called and any( + map( + lambda e: ("CREATE TEMP FILE FORMAT" in str(e[0])), + m_execute.call_args_list, + ) + ) + + +@pytest.mark.parametrize("quote_identifiers", [True, False]) +async def test_default_value_insertion( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + quote_identifiers: bool, +): + """Tests whether default values can be successfully inserted with the pandas writeback.""" + table_name = "users" + df_data = [("Mark", 10), ("Luke", 20)] + + # Create a DataFrame containing data about customers + df = pandas.DataFrame(df_data, columns=["name", "balance"]) + # Assume quote_identifiers is true in string and if not remove " from strings + create_sql = """CREATE OR REPLACE TABLE "{}" + ("name" STRING, "balance" INT, + "id" varchar(36) default uuid_string(), + "ts" timestamp_ltz default current_timestamp)""".format( + table_name + ) + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + if not quote_identifiers: + create_sql = create_sql.replace('"', "") + select_sql = select_sql.replace('"', "") + drop_sql = drop_sql.replace('"', "") + async with conn_cnx() as cnx: # type: SnowflakeConnection + await cnx.execute_string(create_sql) + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, df, table_name, quote_identifiers=quote_identifiers + ) + + # Check write_pandas output + assert success + assert nrows == len(df_data) + assert nchunks == 1 + # Check table's contents + result = await (await cnx.cursor(DictCursor).execute(select_sql)).fetchall() + for row in result: + assert ( + row["id" if quote_identifiers else "ID"] is not None + ) # ID (UUID String) + assert len(row["id" if quote_identifiers else "ID"]) == 36 + assert ( + row["ts" if quote_identifiers else "TS"] is not None + ) # TS (Current Timestamp) + assert isinstance(row["ts" if quote_identifiers else "TS"], datetime) + assert ( + row["name" if quote_identifiers else "NAME"], + row["balance" if quote_identifiers else "BALANCE"], + ) in df_data + finally: + await cnx.execute_string(drop_sql) + + +@pytest.mark.parametrize("quote_identifiers", [True, False]) +async def 
test_autoincrement_insertion( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + quote_identifiers: bool, +): + """Tests whether default values can be successfully inserted with the pandas writeback.""" + table_name = "users" + df_data = [("Mark", 10), ("Luke", 20)] + + # Create a DataFrame containing data about customers + df = pandas.DataFrame(df_data, columns=["name", "balance"]) + # Assume quote_identifiers is true in string and if not remove " from strings + create_sql = ( + 'CREATE OR REPLACE TABLE "{}"' + '("name" STRING, "balance" INT, "id" INT AUTOINCREMENT)' + ).format(table_name) + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + if not quote_identifiers: + create_sql = create_sql.replace('"', "") + select_sql = select_sql.replace('"', "") + drop_sql = drop_sql.replace('"', "") + async with conn_cnx() as cnx: # type: SnowflakeConnection + await cnx.execute_string(create_sql) + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, df, table_name, quote_identifiers=quote_identifiers + ) + + # Check write_pandas output + assert success + assert nrows == len(df_data) + assert nchunks == 1 + # Check table's contents + result = await (await cnx.cursor(DictCursor).execute(select_sql)).fetchall() + for row in result: + assert row["id" if quote_identifiers else "ID"] in (1, 2) + assert ( + row["name" if quote_identifiers else "NAME"], + row["balance" if quote_identifiers else "BALANCE"], + ) in df_data + finally: + await cnx.execute_string(drop_sql) + + +@pytest.mark.parametrize("auto_create_table", [True, False]) +@pytest.mark.parametrize( + "column_names", + [ + ["00 name", "bAl_ance"], + ['c""ol', '"col"'], + ["c''ol", "'col'"], + ["チリヌル", "熊猫"], + ], +) +async def test_special_name_quoting( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + auto_create_table: bool, + column_names: list[str], +): + """Tests whether special column names get quoted as expected.""" + table_name = "users" + df_data = [("Mark", 10), ("Luke", 20)] + + df = pandas.DataFrame(df_data, columns=column_names) + snowflake_column_names = [c.replace('"', '""') for c in column_names] + create_sql = ( + f'CREATE OR REPLACE TABLE "{table_name}"' + f'("{snowflake_column_names[0]}" STRING, "{snowflake_column_names[1]}" INT, "id" INT AUTOINCREMENT)' + ) + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + async with conn_cnx() as cnx: # type: SnowflakeConnection + if not auto_create_table: + await cnx.execute_string(create_sql) + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, + df, + table_name, + quote_identifiers=True, + auto_create_table=auto_create_table, + ) + + # Check write_pandas output + assert success + assert nrows == len(df_data) + assert nchunks == 1 + # Check table's contents + result = await (await cnx.cursor(DictCursor).execute(select_sql)).fetchall() + for row in result: + # The auto create table functionality does not auto-create an incrementing ID + if not auto_create_table: + assert row["id"] in (1, 2) + assert ( + row[column_names[0]], + row[column_names[1]], + ) in df_data + finally: + await cnx.execute_string(drop_sql) + + +async def test_auto_create_table_similar_column_names( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + """Tests whether similar names do not cause issues when auto-creating a table as expected.""" + table_name = random_string(5, "numbas_") + df_data = [(10, 11), (20, 21)] + + df = pandas.DataFrame(df_data, 
columns=["number", "Number"]) + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + async with conn_cnx() as cnx: + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, df, table_name, quote_identifiers=True, auto_create_table=True + ) + + # Check write_pandas output + assert success + assert nrows == len(df_data) + assert nchunks == 1 + # Check table's contents + result = await (await cnx.cursor(DictCursor).execute(select_sql)).fetchall() + for row in result: + assert ( + row["number"], + row["Number"], + ) in df_data + finally: + await cnx.execute_string(drop_sql) + + +async def test_all_pandas_types( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + table_name = random_string(5, "all_types_") + datetime_with_tz = datetime(1997, 6, 3, 14, 21, 32, 00, tzinfo=timezone.utc) + datetime_with_ntz = datetime(1997, 6, 3, 14, 21, 32, 00) + df_data = [ + [ + 1, + 1.1, + "1string1", + True, + datetime_with_tz, + datetime_with_ntz, + datetime_with_tz.date(), + datetime_with_tz.time(), + bytes("a", "utf-8"), + ], + [ + 2, + 2.2, + "2string2", + False, + datetime_with_tz, + datetime_with_ntz, + datetime_with_tz.date(), + datetime_with_tz.time(), + bytes("b", "utf-16"), + ], + ] + columns = [ + "int", + "float", + "string", + "bool", + "timestamp_tz", + "timestamp_ntz", + "date", + "time", + "binary", + ] + + df = pandas.DataFrame( + df_data, + columns=columns, + ) + + select_sql = f'SELECT * FROM "{table_name}"' + drop_sql = f'DROP TABLE IF EXISTS "{table_name}"' + async with conn_cnx() as cnx: + try: + success, nchunks, nrows, _ = await write_pandas( + cnx, + df, + table_name, + quote_identifiers=True, + auto_create_table=True, + use_logical_type=True, + ) + + # Check write_pandas output + assert success + assert nrows == len(df_data) + assert nchunks == 1 + # Check table's contents + cur = await cnx.cursor(DictCursor).execute(select_sql) + result = await cur.fetchall() + for row, data in zip(result, df_data): + for c in columns: + # TODO: check values of timestamp data after SNOW-667350 is fixed + if "timestamp" in c: + assert row[c] is not None + else: + assert row[c] in data + finally: + await cnx.execute_string(drop_sql) + + +@pytest.mark.parametrize("object_type", ["STAGE", "FILE FORMAT"]) +async def test_no_create_internal_object_privilege_in_target_schema( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + caplog, + object_type, +): + source_schema = random_string(5, "source_schema_") + target_schema = random_string(5, "target_schema_no_create_") + table = random_string(5, "table_") + select_sql = f"select * from {target_schema}.{table}" + + async with conn_cnx() as cnx: + try: + await cnx.execute_string(f"create or replace schema {source_schema}") + await cnx.execute_string(f"create or replace schema {target_schema}") + original_execute = SnowflakeCursor.execute + + async def mock_execute(*args, **kwargs): + if ( + f"CREATE TEMP {object_type}" in args[0] + and "target_schema_no_create_" in kwargs["params"][0] + ): + raise ProgrammingError("Cannot create temp object in target schema") + cursor = cnx.cursor() + await original_execute(cursor, *args, **kwargs) + return cursor + + with mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor.execute", + side_effect=mock_execute, + ): + with caplog.at_level("DEBUG"): + success, num_of_chunks, _, _ = await write_pandas( + cnx, + sf_connector_version_df.get(), + table, + database=cnx.database, + schema=target_schema, + auto_create_table=True, + 
quote_identifiers=False, + ) + + assert "Fall back to use current schema" in caplog.text + assert success + await assert_result_equals( + cnx, num_of_chunks, select_sql, sf_connector_version_data + ) + finally: + await cnx.execute_string(f"drop schema if exists {source_schema}") + await cnx.execute_string(f"drop schema if exists {target_schema}") + + +def test__iceberg_config_statement_helper(): + config = { + "EXTERNAL_VOLUME": "vol", + "CATALOG": "'SNOWFLAKE'", + "BASE_LOCATION": "/root", + "CATALOG_SYNC": "foo", + "STORAGE_SERIALIZATION_POLICY": "bar", + } + assert ( + _iceberg_config_statement_helper(config) + == "EXTERNAL_VOLUME='vol' CATALOG='SNOWFLAKE' BASE_LOCATION='/root' CATALOG_SYNC='foo' STORAGE_SERIALIZATION_POLICY='bar'" + ) + + config["STORAGE_SERIALIZATION_POLICY"] = None + assert ( + _iceberg_config_statement_helper(config) + == "EXTERNAL_VOLUME='vol' CATALOG='SNOWFLAKE' BASE_LOCATION='/root' CATALOG_SYNC='foo'" + ) + + config["foo"] = True + config["bar"] = True + with pytest.raises( + ProgrammingError, + match=re.escape("Invalid iceberg configurations option(s) provided BAR, FOO"), + ): + _iceberg_config_statement_helper(config) + + +async def test_write_pandas_with_on_error( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + """Tests whether overwriting table using a Pandas DataFrame works as expected.""" + random_table_name = random_string(5, "userspoints_") + df_data = [("Dash", 50)] + df = pandas.DataFrame(df_data, columns=["name", "points"]) + + table_name = random_table_name + col_id = "id" + col_name = "name" + col_points = "points" + + create_sql = ( + f"CREATE OR REPLACE TABLE {table_name}" + f"({col_name} STRING, {col_points} INT, {col_id} INT AUTOINCREMENT)" + ) + + select_count_sql = f"SELECT count(*) FROM {table_name}" + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + async with conn_cnx() as cnx: # type: SnowflakeConnection + await cnx.execute_string(create_sql) + try: + # Write dataframe with 1 row + success, nchunks, nrows, _ = await write_pandas( + cnx, + df, + random_table_name, + quote_identifiers=False, + auto_create_table=False, + overwrite=True, + index=True, + on_error="continue", + ) + # Check write_pandas output + assert success + assert nchunks == 1 + assert nrows == 1 + result = await ( + await cnx.cursor(DictCursor).execute(select_count_sql) + ).fetchone() + # Check number of rows + assert result["COUNT(*)"] == 1 + finally: + await cnx.execute_string(drop_sql) + + +async def test_pandas_with_single_quote( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], +): + random_table_name = random_string(5, "test'table") + table_name = f'"{random_table_name}"' + create_sql = f"CREATE OR REPLACE TABLE {table_name}(A INT)" + df_data = [[1]] + df = pandas.DataFrame(df_data, columns=["a"]) + async with conn_cnx() as cnx: # type: SnowflakeConnection + try: + await cnx.execute_string(create_sql) + await write_pandas( + cnx, + df, + table_name, + quote_identifiers=False, + auto_create_table=False, + index=False, + ) + finally: + await cnx.execute_string(f"drop table if exists {table_name}") + + +@pytest.mark.parametrize("bulk_upload_chunks", [True, False]) +async def test_write_pandas_bulk_chunks_upload(conn_cnx, bulk_upload_chunks): + """Tests whether overwriting table using a Pandas DataFrame works as expected.""" + random_table_name = random_string(5, "userspoints_") + df_data = [("Dash", 50), ("Luke", 20), ("Mark", 10), ("John", 30)] + df = pandas.DataFrame(df_data, columns=["name", "points"]) + + table_name = random_table_name + 
col_id = "id" + col_name = "name" + col_points = "points" + + create_sql = ( + f"CREATE OR REPLACE TABLE {table_name}" + f"({col_name} STRING, {col_points} INT, {col_id} INT AUTOINCREMENT)" + ) + + select_count_sql = f"SELECT count(*) FROM {table_name}" + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + async with conn_cnx() as cnx: # type: SnowflakeConnection + await cnx.execute_string(create_sql) + try: + # Write dataframe with 1 row + success, nchunks, nrows, _ = await write_pandas( + cnx, + df, + random_table_name, + quote_identifiers=False, + auto_create_table=False, + overwrite=True, + index=True, + on_error="continue", + chunk_size=1, + bulk_upload_chunks=bulk_upload_chunks, + ) + # Check write_pandas output + assert success + assert nchunks == 4 + assert nrows == 4 + result = await ( + await cnx.cursor(DictCursor).execute(select_count_sql) + ).fetchone() + # Check number of rows + assert result["COUNT(*)"] == 4 + finally: + await cnx.execute_string(drop_sql) + + +@pytest.mark.parametrize( + "use_vectorized_scanner", + [ + True, + False, + ], +) +async def test_write_pandas_with_use_vectorized_scanner( + conn_cnx: Callable[..., Generator[SnowflakeConnection]], + use_vectorized_scanner, + caplog, +): + """Tests whether overwriting table using a Pandas DataFrame works as expected.""" + random_table_name = random_string(5, "userspoints_") + df_data = [("Dash", 50)] + df = pandas.DataFrame(df_data, columns=["name", "points"]) + + table_name = random_table_name + col_id = "id" + col_name = "name" + col_points = "points" + + create_sql = ( + f"CREATE OR REPLACE TABLE {table_name}" + f"({col_name} STRING, {col_points} INT, {col_id} INT AUTOINCREMENT)" + ) + + drop_sql = f"DROP TABLE IF EXISTS {table_name}" + async with conn_cnx() as cnx: # type: SnowflakeConnection + original_cur = (await cnx.cursor()).execute + + async def fake_execute(query, params=None, *args, **kwargs): + return await original_cur(query, params, *args, **kwargs) + + await cnx.execute_string(create_sql) + try: + with mock.patch( + "snowflake.connector.aio._cursor.SnowflakeCursor.execute", + side_effect=fake_execute, + ) as execute: + # Write dataframe with 1 row + success, nchunks, nrows, _ = await write_pandas( + cnx, + df, + random_table_name, + quote_identifiers=False, + auto_create_table=False, + overwrite=True, + index=True, + use_vectorized_scanner=use_vectorized_scanner, + ) + # Check write_pandas output + assert success + assert nchunks == 1 + assert nrows == 1 + + for call in execute.call_args_list: + if call.args[0].startswith("COPY"): + assert ( + f"USE_VECTORIZED_SCANNER={use_vectorized_scanner}" + in call.args[0] + ) + + finally: + await cnx.execute_string(drop_sql)
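
For reference, a minimal end-to-end sketch of calling the new async write_pandas, assembled from the docstring example and the integration tests in this patch. It is not part of the diff: connection parameters are placeholders, and the import path assumes the private module added here (snowflake.connector.aio._pandas_tools); a public re-export, if one exists, may differ.

import asyncio

import pandas

from snowflake.connector.aio import SnowflakeConnection
from snowflake.connector.aio._pandas_tools import write_pandas


async def main() -> None:
    df = pandas.DataFrame([("Mark", 10), ("Luke", 20)], columns=["name", "balance"])
    # The async connection is used as an async context manager, mirroring the
    # docstring example in _pandas_tools.py; fill in real credentials.
    async with SnowflakeConnection(
        account="<account>",
        user="<user>",
        password="<password>",
        database="<database>",
        schema="<schema>",
        warehouse="<warehouse>",
    ) as conn:
        success, nchunks, nrows, _ = await write_pandas(
            conn,
            df,
            "CUSTOMERS",
            auto_create_table=True,  # create the table from the inferred Parquet schema
            table_type="temp",       # temporary table, dropped at end of session
            use_logical_type=True,   # enable Parquet logical types (relevant for tz-aware datetimes)
        )
        assert success and nchunks == 1 and nrows == len(df)


asyncio.run(main())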