Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions steps/02_load_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
# Last Updated: 1/9/2023
#------------------------------------------------------------------------------

import time
import time,os
from dotenv import load_dotenv
from snowflake.snowpark import Session
#import snowflake.snowpark.types as T
#import snowflake.snowpark.functions as F
Expand Down Expand Up @@ -69,7 +70,16 @@ def validate_raw_tables(session):

# For local debugging
if __name__ == "__main__":
# Create a local Snowpark session
with Session.builder.getOrCreate() as session:
connection_parameters = {
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"user": os.getenv("SNOWFLAKE_USER"),
"password": os.getenv("SNOWFLAKE_PASSWORD"),
"role": os.getenv("SNOWFLAKE_ROLE"),
"warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
"database": os.getenv("SNOWFLAKE_DATABASE"),
}
with Session.builder.configs(connection_parameters).create() as session:
load_all_raw_tables(session)
# validate_raw_tables(session)
print("Done.")

2 changes: 1 addition & 1 deletion steps/03_load_weather.sql
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ GRANT IMPORTED PRIVILEGES ON DATABASE FROSTBYTE_WEATHERSOURCE TO ROLE HOL_ROLE;


-- Let's look at the data - same 3-part naming convention as any other table
SELECT * FROM FROSTBYTE_WEATHERSOURCE.ONPOINT_ID.POSTAL_CODES LIMIT 100;
SELECT * FROM FROSTBYTE_WEATHERSOURCE.ONPOINT_ID.POSTAL_CODES ;
13 changes: 11 additions & 2 deletions steps/04_create_pos_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from snowflake.snowpark import Session
#import snowflake.snowpark.types as T
import snowflake.snowpark.functions as F
from dotenv import load_dotenv
import os


def create_pos_view(session):
Expand Down Expand Up @@ -106,8 +108,15 @@ def test_pos_view(session):

# For local debugging
if __name__ == "__main__":
# Create a local Snowpark session
with Session.builder.getOrCreate() as session:
connection_parameters = {
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"user": os.getenv("SNOWFLAKE_USER"),
"password": os.getenv("SNOWFLAKE_PASSWORD"),
"role": os.getenv("SNOWFLAKE_ROLE"),
"warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
"database": os.getenv("SNOWFLAKE_DATABASE"),
}
with Session.builder.configs(connection_parameters).create() as session:
create_pos_view(session)
create_pos_view_stream(session)
# test_pos_view(session)
44 changes: 32 additions & 12 deletions steps/06_orders_update_sp/orders_update_sp/procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,51 @@

# SNOWFLAKE ADVANTAGE: Python Stored Procedures

import os
import sys
import time
from snowflake.snowpark import Session
#import snowflake.snowpark.types as T
# import snowflake.snowpark.types as T
import snowflake.snowpark.functions as F


def table_exists(session, schema='', name=''):
exists = session.sql("SELECT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{}' AND TABLE_NAME = '{}') AS TABLE_EXISTS".format(schema, name)).collect()[0]['TABLE_EXISTS']
exists = session.sql(
"SELECT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES "
"WHERE TABLE_SCHEMA = '{}' AND TABLE_NAME = '{}') AS TABLE_EXISTS"
.format(schema, name)
).collect()[0]['TABLE_EXISTS']
return exists


def create_orders_table(session):
_ = session.sql("CREATE TABLE HARMONIZED.ORDERS LIKE HARMONIZED.POS_FLATTENED_V").collect()
_ = session.sql("ALTER TABLE HARMONIZED.ORDERS ADD COLUMN META_UPDATED_AT TIMESTAMP").collect()


def create_orders_stream(session):
_ = session.sql("CREATE STREAM HARMONIZED.ORDERS_STREAM ON TABLE HARMONIZED.ORDERS").collect()


def merge_order_updates(session):
_ = session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XLARGE WAIT_FOR_COMPLETION = TRUE').collect()

source = session.table('HARMONIZED.POS_FLATTENED_V_STREAM')
target = session.table('HARMONIZED.ORDERS')

# TODO: Is the if clause supposed to be based on "META_UPDATED_AT"?
cols_to_update = {c: source[c] for c in source.schema.names if "METADATA" not in c}
metadata_col_to_update = {"META_UPDATED_AT": F.current_timestamp()}
updates = {**cols_to_update, **metadata_col_to_update}

# merge into DIM_CUSTOMER
target.merge(source, target['ORDER_DETAIL_ID'] == source['ORDER_DETAIL_ID'], \
[F.when_matched().update(updates), F.when_not_matched().insert(updates)])
target.merge(
source,
target['ORDER_DETAIL_ID'] == source['ORDER_DETAIL_ID'],
[F.when_matched().update(updates), F.when_not_matched().insert(updates)]
)

_ = session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL').collect()


def main(session: Session) -> str:
# Create the ORDERS table and ORDERS_STREAM stream if they don't exist
if not table_exists(session, schema='HARMONIZED', name='ORDERS'):
Expand All @@ -49,18 +60,27 @@ def main(session: Session) -> str:

# Process data incrementally
merge_order_updates(session)
# session.table('HARMONIZED.ORDERS').limit(5).show()
# session.table('HARMONIZED.ORDERS').limit(5).show()

return f"Successfully processed ORDERS"
return "Successfully processed ORDERS"


# For local debugging
# Be aware you may need to type-convert arguments if you add input parameters
if __name__ == '__main__':
# Create a local Snowpark session
with Session.builder.getOrCreate() as session:
import sys
connection_parameters = {
"account": os.getenv("SNOWFLAKE_ACCOUNT"),
"user": os.getenv("SNOWFLAKE_USER"),
"password": os.getenv("SNOWFLAKE_PASSWORD"),
"role": os.getenv("SNOWFLAKE_ROLE"),
"warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
"database": os.getenv("SNOWFLAKE_DATABASE"),
}

# Create the session inside the main block
with Session.builder.configs(connection_parameters).create() as session:
if len(sys.argv) > 1:
print(main(session, *sys.argv[1:])) # type: ignore
print(main(session, *sys.argv[1:])) # Pass CLI arguments to main()
else:
print(main(session)) # type: ignore
print(main(session))