@@ -79,36 +79,66 @@ def test_feature_store_create(
7979 role_arn = role_arn ,
8080 enable_online_store = True ,
8181 )
82- assert sagemaker_session_mock .create_feature_group .called_with (
82+ sagemaker_session_mock .create_feature_group .assert_called_with (
8383 feature_group_name = "MyFeatureGroup" ,
8484 record_identifier_name = "feature1" ,
8585 event_time_feature_name = "feature2" ,
86+ feature_definitions = [fd .to_dict () for fd in feature_group_dummy_definitions ],
8687 role_arn = role_arn ,
88+ description = None ,
89+ tags = None ,
8790 online_store_config = {"EnableOnlineStore" : True },
91+ offline_store_config = {
92+ "DisableGlueTableCreation" : False ,
93+ "S3StorageConfig" : {"S3Uri" : s3_uri },
94+ },
95+ )
96+
97+
98+ def test_feature_store_create_online_only (
99+ sagemaker_session_mock , role_arn , feature_group_dummy_definitions
100+ ):
101+ feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
102+ feature_group .feature_definitions = feature_group_dummy_definitions
103+ feature_group .create (
104+ s3_uri = False ,
105+ record_identifier_name = "feature1" ,
106+ event_time_feature_name = "feature2" ,
107+ role_arn = role_arn ,
108+ enable_online_store = True ,
109+ )
110+ sagemaker_session_mock .create_feature_group .assert_called_with (
111+ feature_group_name = "MyFeatureGroup" ,
112+ record_identifier_name = "feature1" ,
113+ event_time_feature_name = "feature2" ,
88114 feature_definitions = [fd .to_dict () for fd in feature_group_dummy_definitions ],
115+ role_arn = role_arn ,
116+ description = None ,
117+ tags = None ,
118+ online_store_config = {"EnableOnlineStore" : True },
89119 )
90120
91121
92122def test_feature_store_delete (sagemaker_session_mock ):
93123 feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
94124 feature_group .delete ()
95- assert sagemaker_session_mock .delete_feature_group .called_with (
125+ sagemaker_session_mock .delete_feature_group .assert_called_with (
96126 feature_group_name = "MyFeatureGroup"
97127 )
98128
99129
100130def test_feature_store_describe (sagemaker_session_mock ):
101131 feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
102132 feature_group .describe ()
103- assert sagemaker_session_mock .describe_feature_group .called_with (
104- feature_group_name = "MyFeatureGroup"
133+ sagemaker_session_mock .describe_feature_group .assert_called_with (
134+ feature_group_name = "MyFeatureGroup" , next_token = None
105135 )
106136
107137
108138def test_put_record (sagemaker_session_mock ):
109139 feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
110140 feature_group .put_record (record = [])
111- assert sagemaker_session_mock .put_record .called_with (
141+ sagemaker_session_mock .put_record .assert_called_with (
112142 feature_group_name = "MyFeatureGroup" , record = []
113143 )
114144
@@ -268,7 +298,7 @@ def query(sagemaker_session_mock):
268298def test_athena_query_run (sagemaker_session_mock , query ):
269299 sagemaker_session_mock .start_query_execution .return_value = {"QueryExecutionId" : "query_id" }
270300 query .run (query_string = "query" , output_location = "s3://some-bucket/some-path" )
271- assert sagemaker_session_mock .start_query_execution .called_with (
301+ sagemaker_session_mock .start_query_execution .assert_called_with (
272302 catalog = "catalog" ,
273303 database = "database" ,
274304 query_string = "query" ,
@@ -283,13 +313,13 @@ def test_athena_query_run(sagemaker_session_mock, query):
283313def test_athena_query_wait (sagemaker_session_mock , query ):
284314 query ._current_query_execution_id = "query_id"
285315 query .wait ()
286- assert sagemaker_session_mock .wait_for_athena_query .called_with (query_execution_id = "query_id" )
316+ sagemaker_session_mock .wait_for_athena_query .assert_called_with (query_execution_id = "query_id" )
287317
288318
289319def test_athena_query_get_query_execution (sagemaker_session_mock , query ):
290320 query ._current_query_execution_id = "query_id"
291321 query .get_query_execution ()
292- assert sagemaker_session_mock .wait_for_athena_query . called_with (query_execution_id = "query_id" )
322+ sagemaker_session_mock .get_query_execution . assert_called_with (query_execution_id = "query_id" )
293323
294324
295325@patch ("tempfile.gettempdir" , Mock (return_value = "tmp" ))
@@ -302,13 +332,13 @@ def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query):
302332 query ._result_bucket = "bucket"
303333 query ._result_file_prefix = "prefix"
304334 query .as_dataframe ()
305- assert sagemaker_session_mock .download_athena_query_result .called_with (
335+ sagemaker_session_mock .download_athena_query_result .assert_called_with (
306336 bucket = "bucket" ,
307337 prefix = "prefix" ,
308338 query_execution_id = "query_id" ,
309339 filename = "tmp/query_id.csv" ,
310340 )
311- assert read_csv .called_with ("tmp/query_id.csv" , delimiter = "," )
341+ read_csv .assert_called_with ("tmp/query_id.csv" , delimiter = "," )
312342
313343
314344@patch ("tempfile.gettempdir" , Mock (return_value = "tmp" ))
0 commit comments