|
1 | 1 | import logging |
2 | | -import pytest |
3 | 2 |
|
| 3 | +import pytest |
4 | 4 | from pyspark.sql import functions as F |
5 | 5 | from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, ArrayType, MapType, \ |
6 | | - BinaryType, LongType |
| 6 | + BinaryType, LongType |
7 | 7 |
|
8 | 8 | import dbldatagen as dg |
9 | 9 |
|
@@ -245,6 +245,7 @@ def test_basic_arrays_with_existing_schema6(self, arraySchema, setupLogging): |
245 | 245 | .withColumnSpec("arrayVal", expr="array(id+1)") |
246 | 246 | ) |
247 | 247 | df = gen1.build() |
| 248 | + assert df is not None |
248 | 249 | df.show() |
249 | 250 |
|
250 | 251 | def test_use_of_struct_in_schema1(self, setupLogging): |
@@ -290,3 +291,108 @@ def test_varying_arrays(self, setupLogging): |
290 | 291 |
|
291 | 292 | df = df_spec.build() |
292 | 293 | df.show() |
| 294 | + |
| 295 | + def test_array_values(self): |
| 296 | + df_spec = dg.DataGenerator(spark, name="test-data", rows=2) |
| 297 | + df_spec = df_spec.withColumn( |
| 298 | + "test", |
| 299 | + ArrayType(StringType()), |
| 300 | + values=[ |
| 301 | + F.array(F.lit("A")), |
| 302 | + F.array(F.lit("C")), |
| 303 | + F.array(F.lit("T")), |
| 304 | + F.array(F.lit("G")), |
| 305 | + ], |
| 306 | + ) |
| 307 | + test_df = df_spec.build() |
| 308 | + |
| 309 | + rows = test_df.collect() |
| 310 | + |
| 311 | + for r in rows: |
| 312 | + assert r['test'] is not None |
| 313 | + |
| 314 | + def test_single_element_array(self): |
| 315 | + df_spec = dg.DataGenerator(spark, name="test-data", rows=2) |
| 316 | + df_spec = df_spec.withColumn( |
| 317 | + "test1", |
| 318 | + ArrayType(StringType()), |
| 319 | + values=[ |
| 320 | + F.array(F.lit("A")), |
| 321 | + F.array(F.lit("C")), |
| 322 | + F.array(F.lit("T")), |
| 323 | + F.array(F.lit("G")), |
| 324 | + ], |
| 325 | + ) |
| 326 | + df_spec = df_spec.withColumn( |
| 327 | + "test2", "string", structType="array", numFeatures=1, values=["one", "two", "three"] |
| 328 | + ) |
| 329 | + df_spec = df_spec.withColumn( |
| 330 | + "test3", "string", structType="array", numFeatures=(1, 1), values=["one", "two", "three"] |
| 331 | + ) |
| 332 | + df_spec = df_spec.withColumn( |
| 333 | + "test4", "string", structType="array", values=["one", "two", "three"] |
| 334 | + ) |
| 335 | + |
| 336 | + test_df = df_spec.build() |
| 337 | + |
| 338 | + for field in test_df.schema: |
| 339 | + assert isinstance(field.dataType, ArrayType) |
| 340 | + |
| 341 | + def test_map_values(self): |
| 342 | + df_spec = dg.DataGenerator(spark, name="test-data", rows=50, random=True) |
| 343 | + df_spec = df_spec.withColumn( |
| 344 | + "v1", |
| 345 | + "array<string>", |
| 346 | + values=[ |
| 347 | + F.array(F.lit("A")), |
| 348 | + F.array(F.lit("C")), |
| 349 | + F.array(F.lit("T")), |
| 350 | + F.array(F.lit("G")), |
| 351 | + ], |
| 352 | + ) |
| 353 | + df_spec = df_spec.withColumn( |
| 354 | + "v2", |
| 355 | + "array<string>", |
| 356 | + values=[ |
| 357 | + F.array(F.lit("one")), |
| 358 | + F.array(F.lit("two")), |
| 359 | + F.array(F.lit("three")), |
| 360 | + F.array(F.lit("four")), |
| 361 | + ], |
| 362 | + ) |
| 363 | + df_spec = df_spec.withColumn( |
| 364 | + "v3", |
| 365 | + "array<string>", |
| 366 | + values=[ |
| 367 | + F.array(F.lit("alpha")), |
| 368 | + F.array(F.lit("beta")), |
| 369 | + F.array(F.lit("delta")), |
| 370 | + F.array(F.lit("gamma")), |
| 371 | + ], |
| 372 | + ) |
| 373 | + df_spec = df_spec.withColumn( |
| 374 | + "v4", |
| 375 | + "string", |
| 376 | + values=["this", "is", "a", "test"], |
| 377 | + numFeatures=1, |
| 378 | + structType="array" |
| 379 | + ) |
| 380 | + |
| 381 | + df_spec = df_spec.withColumn( |
| 382 | + "test", |
| 383 | + "map<string,string>", |
| 384 | + values=[F.map_from_arrays(F.col("v1"), F.col("v2")), |
| 385 | + F.map_from_arrays(F.col("v1"), F.col("v3")), |
| 386 | + F.map_from_arrays(F.col("v2"), F.col("v3")), |
| 387 | + F.map_from_arrays(F.col("v1"), F.col("v4")), |
| 388 | + F.map_from_arrays(F.col("v2"), F.col("v4")), |
| 389 | + F.map_from_arrays(F.col("v3"), F.col("v4")) |
| 390 | + ], |
| 391 | + baseColumns=["v1", "v2", "v3", "v4"] |
| 392 | + ) |
| 393 | + test_df = df_spec.build() |
| 394 | + |
| 395 | + rows = test_df.collect() |
| 396 | + |
| 397 | + for r in rows: |
| 398 | + assert r['test'] is not None |
0 commit comments