diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 279ad438f1..7cf3a66e8a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,13 +27,13 @@ repos: - id: check-yaml - id: check-ast - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.9 + rev: v0.14.3 hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.17.1 + rev: v1.18.2 hooks: - id: mypy args: diff --git a/pyiceberg/avro/codecs/__init__.py b/pyiceberg/avro/codecs/__init__.py index ce592ccc5a..d5d3a7c4e5 100644 --- a/pyiceberg/avro/codecs/__init__.py +++ b/pyiceberg/avro/codecs/__init__.py @@ -26,7 +26,7 @@ from __future__ import annotations -from typing import Dict, Literal, Optional, Type +from typing import Dict, Literal, Type from typing_extensions import TypeAlias @@ -40,7 +40,7 @@ AVRO_CODEC_KEY = "avro.codec" -KNOWN_CODECS: Dict[AvroCompressionCodec, Optional[Type[Codec]]] = { +KNOWN_CODECS: Dict[AvroCompressionCodec, Type[Codec] | None] = { "null": None, "bzip2": BZip2Codec, "snappy": SnappyCodec, diff --git a/pyiceberg/avro/decoder.py b/pyiceberg/avro/decoder.py index 708392aad4..75b3209027 100644 --- a/pyiceberg/avro/decoder.py +++ b/pyiceberg/avro/decoder.py @@ -21,7 +21,6 @@ Dict, List, Tuple, - Union, cast, ) @@ -137,7 +136,7 @@ class StreamingBinaryDecoder(BinaryDecoder): __slots__ = "_input_stream" _input_stream: InputStream - def __init__(self, input_stream: Union[bytes, InputStream]) -> None: + def __init__(self, input_stream: bytes | InputStream) -> None: """Reader is a Python object on which we can call read, seek, and tell.""" if isinstance(input_stream, bytes): # In the case of bytes, we wrap it into a BytesIO to make it a stream diff --git a/pyiceberg/avro/file.py b/pyiceberg/avro/file.py index 82b042a412..3b91d70d85 100644 --- a/pyiceberg/avro/file.py +++ b/pyiceberg/avro/file.py @@ -30,7 +30,6 @@ Dict, Generic, List, - Optional, Type, TypeVar, ) @@ -85,7 +84,7 @@ def meta(self) -> Dict[str, str]: def sync(self) -> bytes: return self._data[2] - def compression_codec(self) -> Optional[Type[Codec]]: + def compression_codec(self) -> Type[Codec] | None: """Get the file's compression codec algorithm from the file's metadata. 
In the case of a null codec, we return a None indicating that we @@ -146,7 +145,7 @@ class AvroFile(Generic[D]): "block", ) input_file: InputFile - read_schema: Optional[Schema] + read_schema: Schema | None read_types: Dict[int, Callable[..., StructProtocol]] read_enums: Dict[int, Callable[..., Enum]] header: AvroFileHeader @@ -154,12 +153,12 @@ class AvroFile(Generic[D]): reader: Reader decoder: BinaryDecoder - block: Optional[Block[D]] + block: Block[D] | None def __init__( self, input_file: InputFile, - read_schema: Optional[Schema] = None, + read_schema: Schema | None = None, read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT, read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT, ) -> None: @@ -186,9 +185,7 @@ def __enter__(self) -> AvroFile[D]: return self - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ) -> None: + def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None: """Perform cleanup when exiting the scope of a 'with' statement.""" def __iter__(self) -> AvroFile[D]: @@ -242,7 +239,7 @@ def __init__( output_file: OutputFile, file_schema: Schema, schema_name: str, - record_schema: Optional[Schema] = None, + record_schema: Schema | None = None, metadata: Dict[str, str] = EMPTY_DICT, ) -> None: self.output_file = output_file @@ -270,9 +267,7 @@ def __enter__(self) -> AvroOutputFile[D]: return self - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ) -> None: + def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None: """Perform cleanup when exiting the scope of a 'with' statement.""" self.output_stream.close() @@ -289,7 +284,7 @@ def _write_header(self) -> None: header = AvroFileHeader(MAGIC, meta, self.sync_bytes) construct_writer(META_SCHEMA).write(self.encoder, header) - def compression_codec(self) -> Optional[Type[Codec]]: + def compression_codec(self) -> Type[Codec] | None: """Get the file's compression codec algorithm from the file's metadata. In the case of a null codec, we return a None indicating that we diff --git a/pyiceberg/avro/reader.py b/pyiceberg/avro/reader.py index bccc772022..97c41be473 100644 --- a/pyiceberg/avro/reader.py +++ b/pyiceberg/avro/reader.py @@ -35,7 +35,6 @@ Callable, List, Mapping, - Optional, Tuple, ) from uuid import UUID @@ -292,7 +291,7 @@ def __repr__(self) -> str: class OptionReader(Reader): option: Reader = dataclassfield() - def read(self, decoder: BinaryDecoder) -> Optional[Any]: + def read(self, decoder: BinaryDecoder) -> Any | None: # For the Iceberg spec it is required to set the default value to null # From https://iceberg.apache.org/spec/#avro # Optional fields must always set the Avro field default value to null. @@ -320,14 +319,14 @@ class StructReader(Reader): "_hash", "_max_pos", ) - field_readers: Tuple[Tuple[Optional[int], Reader], ...] + field_readers: Tuple[Tuple[int | None, Reader], ...] create_struct: Callable[..., StructProtocol] struct: StructType - field_reader_functions = Tuple[Tuple[Optional[str], int, Optional[Callable[[BinaryDecoder], Any]]], ...] + field_reader_functions = Tuple[Tuple[str | None, int, Callable[[BinaryDecoder], Any] | None], ...] 
def __init__( self, - field_readers: Tuple[Tuple[Optional[int], Reader], ...], + field_readers: Tuple[Tuple[int | None, Reader], ...], create_struct: Callable[..., StructProtocol], struct: StructType, ) -> None: @@ -339,7 +338,7 @@ def __init__( if not isinstance(self.create_struct(), StructProtocol): raise ValueError(f"Incompatible with StructProtocol: {self.create_struct}") - reading_callbacks: List[Tuple[Optional[int], Callable[[BinaryDecoder], Any]]] = [] + reading_callbacks: List[Tuple[int | None, Callable[[BinaryDecoder], Any]]] = [] max_pos = -1 for pos, field in field_readers: if pos is not None: diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index c4ec393513..c11d2878aa 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -20,9 +20,7 @@ Callable, Dict, List, - Optional, Tuple, - Union, ) from pyiceberg.avro.decoder import BinaryDecoder @@ -116,7 +114,7 @@ def construct_reader( - file_schema: Union[Schema, IcebergType], read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT + file_schema: Schema | IcebergType, read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT ) -> Reader: """Construct a reader from a file schema. @@ -130,7 +128,7 @@ def construct_reader( return resolve_reader(file_schema, file_schema, read_types) -def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer: +def construct_writer(file_schema: Schema | IcebergType) -> Writer: """Construct a writer from a file schema. Args: @@ -216,8 +214,8 @@ def visit_unknown(self, unknown_type: UnknownType) -> Writer: def resolve_writer( - record_schema: Union[Schema, IcebergType], - file_schema: Union[Schema, IcebergType], + record_schema: Schema | IcebergType, + file_schema: Schema | IcebergType, ) -> Writer: """Resolve the file and read schema to produce a reader. 
@@ -234,8 +232,8 @@ def resolve_writer( def resolve_reader( - file_schema: Union[Schema, IcebergType], - read_schema: Union[Schema, IcebergType], + file_schema: Schema | IcebergType, + read_schema: Schema | IcebergType, read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT, read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT, ) -> Reader: @@ -273,15 +271,15 @@ def skip(self, decoder: BinaryDecoder) -> None: class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]): - def schema(self, file_schema: Schema, record_schema: Optional[IcebergType], result: Writer) -> Writer: + def schema(self, file_schema: Schema, record_schema: IcebergType | None, result: Writer) -> Writer: return result - def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], file_writers: List[Writer]) -> Writer: + def struct(self, file_schema: StructType, record_struct: IcebergType | None, file_writers: List[Writer]) -> Writer: if not isinstance(record_struct, StructType): raise ResolveError(f"File/write schema are not aligned for struct, got {record_struct}") record_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(record_struct.fields)} - results: List[Tuple[Optional[int], Writer]] = [] + results: List[Tuple[int | None, Writer]] = [] for writer, file_field in zip(file_writers, file_schema.fields): if file_field.field_id in record_struct_positions: @@ -298,18 +296,16 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], return StructWriter(field_writers=tuple(results)) - def field(self, file_field: NestedField, record_type: Optional[IcebergType], field_writer: Writer) -> Writer: + def field(self, file_field: NestedField, record_type: IcebergType | None, field_writer: Writer) -> Writer: return field_writer if file_field.required else OptionWriter(field_writer) - def list(self, file_list_type: ListType, file_list: Optional[IcebergType], element_writer: Writer) -> Writer: + def list(self, file_list_type: ListType, file_list: IcebergType | None, element_writer: Writer) -> Writer: return ListWriter(element_writer if file_list_type.element_required else OptionWriter(element_writer)) - def map( - self, file_map_type: MapType, file_primitive: Optional[IcebergType], key_writer: Writer, value_writer: Writer - ) -> Writer: + def map(self, file_map_type: MapType, file_primitive: IcebergType | None, key_writer: Writer, value_writer: Writer) -> Writer: return MapWriter(key_writer, value_writer if file_map_type.value_required else OptionWriter(value_writer)) - def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[IcebergType]) -> Writer: + def primitive(self, file_primitive: PrimitiveType, record_primitive: IcebergType | None) -> Writer: if record_primitive is not None: # ensure that the type can be projected to the expected if file_primitive != record_primitive: @@ -317,55 +313,55 @@ def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[Ic return super().primitive(file_primitive, file_primitive) - def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer: + def visit_boolean(self, boolean_type: BooleanType, partner: IcebergType | None) -> Writer: return BooleanWriter() - def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Writer: + def visit_integer(self, integer_type: IntegerType, partner: IcebergType | None) -> Writer: return IntegerWriter() - def visit_long(self, long_type: LongType, partner: 
Optional[IcebergType]) -> Writer: + def visit_long(self, long_type: LongType, partner: IcebergType | None) -> Writer: return IntegerWriter() - def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Writer: + def visit_float(self, float_type: FloatType, partner: IcebergType | None) -> Writer: return FloatWriter() - def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Writer: + def visit_double(self, double_type: DoubleType, partner: IcebergType | None) -> Writer: return DoubleWriter() - def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Writer: + def visit_decimal(self, decimal_type: DecimalType, partner: IcebergType | None) -> Writer: return DecimalWriter(decimal_type.precision, decimal_type.scale) - def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Writer: + def visit_date(self, date_type: DateType, partner: IcebergType | None) -> Writer: return DateWriter() - def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Writer: + def visit_time(self, time_type: TimeType, partner: IcebergType | None) -> Writer: return TimeWriter() - def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Writer: + def visit_timestamp(self, timestamp_type: TimestampType, partner: IcebergType | None) -> Writer: return TimestampWriter() - def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[IcebergType]) -> Writer: + def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: IcebergType | None) -> Writer: return TimestampNanoWriter() - def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Writer: + def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: IcebergType | None) -> Writer: return TimestamptzWriter() - def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[IcebergType]) -> Writer: + def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: IcebergType | None) -> Writer: return TimestamptzNanoWriter() - def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Writer: + def visit_string(self, string_type: StringType, partner: IcebergType | None) -> Writer: return StringWriter() - def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Writer: + def visit_uuid(self, uuid_type: UUIDType, partner: IcebergType | None) -> Writer: return UUIDWriter() - def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Writer: + def visit_fixed(self, fixed_type: FixedType, partner: IcebergType | None) -> Writer: return FixedWriter(len(fixed_type)) - def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Writer: + def visit_binary(self, binary_type: BinaryType, partner: IcebergType | None) -> Writer: return BinaryWriter() - def visit_unknown(self, unknown_type: UnknownType, partner: Optional[IcebergType]) -> Writer: + def visit_unknown(self, unknown_type: UnknownType, partner: IcebergType | None) -> Writer: return UnknownWriter() @@ -384,16 +380,16 @@ def __init__( self.read_enums = read_enums self.context = [] - def schema(self, schema: Schema, expected_schema: Optional[IcebergType], result: Reader) -> Reader: + def schema(self, schema: Schema, expected_schema: IcebergType | None, result: Reader) -> Reader: return result - def before_field(self, field: NestedField, field_partner: 
Optional[NestedField]) -> None: + def before_field(self, field: NestedField, field_partner: NestedField | None) -> None: self.context.append(field.field_id) - def after_field(self, field: NestedField, field_partner: Optional[NestedField]) -> None: + def after_field(self, field: NestedField, field_partner: NestedField | None) -> None: self.context.pop() - def struct(self, struct: StructType, expected_struct: Optional[IcebergType], field_readers: List[Reader]) -> Reader: + def struct(self, struct: StructType, expected_struct: IcebergType | None, field_readers: List[Reader]) -> Reader: read_struct_id = self.context[STRUCT_ROOT] if len(self.context) > 0 else STRUCT_ROOT struct_callable = self.read_types.get(read_struct_id, Record) @@ -406,7 +402,7 @@ def struct(self, struct: StructType, expected_struct: Optional[IcebergType], fie expected_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(expected_struct.fields)} # first, add readers for the file fields that must be in order - results: List[Tuple[Optional[int], Reader]] = [ + results: List[Tuple[int | None, Reader]] = [ ( expected_positions.get(field.field_id), # Check if we need to convert it to an Enum @@ -430,22 +426,22 @@ def struct(self, struct: StructType, expected_struct: Optional[IcebergType], fie return StructReader(tuple(results), struct_callable, expected_struct) - def field(self, field: NestedField, expected_field: Optional[IcebergType], field_reader: Reader) -> Reader: + def field(self, field: NestedField, expected_field: IcebergType | None, field_reader: Reader) -> Reader: return field_reader if field.required else OptionReader(field_reader) - def list(self, list_type: ListType, expected_list: Optional[IcebergType], element_reader: Reader) -> Reader: + def list(self, list_type: ListType, expected_list: IcebergType | None, element_reader: Reader) -> Reader: if expected_list and not isinstance(expected_list, ListType): raise ResolveError(f"File/read schema are not aligned for list, got {expected_list}") return ListReader(element_reader if list_type.element_required else OptionReader(element_reader)) - def map(self, map_type: MapType, expected_map: Optional[IcebergType], key_reader: Reader, value_reader: Reader) -> Reader: + def map(self, map_type: MapType, expected_map: IcebergType | None, key_reader: Reader, value_reader: Reader) -> Reader: if expected_map and not isinstance(expected_map, MapType): raise ResolveError(f"File/read schema are not aligned for map, got {expected_map}") return MapReader(key_reader, value_reader if map_type.value_required else OptionReader(value_reader)) - def primitive(self, primitive: PrimitiveType, expected_primitive: Optional[IcebergType]) -> Reader: + def primitive(self, primitive: PrimitiveType, expected_primitive: IcebergType | None) -> Reader: if expected_primitive is not None: if not isinstance(expected_primitive, PrimitiveType): raise ResolveError(f"File/read schema are not aligned for {primitive}, got {expected_primitive}") @@ -456,66 +452,66 @@ def primitive(self, primitive: PrimitiveType, expected_primitive: Optional[Icebe return super().primitive(primitive, expected_primitive) - def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Reader: + def visit_boolean(self, boolean_type: BooleanType, partner: IcebergType | None) -> Reader: return BooleanReader() - def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Reader: + def visit_integer(self, integer_type: IntegerType, partner: IcebergType | None) -> 
Reader: return IntegerReader() - def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Reader: + def visit_long(self, long_type: LongType, partner: IcebergType | None) -> Reader: return IntegerReader() - def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Reader: + def visit_float(self, float_type: FloatType, partner: IcebergType | None) -> Reader: return FloatReader() - def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Reader: + def visit_double(self, double_type: DoubleType, partner: IcebergType | None) -> Reader: return DoubleReader() - def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Reader: + def visit_decimal(self, decimal_type: DecimalType, partner: IcebergType | None) -> Reader: return DecimalReader(decimal_type.precision, decimal_type.scale) - def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Reader: + def visit_date(self, date_type: DateType, partner: IcebergType | None) -> Reader: return DateReader() - def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Reader: + def visit_time(self, time_type: TimeType, partner: IcebergType | None) -> Reader: return TimeReader() - def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Reader: + def visit_timestamp(self, timestamp_type: TimestampType, partner: IcebergType | None) -> Reader: return TimestampReader() - def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[IcebergType]) -> Reader: + def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: IcebergType | None) -> Reader: return TimestampNanoReader() - def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Reader: + def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: IcebergType | None) -> Reader: return TimestamptzReader() - def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[IcebergType]) -> Reader: + def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: IcebergType | None) -> Reader: return TimestamptzNanoReader() - def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Reader: + def visit_string(self, string_type: StringType, partner: IcebergType | None) -> Reader: return StringReader() - def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Reader: + def visit_uuid(self, uuid_type: UUIDType, partner: IcebergType | None) -> Reader: return UUIDReader() - def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Reader: + def visit_fixed(self, fixed_type: FixedType, partner: IcebergType | None) -> Reader: return FixedReader(len(fixed_type)) - def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Reader: + def visit_binary(self, binary_type: BinaryType, partner: IcebergType | None) -> Reader: return BinaryReader() - def visit_unknown(self, unknown_type: UnknownType, partner: Optional[IcebergType]) -> Reader: + def visit_unknown(self, unknown_type: UnknownType, partner: IcebergType | None) -> Reader: return UnknownReader() class SchemaPartnerAccessor(PartnerAccessor[IcebergType]): - def schema_partner(self, partner: Optional[IcebergType]) -> Optional[IcebergType]: + def schema_partner(self, partner: IcebergType | None) -> IcebergType | None: if isinstance(partner, Schema): return partner.as_struct() 
raise ResolveError(f"File/read schema are not aligned for schema, got {partner}") - def field_partner(self, partner: Optional[IcebergType], field_id: int, field_name: str) -> Optional[IcebergType]: + def field_partner(self, partner: IcebergType | None, field_id: int, field_name: str) -> IcebergType | None: if isinstance(partner, StructType): field = partner.field(field_id) else: @@ -523,19 +519,19 @@ def field_partner(self, partner: Optional[IcebergType], field_id: int, field_nam return field.field_type if field else None - def list_element_partner(self, partner_list: Optional[IcebergType]) -> Optional[IcebergType]: + def list_element_partner(self, partner_list: IcebergType | None) -> IcebergType | None: if isinstance(partner_list, ListType): return partner_list.element_type raise ResolveError(f"File/read schema are not aligned for list, got {partner_list}") - def map_key_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]: + def map_key_partner(self, partner_map: IcebergType | None) -> IcebergType | None: if isinstance(partner_map, MapType): return partner_map.key_type raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}") - def map_value_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]: + def map_value_partner(self, partner_map: IcebergType | None) -> IcebergType | None: if isinstance(partner_map, MapType): return partner_map.value_type diff --git a/pyiceberg/avro/writer.py b/pyiceberg/avro/writer.py index 6fa485f21a..ba66d3003c 100644 --- a/pyiceberg/avro/writer.py +++ b/pyiceberg/avro/writer.py @@ -30,9 +30,7 @@ Any, Dict, List, - Optional, Tuple, - Union, ) from uuid import UUID @@ -122,7 +120,7 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None: @dataclass(frozen=True) class UUIDWriter(Writer): - def write(self, encoder: BinaryEncoder, val: Union[UUID, bytes]) -> None: + def write(self, encoder: BinaryEncoder, val: UUID | bytes) -> None: if isinstance(val, UUID): encoder.write(val.bytes) else: @@ -188,7 +186,7 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None: @dataclass(frozen=True) class StructWriter(Writer): - field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield() + field_writers: Tuple[Tuple[int | None, Writer], ...] = dataclassfield() def write(self, encoder: BinaryEncoder, val: Record) -> None: for pos, writer in self.field_writers: diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index a434193573..5b39062948 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -30,11 +30,9 @@ Callable, Dict, List, - Optional, Set, Tuple, Type, - Union, cast, ) @@ -195,7 +193,7 @@ def load_bigquery(name: str, conf: Properties) -> Catalog: } -def infer_catalog_type(name: str, catalog_properties: RecursiveDict) -> Optional[CatalogType]: +def infer_catalog_type(name: str, catalog_properties: RecursiveDict) -> CatalogType | None: """Try to infer the type based on the dict. Args: @@ -225,7 +223,7 @@ def infer_catalog_type(name: str, catalog_properties: RecursiveDict) -> Optional ) -def load_catalog(name: Optional[str] = None, **properties: Optional[str]) -> Catalog: +def load_catalog(name: str | None = None, **properties: str | None) -> Catalog: """Load the catalog based on the properties. Will look up the properties from the config, based on the name. 
@@ -247,7 +245,7 @@ def load_catalog(name: Optional[str] = None, **properties: Optional[str]) -> Cat env = _ENV_CONFIG.get_catalog_config(name) conf: RecursiveDict = merge_config(env or {}, cast(RecursiveDict, properties)) - catalog_type: Optional[CatalogType] + catalog_type: CatalogType | None provided_catalog_type = conf.get(TYPE) if catalog_impl := properties.get(PY_CATALOG_IMPL): @@ -317,7 +315,7 @@ def delete_data_files(io: FileIO, manifests_to_delete: List[ManifestFile]) -> No deleted_files[path] = True -def _import_catalog(name: str, catalog_impl: str, properties: Properties) -> Optional[Catalog]: +def _import_catalog(name: str, catalog_impl: str, properties: Properties) -> Catalog | None: try: path_parts = catalog_impl.split(".") if len(path_parts) < 2: @@ -362,9 +360,9 @@ def __init__(self, name: str, **properties: str): @abstractmethod def create_table( self, - identifier: Union[str, Identifier], - schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + identifier: str | Identifier, + schema: Schema | "pa.Schema", + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -389,9 +387,9 @@ def create_table( @abstractmethod def create_table_transaction( self, - identifier: Union[str, Identifier], - schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + identifier: str | Identifier, + schema: Schema | "pa.Schema", + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -412,9 +410,9 @@ def create_table_transaction( def create_table_if_not_exists( self, - identifier: Union[str, Identifier], - schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + identifier: str | Identifier, + schema: Schema | "pa.Schema", + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -439,7 +437,7 @@ def create_table_if_not_exists( return self.load_table(identifier) @abstractmethod - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: """Load the table's metadata and returns the table instance. You can also use this method to check for table existence using 'try catalog.table() except NoSuchTableError'. @@ -456,7 +454,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: """ @abstractmethod - def table_exists(self, identifier: Union[str, Identifier]) -> bool: + def table_exists(self, identifier: str | Identifier) -> bool: """Check if a table exists. Args: @@ -467,7 +465,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool: """ @abstractmethod - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: """Check if a view exists. Args: @@ -478,7 +476,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool: """ @abstractmethod - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. 
Args: @@ -493,7 +491,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: """ @abstractmethod - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: """Drop a table. Args: @@ -504,7 +502,7 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: """ @abstractmethod - def purge_table(self, identifier: Union[str, Identifier]) -> None: + def purge_table(self, identifier: str | Identifier) -> None: """Drop a table and purge all data and metadata files. Note: This method only logs warning rather than raise exception when encountering file deletion failure. @@ -517,7 +515,7 @@ def purge_table(self, identifier: Union[str, Identifier]) -> None: """ @abstractmethod - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: """Rename a fully classified table name. Args: @@ -552,7 +550,7 @@ def commit_table( """ @abstractmethod - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. Args: @@ -563,7 +561,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper NamespaceAlreadyExistsError: If a namespace with the given name already exists. """ - def create_namespace_if_not_exists(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace_if_not_exists(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace if it does not exist. Args: @@ -576,7 +574,7 @@ def create_namespace_if_not_exists(self, namespace: Union[str, Identifier], prop pass @abstractmethod - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: """Drop a namespace. Args: @@ -588,7 +586,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: """ @abstractmethod - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: """List tables under the given namespace in the catalog. Args: @@ -602,7 +600,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: """ @abstractmethod - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: """List namespaces from the given namespace. If not given, list top-level namespaces from the catalog. Args: @@ -616,7 +614,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi """ @abstractmethod - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: """List views under the given namespace in the catalog. Args: @@ -630,7 +628,7 @@ def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: """ @abstractmethod - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: """Get properties for a namespace. 
Args: @@ -645,7 +643,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper @abstractmethod def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: """Remove provided property keys and updates properties for a namespace. @@ -660,7 +658,7 @@ def update_namespace_properties( """ @abstractmethod - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: """Drop a view. Args: @@ -671,7 +669,7 @@ def drop_view(self, identifier: Union[str, Identifier]) -> None: """ @staticmethod - def identifier_to_tuple(identifier: Union[str, Identifier]) -> Identifier: + def identifier_to_tuple(identifier: str | Identifier) -> Identifier: """Parse an identifier to a tuple. If the identifier is a string, it is split into a tuple on '.'. If it is a tuple, it is used as-is. @@ -685,7 +683,7 @@ def identifier_to_tuple(identifier: Union[str, Identifier]) -> Identifier: return identifier if isinstance(identifier, tuple) else tuple(str.split(identifier, ".")) @staticmethod - def table_name_from(identifier: Union[str, Identifier]) -> str: + def table_name_from(identifier: str | Identifier) -> str: """Extract table name from a table identifier. Args: @@ -697,7 +695,7 @@ def table_name_from(identifier: Union[str, Identifier]) -> str: return Catalog.identifier_to_tuple(identifier)[-1] @staticmethod - def namespace_from(identifier: Union[str, Identifier]) -> Identifier: + def namespace_from(identifier: str | Identifier) -> Identifier: """Extract table namespace from a table identifier. Args: @@ -709,9 +707,7 @@ def namespace_from(identifier: Union[str, Identifier]) -> Identifier: return Catalog.identifier_to_tuple(identifier)[:-1] @staticmethod - def namespace_to_string( - identifier: Union[str, Identifier], err: Union[Type[ValueError], Type[NoSuchNamespaceError]] = ValueError - ) -> str: + def namespace_to_string(identifier: str | Identifier, err: Type[ValueError] | Type[NoSuchNamespaceError] = ValueError) -> str: """Transform a namespace identifier into a string. 
Args: @@ -733,7 +729,7 @@ def namespace_to_string( @staticmethod def identifier_to_database( - identifier: Union[str, Identifier], err: Union[Type[ValueError], Type[NoSuchNamespaceError]] = ValueError + identifier: str | Identifier, err: Type[ValueError] | Type[NoSuchNamespaceError] = ValueError ) -> str: tuple_identifier = Catalog.identifier_to_tuple(identifier) if len(tuple_identifier) != 1: @@ -743,8 +739,8 @@ def identifier_to_database( @staticmethod def identifier_to_database_and_table( - identifier: Union[str, Identifier], - err: Union[Type[ValueError], Type[NoSuchTableError], Type[NoSuchNamespaceError]] = ValueError, + identifier: str | Identifier, + err: Type[ValueError] | Type[NoSuchTableError] | Type[NoSuchNamespaceError] = ValueError, ) -> Tuple[str, str]: tuple_identifier = Catalog.identifier_to_tuple(identifier) if len(tuple_identifier) != 2: @@ -752,12 +748,12 @@ def identifier_to_database_and_table( return tuple_identifier[0], tuple_identifier[1] - def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[str] = None) -> FileIO: + def _load_file_io(self, properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO: return load_file_io({**self.properties, **properties}, location) @staticmethod def _convert_schema_if_needed( - schema: Union[Schema, "pa.Schema"], format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION + schema: Schema | "pa.Schema", format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION ) -> Schema: if isinstance(schema, Schema): return schema @@ -811,7 +807,7 @@ def __enter__(self) -> "Catalog": """ return self - def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None: + def __exit__(self, exc_type: type | None, exc_val: BaseException | None, exc_tb: Any | None) -> None: """Exit the context manager and close the catalog. 
Args: @@ -832,9 +828,9 @@ def __init__(self, name: str, **properties: str): def create_table_transaction( self, - identifier: Union[str, Identifier], - schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + identifier: str | Identifier, + schema: Schema | "pa.Schema", + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -843,14 +839,14 @@ def create_table_transaction( self._create_staged_table(identifier, schema, location, partition_spec, sort_order, properties) ) - def table_exists(self, identifier: Union[str, Identifier]) -> bool: + def table_exists(self, identifier: str | Identifier) -> bool: try: self.load_table(identifier) return True except NoSuchTableError: return False - def purge_table(self, identifier: Union[str, Identifier]) -> None: + def purge_table(self, identifier: str | Identifier) -> None: table = self.load_table(identifier) self.drop_table(identifier) io = load_file_io(self.properties, table.metadata_location) @@ -872,9 +868,9 @@ def purge_table(self, identifier: Union[str, Identifier]) -> None: def _create_staged_table( self, - identifier: Union[str, Identifier], - schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + identifier: str | Identifier, + schema: Schema | "pa.Schema", + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -916,7 +912,7 @@ def _create_staged_table( def _update_and_stage_table( self, - current_table: Optional[Table], + current_table: Table | None, table_identifier: Identifier, requirements: Tuple[TableRequirement, ...], updates: Tuple[TableUpdate, ...], @@ -944,7 +940,7 @@ def _update_and_stage_table( ) def _get_updated_props_and_update_summary( - self, current_properties: Properties, removals: Optional[Set[str]], updates: Properties + self, current_properties: Properties, removals: Set[str] | None, updates: Properties ) -> Tuple[PropertiesUpdateSummary, Properties]: self._check_for_overlap(updates=updates, removals=removals) updated_properties = dict(current_properties) @@ -969,7 +965,7 @@ def _get_updated_props_and_update_summary( return properties_update_summary, updated_properties - def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str: + def _resolve_table_location(self, location: str | None, database_name: str, table_name: str) -> str: if not location: return self._get_default_warehouse_location(database_name, table_name) return location.rstrip("/") @@ -1032,7 +1028,7 @@ def _parse_metadata_version(metadata_location: str) -> int: return -1 @staticmethod - def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> None: + def _check_for_overlap(removals: Set[str] | None, updates: Properties) -> None: if updates and removals: overlap = set(removals) & set(updates.keys()) if overlap: diff --git a/pyiceberg/catalog/bigquery_metastore.py b/pyiceberg/catalog/bigquery_metastore.py index 2336874b52..4b1b922b41 100644 --- a/pyiceberg/catalog/bigquery_metastore.py +++ b/pyiceberg/catalog/bigquery_metastore.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import json -from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Set, Tuple, Union from google.api_core.exceptions import NotFound from google.cloud.bigquery import Client, Dataset, DatasetReference, TableReference @@ -62,10 +62,10 @@ class BigQueryMetastoreCatalog(MetastoreCatalog): def __init__(self, name: str, **properties: str): super().__init__(name, **properties) - project_id: Optional[str] = self.properties.get(GCP_PROJECT_ID) - location: Optional[str] = self.properties.get(GCP_LOCATION) - credentials_file: Optional[str] = self.properties.get(GCP_CREDENTIALS_FILE) - credentials_info_str: Optional[str] = self.properties.get(GCP_CREDENTIALS_INFO) + project_id: str | None = self.properties.get(GCP_PROJECT_ID) + location: str | None = self.properties.get(GCP_LOCATION) + credentials_file: str | None = self.properties.get(GCP_CREDENTIALS_FILE) + credentials_info_str: str | None = self.properties.get(GCP_CREDENTIALS_INFO) if not project_id: raise ValueError(f"Missing property: {GCP_PROJECT_ID}") @@ -100,9 +100,9 @@ def __init__(self, name: str, **properties: str): def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -154,7 +154,7 @@ def create_table( return self.load_table(identifier=identifier) - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. Args: @@ -177,7 +177,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper except Conflict as e: raise NamespaceAlreadyExistsError("Namespace {database_name} already exists") from e - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: """ Load the table's metadata and returns the table instance. @@ -206,7 +206,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: except NotFound as e: raise NoSuchTableError(f"Table does not exist: {dataset_name}.{table_name}") from e - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: """Drop a table. 
Args: @@ -231,10 +231,10 @@ def commit_table( ) -> CommitTableResponse: raise NotImplementedError - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: raise NotImplementedError - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: database_name = self.identifier_to_database(namespace) try: @@ -244,7 +244,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: except NotFound as e: raise NoSuchNamespaceError(f"Namespace {namespace} does not exist.") from e - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: database_name = self.identifier_to_database(namespace) iceberg_tables: List[Identifier] = [] try: @@ -258,7 +258,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: raise NoSuchNamespaceError(f"Namespace (dataset) '{database_name}' not found.") from None return iceberg_tables - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: # Since this catalog only supports one-level namespaces, it always returns an empty list unless # passed an empty namespace to list all namespaces within the catalog. if namespace: @@ -268,7 +268,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi datasets_iterator = self.client.list_datasets() return [(dataset.dataset_id,) for dataset in datasets_iterator] - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. 
Args: @@ -299,16 +299,16 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: return self.load_table(identifier=identifier) - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: raise NotImplementedError - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: dataset_name = self.identifier_to_database(namespace) try: @@ -321,7 +321,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper return {} def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: raise NotImplementedError @@ -364,7 +364,7 @@ def _create_external_catalog_dataset_options( parameters=metadataParameters, ) - def _convert_bigquery_table_to_iceberg_table(self, identifier: Union[str, Identifier], table: BQTable) -> Table: + def _convert_bigquery_table_to_iceberg_table(self, identifier: str | Identifier, table: BQTable) -> Table: dataset_name, table_name = self.identifier_to_database_and_table(identifier, NoSuchTableError) metadata_location = "" if table.external_catalog_table_options and table.external_catalog_table_options.parameters: @@ -405,7 +405,7 @@ def _create_table_parameters(self, metadata_file_location: str, table_metadata: return parameters - def _default_storage_location(self, location: Optional[str], dataset_ref: DatasetReference) -> Union[str, None]: + def _default_storage_location(self, location: str | None, dataset_ref: DatasetReference) -> str | None: if location: return location dataset = self.client.get_dataset(dataset_ref) diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py index 420fa5b523..59ce9f1b13 100644 --- a/pyiceberg/catalog/dynamodb.py +++ b/pyiceberg/catalog/dynamodb.py @@ -155,9 +155,9 @@ def _dynamodb_table_exists(self) -> bool: def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -212,7 +212,7 @@ def create_table( return self.load_table(identifier=identifier) - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. Args: @@ -246,7 +246,7 @@ def commit_table( """ raise NotImplementedError - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: """ Load the table's metadata and returns the table instance. 
@@ -266,7 +266,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: dynamo_table_item = self._get_iceberg_table_item(database_name=database_name, table_name=table_name) return self._convert_dynamo_table_item_to_iceberg_table(dynamo_table_item=dynamo_table_item) - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: """Drop a table. Args: @@ -286,7 +286,7 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: except ConditionalCheckFailedException as e: raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}") from e - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: """Rename a fully classified table name. This method can only rename Iceberg tables in AWS Glue. @@ -352,7 +352,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U return self.load_table(to_identifier) - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. Args: @@ -373,7 +373,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper except ConditionalCheckFailedException as e: raise NamespaceAlreadyExistsError(f"Database {database_name} already exists") from e - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: """Drop a namespace. A Glue namespace can only be dropped if it is empty. @@ -400,7 +400,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: except ConditionalCheckFailedException as e: raise NoSuchNamespaceError(f"Database does not exist: {database_name}") from e - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: """List Iceberg tables under the given namespace in the catalog. Args: @@ -444,7 +444,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: return table_identifiers - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: """List top-level namespaces from the catalog. We do not support hierarchical namespace. @@ -486,7 +486,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi return database_identifiers - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: """ Get properties for a namespace. @@ -505,7 +505,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper return _get_namespace_properties(namespace_dict=namespace_dict) def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: """ Remove or update provided property keys for a namespace. 
@@ -541,13 +541,13 @@ def update_namespace_properties( return properties_update_summary - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: raise NotImplementedError - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError def _get_iceberg_table_item(self, database_name: str, table_name: str) -> Dict[str, Any]: diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index 98c656079b..f19cb6dec0 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -142,8 +142,8 @@ def _construct_parameters( metadata_location: str, glue_table: Optional["TableTypeDef"] = None, - prev_metadata_location: Optional[str] = None, - metadata_properties: Optional[Properties] = None, + prev_metadata_location: str | None = None, + metadata_properties: Properties | None = None, ) -> Properties: new_parameters = glue_table.get("Parameters", {}) if glue_table else {} new_parameters.update({TABLE_TYPE: ICEBERG.upper(), METADATA_LOCATION: metadata_location}) @@ -239,7 +239,7 @@ def _construct_table_input( properties: Properties, metadata: TableMetadata, glue_table: Optional["TableTypeDef"] = None, - prev_metadata_location: Optional[str] = None, + prev_metadata_location: str | None = None, ) -> "TableInputTypeDef": table_input: "TableInputTypeDef" = { "Name": table_name, @@ -422,9 +422,9 @@ def _get_glue_table(self, database_name: str, table_name: str) -> "TableTypeDef" def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -464,7 +464,7 @@ def create_table( return self.load_table(identifier=identifier) - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. Args: @@ -506,9 +506,9 @@ def commit_table( table_identifier = table.name() database_name, table_name = self.identifier_to_database_and_table(table_identifier, NoSuchTableError) - current_glue_table: Optional["TableTypeDef"] - glue_table_version_id: Optional[str] - current_table: Optional[Table] + current_glue_table: "TableTypeDef" | None + glue_table_version_id: str | None + current_table: Table | None try: current_glue_table = self._get_glue_table(database_name=database_name, table_name=table_name) glue_table_version_id = current_glue_table.get("VersionId") @@ -565,7 +565,7 @@ def commit_table( metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location ) - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: """Load the table's metadata and returns the table instance. You can also use this method to check for table existence using 'try catalog.table() except TableNotFoundError'. 
@@ -584,7 +584,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: return self._convert_glue_to_iceberg(self._get_glue_table(database_name=database_name, table_name=table_name)) - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: """Drop a table. Args: @@ -599,7 +599,7 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: except self.glue.exceptions.EntityNotFoundException as e: raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}") from e - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: """Rename a fully classified table name. This method can only rename Iceberg tables in AWS Glue. @@ -659,7 +659,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U return self.load_table(to_identifier) - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. Args: @@ -676,7 +676,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper except self.glue.exceptions.AlreadyExistsException as e: raise NamespaceAlreadyExistsError(f"Database {database_name} already exists") from e - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: """Drop a namespace. A Glue namespace can only be dropped if it is empty. @@ -705,7 +705,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: ) self.glue.delete_database(Name=database_name) - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: """List Iceberg tables under the given namespace in the catalog. Args: @@ -719,7 +719,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: """ database_name = self.identifier_to_database(namespace, NoSuchNamespaceError) table_list: List["TableTypeDef"] = [] - next_token: Optional[str] = None + next_token: str | None = None try: while True: table_list_response = ( @@ -736,7 +736,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: raise NoSuchNamespaceError(f"Database does not exist: {database_name}") from e return [(database_name, table["Name"]) for table in table_list if self.__is_iceberg_table(table)] - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: """List namespaces from the given namespace. If not given, list top-level namespaces from the catalog. 
Returns: @@ -747,7 +747,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi return [] database_list: List["DatabaseTypeDef"] = [] - next_token: Optional[str] = None + next_token: str | None = None while True: databases_response = self.glue.get_databases() if not next_token else self.glue.get_databases(NextToken=next_token) @@ -758,7 +758,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi return [self.identifier_to_tuple(database["Name"]) for database in database_list] - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: """Get properties for a namespace. Args: @@ -789,7 +789,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper return properties def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: """Remove provided property keys and updates properties for a namespace. @@ -812,13 +812,13 @@ def update_namespace_properties( return properties_update_summary - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: raise NotImplementedError - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError @staticmethod diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 93ece35cbb..a6f7131b06 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -24,7 +24,6 @@ Any, Dict, List, - Optional, Set, Tuple, Type, @@ -149,14 +148,14 @@ class _HiveClient: """Helper class to nicely open and close the transport.""" _transport: TTransport - _ugi: Optional[List[str]] + _ugi: List[str] | None def __init__( self, uri: str, - ugi: Optional[str] = None, - kerberos_auth: Optional[bool] = HIVE_KERBEROS_AUTH_DEFAULT, - kerberos_service_name: Optional[str] = HIVE_KERBEROS_SERVICE_NAME, + ugi: str | None = None, + kerberos_auth: bool | None = HIVE_KERBEROS_AUTH_DEFAULT, + kerberos_service_name: str | None = HIVE_KERBEROS_SERVICE_NAME, ): self._uri = uri self._kerberos_auth = kerberos_auth @@ -195,17 +194,13 @@ def __enter__(self) -> Client: self._transport.open() return self._client() # recreate the client - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ) -> None: + def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None: """Close transport if it was opened.""" if self._transport.isOpen(): self._transport.close() -def _construct_hive_storage_descriptor( - schema: Schema, location: Optional[str], hive2_compatible: bool = False -) -> StorageDescriptor: +def _construct_hive_storage_descriptor(schema: Schema, location: str | None, hive2_compatible: bool = False) -> StorageDescriptor: ser_de_info = SerDeInfo(serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") return StorageDescriptor( [ @@ -227,7 +222,7 @@ def 
_construct_hive_storage_descriptor( def _construct_parameters( - metadata_location: str, previous_metadata_location: Optional[str] = None, metadata_properties: Optional[Properties] = None + metadata_location: str, previous_metadata_location: str | None = None, metadata_properties: Properties | None = None ) -> Dict[str, Any]: properties = {PROP_EXTERNAL: "TRUE", PROP_TABLE_TYPE: "ICEBERG", PROP_METADATA_LOCATION: metadata_location} if previous_metadata_location: @@ -400,9 +395,9 @@ def _get_hive_table(self, open_client: Client, database_name: str, table_name: s def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -444,7 +439,7 @@ def create_table( return self._convert_hive_into_iceberg(hive_table) - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. Args: @@ -474,10 +469,10 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: return self._convert_hive_into_iceberg(hive_table) - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError def _create_lock_request(self, database_name: str, table_name: str) -> LockRequest: @@ -540,8 +535,8 @@ def commit_table( else: raise CommitFailedException(f"Failed to acquire lock for {table_identifier}, state: {lock.state}") - hive_table: Optional[HiveTable] - current_table: Optional[Table] + hive_table: HiveTable | None + current_table: Table | None try: hive_table = self._get_hive_table(open_client, database_name, table_name) current_table = self._convert_hive_into_iceberg(hive_table) @@ -599,7 +594,7 @@ def commit_table( metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location ) - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: """Load the table's metadata and return the table instance. You can also use this method to check for table existence using 'try catalog.table() except TableNotFoundError'. @@ -621,7 +616,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: return self._convert_hive_into_iceberg(hive_table) - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: """Drop a table. Args: @@ -638,11 +633,11 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: # When the namespace doesn't exist, it throws the same error raise NoSuchTableError(f"Table does not exists: {table_name}") from e - def purge_table(self, identifier: Union[str, Identifier]) -> None: + def purge_table(self, identifier: str | Identifier) -> None: # This requires to traverse the reachability set, and drop all the data files. 
raise NotImplementedError("Not yet implemented") - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: """Rename a fully classified table name. Args: @@ -681,7 +676,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U raise NoSuchNamespaceError(f"Database does not exists: {to_database_name}") from e return self.load_table(to_identifier) - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. Args: @@ -701,7 +696,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper except AlreadyExistsException as e: raise NamespaceAlreadyExistsError(f"Database {database_name} already exists") from e - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: """Drop a namespace. Args: @@ -720,7 +715,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: except MetaException as e: raise NoSuchNamespaceError(f"Database does not exists: {database_name}") from e - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: """List Iceberg tables under the given namespace in the catalog. When the database doesn't exist, it will just return an empty list. @@ -744,7 +739,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: if table.parameters.get(TABLE_TYPE, "").lower() == ICEBERG ] - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: """List namespaces from the given namespace. If not given, list top-level namespaces from the catalog. Returns: @@ -757,7 +752,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi with self._client as open_client: return list(map(self.identifier_to_tuple, open_client.get_all_databases())) - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: """Get properties for a namespace. Args: @@ -782,7 +777,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper raise NoSuchNamespaceError(f"Database does not exists: {database_name}") from e def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: """Remove provided property keys and update properties for a namespace. 
@@ -823,7 +818,7 @@ def update_namespace_properties( return PropertiesUpdateSummary(removed=list(removed or []), updated=list(updated or []), missing=list(expected_to_change)) - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: raise NotImplementedError def _get_default_warehouse_location(self, database_name: str, table_name: str) -> str: diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py index eb3132a9ac..08b71d90af 100644 --- a/pyiceberg/catalog/noop.py +++ b/pyiceberg/catalog/noop.py @@ -17,7 +17,6 @@ from typing import ( TYPE_CHECKING, List, - Optional, Set, Tuple, Union, @@ -45,9 +44,9 @@ class NoopCatalog(Catalog): def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -56,22 +55,22 @@ def create_table( def create_table_transaction( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, ) -> CreateTableTransaction: raise NotImplementedError - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: raise NotImplementedError - def table_exists(self, identifier: Union[str, Identifier]) -> bool: + def table_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. 
Args: @@ -86,13 +85,13 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: """ raise NotImplementedError - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: raise NotImplementedError - def purge_table(self, identifier: Union[str, Identifier]) -> None: + def purge_table(self, identifier: str | Identifier) -> None: raise NotImplementedError - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: raise NotImplementedError def commit_table( @@ -100,31 +99,31 @@ def commit_table( ) -> CommitTableResponse: raise NotImplementedError - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: raise NotImplementedError - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: raise NotImplementedError - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: raise NotImplementedError - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: raise NotImplementedError def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: raise NotImplementedError - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: raise NotImplementedError diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 2915c7e347..e9571aa491 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -20,7 +20,6 @@ Any, Dict, List, - Optional, Set, Tuple, Union, @@ -153,17 +152,17 @@ def _retry_hook(retry_state: RetryCallState) -> None: class TableResponse(IcebergBaseModel): - metadata_location: Optional[str] = Field(alias="metadata-location", default=None) + metadata_location: str | None = Field(alias="metadata-location", default=None) metadata: TableMetadata config: Properties = Field(default_factory=dict) class CreateTableRequest(IcebergBaseModel): name: str = Field() - location: Optional[str] = Field() + location: str | None = Field() table_schema: Schema = Field(alias="schema") - partition_spec: Optional[PartitionSpec] = Field(alias="partition-spec") - write_order: Optional[SortOrder] = Field(alias="write-order") + partition_spec: PartitionSpec | None = 
Field(alias="partition-spec") + write_order: SortOrder | None = Field(alias="write-order") stage_create: bool = Field(alias="stage-create", default=False) properties: Dict[str, str] = Field(default_factory=dict) @@ -179,8 +178,8 @@ class RegisterTableRequest(IcebergBaseModel): class ConfigResponse(IcebergBaseModel): - defaults: Optional[Properties] = Field(default_factory=dict) - overrides: Optional[Properties] = Field(default_factory=dict) + defaults: Properties | None = Field(default_factory=dict) + overrides: Properties | None = Field(default_factory=dict) class ListNamespaceResponse(IcebergBaseModel): @@ -294,7 +293,7 @@ def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: return AuthManagerFactory.create("legacyoauth2", auth_config) - def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier: + def _check_valid_namespace_identifier(self, identifier: str | Identifier) -> Identifier: """Check if the identifier has at least one element.""" identifier_tuple = Catalog.identifier_to_tuple(identifier) if len(identifier_tuple) < 1: @@ -377,14 +376,14 @@ def _fetch_config(self) -> None: # Update URI based on overrides self.uri = config[URI] - def _identifier_to_validated_tuple(self, identifier: Union[str, Identifier]) -> Identifier: + def _identifier_to_validated_tuple(self, identifier: str | Identifier) -> Identifier: identifier_tuple = self.identifier_to_tuple(identifier) if len(identifier_tuple) <= 1: raise NoSuchIdentifierError(f"Missing namespace or invalid identifier: {'.'.join(identifier_tuple)}") return identifier_tuple def _split_identifier_for_path( - self, identifier: Union[str, Identifier, TableIdentifier], kind: IdentifierKind = IdentifierKind.TABLE + self, identifier: str | Identifier | TableIdentifier, kind: IdentifierKind = IdentifierKind.TABLE ) -> Properties: if isinstance(identifier, TableIdentifier): return {"namespace": NAMESPACE_SEPARATOR.join(identifier.namespace.root), kind.value: identifier.name} @@ -392,7 +391,7 @@ def _split_identifier_for_path( return {"namespace": NAMESPACE_SEPARATOR.join(identifier_tuple[:-1]), kind.value: identifier_tuple[-1]} - def _split_identifier_for_json(self, identifier: Union[str, Identifier]) -> Dict[str, Union[Identifier, str]]: + def _split_identifier_for_json(self, identifier: str | Identifier) -> Dict[str, Identifier | str]: identifier_tuple = self._identifier_to_validated_tuple(identifier) return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]} @@ -488,9 +487,9 @@ def _config_headers(self, session: Session) -> None: def _create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -530,9 +529,9 @@ def _create_table( @retry(**_RETRY_ARGS) def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -551,9 +550,9 @@ def create_table( @retry(**_RETRY_ARGS) def create_table_transaction( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: 
Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -571,7 +570,7 @@ def create_table_transaction( return CreateTableTransaction(staged_table) @retry(**_RETRY_ARGS) - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. Args: @@ -603,7 +602,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: return self._response_to_table(self.identifier_to_tuple(identifier), table_response) @retry(**_RETRY_ARGS) - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat)) @@ -614,7 +613,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: return [(*table.namespace, table.name) for table in ListTablesResponse.model_validate_json(response.text).identifiers] @retry(**_RETRY_ARGS) - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: params = {} if mode := self.properties.get(SNAPSHOT_LOADING_MODE): if mode in {"all", "refs"}: @@ -634,7 +633,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: return self._response_to_table(self.identifier_to_tuple(identifier), table_response) @retry(**_RETRY_ARGS) - def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None: + def drop_table(self, identifier: str | Identifier, purge_requested: bool = False) -> None: response = self._session.delete( self.url(Endpoints.drop_table, prefixed=True, **self._split_identifier_for_path(identifier)), params={"purgeRequested": purge_requested}, @@ -645,11 +644,11 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = _handle_non_200_response(exc, {404: NoSuchTableError}) @retry(**_RETRY_ARGS) - def purge_table(self, identifier: Union[str, Identifier]) -> None: + def purge_table(self, identifier: str | Identifier) -> None: self.drop_table(identifier=identifier, purge_requested=True) @retry(**_RETRY_ARGS) - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: payload = { "source": self._split_identifier_for_json(from_identifier), "destination": self._split_identifier_for_json(to_identifier), @@ -684,7 +683,7 @@ def _remove_catalog_name_from_table_request_identifier(self, table_request: Comm return table_request @retry(**_RETRY_ARGS) - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.get(self.url(Endpoints.list_views, namespace=namespace_concat)) @@ -741,7 +740,7 @@ def commit_table( return CommitTableResponse.model_validate_json(response.text) @retry(**_RETRY_ARGS) - def 
create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: namespace_tuple = self._check_valid_namespace_identifier(namespace) payload = {"namespace": namespace_tuple, "properties": properties} response = self._session.post(self.url(Endpoints.create_namespace), json=payload) @@ -751,7 +750,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper _handle_non_200_response(exc, {409: NamespaceAlreadyExistsError}) @retry(**_RETRY_ARGS) - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace)) @@ -761,7 +760,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: _handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError}) @retry(**_RETRY_ARGS) - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: namespace_tuple = self.identifier_to_tuple(namespace) response = self._session.get( self.url( @@ -778,7 +777,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi return ListNamespaceResponse.model_validate_json(response.text).namespaces @retry(**_RETRY_ARGS) - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace)) @@ -791,7 +790,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper @retry(**_RETRY_ARGS) def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) @@ -809,7 +808,7 @@ def update_namespace_properties( ) @retry(**_RETRY_ARGS) - def namespace_exists(self, namespace: Union[str, Identifier]) -> bool: + def namespace_exists(self, namespace: str | Identifier) -> bool: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) response = self._session.head(self.url(Endpoints.namespace_exists, namespace=namespace)) @@ -827,7 +826,7 @@ def namespace_exists(self, namespace: Union[str, Identifier]) -> bool: return False @retry(**_RETRY_ARGS) - def table_exists(self, identifier: Union[str, Identifier]) -> bool: + def table_exists(self, identifier: str | Identifier) -> bool: """Check if a table exists. Args: @@ -853,7 +852,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool: return False @retry(**_RETRY_ARGS) - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: """Check if a view exists. 
Args: @@ -878,7 +877,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool: return False @retry(**_RETRY_ARGS) - def drop_view(self, identifier: Union[str]) -> None: + def drop_view(self, identifier: str) -> None: response = self._session.delete( self.url(Endpoints.drop_view, prefixed=True, **self._split_identifier_for_path(identifier, IdentifierKind.VIEW)), ) diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index 527e82060e..48e49f08da 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -22,7 +22,7 @@ import time from abc import ABC, abstractmethod from functools import cached_property -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Type import requests from requests import HTTPError, PreparedRequest, Session @@ -43,14 +43,14 @@ class AuthManager(ABC): """ @abstractmethod - def auth_header(self) -> Optional[str]: + def auth_header(self) -> str | None: """Return the Authorization header value, or None if not applicable.""" class NoopAuthManager(AuthManager): """Auth Manager implementation with no auth.""" - def auth_header(self) -> Optional[str]: + def auth_header(self) -> str | None: return None @@ -73,18 +73,18 @@ class LegacyOAuth2AuthManager(AuthManager): """ _session: Session - _auth_url: Optional[str] - _token: Optional[str] - _credential: Optional[str] - _optional_oauth_params: Optional[Dict[str, str]] + _auth_url: str | None + _token: str | None + _credential: str | None + _optional_oauth_params: Dict[str, str] | None def __init__( self, session: Session, - auth_url: Optional[str] = None, - credential: Optional[str] = None, - initial_token: Optional[str] = None, - optional_oauth_params: Optional[Dict[str, str]] = None, + auth_url: str | None = None, + credential: str | None = None, + initial_token: str | None = None, + optional_oauth_params: Dict[str, str] | None = None, ): self._session = session self._auth_url = auth_url @@ -131,11 +131,11 @@ class OAuth2TokenProvider: client_id: str client_secret: str token_url: str - scope: Optional[str] + scope: str | None refresh_margin: int - expires_in: Optional[int] + expires_in: int | None - _token: Optional[str] + _token: str | None _expires_at: int _lock: threading.Lock @@ -144,9 +144,9 @@ def __init__( client_id: str, client_secret: str, token_url: str, - scope: Optional[str] = None, + scope: str | None = None, refresh_margin: int = 60, - expires_in: Optional[int] = None, + expires_in: int | None = None, ): self.client_id = client_id self.client_secret = client_secret @@ -200,9 +200,9 @@ def __init__( client_id: str, client_secret: str, token_url: str, - scope: Optional[str] = None, + scope: str | None = None, refresh_margin: int = 60, - expires_in: Optional[int] = None, + expires_in: int | None = None, ): self.token_provider = OAuth2TokenProvider( client_id, @@ -220,7 +220,7 @@ def auth_header(self) -> str: class GoogleAuthManager(AuthManager): """An auth manager that is responsible for handling Google credentials.""" - def __init__(self, credentials_path: Optional[str] = None, scopes: Optional[List[str]] = None): + def __init__(self, credentials_path: str | None = None, scopes: List[str] | None = None): """ Initialize GoogleAuthManager. 
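The auth_header contract above returns str | None, where None means "attach no Authorization header at all". A minimal sketch of a manager and a caller honoring that contract; the class and function names here are illustrative, not pyiceberg API:

from abc import ABC, abstractmethod
from typing import Dict


class AuthManager(ABC):
    @abstractmethod
    def auth_header(self) -> str | None:
        """Return the Authorization header value, or None if not applicable."""


class StaticTokenAuthManager(AuthManager):
    """Illustrative manager: serve a fixed bearer token, or no auth at all."""

    def __init__(self, token: str | None = None) -> None:
        self._token = token

    def auth_header(self) -> str | None:
        return f"Bearer {self._token}" if self._token is not None else None


def apply_auth(headers: Dict[str, str], manager: AuthManager) -> None:
    # None is a first-class "no auth" outcome here, which is exactly what the
    # str | None annotation encodes; an empty string would be ambiguous.
    if (header := manager.auth_header()) is not None:
        headers["Authorization"] = header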
diff --git a/pyiceberg/catalog/rest/response.py b/pyiceberg/catalog/rest/response.py index 8f23af8c35..d28a7c3f71 100644 --- a/pyiceberg/catalog/rest/response.py +++ b/pyiceberg/catalog/rest/response.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from json import JSONDecodeError -from typing import Dict, Literal, Optional, Type +from typing import Dict, Literal, Type from pydantic import Field, ValidationError from requests import HTTPError @@ -36,10 +36,10 @@ class TokenResponse(IcebergBaseModel): access_token: str = Field() token_type: str = Field() - expires_in: Optional[int] = Field(default=None) - issued_token_type: Optional[str] = Field(default=None) - refresh_token: Optional[str] = Field(default=None) - scope: Optional[str] = Field(default=None) + expires_in: int | None = Field(default=None) + issued_token_type: str | None = Field(default=None) + refresh_token: str | None = Field(default=None) + scope: str | None = Field(default=None) class ErrorResponseMessage(IcebergBaseModel): @@ -56,8 +56,8 @@ class OAuthErrorResponse(IcebergBaseModel): error: Literal[ "invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope" ] - error_description: Optional[str] = None - error_uri: Optional[str] = None + error_description: str | None = None + error_uri: str | None = None def _handle_non_200_response(exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None: diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index dfa573bc13..cefb22b95b 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -18,7 +18,6 @@ from typing import ( TYPE_CHECKING, List, - Optional, Set, Tuple, Union, @@ -91,8 +90,8 @@ class IcebergTables(SqlCatalogBaseTable): catalog_name: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) table_namespace: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) table_name: Mapped[str] = mapped_column(String(255), nullable=False, primary_key=True) - metadata_location: Mapped[Optional[str]] = mapped_column(String(1000), nullable=True) - previous_metadata_location: Mapped[Optional[str]] = mapped_column(String(1000), nullable=True) + metadata_location: Mapped[str | None] = mapped_column(String(1000), nullable=True) + previous_metadata_location: Mapped[str | None] = mapped_column(String(1000), nullable=True) class IcebergNamespaceProperties(SqlCatalogBaseTable): @@ -174,9 +173,9 @@ def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table: def create_table( self, - identifier: Union[str, Identifier], + identifier: str | Identifier, schema: Union[Schema, "pa.Schema"], - location: Optional[str] = None, + location: str | None = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, @@ -237,7 +236,7 @@ def create_table( return self.load_table(identifier=identifier) - def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: + def register_table(self, identifier: str | Identifier, metadata_location: str) -> Table: """Register a new table using existing metadata. 
Args: @@ -274,7 +273,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: return self.load_table(identifier=identifier) - def load_table(self, identifier: Union[str, Identifier]) -> Table: + def load_table(self, identifier: str | Identifier) -> Table: """Load the table's metadata and return the table instance. You can also use this method to check for table existence using 'try catalog.table() except NoSuchTableError'. @@ -303,7 +302,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: return self._convert_orm_to_iceberg(result) raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") - def drop_table(self, identifier: Union[str, Identifier]) -> None: + def drop_table(self, identifier: str | Identifier) -> None: """Drop a table. Args: @@ -343,7 +342,7 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: raise NoSuchTableError(f"Table does not exist: {namespace}.{table_name}") from e session.commit() - def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: + def rename_table(self, from_identifier: str | Identifier, to_identifier: str | Identifier) -> Table: """Rename a fully classified table name. Args: @@ -424,7 +423,7 @@ def commit_table( namespace = Catalog.namespace_to_string(namespace_tuple) table_name = Catalog.table_name_from(table_identifier) - current_table: Optional[Table] + current_table: Table | None try: current_table = self.load_table(table_identifier) except NoSuchTableError: @@ -498,7 +497,7 @@ def commit_table( metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location ) - def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: + def _namespace_exists(self, identifier: str | Identifier) -> bool: namespace_tuple = Catalog.identifier_to_tuple(identifier) namespace = Catalog.namespace_to_string(namespace_tuple, NoSuchNamespaceError) namespace_starts_with = namespace.replace("!", "!!").replace("_", "!_").replace("%", "!%") + ".%" @@ -530,7 +529,7 @@ def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: return True return False - def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: + def create_namespace(self, namespace: str | Identifier, properties: Properties = EMPTY_DICT) -> None: """Create a namespace in the catalog. Args: @@ -558,7 +557,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper ) session.commit() - def drop_namespace(self, namespace: Union[str, Identifier]) -> None: + def drop_namespace(self, namespace: str | Identifier) -> None: """Drop a namespace. Args: @@ -584,7 +583,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: ) session.commit() - def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_tables(self, namespace: str | Identifier) -> List[Identifier]: """List tables under the given namespace in the catalog. Args: @@ -605,7 +604,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: result = session.scalars(stmt) return [(Catalog.identifier_to_tuple(table.table_namespace) + (table.table_name,)) for table in result] - def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: + def list_namespaces(self, namespace: str | Identifier = ()) -> List[Identifier]: """List namespaces from the given namespace. 
If not given, list top-level namespaces from the catalog. Args: @@ -646,7 +645,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi return namespaces - def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: + def load_namespace_properties(self, namespace: str | Identifier) -> Properties: """Get properties for a namespace. Args: @@ -670,7 +669,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper return {props.property_key: props.property_value for props in result} def update_namespace_properties( - self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT + self, namespace: str | Identifier, removals: Set[str] | None = None, updates: Properties = EMPTY_DICT ) -> PropertiesUpdateSummary: """Remove provided property keys and update properties for a namespace. @@ -725,13 +724,13 @@ def update_namespace_properties( session.commit() return properties_update_summary - def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: + def list_views(self, namespace: str | Identifier) -> List[Identifier]: raise NotImplementedError - def view_exists(self, identifier: Union[str, Identifier]) -> bool: + def view_exists(self, identifier: str | Identifier) -> bool: raise NotImplementedError - def drop_view(self, identifier: Union[str, Identifier]) -> None: + def drop_view(self, identifier: str | Identifier) -> None: raise NotImplementedError def close(self) -> None: diff --git a/pyiceberg/cli/console.py b/pyiceberg/cli/console.py index d918f87918..f3adf830b2 100644 --- a/pyiceberg/cli/console.py +++ b/pyiceberg/cli/console.py @@ -21,7 +21,6 @@ Callable, Dict, Literal, - Optional, Tuple, ) @@ -64,12 +63,12 @@ def wrapper(*args: Any, **kwargs: Any): # type: ignore @click.pass_context def run( ctx: Context, - catalog: Optional[str], + catalog: str | None, verbose: bool, output: str, - ugi: Optional[str], - uri: Optional[str], - credential: Optional[str], + ugi: str | None, + uri: str | None, + credential: str | None, ) -> None: properties = {} if ugi: @@ -107,7 +106,7 @@ def _catalog_and_output(ctx: Context) -> Tuple[Catalog, Output]: @click.pass_context @click.argument("parent", required=False) @catch_exception() -def list(ctx: Context, parent: Optional[str]) -> None: # pylint: disable=redefined-builtin +def list(ctx: Context, parent: str | None) -> None: # pylint: disable=redefined-builtin """List tables or namespaces.""" catalog, output = _catalog_and_output(ctx) diff --git a/pyiceberg/cli/output.py b/pyiceberg/cli/output.py index 0eb85841bf..b546877fac 100644 --- a/pyiceberg/cli/output.py +++ b/pyiceberg/cli/output.py @@ -20,7 +20,6 @@ Any, Dict, List, - Optional, Tuple, ) from uuid import UUID @@ -65,7 +64,7 @@ def schema(self, schema: Schema) -> None: ... def spec(self, spec: PartitionSpec) -> None: ... @abstractmethod - def uuid(self, uuid: Optional[UUID]) -> None: ... + def uuid(self, uuid: UUID | None) -> None: ... @abstractmethod def version(self, version: str) -> None: ... 
@@ -169,7 +168,7 @@ def schema(self, schema: Schema) -> None: def spec(self, spec: PartitionSpec) -> None: Console().print(str(spec)) - def uuid(self, uuid: Optional[UUID]) -> None: + def uuid(self, uuid: UUID | None) -> None: Console().print(str(uuid) if uuid else "missing") def version(self, version: str) -> None: @@ -235,7 +234,7 @@ def files(self, table: Table, history: bool) -> None: def spec(self, spec: PartitionSpec) -> None: print(spec.model_dump_json()) - def uuid(self, uuid: Optional[UUID]) -> None: + def uuid(self, uuid: UUID | None) -> None: self._out({"uuid": str(uuid) if uuid else "missing"}) def version(self, version: str) -> None: diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py index 7bf7b462e2..b4eaea1b8e 100644 --- a/pyiceberg/conversions.py +++ b/pyiceberg/conversions.py @@ -38,8 +38,6 @@ from typing import ( Any, Callable, - Optional, - Union, ) from pyiceberg.typedef import UTF8, L @@ -98,7 +96,7 @@ def handle_none(func: Callable) -> Callable: # type: ignore func (Callable): A function registered to the singledispatch function `partition_to_py`. """ - def wrapper(primitive_type: PrimitiveType, value_str: Optional[str]) -> Any: + def wrapper(primitive_type: PrimitiveType, value_str: str | None) -> Any: if value_str is None: return None elif value_str == "__HIVE_DEFAULT_PARTITION__": @@ -109,7 +107,7 @@ def wrapper(primitive_type: PrimitiveType, value_str: Optional[str]) -> Any: @singledispatch -def partition_to_py(primitive_type: PrimitiveType, value_str: str) -> Union[int, float, str, uuid.UUID, bytes, Decimal]: +def partition_to_py(primitive_type: PrimitiveType, value_str: str) -> int | float | str | uuid.UUID | bytes | Decimal: """Convert a partition string to a python built-in. Args: @@ -121,7 +119,7 @@ def partition_to_py(primitive_type: PrimitiveType, value_str: str) -> Union[int, @partition_to_py.register(BooleanType) @handle_none -def _(primitive_type: BooleanType, value_str: str) -> Union[int, float, str, uuid.UUID]: +def _(primitive_type: BooleanType, value_str: str) -> int | float | str | uuid.UUID: return strtobool(value_str) @@ -186,7 +184,7 @@ def _(type_: UnknownType, _: str) -> None: @singledispatch def to_bytes( - primitive_type: PrimitiveType, _: Union[bool, bytes, Decimal, date, datetime, float, int, str, time, uuid.UUID] + primitive_type: PrimitiveType, _: bool | bytes | Decimal | date | datetime | float | int | str | time | uuid.UUID ) -> bytes: """Convert a built-in python value to bytes. 
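conversions.py drives these converters through functools.singledispatch: dispatch happens on the Iceberg type in the first parameter, while the second parameter carries a value union such as date | int. A toy sketch of the pattern, using stand-in type classes rather than the real pyiceberg ones:

from datetime import date
from functools import singledispatch


class PrimitiveType: ...


class DateType(PrimitiveType): ...


class StringType(PrimitiveType): ...


@singledispatch
def to_py(primitive_type: PrimitiveType, value: str) -> object:
    raise TypeError(f"Cannot convert value for type: {primitive_type}")


@to_py.register(DateType)
def _(_: DateType, value: date | str) -> date:
    # Accept the serialized form or an already-parsed value, mirroring the
    # date | int style unions used by the real to_bytes/to_json converters.
    return value if isinstance(value, date) else date.fromisoformat(value)


@to_py.register(StringType)
def _(_: StringType, value: str) -> str:
    return value


# to_py(DateType(), "2024-01-01") evaluates to datetime.date(2024, 1, 1)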
@@ -218,7 +216,7 @@ def _(_: PrimitiveType, value: int) -> bytes: @to_bytes.register(TimestampType) @to_bytes.register(TimestamptzType) -def _(_: PrimitiveType, value: Union[datetime, int]) -> bytes: +def _(_: PrimitiveType, value: datetime | int) -> bytes: if isinstance(value, datetime): value = datetime_to_micros(value) return _LONG_STRUCT.pack(value) @@ -226,21 +224,21 @@ def _(_: PrimitiveType, value: Union[datetime, int]) -> bytes: @to_bytes.register(TimestampNanoType) @to_bytes.register(TimestamptzNanoType) -def _(_: PrimitiveType, value: Union[datetime, int]) -> bytes: +def _(_: PrimitiveType, value: datetime | int) -> bytes: if isinstance(value, datetime): value = datetime_to_nanos(value) return _LONG_STRUCT.pack(value) @to_bytes.register(DateType) -def _(_: DateType, value: Union[date, int]) -> bytes: +def _(_: DateType, value: date | int) -> bytes: if isinstance(value, date): value = date_to_days(value) return _INT_STRUCT.pack(value) @to_bytes.register(TimeType) -def _(_: TimeType, value: Union[time, int]) -> bytes: +def _(_: TimeType, value: time | int) -> bytes: if isinstance(value, time): value = time_to_micros(value) return _LONG_STRUCT.pack(value) @@ -267,7 +265,7 @@ def _(_: StringType, value: str) -> bytes: @to_bytes.register(UUIDType) -def _(_: UUIDType, value: Union[uuid.UUID, bytes]) -> bytes: +def _(_: UUIDType, value: uuid.UUID | bytes) -> bytes: if isinstance(value, bytes): return value return value.bytes @@ -392,13 +390,13 @@ def _(_: BooleanType, val: bool) -> bool: @to_json.register(IntegerType) @to_json.register(LongType) -def _(_: Union[IntegerType, LongType], val: int) -> int: +def _(_: IntegerType | LongType, val: int) -> int: """Python int automatically converts to a JSON int.""" return val @to_json.register(DateType) -def _(_: DateType, val: Union[date, int]) -> str: +def _(_: DateType, val: date | int) -> str: """JSON date is string encoded.""" if isinstance(val, date): val = date_to_days(val) @@ -406,7 +404,7 @@ def _(_: DateType, val: Union[date, int]) -> str: @to_json.register(TimeType) -def _(_: TimeType, val: Union[int, time]) -> str: +def _(_: TimeType, val: int | time) -> str: """Python time or microseconds since epoch serializes into an ISO8601 time.""" if isinstance(val, time): val = time_to_micros(val) @@ -414,7 +412,7 @@ def _(_: TimeType, val: Union[int, time]) -> str: @to_json.register(TimestampType) -def _(_: PrimitiveType, val: Union[int, datetime]) -> str: +def _(_: PrimitiveType, val: int | datetime) -> str: """Python datetime (without timezone) or microseconds since epoch serializes into an ISO8601 timestamp.""" if isinstance(val, datetime): val = datetime_to_micros(val) @@ -423,7 +421,7 @@ def _(_: PrimitiveType, val: Union[int, datetime]) -> str: @to_json.register(TimestamptzType) -def _(_: TimestamptzType, val: Union[int, datetime]) -> str: +def _(_: TimestamptzType, val: int | datetime) -> str: """Python datetime (with timezone) or microseconds since epoch serializes into an ISO8601 timestamp.""" if isinstance(val, datetime): val = datetime_to_micros(val) @@ -432,7 +430,7 @@ def _(_: TimestamptzType, val: Union[int, datetime]) -> str: @to_json.register(FloatType) @to_json.register(DoubleType) -def _(_: Union[FloatType, DoubleType], val: float) -> float: +def _(_: FloatType | DoubleType, val: float) -> float: """Float serializes into JSON float.""" return val @@ -497,13 +495,13 @@ def _(_: BooleanType, val: bool) -> bool: @from_json.register(IntegerType) @from_json.register(LongType) -def _(_: Union[IntegerType, LongType], val: int) -> 
int: +def _(_: IntegerType | LongType, val: int) -> int: """JSON int automatically converts to a Python int.""" return val @from_json.register(DateType) -def _(_: DateType, val: Union[str, int, date]) -> date: +def _(_: DateType, val: str | int | date) -> date: """JSON date is string encoded.""" if isinstance(val, str): val = date_str_to_days(val) @@ -514,7 +512,7 @@ def _(_: DateType, val: Union[str, int, date]) -> date: @from_json.register(TimeType) -def _(_: TimeType, val: Union[str, int, time]) -> time: +def _(_: TimeType, val: str | int | time) -> time: """JSON ISO8601 string into Python time.""" if isinstance(val, str): val = time_str_to_micros(val) @@ -525,7 +523,7 @@ def _(_: TimeType, val: Union[str, int, time]) -> time: @from_json.register(TimestampType) -def _(_: PrimitiveType, val: Union[str, int, datetime]) -> datetime: +def _(_: PrimitiveType, val: str | int | datetime) -> datetime: """JSON ISO8601 string into Python datetime.""" if isinstance(val, str): val = timestamp_to_micros(val) @@ -536,7 +534,7 @@ def _(_: PrimitiveType, val: Union[str, int, datetime]) -> datetime: @from_json.register(TimestamptzType) -def _(_: TimestamptzType, val: Union[str, int, datetime]) -> datetime: +def _(_: TimestamptzType, val: str | int | datetime) -> datetime: """JSON ISO8601 string into Python datetime.""" if isinstance(val, str): val = timestamptz_to_micros(val) @@ -548,7 +546,7 @@ def _(_: TimestamptzType, val: Union[str, int, datetime]) -> datetime: @from_json.register(FloatType) @from_json.register(DoubleType) -def _(_: Union[FloatType, DoubleType], val: float) -> float: +def _(_: FloatType | DoubleType, val: float) -> float: """JSON float deserializes into a Python float.""" return val @@ -560,7 +558,7 @@ def _(_: StringType, val: str) -> str: @from_json.register(FixedType) -def _(t: FixedType, val: Union[str, bytes]) -> bytes: +def _(t: FixedType, val: str | bytes) -> bytes: """JSON hexadecimal encoded string into bytes.""" if isinstance(val, str): val = codecs.decode(val.encode(UTF8), "hex") @@ -572,7 +570,7 @@ def _(t: FixedType, val: Union[str, bytes]) -> bytes: @from_json.register(BinaryType) -def _(_: BinaryType, val: Union[bytes, str]) -> bytes: +def _(_: BinaryType, val: bytes | str) -> bytes: """JSON hexadecimal encoded string into bytes.""" if isinstance(val, str): return codecs.decode(val.encode(UTF8), "hex") @@ -587,7 +585,7 @@ def _(_: DecimalType, val: str) -> Decimal: @from_json.register(UUIDType) -def _(_: UUIDType, val: Union[str, bytes, uuid.UUID]) -> uuid.UUID: +def _(_: UUIDType, val: str | bytes | uuid.UUID) -> uuid.UUID: """Convert JSON string into Python UUID.""" if isinstance(val, str): return uuid.UUID(val) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 0491a1f3c8..330d22b1a4 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -29,7 +29,6 @@ Tuple, Type, TypeVar, - Union, ) from typing import Literal as TypingLiteral @@ -52,15 +51,15 @@ ConfigDict = dict -def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]: +def _to_unbound_term(term: str | UnboundTerm[Any]) -> UnboundTerm[Any]: return Reference(term) if isinstance(term, str) else term -def _to_literal_set(values: Union[Iterable[L], Iterable[Literal[L]]]) -> Set[Literal[L]]: +def _to_literal_set(values: Iterable[L] | Iterable[Literal[L]]) -> Set[Literal[L]]: return {_to_literal(v) for v in values} -def _to_literal(value: Union[L, Literal[L]]) -> Literal[L]: +def _to_literal(value: L | Literal[L]) -> 
Literal[L]: if isinstance(value, Literal): return value else: @@ -448,7 +447,7 @@ def as_unbound(self) -> Type[UnboundPredicate[Any]]: ... class UnboundPredicate(Generic[L], Unbound[BooleanExpression], BooleanExpression, ABC): term: UnboundTerm[Any] - def __init__(self, term: Union[str, UnboundTerm[Any]]): + def __init__(self, term: str | UnboundTerm[Any]): self.term = _to_unbound_term(term) def __eq__(self, other: Any) -> bool: @@ -468,7 +467,7 @@ class UnaryPredicate(IcebergBaseModel, UnboundPredicate[Any], ABC): model_config = {"arbitrary_types_allowed": True} - def __init__(self, term: Union[str, UnboundTerm[Any]]): + def __init__(self, term: str | UnboundTerm[Any]): unbound = _to_unbound_term(term) super().__init__(term=unbound) @@ -620,7 +619,7 @@ class SetPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): type: TypingLiteral["in", "not-in"] = Field(default="in") literals: Set[Literal[L]] = Field(alias="items") - def __init__(self, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]): + def __init__(self, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]]): super().__init__(term=_to_unbound_term(term), items=_to_literal_set(literals)) # type: ignore def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate[L]: @@ -736,7 +735,7 @@ class In(SetPredicate[L]): type: TypingLiteral["in"] = Field(default="in", alias="type") def __new__( # type: ignore # pylint: disable=W0221 - cls, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]] + cls, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]] ) -> BooleanExpression: literals_set: Set[Literal[L]] = _to_literal_set(literals) count = len(literals_set) @@ -760,7 +759,7 @@ class NotIn(SetPredicate[L], ABC): type: TypingLiteral["not-in"] = Field(default="not-in", alias="type") def __new__( # type: ignore # pylint: disable=W0221 - cls, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]] + cls, term: str | UnboundTerm[Any], literals: Iterable[L] | Iterable[Literal[L]] ) -> BooleanExpression: literals_set: Set[Literal[L]] = _to_literal_set(literals) count = len(literals_set) @@ -786,7 +785,7 @@ class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC): value: Literal[L] = Field() model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) - def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): + def __init__(self, term: str | UnboundTerm[Any], literal: L | Literal[L]): super().__init__(term=_to_unbound_term(term), value=_to_literal(literal)) # type: ignore[call-arg] @property diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index a6268c0d48..ee8d1e930a 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -27,7 +27,6 @@ SupportsFloat, Tuple, TypeVar, - Union, ) from pyiceberg.conversions import from_bytes @@ -1014,7 +1013,7 @@ class ExpressionToPlainFormat(BoundBooleanExpressionVisitor[List[Tuple[str, str, def __init__(self, cast_int_to_date: bool = False) -> None: self.cast_int_to_date = cast_int_to_date - def _cast_if_necessary(self, iceberg_type: IcebergType, literal: Union[L, Set[L]]) -> Union[L, Set[L]]: + def _cast_if_necessary(self, iceberg_type: IcebergType, literal: L | Set[L]) -> L | Set[L]: if self.cast_int_to_date: iceberg_type_class = type(iceberg_type) conversions = {TimestampType: micros_to_timestamp, TimestamptzType: 
micros_to_timestamptz} diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py index 8836dec79d..1915afcd0b 100644 --- a/pyiceberg/io/__init__.py +++ b/pyiceberg/io/__init__.py @@ -34,10 +34,8 @@ from typing import ( Dict, List, - Optional, Protocol, Type, - Union, runtime_checkable, ) from urllib.parse import urlparse @@ -128,9 +126,7 @@ def __enter__(self) -> InputStream: """Provide setup when opening an InputStream using a 'with' statement.""" @abstractmethod - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ) -> None: + def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None: """Perform cleanup when exiting the scope of a 'with' statement.""" @@ -153,9 +149,7 @@ def __enter__(self) -> OutputStream: """Provide setup when opening an OutputStream using a 'with' statement.""" @abstractmethod - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ) -> None: + def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None: """Perform cleanup when exiting the scope of a 'with' statement.""" @@ -283,7 +277,7 @@ def new_output(self, location: str) -> OutputFile: """ @abstractmethod - def delete(self, location: Union[str, InputFile, OutputFile]) -> None: + def delete(self, location: str | InputFile | OutputFile) -> None: """Delete the file at the given path. Args: @@ -321,7 +315,7 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None: } -def _import_file_io(io_impl: str, properties: Properties) -> Optional[FileIO]: +def _import_file_io(io_impl: str, properties: Properties) -> FileIO | None: try: path_parts = io_impl.split(".") if len(path_parts) < 2: @@ -338,7 +332,7 @@ def _import_file_io(io_impl: str, properties: Properties) -> Optional[FileIO]: PY_IO_IMPL = "py-io-impl" -def _infer_file_io_from_scheme(path: str, properties: Properties) -> Optional[FileIO]: +def _infer_file_io_from_scheme(path: str, properties: Properties) -> FileIO | None: parsed_url = urlparse(path) if parsed_url.scheme: if file_ios := SCHEMA_TO_FILE_IO.get(parsed_url.scheme): @@ -350,7 +344,7 @@ def _infer_file_io_from_scheme(path: str, properties: Properties) -> Optional[Fi return None -def load_file_io(properties: Properties = EMPTY_DICT, location: Optional[str] = None) -> FileIO: +def load_file_io(properties: Properties = EMPTY_DICT, location: str | None = None) -> FileIO: # First look for the py-io-impl property to directly load the class if io_impl := properties.get(PY_IO_IMPL): if file_io := _import_file_io(io_impl, properties): diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py index d5e89c92f7..8f2fcc4312 100644 --- a/pyiceberg/io/fsspec.py +++ b/pyiceberg/io/fsspec.py @@ -30,7 +30,6 @@ Callable, Dict, Type, - Union, ) from urllib.parse import urlparse @@ -426,7 +425,7 @@ def new_output(self, location: str) -> FsspecOutputFile: fs = self.get_fs(uri.scheme) return FsspecOutputFile(location=location, fs=fs) - def delete(self, location: Union[str, InputFile, OutputFile]) -> None: + def delete(self, location: str | InputFile | OutputFile) -> None: """Delete the file at the given location. 
Args: diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index d1484a834a..5b4c041ff5 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -49,11 +49,9 @@ Iterable, Iterator, List, - Optional, Set, Tuple, TypeVar, - Union, cast, ) from urllib.parse import urlparse @@ -214,7 +212,7 @@ @lru_cache -def _cached_resolve_s3_region(bucket: str) -> Optional[str]: +def _cached_resolve_s3_region(bucket: str) -> str | None: from pyarrow.fs import resolve_s3_region try: @@ -224,7 +222,7 @@ def _cached_resolve_s3_region(bucket: str) -> Optional[str]: return None -def _import_retry_strategy(impl: str) -> Optional[S3RetryStrategy]: +def _import_retry_strategy(impl: str) -> S3RetryStrategy | None: try: path_parts = impl.split(".") if len(path_parts) < 2: @@ -387,10 +385,10 @@ def to_input_file(self) -> PyArrowFile: class PyArrowFileIO(FileIO): - fs_by_scheme: Callable[[str, Optional[str]], FileSystem] + fs_by_scheme: Callable[[str, str | None], FileSystem] def __init__(self, properties: Properties = EMPTY_DICT): - self.fs_by_scheme: Callable[[str, Optional[str]], FileSystem] = lru_cache(self._initialize_fs) + self.fs_by_scheme: Callable[[str, str | None], FileSystem] = lru_cache(self._initialize_fs) super().__init__(properties=properties) @staticmethod @@ -410,7 +408,7 @@ def parse_location(location: str, properties: Properties = EMPTY_DICT) -> Tuple[ else: return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}" - def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem: + def _initialize_fs(self, scheme: str, netloc: str | None = None) -> FileSystem: """Initialize FileSystem for different scheme.""" if scheme in {"oss"}: return self._initialize_oss_fs() @@ -465,7 +463,7 @@ def _initialize_oss_fs(self) -> FileSystem: return S3FileSystem(**client_kwargs) - def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem: + def _initialize_s3_fs(self, netloc: str | None) -> FileSystem: from pyarrow.fs import S3FileSystem provided_region = get_first_property_value(self.properties, S3_REGION, AWS_REGION) @@ -575,7 +573,7 @@ def _initialize_azure_fs(self) -> FileSystem: return AzureFileSystem(**client_kwargs) - def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem: + def _initialize_hdfs_fs(self, scheme: str, netloc: str | None) -> FileSystem: from pyarrow.fs import HadoopFileSystem hdfs_kwargs: Dict[str, Any] = {} @@ -647,7 +645,7 @@ def new_output(self, location: str) -> PyArrowFile: buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)), ) - def delete(self, location: Union[str, InputFile, OutputFile]) -> None: + def delete(self, location: str | InputFile | OutputFile) -> None: """Delete the file at the given location. 
Args: @@ -689,7 +687,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: def schema_to_pyarrow( - schema: Union[Schema, IcebergType], + schema: Schema | IcebergType, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True, file_format: FileFormat = FileFormat.PARQUET, @@ -701,7 +699,7 @@ class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]): _metadata: Dict[bytes, bytes] def __init__( - self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True, file_format: Optional[FileFormat] = None + self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True, file_format: FileFormat | None = None ) -> None: self._metadata = metadata self._include_field_ids = include_field_ids @@ -1080,7 +1078,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start def pyarrow_to_schema( schema: pa.Schema, - name_mapping: Optional[NameMapping] = None, + name_mapping: NameMapping | None = None, downcast_ns_timestamp_to_us: bool = False, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, ) -> Schema: @@ -1120,7 +1118,7 @@ def _pyarrow_schema_ensure_small_types(schema: pa.Schema) -> pa.Schema: @singledispatch -def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T: +def visit_pyarrow(obj: pa.DataType | pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T: """Apply a pyarrow schema visitor to any point within a schema. The function traverses the schema in post-order fashion. @@ -1150,7 +1148,7 @@ def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> T: @visit_pyarrow.register(pa.ListType) @visit_pyarrow.register(pa.FixedSizeListType) @visit_pyarrow.register(pa.LargeListType) -def _(obj: Union[pa.ListType, pa.LargeListType, pa.FixedSizeListType], visitor: PyArrowSchemaVisitor[T]) -> T: +def _(obj: pa.ListType | pa.LargeListType | pa.FixedSizeListType, visitor: PyArrowSchemaVisitor[T]) -> T: visitor.before_list_element(obj.value_field) result = visit_pyarrow(obj.value_type, visitor) visitor.after_list_element(obj.value_field) @@ -1251,7 +1249,7 @@ def primitive(self, primitive: pa.DataType) -> T: """Visit a primitive type.""" -def _get_field_id(field: pa.Field) -> Optional[int]: +def _get_field_id(field: pa.Field) -> int | None: """Return the Iceberg field ID from Parquet or ORC metadata if available.""" if field.metadata: # Try Parquet field ID first @@ -1291,7 +1289,7 @@ def primitive(self, primitive: pa.DataType) -> bool: return True -class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): +class _ConvertToIceberg(PyArrowSchemaVisitor[IcebergType | Schema]): """Converts PyArrowSchema to Iceberg Schema. 
Applies the IDs from name_mapping if provided.""" _field_names: List[str] @@ -1428,7 +1426,7 @@ def after_map_value(self, element: pa.Field) -> None: self._field_names.pop() -class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]): +class _ConvertToLargeTypes(PyArrowSchemaVisitor[pa.DataType | pa.Schema]): def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema: return pa.schema(struct_result) @@ -1452,7 +1450,7 @@ def primitive(self, primitive: pa.DataType) -> pa.DataType: return primitive -class _ConvertToSmallTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]): +class _ConvertToSmallTypes(PyArrowSchemaVisitor[pa.DataType | pa.Schema]): def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema: return pa.schema(struct_result) @@ -1495,7 +1493,7 @@ def _get_column_projection_values( file: DataFile, projected_schema: Schema, table_schema: Schema, - partition_spec: Optional[PartitionSpec], + partition_spec: PartitionSpec | None, file_project_field_ids: Set[int], ) -> Dict[int, Any]: """Apply Column Projection rules to File Schema.""" @@ -1523,12 +1521,12 @@ def _task_to_record_batches( projected_schema: Schema, table_schema: Schema, projected_field_ids: Set[int], - positional_deletes: Optional[List[ChunkedArray]], + positional_deletes: List[ChunkedArray] | None, case_sensitive: bool, - name_mapping: Optional[NameMapping] = None, - partition_spec: Optional[PartitionSpec] = None, + name_mapping: NameMapping | None = None, + partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, - downcast_ns_timestamp_to_us: Optional[bool] = None, + downcast_ns_timestamp_to_us: bool | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) with io.new_input(task.file.file_path).open() as fin: @@ -1629,8 +1627,8 @@ class ArrowScan: _projected_schema: Schema _bound_row_filter: BooleanExpression _case_sensitive: bool - _limit: Optional[int] - _downcast_ns_timestamp_to_us: Optional[bool] + _limit: int | None + _downcast_ns_timestamp_to_us: bool | None """Scan the Iceberg Table and create an Arrow construct.
Attributes: @@ -1649,7 +1647,7 @@ def __init__( projected_schema: Schema, row_filter: BooleanExpression, case_sensitive: bool = True, - limit: Optional[int] = None, + limit: int | None = None, ) -> None: self._table_metadata = table_metadata self._io = io @@ -1807,11 +1805,11 @@ def _to_requested_schema( return pa.RecordBatch.from_struct_array(struct_array) -class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): +class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, pa.Array | None]): _file_schema: Schema _include_field_ids: bool _downcast_ns_timestamp_to_us: bool - _use_large_types: Optional[bool] + _use_large_types: bool | None _projected_missing_fields: Dict[int, Any] def __init__( @@ -1819,7 +1817,7 @@ def __init__( file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False, - use_large_types: Optional[bool] = None, + use_large_types: bool | None = None, projected_missing_fields: Dict[int, Any] = EMPTY_DICT, ) -> None: self._file_schema = file_schema @@ -1892,12 +1890,10 @@ def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Fi metadata=metadata, ) - def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]: + def schema(self, schema: Schema, schema_partner: pa.Array | None, struct_result: pa.Array | None) -> pa.Array | None: return struct_result - def struct( - self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]] - ) -> Optional[pa.Array]: + def struct(self, struct: StructType, struct_array: pa.Array | None, field_results: List[pa.Array | None]) -> pa.Array | None: if struct_array is None: return None field_arrays: List[pa.Array] = [] @@ -1926,10 +1922,10 @@ def struct( mask=struct_array.is_null() if isinstance(struct_array, pa.StructArray) else None, ) - def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]: + def field(self, field: NestedField, _: pa.Array | None, field_array: pa.Array | None) -> pa.Array | None: return field_array - def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]: + def list(self, list_type: ListType, list_array: pa.Array | None, value_array: pa.Array | None) -> pa.Array | None: if isinstance(list_array, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) and value_array is not None: list_initializer = pa.large_list if isinstance(list_array, pa.LargeListArray) else pa.list_ if isinstance(value_array, pa.StructArray): @@ -1943,8 +1939,8 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: return None def map( - self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array] - ) -> Optional[pa.Array]: + self, map_type: MapType, map_array: pa.Array | None, key_result: pa.Array | None, value_result: pa.Array | None + ) -> pa.Array | None: if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None: key_result = self._cast_if_needed(map_type.key_field, key_result) value_result = self._cast_if_needed(map_type.value_field, value_result) @@ -1960,7 +1956,7 @@ def map( else: return None - def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.Array]: + def primitive(self, _: PrimitiveType, array: pa.Array | None) -> pa.Array | None: return array @@ -1970,10 +1966,10 @@ class 
ArrowAccessor(PartnerAccessor[pa.Array]): def __init__(self, file_schema: Schema): self.file_schema = file_schema - def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]: + def schema_partner(self, partner: pa.Array | None) -> pa.Array | None: return partner - def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]: + def field_partner(self, partner_struct: pa.Array | None, field_id: int, _: str) -> pa.Array | None: if partner_struct is not None: # use the field name from the file schema try: @@ -1992,13 +1988,13 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st return None - def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]: + def list_element_partner(self, partner_list: pa.Array | None) -> pa.Array | None: return partner_list.values if isinstance(partner_list, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) else None - def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]: + def map_key_partner(self, partner_map: pa.Array | None) -> pa.Array | None: return partner_map.keys if isinstance(partner_map, pa.MapArray) else None - def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]: + def map_value_partner(self, partner_map: pa.Array | None) -> pa.Array | None: return partner_map.items if isinstance(partner_map, pa.MapArray) else None @@ -2080,9 +2076,9 @@ def visit_unknown(self, unknown_type: UnknownType) -> str: class StatsAggregator: current_min: Any current_max: Any - trunc_length: Optional[int] + trunc_length: int | None - def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None: + def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: int | None = None) -> None: self.current_min = None self.current_max = None self.trunc_length = trunc_length @@ -2110,19 +2106,19 @@ def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc def serialize(self, value: Any) -> bytes: return to_bytes(self.primitive_type, value) - def update_min(self, val: Optional[Any]) -> None: + def update_min(self, val: Any | None) -> None: if self.current_min is None: self.current_min = val elif val is not None: self.current_min = min(val, self.current_min) - def update_max(self, val: Optional[Any]) -> None: + def update_max(self, val: Any | None) -> None: if self.current_max is None: self.current_max = val elif val is not None: self.current_max = max(val, self.current_max) - def min_as_bytes(self) -> Optional[bytes]: + def min_as_bytes(self) -> bytes | None: if self.current_min is None: return None @@ -2132,7 +2128,7 @@ def min_as_bytes(self) -> Optional[bytes]: else TruncateTransform(width=self.trunc_length).transform(self.primitive_type)(self.current_min) ) - def max_as_bytes(self) -> Optional[bytes]: + def max_as_bytes(self) -> bytes | None: if self.current_max is None: return None @@ -2166,7 +2162,7 @@ class MetricModeTypes(Enum): @dataclass(frozen=True) class MetricsMode(Singleton): type: MetricModeTypes - length: Optional[int] = None + length: int | None = None def match_metrics_mode(mode: str) -> MetricsMode: @@ -2766,8 +2762,8 @@ def _dataframe_to_data_files( table_metadata: TableMetadata, df: pa.Table, io: FileIO, - write_uuid: Optional[uuid.UUID] = None, - counter: Optional[itertools.count[int]] = None, + write_uuid: uuid.UUID | None = None, + counter: itertools.count[int] | None = None, ) 
-> Iterable[DataFile]: """Convert a PyArrow table into a DataFile. diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index eafb2b7c03..40163a18b3 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -28,10 +28,8 @@ Iterator, List, Literal, - Optional, Tuple, Type, - Union, ) from cachetools import LRUCache, cached @@ -102,7 +100,7 @@ class FileFormat(str, Enum): PUFFIN = "PUFFIN" @classmethod - def _missing_(cls, value: object) -> Union[None, str]: + def _missing_(cls, value: object) -> None | str: for member in cls: if member.value == str(value).upper(): return member @@ -501,19 +499,19 @@ def upper_bounds(self) -> Dict[int, bytes]: return self._data[11] @property - def key_metadata(self) -> Optional[bytes]: + def key_metadata(self) -> bytes | None: return self._data[12] @property - def split_offsets(self) -> Optional[List[int]]: + def split_offsets(self) -> List[int] | None: return self._data[13] @property - def equality_ids(self) -> Optional[List[int]]: + def equality_ids(self) -> List[int] | None: return self._data[14] @property - def sort_order_id(self) -> Optional[int]: + def sort_order_id(self) -> int | None: return self._data[15] # Spec ID should not be stored in the file @@ -594,7 +592,7 @@ def status(self, value: ManifestEntryStatus) -> None: self._data[0] = value @property - def snapshot_id(self) -> Optional[int]: + def snapshot_id(self) -> int | None: return self._data[1] @snapshot_id.setter @@ -602,7 +600,7 @@ def snapshot_id(self, value: int) -> None: self._data[0] = value @property - def sequence_number(self) -> Optional[int]: + def sequence_number(self) -> int | None: return self._data[2] @sequence_number.setter @@ -610,7 +608,7 @@ def sequence_number(self, value: int) -> None: self._data[2] = value @property - def file_sequence_number(self) -> Optional[int]: + def file_sequence_number(self) -> int | None: return self._data[3] @file_sequence_number.setter @@ -644,15 +642,15 @@ def contains_null(self) -> bool: return self._data[0] @property - def contains_nan(self) -> Optional[bool]: + def contains_nan(self) -> bool | None: return self._data[1] @property - def lower_bound(self) -> Optional[bytes]: + def lower_bound(self) -> bytes | None: return self._data[2] @property - def upper_bound(self) -> Optional[bytes]: + def upper_bound(self) -> bytes | None: return self._data[3] @@ -660,8 +658,8 @@ class PartitionFieldStats: _type: PrimitiveType _contains_null: bool _contains_nan: bool - _min: Optional[Any] - _max: Optional[Any] + _min: Any | None + _max: Any | None def __init__(self, iceberg_type: PrimitiveType) -> None: self._type = iceberg_type @@ -802,39 +800,39 @@ def min_sequence_number(self, value: int) -> None: self._data[5] = value @property - def added_snapshot_id(self) -> Optional[int]: + def added_snapshot_id(self) -> int | None: return self._data[6] @property - def added_files_count(self) -> Optional[int]: + def added_files_count(self) -> int | None: return self._data[7] @property - def existing_files_count(self) -> Optional[int]: + def existing_files_count(self) -> int | None: return self._data[8] @property - def deleted_files_count(self) -> Optional[int]: + def deleted_files_count(self) -> int | None: return self._data[9] @property - def added_rows_count(self) -> Optional[int]: + def added_rows_count(self) -> int | None: return self._data[10] @property - def existing_rows_count(self) -> Optional[int]: + def existing_rows_count(self) -> int | None: return self._data[11] @property - def deleted_rows_count(self) -> Optional[int]: + def 
deleted_rows_count(self) -> int | None: return self._data[12] @property - def partitions(self) -> Optional[List[PartitionFieldSummary]]: + def partitions(self) -> List[PartitionFieldSummary] | None: return self._data[13] @property - def key_metadata(self) -> Optional[bytes]: + def key_metadata(self) -> bytes | None: return self._data[14] def has_added_files(self) -> bool: @@ -954,7 +952,7 @@ class ManifestWriter(ABC): _existing_rows: int _deleted_files: int _deleted_rows: int - _min_sequence_number: Optional[int] + _min_sequence_number: int | None _partitions: List[Record] _compression: AvroCompressionCodec @@ -990,9 +988,9 @@ def __enter__(self) -> ManifestWriter: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: """Close the writer.""" if (self._added_files + self._existing_files + self._deleted_files) == 0: @@ -1224,9 +1222,9 @@ def __enter__(self) -> ManifestListWriter: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: """Close the writer.""" self._writer.__exit__(exc_type, exc_value, traceback) @@ -1245,7 +1243,7 @@ def __init__( self, output_file: OutputFile, snapshot_id: int, - parent_snapshot_id: Optional[int], + parent_snapshot_id: int | None, compression: AvroCompressionCodec, ): super().__init__( @@ -1273,7 +1271,7 @@ def __init__( self, output_file: OutputFile, snapshot_id: int, - parent_snapshot_id: Optional[int], + parent_snapshot_id: int | None, sequence_number: int, compression: AvroCompressionCodec, ): @@ -1318,8 +1316,8 @@ def write_manifest_list( format_version: TableVersion, output_file: OutputFile, snapshot_id: int, - parent_snapshot_id: Optional[int], - sequence_number: Optional[int], + parent_snapshot_id: int | None, + sequence_number: int | None, avro_compression: AvroCompressionCodec, ) -> ManifestListWriter: if format_version == 1: diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 408126d3b3..046782c0dc 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -21,7 +21,7 @@ from dataclasses import dataclass from datetime import date, datetime, time from functools import cached_property, singledispatch -from typing import Annotated, Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar, Union +from typing import Annotated, Any, Dict, Generic, List, Set, Tuple, TypeVar from urllib.parse import quote_plus from pydantic import ( @@ -86,10 +86,10 @@ class PartitionField(IcebergBaseModel): def __init__( self, - source_id: Optional[int] = None, - field_id: Optional[int] = None, - transform: Optional[Transform[Any, Any]] = None, - name: Optional[str] = None, + source_id: int | None = None, + field_id: int | None = None, + transform: Transform[Any, Any] | None = None, + name: str | None = None, **data: Any, ): if source_id is not None: @@ -466,7 +466,7 @@ def _to_partition_representation(type: IcebergType, value: Any) -> Any: @_to_partition_representation.register(TimestampType) @_to_partition_representation.register(TimestamptzType) -def _(type: IcebergType, value: Optional[Union[int, datetime]]) -> Optional[int]: +def _(type: IcebergType, value: int | datetime | None) -> int | None: if value is None: return None elif 
isinstance(value, int): @@ -478,7 +478,7 @@ def _(type: IcebergType, value: Optional[Union[int, datetime]]) -> Optional[int] @_to_partition_representation.register(DateType) -def _(type: IcebergType, value: Optional[Union[int, date]]) -> Optional[int]: +def _(type: IcebergType, value: int | date | None) -> int | None: if value is None: return None elif isinstance(value, int): @@ -490,12 +490,12 @@ def _(type: IcebergType, value: Optional[Union[int, date]]) -> Optional[int]: @_to_partition_representation.register(TimeType) -def _(type: IcebergType, value: Optional[time]) -> Optional[int]: +def _(type: IcebergType, value: time | None) -> int | None: return time_to_micros(value) if value is not None else None @_to_partition_representation.register(UUIDType) -def _(type: IcebergType, value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]: +def _(type: IcebergType, value: uuid.UUID | int | bytes | None) -> bytes | int | None: if value is None: return None elif isinstance(value, bytes): @@ -509,5 +509,5 @@ def _(type: IcebergType, value: Optional[Union[uuid.UUID, int, bytes]]) -> Optio @_to_partition_representation.register(PrimitiveType) -def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]: +def _(type: IcebergType, value: Any | None) -> Any | None: return value diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index d9c2d7ddfc..c8c73eded8 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -29,11 +29,9 @@ Generic, List, Literal, - Optional, Set, Tuple, TypeVar, - Union, ) from pydantic import Field, PrivateAttr, model_validator @@ -194,7 +192,7 @@ def as_arrow(self) -> "pa.Schema": return schema_to_pyarrow(self) - def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField: + def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: """Find a field using a field name or field ID. Args: @@ -222,7 +220,7 @@ def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) - return self._lazy_id_to_field[field_id] - def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> IcebergType: + def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: """Find a field type using a field name or field ID. Args: @@ -247,7 +245,7 @@ def name_mapping(self) -> NameMapping: return create_mapping_from_schema(self) - def find_column_name(self, column_id: int) -> Optional[str]: + def find_column_name(self, column_id: int) -> str | None: """Find a column name given a column ID. 
Args: @@ -466,63 +464,63 @@ def primitive(self, primitive: PrimitiveType) -> T: class SchemaWithPartnerVisitor(Generic[P, T], ABC): - def before_field(self, field: NestedField, field_partner: Optional[P]) -> None: + def before_field(self, field: NestedField, field_partner: P | None) -> None: """Override this method to perform an action immediately before visiting a field.""" - def after_field(self, field: NestedField, field_partner: Optional[P]) -> None: + def after_field(self, field: NestedField, field_partner: P | None) -> None: """Override this method to perform an action immediately after visiting a field.""" - def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None: + def before_list_element(self, element: NestedField, element_partner: P | None) -> None: """Override this method to perform an action immediately before visiting an element within a ListType.""" self.before_field(element, element_partner) - def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None: + def after_list_element(self, element: NestedField, element_partner: P | None) -> None: """Override this method to perform an action immediately after visiting an element within a ListType.""" self.after_field(element, element_partner) - def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None: + def before_map_key(self, key: NestedField, key_partner: P | None) -> None: """Override this method to perform an action immediately before visiting a key within a MapType.""" self.before_field(key, key_partner) - def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None: + def after_map_key(self, key: NestedField, key_partner: P | None) -> None: """Override this method to perform an action immediately after visiting a key within a MapType.""" self.after_field(key, key_partner) - def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None: + def before_map_value(self, value: NestedField, value_partner: P | None) -> None: """Override this method to perform an action immediately before visiting a value within a MapType.""" self.before_field(value, value_partner) - def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None: + def after_map_value(self, value: NestedField, value_partner: P | None) -> None: """Override this method to perform an action immediately after visiting a value within a MapType.""" self.after_field(value, value_partner) @abstractmethod - def schema(self, schema: Schema, schema_partner: Optional[P], struct_result: T) -> T: + def schema(self, schema: Schema, schema_partner: P | None, struct_result: T) -> T: """Visit a schema with a partner.""" @abstractmethod - def struct(self, struct: StructType, struct_partner: Optional[P], field_results: List[T]) -> T: + def struct(self, struct: StructType, struct_partner: P | None, field_results: List[T]) -> T: """Visit a struct type with a partner.""" @abstractmethod - def field(self, field: NestedField, field_partner: Optional[P], field_result: T) -> T: + def field(self, field: NestedField, field_partner: P | None, field_result: T) -> T: """Visit a nested field with a partner.""" @abstractmethod - def list(self, list_type: ListType, list_partner: Optional[P], element_result: T) -> T: + def list(self, list_type: ListType, list_partner: P | None, element_result: T) -> T: """Visit a list type with a partner.""" @abstractmethod - def map(self, map_type: MapType, map_partner: Optional[P], key_result: T, value_result: T) -> T: + def 
map(self, map_type: MapType, map_partner: P | None, key_result: T, value_result: T) -> T: """Visit a map type with a partner.""" @abstractmethod - def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T: + def primitive(self, primitive: PrimitiveType, primitive_partner: P | None) -> T: """Visit a primitive type with a partner.""" class PrimitiveWithPartnerVisitor(SchemaWithPartnerVisitor[P, T]): - def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T: + def primitive(self, primitive: PrimitiveType, primitive_partner: P | None) -> T: """Visit a PrimitiveType.""" if isinstance(primitive, BooleanType): return self.visit_boolean(primitive, primitive_partner) @@ -562,99 +560,99 @@ def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> raise ValueError(f"Type not recognized: {primitive}") @abstractmethod - def visit_boolean(self, boolean_type: BooleanType, partner: Optional[P]) -> T: + def visit_boolean(self, boolean_type: BooleanType, partner: P | None) -> T: """Visit a BooleanType.""" @abstractmethod - def visit_integer(self, integer_type: IntegerType, partner: Optional[P]) -> T: + def visit_integer(self, integer_type: IntegerType, partner: P | None) -> T: """Visit a IntegerType.""" @abstractmethod - def visit_long(self, long_type: LongType, partner: Optional[P]) -> T: + def visit_long(self, long_type: LongType, partner: P | None) -> T: """Visit a LongType.""" @abstractmethod - def visit_float(self, float_type: FloatType, partner: Optional[P]) -> T: + def visit_float(self, float_type: FloatType, partner: P | None) -> T: """Visit a FloatType.""" @abstractmethod - def visit_double(self, double_type: DoubleType, partner: Optional[P]) -> T: + def visit_double(self, double_type: DoubleType, partner: P | None) -> T: """Visit a DoubleType.""" @abstractmethod - def visit_decimal(self, decimal_type: DecimalType, partner: Optional[P]) -> T: + def visit_decimal(self, decimal_type: DecimalType, partner: P | None) -> T: """Visit a DecimalType.""" @abstractmethod - def visit_date(self, date_type: DateType, partner: Optional[P]) -> T: + def visit_date(self, date_type: DateType, partner: P | None) -> T: """Visit a DecimalType.""" @abstractmethod - def visit_time(self, time_type: TimeType, partner: Optional[P]) -> T: + def visit_time(self, time_type: TimeType, partner: P | None) -> T: """Visit a DecimalType.""" @abstractmethod - def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[P]) -> T: + def visit_timestamp(self, timestamp_type: TimestampType, partner: P | None) -> T: """Visit a TimestampType.""" @abstractmethod - def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[P]) -> T: + def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: P | None) -> T: """Visit a TimestampNanoType.""" @abstractmethod - def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[P]) -> T: + def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: P | None) -> T: """Visit a TimestamptzType.""" @abstractmethod - def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[P]) -> T: + def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: P | None) -> T: """Visit a TimestamptzNanoType.""" @abstractmethod - def visit_string(self, string_type: StringType, partner: Optional[P]) -> T: + def visit_string(self, string_type: StringType, partner: P | None) -> T: """Visit a StringType.""" 
@abstractmethod - def visit_uuid(self, uuid_type: UUIDType, partner: Optional[P]) -> T: + def visit_uuid(self, uuid_type: UUIDType, partner: P | None) -> T: """Visit a UUIDType.""" @abstractmethod - def visit_fixed(self, fixed_type: FixedType, partner: Optional[P]) -> T: + def visit_fixed(self, fixed_type: FixedType, partner: P | None) -> T: """Visit a FixedType.""" @abstractmethod - def visit_binary(self, binary_type: BinaryType, partner: Optional[P]) -> T: + def visit_binary(self, binary_type: BinaryType, partner: P | None) -> T: """Visit a BinaryType.""" @abstractmethod - def visit_unknown(self, unknown_type: UnknownType, partner: Optional[P]) -> T: + def visit_unknown(self, unknown_type: UnknownType, partner: P | None) -> T: """Visit a UnknownType.""" class PartnerAccessor(Generic[P], ABC): @abstractmethod - def schema_partner(self, partner: Optional[P]) -> Optional[P]: + def schema_partner(self, partner: P | None) -> P | None: """Return the equivalent of the schema as a struct.""" @abstractmethod - def field_partner(self, partner_struct: Optional[P], field_id: int, field_name: str) -> Optional[P]: + def field_partner(self, partner_struct: P | None, field_id: int, field_name: str) -> P | None: """Return the equivalent struct field by name or id in the partner struct.""" @abstractmethod - def list_element_partner(self, partner_list: Optional[P]) -> Optional[P]: + def list_element_partner(self, partner_list: P | None) -> P | None: """Return the equivalent list element in the partner list.""" @abstractmethod - def map_key_partner(self, partner_map: Optional[P]) -> Optional[P]: + def map_key_partner(self, partner_map: P | None) -> P | None: """Return the equivalent map key in the partner map.""" @abstractmethod - def map_value_partner(self, partner_map: Optional[P]) -> Optional[P]: + def map_value_partner(self, partner_map: P | None) -> P | None: """Return the equivalent map value in the partner map.""" @singledispatch def visit_with_partner( - schema_or_type: Union[Schema, IcebergType], partner: P, visitor: SchemaWithPartnerVisitor[T, P], accessor: PartnerAccessor[P] + schema_or_type: Schema | IcebergType, partner: P, visitor: SchemaWithPartnerVisitor[T, P], accessor: PartnerAccessor[P] ) -> T: raise ValueError(f"Unsupported type: {schema_or_type}") @@ -829,7 +827,7 @@ class Accessor: """An accessor for a specific position in a container that implements the StructProtocol.""" position: int - inner: Optional[Accessor] = None + inner: Accessor | None = None def __str__(self) -> str: """Return the string representation of the Accessor class.""" @@ -859,7 +857,7 @@ def get(self, container: StructProtocol) -> Any: @singledispatch -def visit(obj: Union[Schema, IcebergType], visitor: SchemaVisitor[T]) -> T: +def visit(obj: Schema | IcebergType, visitor: SchemaVisitor[T]) -> T: """Apply a schema visitor to any point within a schema. The function traverses the schema in post-order fashion. @@ -925,7 +923,7 @@ def _(obj: PrimitiveType, visitor: SchemaVisitor[T]) -> T: @singledispatch -def pre_order_visit(obj: Union[Schema, IcebergType], visitor: PreOrderSchemaVisitor[T]) -> T: +def pre_order_visit(obj: Schema | IcebergType, visitor: PreOrderSchemaVisitor[T]) -> T: """Apply a schema visitor to any point within a schema. The function traverses the schema in pre-order fashion. 
This is a slimmed down version @@ -1015,7 +1013,7 @@ def primitive(self, primitive: PrimitiveType) -> Dict[int, NestedField]: return self._index -def index_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, NestedField]: +def index_by_id(schema_or_type: Schema | IcebergType) -> Dict[int, NestedField]: """Generate an index of field IDs to NestedField instances. Args: @@ -1066,7 +1064,7 @@ def primitive(self, primitive: PrimitiveType) -> Dict[int, int]: return self.id_to_parent -def _index_parents(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, int]: +def _index_parents(schema_or_type: Schema | IcebergType) -> Dict[int, int]: """Generate an index of field IDs to their parent field IDs. Args: @@ -1182,7 +1180,7 @@ def by_id(self) -> Dict[int, str]: return id_to_full_name -def index_by_name(schema_or_type: Union[Schema, IcebergType]) -> Dict[str, int]: +def index_by_name(schema_or_type: Schema | IcebergType) -> Dict[str, int]: """Generate an index of field names to field IDs. Args: @@ -1199,7 +1197,7 @@ def index_by_name(schema_or_type: Union[Schema, IcebergType]) -> Dict[str, int]: return EMPTY_DICT -def index_name_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, str]: +def index_name_by_id(schema_or_type: Schema | IcebergType) -> Dict[int, str]: """Generate an index of field IDs full field names. Args: @@ -1278,7 +1276,7 @@ def primitive(self, primitive: PrimitiveType) -> Dict[Position, Accessor]: return {} -def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, Accessor]: +def build_position_accessors(schema_or_type: Schema | IcebergType) -> Dict[int, Accessor]: """Generate an index of field IDs to schema position accessors. Args: @@ -1290,7 +1288,7 @@ def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict return visit(schema_or_type, _BuildPositionAccessors()) -def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id: Optional[Callable[[], int]] = None) -> Schema: +def assign_fresh_schema_ids(schema_or_type: Schema | IcebergType, next_id: Callable[[], int] | None = None) -> Schema: """Traverses the schema, and sets new IDs.""" return pre_order_visit(schema_or_type, _SetFreshIDs(next_id_func=next_id)) @@ -1300,7 +1298,7 @@ class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]): old_id_to_new_id: Dict[int, int] - def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None: + def __init__(self, next_id_func: Callable[[], int] | None = None) -> None: self.old_id_to_new_id = {} counter = itertools.count(1) self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter) @@ -1434,11 +1432,11 @@ def sanitize_column_names(schema: Schema) -> Schema: ) -class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]): - def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]: +class _SanitizeColumnsVisitor(SchemaVisitor[IcebergType | None]): + def schema(self, schema: Schema, struct_result: IcebergType | None) -> IcebergType | None: return struct_result - def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]: + def field(self, field: NestedField, field_result: IcebergType | None) -> IcebergType | None: return NestedField( field_id=field.field_id, name=make_compatible_name(field.name), @@ -1447,15 +1445,13 @@ def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Opti required=field.required, ) - def struct(self, struct: StructType, 
field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]: + def struct(self, struct: StructType, field_results: List[IcebergType | None]) -> IcebergType | None: return StructType(*[field for field in field_results if field is not None]) - def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: + def list(self, list_type: ListType, element_result: IcebergType | None) -> IcebergType | None: return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required) - def map( - self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType] - ) -> Optional[IcebergType]: + def map(self, map_type: MapType, key_result: IcebergType | None, value_result: IcebergType | None) -> IcebergType | None: return MapType( key_id=map_type.key_id, value_id=map_type.value_id, @@ -1464,7 +1460,7 @@ def map( value_required=map_type.value_required, ) - def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: + def primitive(self, primitive: PrimitiveType) -> IcebergType | None: return primitive @@ -1487,7 +1483,7 @@ def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = ) -class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]): +class _PruneColumnsVisitor(SchemaVisitor[IcebergType | None]): selected: Set[int] select_full_types: bool @@ -1495,10 +1491,10 @@ def __init__(self, selected: Set[int], select_full_types: bool): self.selected = selected self.select_full_types = select_full_types - def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]: + def schema(self, schema: Schema, struct_result: IcebergType | None) -> IcebergType | None: return struct_result - def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]: + def struct(self, struct: StructType, field_results: List[IcebergType | None]) -> IcebergType | None: fields = struct.fields selected_fields = [] same_type = True @@ -1528,7 +1524,7 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) return StructType(*selected_fields) return None - def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]: + def field(self, field: NestedField, field_result: IcebergType | None) -> IcebergType | None: if field.field_id in self.selected: if self.select_full_types: return field.field_type @@ -1547,7 +1543,7 @@ def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Opti else: return None - def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: + def list(self, list_type: ListType, element_result: IcebergType | None) -> IcebergType | None: if list_type.element_id in self.selected: if self.select_full_types: return list_type @@ -1565,9 +1561,7 @@ def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Op else: return None - def map( - self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType] - ) -> Optional[IcebergType]: + def map(self, map_type: MapType, key_result: IcebergType | None, value_result: IcebergType | None) -> IcebergType | None: if map_type.value_id in self.selected: if self.select_full_types: return map_type @@ -1585,11 +1579,11 @@ def map( return map_type return None - def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: + def primitive(self, primitive: PrimitiveType) -> 
IcebergType | None: return None @staticmethod - def _project_selected_struct(projected_field: Optional[IcebergType]) -> StructType: + def _project_selected_struct(projected_field: IcebergType | None) -> StructType: if projected_field and not isinstance(projected_field, StructType): raise ValueError("Expected a struct") diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8b7f4d165a..5fd83a81c3 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -33,12 +33,10 @@ Iterable, Iterator, List, - Optional, Set, Tuple, Type, TypeVar, - Union, ) from pydantic import Field @@ -279,9 +277,7 @@ def __enter__(self) -> Transaction: """Start a transaction to update the table.""" return self - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ) -> None: + def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None: """Close and commit the transaction if no exceptions have been raised.""" if exctype is None and excinst is None and exctb is None: self.commit_transaction() @@ -305,7 +301,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ return self - def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, case_sensitive: bool = True) -> DataScan: + def _scan(self, row_filter: str | BooleanExpression = ALWAYS_TRUE, case_sensitive: bool = True) -> DataScan: """Minimal data scan of the table with the current state of the transaction.""" return DataScan( table_metadata=self.table_metadata, io=self._table.io, row_filter=row_filter, case_sensitive=case_sensitive @@ -353,9 +349,9 @@ def _set_ref_snapshot( snapshot_id: int, ref_name: str, type: str, - max_ref_age_ms: Optional[int] = None, - max_snapshot_age_ms: Optional[int] = None, - min_snapshots_to_keep: Optional[int] = None, + max_ref_age_ms: int | None = None, + max_snapshot_age_ms: int | None = None, + min_snapshots_to_keep: int | None = None, ) -> UpdatesAndRequirements: """Update a ref to a snapshot. @@ -408,7 +404,7 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE return expr def _append_snapshot_producer( - self, snapshot_properties: Dict[str, str], branch: Optional[str] = MAIN_BRANCH + self, snapshot_properties: Dict[str, str], branch: str | None = MAIN_BRANCH ) -> _FastAppendFiles: """Determine the append type based on table properties. @@ -457,7 +453,7 @@ def update_sort_order(self, case_sensitive: bool = True) -> UpdateSortOrder: ) def update_snapshot( - self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH ) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. @@ -475,7 +471,7 @@ def update_statistics(self) -> UpdateStatistics: """ return UpdateStatistics(transaction=self) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to a table transaction. 
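Note on the typing rewrite as a whole: the PEP 604 `X | None` spelling is only evaluated at runtime on Python 3.10+. In modules that enable postponed evaluation of annotations (PEP 563), annotations stay strings, so the syntax also parses on older interpreters; unions built outside annotations are still evaluated eagerly. A minimal standalone sketch of the distinction, with hypothetical names, not part of this diff:

from __future__ import annotations  # PEP 563: annotations are stored as strings, never evaluated

def lookup(snapshot_id: int | None = None) -> str | None:
    # The annotations above are never evaluated, so under the future import
    # this parses even on interpreters older than 3.10.
    return None if snapshot_id is None else str(snapshot_id)

# Built eagerly at import time, so this assignment itself requires
# Python 3.10+ (it constructs a types.UnionType instance):
MaybeInt = int | None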
@@ -514,7 +510,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, append_files.append_data_file(data_file) def dynamic_partition_overwrite( - self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH ) -> None: """ Shorthand for overwriting existing partitions with a PyArrow table. @@ -578,10 +574,10 @@ def dynamic_partition_overwrite( def overwrite( self, df: pa.Table, - overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, + overwrite_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> None: """ Shorthand for adding a table overwrite with a PyArrow table to the transaction. @@ -638,10 +634,10 @@ def overwrite( def delete( self, - delete_filter: Union[str, BooleanExpression], + delete_filter: str | BooleanExpression, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> None: """ Shorthand for deleting record from a table. @@ -740,11 +736,11 @@ def delete( def upsert( self, df: pa.Table, - join_cols: Optional[List[str]] = None, + join_cols: List[str] | None = None, when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -886,7 +882,7 @@ def add_files( file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> None: """ Shorthand API for adding files as data files to the table transaction. @@ -994,7 +990,7 @@ def _initial_changes(self, table_metadata: TableMetadata) -> None: self._updates += (AddPartitionSpecUpdate(spec=spec),) self._updates += (SetDefaultSpecUpdate(spec_id=-1),) - sort_order: Optional[SortOrder] = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id) + sort_order: SortOrder | None = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id) if sort_order is None or sort_order.is_unsorted: self._updates += (AddSortOrderUpdate(sort_order=UNSORTED_SORT_ORDER),) else: @@ -1134,12 +1130,12 @@ def name(self) -> Identifier: def scan( self, - row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + row_filter: str | BooleanExpression = ALWAYS_TRUE, selected_fields: Tuple[str, ...] = ("*",), case_sensitive: bool = True, - snapshot_id: Optional[int] = None, + snapshot_id: int | None = None, options: Properties = EMPTY_DICT, - limit: Optional[int] = None, + limit: int | None = None, ) -> DataScan: """Fetch a DataScan based on the table's current metadata. 
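The `row_filter: str | BooleanExpression` unions above accept either a filter DSL string or an expression object. A usage sketch, assuming `table` has already been loaded from a catalog; the column name `id` is illustrative:

from pyiceberg.expressions import GreaterThan

# A string filter is parsed into an unbound BooleanExpression, so these
# two scans are equivalent:
result = table.scan(row_filter="id > 5", selected_fields=("id",), limit=100).to_arrow()
result = table.scan(row_filter=GreaterThan("id", 5), selected_fields=("id",), limit=100).to_arrow()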
@@ -1234,7 +1230,7 @@ def location_provider(self) -> LocationProvider: def last_sequence_number(self) -> int: return self.metadata.last_sequence_number - def current_snapshot(self) -> Optional[Snapshot]: + def current_snapshot(self) -> Snapshot | None: """Get the current snapshot for this table, or None if there is no current snapshot.""" if self.metadata.current_snapshot_id is not None: return self.snapshot_by_id(self.metadata.current_snapshot_id) @@ -1243,17 +1239,17 @@ def current_snapshot(self) -> Optional[Snapshot]: def snapshots(self) -> List[Snapshot]: return self.metadata.snapshots - def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: + def snapshot_by_id(self, snapshot_id: int) -> Snapshot | None: """Get the snapshot of this table with the given id, or None if there is no matching snapshot.""" return self.metadata.snapshot_by_id(snapshot_id) - def snapshot_by_name(self, name: str) -> Optional[Snapshot]: + def snapshot_by_name(self, name: str) -> Snapshot | None: """Return the snapshot referenced by the given name or null if no such reference exists.""" if ref := self.metadata.refs.get(name): return self.snapshot_by_id(ref.snapshot_id) return None - def snapshot_as_of_timestamp(self, timestamp_ms: int, inclusive: bool = True) -> Optional[Snapshot]: + def snapshot_as_of_timestamp(self, timestamp_ms: int, inclusive: bool = True) -> Snapshot | None: """Get the snapshot that was current as of or right before the given timestamp, or None if there is no matching snapshot. Args: @@ -1326,18 +1322,18 @@ def update_sort_order(self, case_sensitive: bool = True) -> UpdateSortOrder: """ return UpdateSortOrder(transaction=Transaction(self, autocommit=True), case_sensitive=case_sensitive) - def name_mapping(self) -> Optional[NameMapping]: + def name_mapping(self) -> NameMapping | None: """Return the table's field-id NameMapping.""" return self.metadata.name_mapping() def upsert( self, df: pa.Table, - join_cols: Optional[List[str]] = None, + join_cols: List[str] | None = None, when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -1384,7 +1380,7 @@ def upsert( branch=branch, ) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to the table. @@ -1397,7 +1393,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch) def dynamic_partition_overwrite( - self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH ) -> None: """Shorthand for dynamic overwriting the table with a PyArrow table. 
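Because several of these annotations are resolved at runtime (pydantic builds validators from the model fields), it matters that the rewrite is behavior-preserving: on Python 3.10+ the two spellings compare equal. A quick sanity check one can run:

import typing

# PEP 604 unions compare equal to their typing-module equivalents,
assert (int | None) == typing.Optional[int]
# decompose identically under introspection,
assert typing.get_args(int | None) == (int, type(None))
# and support isinstance checks directly.
assert isinstance(5, int | None)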
@@ -1413,10 +1409,10 @@ def dynamic_partition_overwrite( def overwrite( self, df: pa.Table, - overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, + overwrite_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> None: """ Shorthand for overwriting the table with a PyArrow table. @@ -1446,10 +1442,10 @@ def overwrite( def delete( self, - delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, + delete_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> None: """ Shorthand for deleting rows from the table. @@ -1470,7 +1466,7 @@ def add_files( file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True, - branch: Optional[str] = MAIN_BRANCH, + branch: str | None = MAIN_BRANCH, ) -> None: """ Shorthand API for adding files as data files to the table. @@ -1657,12 +1653,12 @@ def refresh(self) -> Table: def scan( self, - row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + row_filter: str | BooleanExpression = ALWAYS_TRUE, selected_fields: Tuple[str, ...] = ("*",), case_sensitive: bool = True, - snapshot_id: Optional[int] = None, + snapshot_id: int | None = None, options: Properties = EMPTY_DICT, - limit: Optional[int] = None, + limit: int | None = None, ) -> DataScan: raise ValueError("Cannot scan a staged table") @@ -1670,7 +1666,7 @@ def to_daft(self) -> daft.DataFrame: raise ValueError("Cannot convert a staged table to a Daft DataFrame") -def _parse_row_filter(expr: Union[str, BooleanExpression]) -> BooleanExpression: +def _parse_row_filter(expr: str | BooleanExpression) -> BooleanExpression: """Accept an expression in the form of a BooleanExpression or a string. In the case of a string, it will be converted into a unbound BooleanExpression. @@ -1692,20 +1688,20 @@ class TableScan(ABC): row_filter: BooleanExpression selected_fields: Tuple[str, ...] case_sensitive: bool - snapshot_id: Optional[int] + snapshot_id: int | None options: Properties - limit: Optional[int] + limit: int | None def __init__( self, table_metadata: TableMetadata, io: FileIO, - row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + row_filter: str | BooleanExpression = ALWAYS_TRUE, selected_fields: Tuple[str, ...] 
= ("*",), case_sensitive: bool = True, - snapshot_id: Optional[int] = None, + snapshot_id: int | None = None, options: Properties = EMPTY_DICT, - limit: Optional[int] = None, + limit: int | None = None, ): self.table_metadata = table_metadata self.io = io @@ -1716,7 +1712,7 @@ def __init__( self.options = options self.limit = limit - def snapshot(self) -> Optional[Snapshot]: + def snapshot(self) -> Snapshot | None: if self.snapshot_id: return self.table_metadata.snapshot_by_id(self.snapshot_id) return self.table_metadata.current_snapshot() @@ -1777,7 +1773,7 @@ def select(self: S, *field_names: str) -> S: return self.update(selected_fields=field_names) return self.update(selected_fields=tuple(set(self.selected_fields).intersection(set(field_names)))) - def filter(self: S, expr: Union[str, BooleanExpression]) -> S: + def filter(self: S, expr: str | BooleanExpression) -> S: return self.update(row_filter=And(self.row_filter, _parse_row_filter(expr))) def with_case_sensitive(self: S, case_sensitive: bool = True) -> S: @@ -1804,9 +1800,9 @@ class FileScanTask(ScanTask): def __init__( self, data_file: DataFile, - delete_files: Optional[Set[DataFile]] = None, - start: Optional[int] = None, - length: Optional[int] = None, + delete_files: Set[DataFile] | None = None, + start: int | None = None, + length: int | None = None, residual: BooleanExpression = ALWAYS_TRUE, ) -> None: self.file = data_file @@ -2070,7 +2066,7 @@ def to_pandas(self, **kwargs: Any) -> pd.DataFrame: """ return self.to_arrow().to_pandas(**kwargs) - def to_duckdb(self, table_name: str, connection: Optional[DuckDBPyConnection] = None) -> DuckDBPyConnection: + def to_duckdb(self, table_name: str, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: """Shorthand for loading the Iceberg Table in DuckDB. 
Returns: @@ -2143,8 +2139,8 @@ class WriteTask: task_id: int schema: Schema record_batches: List[pa.RecordBatch] - sort_order_id: Optional[int] = None - partition_key: Optional[PartitionKey] = None + sort_order_id: int | None = None + partition_key: PartitionKey | None = None def generate_data_file_filename(self, extension: str) -> str: # Mimics the behavior in the Java API: diff --git a/pyiceberg/table/inspect.py b/pyiceberg/table/inspect.py index a8791bf1e0..c4591a40e9 100644 --- a/pyiceberg/table/inspect.py +++ b/pyiceberg/table/inspect.py @@ -18,7 +18,7 @@ import itertools from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Set, Tuple from pyiceberg.conversions import from_bytes from pyiceberg.expressions import AlwaysTrue, BooleanExpression @@ -48,7 +48,7 @@ def __init__(self, tbl: Table) -> None: except ModuleNotFoundError as e: raise ModuleNotFoundError("For metadata operations PyArrow needs to be installed") from e - def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot: + def _get_snapshot(self, snapshot_id: int | None = None) -> Snapshot: if snapshot_id is not None: if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id): return snapshot @@ -98,7 +98,7 @@ def snapshots(self) -> "pa.Table": schema=snapshots_schema, ) - def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table": + def entries(self, snapshot_id: int | None = None) -> "pa.Table": import pyarrow as pa from pyiceberg.io.pyarrow import schema_to_pyarrow @@ -261,8 +261,8 @@ def refs(self) -> "pa.Table": def partitions( self, - snapshot_id: Optional[int] = None, - row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + snapshot_id: int | None = None, + row_filter: str | BooleanExpression = ALWAYS_TRUE, case_sensitive: bool = True, ) -> "pa.Table": import pyarrow as pa @@ -330,7 +330,7 @@ def _update_partitions_map_from_manifest_entry( partitions_map: Dict[Tuple[str, Any], Any], file: DataFile, partition_record_dict: Dict[str, Any], - snapshot: Optional[Snapshot], + snapshot: Snapshot | None, ) -> None: partition_record_key = _convert_to_hashable_type(partition_record_dict) if partition_record_key not in partitions_map: @@ -405,7 +405,7 @@ def _get_all_manifests_schema(self) -> "pa.Schema": all_manifests_schema = all_manifests_schema.append(pa.field("reference_snapshot_id", pa.int64(), nullable=False)) return all_manifests_schema - def _generate_manifests_table(self, snapshot: Optional[Snapshot], is_all_manifests_table: bool = False) -> "pa.Table": + def _generate_manifests_table(self, snapshot: Snapshot | None, is_all_manifests_table: bool = False) -> "pa.Table": import pyarrow as pa def _partition_summaries_to_rows( @@ -545,7 +545,7 @@ def history(self) -> "pa.Table": return pa.Table.from_pylist(history, schema=history_schema) def _get_files_from_manifest( - self, manifest_list: ManifestFile, data_file_filter: Optional[Set[DataFileContent]] = None + self, manifest_list: ManifestFile, data_file_filter: Set[DataFileContent] | None = None ) -> "pa.Table": import pyarrow as pa @@ -663,7 +663,7 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: ) return files_schema - def _files(self, snapshot_id: Optional[int] = None, data_file_filter: Optional[Set[DataFileContent]] = None) -> "pa.Table": + def _files(self, snapshot_id: int | None = None, data_file_filter: Set[DataFileContent] | None = None) -> "pa.Table": import pyarrow as pa if 
not snapshot_id and not self.tbl.metadata.current_snapshot(): @@ -680,13 +680,13 @@ def _files(self, snapshot_id: Optional[int] = None, data_file_filter: Optional[S ) return pa.concat_tables(results) - def files(self, snapshot_id: Optional[int] = None) -> "pa.Table": + def files(self, snapshot_id: int | None = None) -> "pa.Table": return self._files(snapshot_id) - def data_files(self, snapshot_id: Optional[int] = None) -> "pa.Table": + def data_files(self, snapshot_id: int | None = None) -> "pa.Table": return self._files(snapshot_id, {DataFileContent.DATA}) - def delete_files(self, snapshot_id: Optional[int] = None) -> "pa.Table": + def delete_files(self, snapshot_id: int | None = None) -> "pa.Table": return self._files(snapshot_id, {DataFileContent.POSITION_DELETES, DataFileContent.EQUALITY_DELETES}) def all_manifests(self) -> "pa.Table": @@ -702,7 +702,7 @@ def all_manifests(self) -> "pa.Table": ) return pa.concat_tables(manifests_by_snapshots) - def _all_files(self, data_file_filter: Optional[Set[DataFileContent]] = None) -> "pa.Table": + def _all_files(self, data_file_filter: Set[DataFileContent] | None = None) -> "pa.Table": import pyarrow as pa snapshots = self.tbl.snapshots() diff --git a/pyiceberg/table/locations.py b/pyiceberg/table/locations.py index 2d604abb6c..25da7e7f6c 100644 --- a/pyiceberg/table/locations.py +++ b/pyiceberg/table/locations.py @@ -18,7 +18,6 @@ import logging import uuid from abc import ABC, abstractmethod -from typing import Optional import mmh3 @@ -60,7 +59,7 @@ def __init__(self, table_location: str, table_properties: Properties): self.metadata_path = f"{self.table_location.rstrip('/')}/metadata" @abstractmethod - def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str: + def new_data_location(self, data_file_name: str, partition_key: PartitionKey | None = None) -> str: """Return a fully-qualified data file location for the given filename. 
diff --git a/pyiceberg/table/locations.py b/pyiceberg/table/locations.py
index 2d604abb6c..25da7e7f6c 100644
--- a/pyiceberg/table/locations.py
+++ b/pyiceberg/table/locations.py
@@ -18,7 +18,6 @@
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import Optional

 import mmh3
@@ -60,7 +59,7 @@ def __init__(self, table_location: str, table_properties: Properties):
         self.metadata_path = f"{self.table_location.rstrip('/')}/metadata"

     @abstractmethod
-    def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
+    def new_data_location(self, data_file_name: str, partition_key: PartitionKey | None = None) -> str:
         """Return a fully-qualified data file location for the given filename.

         Args:
@@ -105,7 +104,7 @@ class SimpleLocationProvider(LocationProvider):
     def __init__(self, table_location: str, table_properties: Properties):
         super().__init__(table_location, table_properties)

-    def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
+    def new_data_location(self, data_file_name: str, partition_key: PartitionKey | None = None) -> str:
         return (
             f"{self.data_path}/{partition_key.to_path()}/{data_file_name}"
             if partition_key
@@ -130,7 +129,7 @@ def __init__(self, table_location: str, table_properties: Properties):
             TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT,
         )

-    def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
+    def new_data_location(self, data_file_name: str, partition_key: PartitionKey | None = None) -> str:
         if self._include_partition_paths and partition_key:
             return self.new_data_location(f"{partition_key.to_path()}/{data_file_name}")
@@ -166,7 +165,7 @@ def _dirs_from_hash(file_hash: str) -> str:

 def _import_location_provider(
     location_provider_impl: str, table_location: str, table_properties: Properties
-) -> Optional[LocationProvider]:
+) -> LocationProvider | None:
     try:
         path_parts = location_provider_impl.split(".")
         if len(path_parts) < 2:
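Custom location providers plug in by subclassing LocationProvider and overriding new_data_location, whose partition_key parameter is now typed PartitionKey | None. A rough sketch under that assumption (the class below is illustrative, not part of this change):

    from pyiceberg.partitioning import PartitionKey
    from pyiceberg.table.locations import LocationProvider

    class FlatLocationProvider(LocationProvider):  # illustrative only
        def new_data_location(self, data_file_name: str, partition_key: PartitionKey | None = None) -> str:
            # Drop partition directories and write every file directly under <table>/data/.
            return f"{self.data_path}/{data_file_name}"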
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 4fa2235fb9..3582a9be8c 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -19,7 +19,7 @@
 import datetime
 import uuid
 from copy import copy
-from typing import Annotated, Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Dict, List, Literal

 from pydantic import Field, field_serializer, field_validator, model_validator
 from pydantic import ValidationError as PydanticValidationError
@@ -163,7 +163,7 @@ class TableMetadataCommonFields(IcebergBaseModel):
     default_spec_id: int = Field(alias="default-spec-id", default=INITIAL_SPEC_ID)
     """ID of the “current” spec that writers should use by default."""

-    last_partition_id: Optional[int] = Field(alias="last-partition-id", default=None)
+    last_partition_id: int | None = Field(alias="last-partition-id", default=None)
     """An integer; the highest assigned partition field ID across all partition specs for the table.
     This is used to ensure partition fields are always assigned an unused ID when evolving specs."""
@@ -174,7 +174,7 @@ class TableMetadataCommonFields(IcebergBaseModel):
     to be used for arbitrary metadata.
     For example, commit.retry.num-retries is used to control the number of commit retries."""

-    current_snapshot_id: Optional[int] = Field(alias="current-snapshot-id", default=None)
+    current_snapshot_id: int | None = Field(alias="current-snapshot-id", default=None)
     """ID of the current table snapshot."""

     snapshots: List[Snapshot] = Field(default_factory=list)
@@ -235,11 +235,11 @@ class TableMetadataCommonFields(IcebergBaseModel):
     def transform_properties_dict_value_to_str(cls, properties: Properties) -> Dict[str, str]:
         return transform_dict_value_to_str(properties)

-    def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]:
+    def snapshot_by_id(self, snapshot_id: int) -> Snapshot | None:
         """Get the snapshot by snapshot_id."""
         return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None)

-    def schema_by_id(self, schema_id: int) -> Optional[Schema]:
+    def schema_by_id(self, schema_id: int) -> Schema | None:
         """Get the schema by schema_id."""
         return next((schema for schema in self.schemas if schema.schema_id == schema_id), None)
@@ -247,7 +247,7 @@ def schema(self) -> Schema:
         """Return the schema for this table."""
         return next(schema for schema in self.schemas if schema.schema_id == self.current_schema_id)

-    def name_mapping(self) -> Optional[NameMapping]:
+    def name_mapping(self) -> NameMapping | None:
         """Return the table's field-id NameMapping."""
         if name_mapping_json := self.properties.get("schema.name-mapping.default"):
             return parse_mapping_from_json(name_mapping_json)
@@ -295,7 +295,7 @@ def new_snapshot_id(self) -> int:

         return snapshot_id

-    def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
+    def snapshot_by_name(self, name: str | None) -> Snapshot | None:
         """Return the snapshot referenced by the given name or null if no such reference exists."""
         if name is None:
             name = MAIN_BRANCH
@@ -303,7 +303,7 @@ def snapshot_by_name(self, name: Optional[str]) -> Optional[Snapshot]:
             return self.snapshot_by_id(ref.snapshot_id)
         return None

-    def current_snapshot(self) -> Optional[Snapshot]:
+    def current_snapshot(self) -> Snapshot | None:
         """Get the current snapshot for this table, or None if there is no current snapshot."""
         if self.current_snapshot_id is not None:
             return self.snapshot_by_id(self.current_snapshot_id)
@@ -312,12 +312,12 @@ def next_sequence_number(self) -> int:
         return self.last_sequence_number + 1 if self.format_version > 1 else INITIAL_SEQUENCE_NUMBER

-    def sort_order_by_id(self, sort_order_id: int) -> Optional[SortOrder]:
+    def sort_order_by_id(self, sort_order_id: int) -> SortOrder | None:
         """Get the sort order by sort_order_id."""
         return next((sort_order for sort_order in self.sort_orders if sort_order.order_id == sort_order_id), None)

     @field_serializer("current_snapshot_id")
-    def serialize_current_snapshot_id(self, current_snapshot_id: Optional[int]) -> Optional[int]:
+    def serialize_current_snapshot_id(self, current_snapshot_id: int | None) -> int | None:
         if current_snapshot_id is None and Config().get_bool("legacy-current-snapshot-id"):
             return -1
         return current_snapshot_id
@@ -559,16 +559,14 @@ def construct_refs(self) -> TableMetadata:
     """The table’s highest assigned sequence number, a monotonically
     increasing long that tracks the order of snapshots in a table."""

-    next_row_id: Optional[int] = Field(alias="next-row-id", default=None)
+    next_row_id: int | None = Field(alias="next-row-id", default=None)
     """A long higher than all assigned row IDs; the next
     snapshot's `first-row-id`."""

-    def model_dump_json(
-        self, exclude_none: bool = True, exclude: Optional[Any] = None, by_alias: bool = True, **kwargs: Any
-    ) -> str:
+    def model_dump_json(self, exclude_none: bool = True, exclude: Any | None = None, by_alias: bool = True, **kwargs: Any) -> str:
         raise NotImplementedError("Writing V3 is not yet supported, see: https://github.com/apache/iceberg-python/issues/1551")

-TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2, TableMetadataV3], Field(discriminator="format_version")]
+TableMetadata = Annotated[TableMetadataV1 | TableMetadataV2 | TableMetadataV3, Field(discriminator="format_version")]


 def new_table_metadata(
@@ -577,7 +575,7 @@ def new_table_metadata(
     sort_order: SortOrder,
     location: str,
     properties: Properties = EMPTY_DICT,
-    table_uuid: Optional[uuid.UUID] = None,
+    table_uuid: uuid.UUID | None = None,
 ) -> TableMetadata:
     from pyiceberg.table import TableProperties
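Pydantic evaluates these X | None annotations at runtime, so they behave exactly like the Optional[X] forms they replace on Python 3.10+. A small self-contained sketch of the same pattern (Ref is a made-up stand-in, not one of the models in this diff):

    from pydantic import BaseModel, Field

    class Ref(BaseModel):  # made-up stand-in for the models in this diff
        snapshot_id: int = Field(alias="snapshot-id")
        schema_id: int | None = Field(alias="schema-id", default=None)

    Ref.model_validate({"snapshot-id": 42})  # schema_id falls back to None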
diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py
index e27763fc6a..8ba9fd2554 100644
--- a/pyiceberg/table/name_mapping.py
+++ b/pyiceberg/table/name_mapping.py
@@ -26,7 +26,7 @@
 from abc import ABC, abstractmethod
 from collections import ChainMap
 from functools import cached_property, singledispatch
-from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, TypeVar

 from pydantic import Field, conlist, field_validator, model_serializer
@@ -36,7 +36,7 @@

 class MappedField(IcebergBaseModel):
-    field_id: Optional[int] = Field(alias="field-id", default=None)
+    field_id: int | None = Field(alias="field-id", default=None)
     names: List[str] = conlist(str)
     fields: List[MappedField] = Field(default_factory=list)
@@ -128,7 +128,7 @@ def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dic

 @singledispatch
-def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[S, T]) -> S:
+def visit_name_mapping(obj: NameMapping | List[MappedField] | MappedField, visitor: NameMappingVisitor[S, T]) -> S:
     """Traverse the name mapping in post-order traversal."""
     raise NotImplementedError(f"Cannot visit non-type: {obj}")
@@ -183,7 +183,7 @@ def __init__(self, updates: Dict[int, NestedField], adds: Dict[int, List[NestedF
         self._adds = adds

     @staticmethod
-    def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> Optional[MappedField]:
+    def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> MappedField | None:
         removed_names = set()
         for name in field.names:
             if (assigned_id := assignments.get(name)) and assigned_id != field.field_id:
@@ -249,12 +249,12 @@ def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds:

 class NameMappingAccessor(PartnerAccessor[MappedField]):
-    def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]:
+    def schema_partner(self, partner: MappedField | None) -> MappedField | None:
         return partner

     def field_partner(
-        self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str
-    ) -> Optional[MappedField]:
+        self, partner_struct: List[MappedField] | MappedField | None, _: int, field_name: str
+    ) -> MappedField | None:
         if partner_struct is not None:
             if isinstance(partner_struct, MappedField):
                 partner_struct = partner_struct.fields
@@ -265,21 +265,21 @@ def field_partner(

         return None

-    def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]:
+    def list_element_partner(self, partner_list: MappedField | None) -> MappedField | None:
         if partner_list is not None:
             for field in partner_list.fields:
                 if "element" in field.names:
                     return field
         return None

-    def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+    def map_key_partner(self, partner_map: MappedField | None) -> MappedField | None:
         if partner_map is not None:
             for field in partner_map.fields:
                 if "key" in field.names:
                     return field
         return None

-    def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]:
+    def map_value_partner(self, partner_map: MappedField | None) -> MappedField | None:
         if partner_map is not None:
             for field in partner_map.fields:
                 if "value" in field.names:
@@ -294,37 +294,37 @@ def __init__(self) -> None:
         # For keeping track where we are in case when a field cannot be found
         self.current_path = []

-    def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+    def before_field(self, field: NestedField, field_partner: P | None) -> None:
         self.current_path.append(field.name)

-    def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+    def after_field(self, field: NestedField, field_partner: P | None) -> None:
         self.current_path.pop()

-    def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+    def before_list_element(self, element: NestedField, element_partner: P | None) -> None:
         self.current_path.append("element")

-    def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+    def after_list_element(self, element: NestedField, element_partner: P | None) -> None:
         self.current_path.pop()

-    def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+    def before_map_key(self, key: NestedField, key_partner: P | None) -> None:
         self.current_path.append("key")

-    def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+    def after_map_key(self, key: NestedField, key_partner: P | None) -> None:
         self.current_path.pop()

-    def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+    def before_map_value(self, value: NestedField, value_partner: P | None) -> None:
         self.current_path.append("value")

-    def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+    def after_map_value(self, value: NestedField, value_partner: P | None) -> None:
         self.current_path.pop()

-    def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType:
+    def schema(self, schema: Schema, schema_partner: MappedField | None, struct_result: StructType) -> IcebergType:
         return Schema(*struct_result.fields, schema_id=schema.schema_id)

-    def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType:
+    def struct(self, struct: StructType, struct_partner: MappedField | None, field_results: List[NestedField]) -> IcebergType:
         return StructType(*field_results)

-    def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType:
+    def field(self, field: NestedField, field_partner: MappedField | None, field_result: IcebergType) -> IcebergType:
         if field_partner is None or field_partner.field_id is None:
             raise ValueError(f"Field or field ID missing from NameMapping: {'.'.join(self.current_path)}")
@@ -338,7 +338,7 @@ def field(self, field: NestedField, field_partner: Optional[MappedField], field_
             initial_write=field.write_default,
         )

-    def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType:
+    def list(self, list_type: ListType, list_partner: MappedField | None, element_result: IcebergType) -> IcebergType:
         if list_partner is None:
             raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
@@ -346,7 +346,7 @@ def list(self, list_type: ListType, list_partner: Optional[MappedField], element
         return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required)

     def map(
-        self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType
+        self, map_type: MapType, map_partner: MappedField | None, key_result: IcebergType, value_result: IcebergType
     ) -> IcebergType:
         if map_partner is None:
             raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
@@ -361,7 +361,7 @@ def map(
             value_required=map_type.value_required,
         )

-    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType:
+    def primitive(self, primitive: PrimitiveType, primitive_partner: MappedField | None) -> PrimitiveType:
         if primitive_partner is None:
             raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}")
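With field_id now typed int | None, a mapped field may legitimately carry no id. Name mappings are still parsed from the schema.name-mapping.default table property; a short sketch (the JSON values are illustrative):

    from pyiceberg.table.name_mapping import parse_mapping_from_json

    mapping = parse_mapping_from_json('[{"field-id": 1, "names": ["id", "record_id"]}]')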
diff --git a/pyiceberg/table/puffin.py b/pyiceberg/table/puffin.py
index a90ef7ee0d..326fe3e37a 100644
--- a/pyiceberg/table/puffin.py
+++ b/pyiceberg/table/puffin.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import math
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+from typing import TYPE_CHECKING, Dict, List, Literal

 from pydantic import Field
 from pyroaring import BitMap, FrozenBitMap
@@ -69,7 +69,7 @@ class PuffinBlobMetadata(IcebergBaseModel):
     sequence_number: int = Field(alias="sequence-number")
     offset: int = Field()
     length: int = Field()
-    compression_codec: Optional[str] = Field(alias="compression-codec", default=None)
+    compression_codec: str | None = Field(alias="compression-codec", default=None)
     properties: Dict[str, str] = Field(default_factory=dict)
diff --git a/pyiceberg/table/refs.py b/pyiceberg/table/refs.py
index 2c9f7ae39e..0cbc11689a 100644
--- a/pyiceberg/table/refs.py
+++ b/pyiceberg/table/refs.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from enum import Enum
-from typing import Annotated, Optional
+from typing import Annotated

 from pydantic import Field, model_validator
@@ -41,9 +41,9 @@ def __str__(self) -> str:
 class SnapshotRef(IcebergBaseModel):
     snapshot_id: int = Field(alias="snapshot-id")
     snapshot_ref_type: SnapshotRefType = Field(alias="type")
-    min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None, gt=0)]
-    max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None, gt=0)]
-    max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None, gt=0)]
+    min_snapshots_to_keep: Annotated[int | None, Field(alias="min-snapshots-to-keep", default=None, gt=0)]
+    max_snapshot_age_ms: Annotated[int | None, Field(alias="max-snapshot-age-ms", default=None, gt=0)]
+    max_ref_age_ms: Annotated[int | None, Field(alias="max-ref-age-ms", default=None, gt=0)]

     @model_validator(mode="after")
     def check_min_snapshots_to_keep(self) -> "SnapshotRef":
diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py
index 13ce52b7eb..14b5fa833c 100644
--- a/pyiceberg/table/snapshots.py
+++ b/pyiceberg/table/snapshots.py
@@ -20,7 +20,7 @@
 import warnings
 from collections import defaultdict
 from enum import Enum
-from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Iterable, List, Mapping, Optional
+from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Iterable, List, Mapping

 from pydantic import Field, PrivateAttr, model_serializer
@@ -185,14 +185,14 @@ class Summary(IcebergBaseModel, Mapping[str, str]):
     operation: Operation = Field()
     _additional_properties: Dict[str, str] = PrivateAttr()

-    def __init__(self, operation: Optional[Operation] = None, **data: Any) -> None:
+    def __init__(self, operation: Operation | None = None, **data: Any) -> None:
         if operation is None:
             warnings.warn("Encountered invalid snapshot summary: operation is missing, defaulting to overwrite")
             operation = Operation.OVERWRITE
         super().__init__(operation=operation, **data)
         self._additional_properties = data

-    def __getitem__(self, __key: str) -> Optional[Any]:  # type: ignore
+    def __getitem__(self, __key: str) -> Any | None:  # type: ignore
         """Return a key as it is a map."""
         if __key.lower() == "operation":
             return self.operation
@@ -238,16 +238,16 @@ def __eq__(self, other: Any) -> bool:

 class Snapshot(IcebergBaseModel):
     snapshot_id: int = Field(alias="snapshot-id")
-    parent_snapshot_id: Optional[int] = Field(alias="parent-snapshot-id", default=None)
-    sequence_number: Optional[int] = Field(alias="sequence-number", default=INITIAL_SEQUENCE_NUMBER)
+    parent_snapshot_id: int | None = Field(alias="parent-snapshot-id", default=None)
+    sequence_number: int | None = Field(alias="sequence-number", default=INITIAL_SEQUENCE_NUMBER)
     timestamp_ms: int = Field(alias="timestamp-ms", default_factory=lambda: int(time.time() * 1000))
     manifest_list: str = Field(alias="manifest-list", description="Location of the snapshot's manifest list file")
-    summary: Optional[Summary] = Field(default=None)
-    schema_id: Optional[int] = Field(alias="schema-id", default=None)
-    first_row_id: Optional[int] = Field(
+    summary: Summary | None = Field(default=None)
+    schema_id: int | None = Field(alias="schema-id", default=None)
+    first_row_id: int | None = Field(
         alias="first-row-id", default=None, description="assigned to the first row in the first data file in the first manifest"
     )
-    added_rows: Optional[int] = Field(
+    added_rows: int | None = Field(
         alias="added-rows",
         default=None,
         description="The upper bound of the number of rows with assigned row IDs"
     )
@@ -376,7 +376,7 @@ def get_prop(prop: str) -> int:

 def update_snapshot_summaries(
-    summary: Summary, previous_summary: Optional[Mapping[str, str]] = None, truncate_full_table: bool = False
+    summary: Summary, previous_summary: Mapping[str, str] | None = None, truncate_full_table: bool = False
 ) -> Summary:
     if summary.operation not in {Operation.APPEND, Operation.OVERWRITE, Operation.DELETE}:
         raise ValueError(f"Operation not implemented: {summary.operation}")
@@ -452,7 +452,7 @@ def set_when_positive(properties: Dict[str, str], num: int, property_name: str)
         properties[property_name] = str(num)

-def ancestors_of(current_snapshot: Optional[Snapshot], table_metadata: TableMetadata) -> Iterable[Snapshot]:
+def ancestors_of(current_snapshot: Snapshot | None, table_metadata: TableMetadata) -> Iterable[Snapshot]:
     """Get the ancestors of and including the given snapshot."""
     snapshot = current_snapshot
     while snapshot is not None:
@@ -462,9 +462,7 @@ def ancestors_of(current_snapshot: Optional[Snapshot], table_metadata: TableMeta
         snapshot = table_metadata.snapshot_by_id(snapshot.parent_snapshot_id)

-def ancestors_between(
-    from_snapshot: Optional[Snapshot], to_snapshot: Snapshot, table_metadata: TableMetadata
-) -> Iterable[Snapshot]:
+def ancestors_between(from_snapshot: Snapshot | None, to_snapshot: Snapshot, table_metadata: TableMetadata) -> Iterable[Snapshot]:
     """Get the ancestors of and including the given snapshot between the to and from snapshots."""
     if from_snapshot is not None:
         for snapshot in ancestors_of(to_snapshot, table_metadata):
diff --git a/pyiceberg/table/sorting.py b/pyiceberg/table/sorting.py
index 244c8ba867..8bd9a08176 100644
--- a/pyiceberg/table/sorting.py
+++ b/pyiceberg/table/sorting.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=keyword-arg-before-vararg
 from enum import Enum
-from typing import Annotated, Any, Callable, Dict, List, Optional, Union
+from typing import Annotated, Any, Callable, Dict, List

 from pydantic import (
     BeforeValidator,
@@ -71,10 +71,10 @@ class SortField(IcebergBaseModel):

     def __init__(
         self,
-        source_id: Optional[int] = None,
-        transform: Optional[Union[Transform[Any, Any], Callable[[IcebergType], Transform[Any, Any]]]] = None,
-        direction: Optional[SortDirection] = None,
-        null_order: Optional[NullOrder] = None,
+        source_id: int | None = None,
+        transform: Transform[Any, Any] | Callable[[IcebergType], Transform[Any, Any]] | None = None,
+        direction: SortDirection | None = None,
+        null_order: NullOrder | None = None,
         **data: Any,
     ):
         if source_id is not None:
diff --git a/pyiceberg/table/statistics.py b/pyiceberg/table/statistics.py
index 484391efb1..25654d0c27 100644
--- a/pyiceberg/table/statistics.py
+++ b/pyiceberg/table/statistics.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal

 from pydantic import Field
@@ -26,7 +26,7 @@ class BlobMetadata(IcebergBaseModel):
     snapshot_id: int = Field(alias="snapshot-id")
     sequence_number: int = Field(alias="sequence-number")
     fields: List[int]
-    properties: Optional[Dict[str, str]] = None
+    properties: Dict[str, str] | None = None


 class StatisticsCommonFields(IcebergBaseModel):
@@ -39,7 +39,7 @@ class StatisticsCommonFields(IcebergBaseModel):

 class StatisticsFile(StatisticsCommonFields):
     file_footer_size_in_bytes: int = Field(alias="file-footer-size-in-bytes")
-    key_metadata: Optional[str] = Field(alias="key-metadata", default=None)
+    key_metadata: str | None = Field(alias="key-metadata", default=None)
     blob_metadata: List[BlobMetadata] = Field(alias="blob-metadata")
@@ -48,7 +48,7 @@ class PartitionStatisticsFile(StatisticsCommonFields):

 def filter_statistics_by_snapshot_id(
-    statistics: List[Union[StatisticsFile, PartitionStatisticsFile]],
+    statistics: List[StatisticsFile | PartitionStatisticsFile],
     reject_snapshot_id: int,
-) -> List[Union[StatisticsFile, PartitionStatisticsFile]]:
+) -> List[StatisticsFile | PartitionStatisticsFile]:
     return [stat for stat in statistics if stat.snapshot_id != reject_snapshot_id]
diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py
index bcbe429688..9d6ab38eae 100644
--- a/pyiceberg/table/update/__init__.py
+++ b/pyiceberg/table/update/__init__.py
@@ -21,7 +21,7 @@
 from abc import ABC, abstractmethod
 from datetime import datetime
 from functools import singledispatch
-from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Tuple, TypeVar, cast

 from pydantic import Field, field_validator, model_serializer, model_validator
@@ -136,9 +136,9 @@ class SetSnapshotRefUpdate(IcebergBaseModel):
     ref_name: str = Field(alias="ref-name")
     type: Literal[SnapshotRefType.TAG, SnapshotRefType.BRANCH]
     snapshot_id: int = Field(alias="snapshot-id")
-    max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
-    max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
-    min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]
+    max_ref_age_ms: Annotated[int | None, Field(alias="max-ref-age-ms", default=None)]
+    max_snapshot_age_ms: Annotated[int | None, Field(alias="max-snapshot-age-ms", default=None)]
+    min_snapshots_to_keep: Annotated[int | None, Field(alias="min-snapshots-to-keep", default=None)]


 class RemoveSnapshotsUpdate(IcebergBaseModel):
@@ -173,7 +173,7 @@ class RemovePropertiesUpdate(IcebergBaseModel):
 class SetStatisticsUpdate(IcebergBaseModel):
     action: Literal["set-statistics"] = Field(default="set-statistics")
     statistics: StatisticsFile
-    snapshot_id: Optional[int] = Field(
+    snapshot_id: int | None = Field(
         None,
         alias="snapshot-id",
         description="snapshot-id is **DEPRECATED for REMOVAL** since it contains redundant information. Use `statistics.snapshot-id` field instead.",
@@ -214,29 +214,27 @@ class RemovePartitionStatisticsUpdate(IcebergBaseModel):

 TableUpdate = Annotated[
-    Union[
-        AssignUUIDUpdate,
-        UpgradeFormatVersionUpdate,
-        AddSchemaUpdate,
-        SetCurrentSchemaUpdate,
-        AddPartitionSpecUpdate,
-        SetDefaultSpecUpdate,
-        AddSortOrderUpdate,
-        SetDefaultSortOrderUpdate,
-        AddSnapshotUpdate,
-        SetSnapshotRefUpdate,
-        RemoveSnapshotsUpdate,
-        RemoveSnapshotRefUpdate,
-        SetLocationUpdate,
-        SetPropertiesUpdate,
-        RemovePropertiesUpdate,
-        SetStatisticsUpdate,
-        RemoveStatisticsUpdate,
-        RemovePartitionSpecsUpdate,
-        RemoveSchemasUpdate,
-        SetPartitionStatisticsUpdate,
-        RemovePartitionStatisticsUpdate,
-    ],
+    AssignUUIDUpdate
+    | UpgradeFormatVersionUpdate
+    | AddSchemaUpdate
+    | SetCurrentSchemaUpdate
+    | AddPartitionSpecUpdate
+    | SetDefaultSpecUpdate
+    | AddSortOrderUpdate
+    | SetDefaultSortOrderUpdate
+    | AddSnapshotUpdate
+    | SetSnapshotRefUpdate
+    | RemoveSnapshotsUpdate
+    | RemoveSnapshotRefUpdate
+    | SetLocationUpdate
+    | SetPropertiesUpdate
+    | RemovePropertiesUpdate
+    | SetStatisticsUpdate
+    | RemoveStatisticsUpdate
+    | RemovePartitionSpecsUpdate
+    | RemoveSchemasUpdate
+    | SetPartitionStatisticsUpdate
+    | RemovePartitionStatisticsUpdate,
     Field(discriminator="action"),
 ]
@@ -676,7 +674,7 @@ def update_table_metadata(
     base_metadata: TableMetadata,
     updates: Tuple[TableUpdate, ...],
     enforce_validation: bool = False,
-    metadata_location: Optional[str] = None,
+    metadata_location: str | None = None,
 ) -> TableMetadata:
     """Update the table metadata with the given updates in one transaction.
@@ -744,7 +742,7 @@ class ValidatableTableRequirement(IcebergBaseModel):
     type: str

     @abstractmethod
-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         """Validate the requirement against the base metadata.

         Args:
@@ -761,7 +759,7 @@ class AssertCreate(ValidatableTableRequirement):

     type: Literal["assert-create"] = Field(default="assert-create")

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is not None:
             raise CommitFailedException("Table already exists")
@@ -772,7 +770,7 @@ class AssertTableUUID(ValidatableTableRequirement):
     type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
     uuid: uuid.UUID

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif self.uuid != base_metadata.table_uuid:
@@ -787,7 +785,7 @@ class AssertRefSnapshotId(ValidatableTableRequirement):

     type: Literal["assert-ref-snapshot-id"] = Field(default="assert-ref-snapshot-id")
     ref: str = Field(...)
-    snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id")
+    snapshot_id: int | None = Field(default=None, alias="snapshot-id")

     @model_serializer(mode="wrap")
     def serialize_model(self, handler: ModelWrapSerializerWithoutInfo) -> dict[str, Any]:
@@ -795,7 +793,7 @@ def serialize_model(self, handler: ModelWrapSerializerWithoutInfo) -> dict[str,
         # Ensure "snapshot-id" is always present, even if value is None
         return {**partial_result, "snapshot-id": self.snapshot_id}

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif len(base_metadata.snapshots) == 0 and self.ref != MAIN_BRANCH:
@@ -820,7 +818,7 @@ class AssertLastAssignedFieldId(ValidatableTableRequirement):
     type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
     last_assigned_field_id: int = Field(..., alias="last-assigned-field-id")

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif base_metadata.last_column_id != self.last_assigned_field_id:
@@ -835,7 +833,7 @@ class AssertCurrentSchemaId(ValidatableTableRequirement):
     type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
     current_schema_id: int = Field(..., alias="current-schema-id")

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif self.current_schema_id != base_metadata.current_schema_id:
@@ -848,9 +846,9 @@ class AssertLastAssignedPartitionId(ValidatableTableRequirement):
     """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""

     type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
-    last_assigned_partition_id: Optional[int] = Field(..., alias="last-assigned-partition-id")
+    last_assigned_partition_id: int | None = Field(..., alias="last-assigned-partition-id")

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif base_metadata.last_partition_id != self.last_assigned_partition_id:
@@ -865,7 +863,7 @@ class AssertDefaultSpecId(ValidatableTableRequirement):
     type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
     default_spec_id: int = Field(..., alias="default-spec-id")

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif self.default_spec_id != base_metadata.default_spec_id:
@@ -880,7 +878,7 @@ class AssertDefaultSortOrderId(ValidatableTableRequirement):
     type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
     default_sort_order_id: int = Field(..., alias="default-sort-order-id")

-    def validate(self, base_metadata: Optional[TableMetadata]) -> None:
+    def validate(self, base_metadata: TableMetadata | None) -> None:
         if base_metadata is None:
             raise CommitFailedException("Requirement failed: current table metadata is missing")
         elif self.default_sort_order_id != base_metadata.default_sort_order_id:
@@ -890,16 +888,14 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:

 TableRequirement = Annotated[
-    Union[
-        AssertCreate,
-        AssertTableUUID,
-        AssertRefSnapshotId,
-        AssertLastAssignedFieldId,
-        AssertCurrentSchemaId,
-        AssertLastAssignedPartitionId,
-        AssertDefaultSpecId,
-        AssertDefaultSortOrderId,
-    ],
+    AssertCreate
+    | AssertTableUUID
+    | AssertRefSnapshotId
+    | AssertLastAssignedFieldId
+    | AssertCurrentSchemaId
+    | AssertLastAssignedPartitionId
+    | AssertDefaultSpecId
+    | AssertDefaultSortOrderId,
     Field(discriminator="type"),
 ]
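The Union[...] → | rewrite also covers Pydantic discriminated unions such as TableUpdate and TableRequirement above; pipe syntax inside Annotated resolves the same way. A stripped-down, runnable sketch of the pattern (the two models are simplified stand-ins, not the real update classes):

    from typing import Annotated, Literal

    from pydantic import BaseModel, Field, TypeAdapter

    class SetLocation(BaseModel):  # simplified stand-in, not the real update model
        action: Literal["set-location"] = "set-location"
        location: str

    class RemoveProperties(BaseModel):
        action: Literal["remove-properties"] = "remove-properties"
        removals: list[str]

    Update = Annotated[SetLocation | RemoveProperties, Field(discriminator="action")]
    TypeAdapter(Update).validate_python({"action": "set-location", "location": "s3://bucket/tbl"})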
diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py
index b7ed7c3351..f28e0aa2ac 100644
--- a/pyiceberg/table/update/schema.py
+++ b/pyiceberg/table/update/schema.py
@@ -20,7 +20,7 @@
 from copy import copy
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple

 from pyiceberg.exceptions import ResolveError, ValidationError
 from pyiceberg.expressions import literal  # type: ignore
@@ -70,7 +70,7 @@ class _Move:
     field_id: int
     full_name: str
     op: _MoveOperation
-    other_field_id: Optional[int] = None
+    other_field_id: int | None = None


 class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
@@ -94,8 +94,8 @@ def __init__(
         self,
         transaction: Transaction,
         allow_incompatible_changes: bool = False,
         case_sensitive: bool = True,
-        schema: Optional[Schema] = None,
-        name_mapping: Optional[NameMapping] = None,
+        schema: Schema | None = None,
+        name_mapping: NameMapping | None = None,
     ) -> None:
         super().__init__(transaction)
@@ -145,7 +145,7 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
     def union_by_name(
         # TODO: Move TableProperties.DEFAULT_FORMAT_VERSION to separate file and set that as format_version default.
         self,
-        new_schema: Union[Schema, "pa.Schema"],
+        new_schema: Schema | "pa.Schema",
         format_version: TableVersion = 2,
     ) -> UpdateSchema:
         from pyiceberg.catalog import Catalog
@@ -161,11 +161,11 @@ def union_by_name(

     def add_column(
         self,
-        path: Union[str, Tuple[str, ...]],
+        path: str | Tuple[str, ...],
         field_type: IcebergType,
-        doc: Optional[str] = None,
+        doc: str | None = None,
         required: bool = False,
-        default_value: Optional[L] = None,
+        default_value: L | None = None,
     ) -> UpdateSchema:
         """Add a new column to a nested struct or Add a new top-level column.
@@ -257,7 +257,7 @@ def add_column(

         return self

-    def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
+    def delete_column(self, path: str | Tuple[str, ...]) -> UpdateSchema:
         """Delete a column from a table.

         Args:
@@ -280,7 +280,7 @@ def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:

         return self

-    def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Optional[L]) -> UpdateSchema:
+    def set_default_value(self, path: str | Tuple[str, ...], default_value: L | None) -> UpdateSchema:
         """Set the default value of a column.

         Args:
@@ -293,7 +293,7 @@ def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Op

         return self

-    def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -> UpdateSchema:
+    def rename_column(self, path_from: str | Tuple[str, ...], new_name: str) -> UpdateSchema:
         """Update the name of a column.

         Args:
@@ -339,7 +339,7 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -

         return self

-    def make_column_optional(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
+    def make_column_optional(self, path: str | Tuple[str, ...]) -> UpdateSchema:
         """Make a column optional.

         Args:
@@ -354,7 +354,7 @@ def make_column_optional(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchem
     def set_identifier_fields(self, *fields: str) -> None:
         self._identifier_field_names = set(fields)

-    def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: bool) -> None:
+    def _set_column_requirement(self, path: str | Tuple[str, ...], required: bool) -> None:
         path = (path,) if isinstance(path, str) else path
         name = ".".join(path)
@@ -391,7 +391,7 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: B
             write_default=field.write_default,
         )

-    def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None:
+    def _set_column_default_value(self, path: str | Tuple[str, ...], default_value: Any) -> None:
         path = (path,) if isinstance(path, str) else path
         name = ".".join(path)
@@ -437,10 +437,10 @@ def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_v

     def update_column(
         self,
-        path: Union[str, Tuple[str, ...]],
-        field_type: Optional[IcebergType] = None,
-        required: Optional[bool] = None,
-        doc: Optional[str] = None,
+        path: str | Tuple[str, ...],
+        field_type: IcebergType | None = None,
+        required: bool | None = None,
+        doc: str | None = None,
     ) -> UpdateSchema:
         """Update the type of column.
@@ -501,7 +501,7 @@ def update_column(

         return self

-    def _find_for_move(self, name: str) -> Optional[int]:
+    def _find_for_move(self, name: str) -> int | None:
         try:
             return self._schema.find_field(name, self._case_sensitive).field_id
         except ValueError:
@@ -534,7 +534,7 @@ def _move(self, move: _Move) -> None:

         self._moves[TABLE_ROOT_ID] = self._moves.get(TABLE_ROOT_ID, []) + [move]

-    def move_first(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
+    def move_first(self, path: str | Tuple[str, ...]) -> UpdateSchema:
         """Move the field to the first position of the parent struct.

         Args:
@@ -554,7 +554,7 @@ def move_first(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:

         return self

-    def move_before(self, path: Union[str, Tuple[str, ...]], before_path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
+    def move_before(self, path: str | Tuple[str, ...], before_path: str | Tuple[str, ...]) -> UpdateSchema:
         """Move the field to before another field.

         Args:
@@ -588,7 +588,7 @@ def move_before(self, path: Union[str, Tuple[str, ...]], before_path: Union[str,

         return self

-    def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, Tuple[str, ...]]) -> UpdateSchema:
+    def move_after(self, path: str | Tuple[str, ...], after_name: str | Tuple[str, ...]) -> UpdateSchema:
         """Move the field to after another field.

         Args:
@@ -693,7 +693,7 @@ def assign_new_column_id(self) -> int:
         return next(self._last_column_id)


-class _ApplyChanges(SchemaVisitor[Optional[IcebergType]]):
+class _ApplyChanges(SchemaVisitor[IcebergType | None]):
     _adds: Dict[int, List[NestedField]]
     _updates: Dict[int, NestedField]
     _deletes: Set[int]
@@ -711,7 +711,7 @@ def __init__(
         self._deletes = deletes
         self._moves = moves

-    def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
+    def schema(self, schema: Schema, struct_result: IcebergType | None) -> IcebergType | None:
         added = self._adds.get(TABLE_ROOT_ID)
         moves = self._moves.get(TABLE_ROOT_ID)
@@ -724,7 +724,7 @@ def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Option

         return struct_result

-    def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
+    def struct(self, struct: StructType, field_results: List[IcebergType | None]) -> IcebergType | None:
         has_changes = False
         new_fields = []
@@ -777,7 +777,7 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]])

         return struct

-    def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
+    def field(self, field: NestedField, field_result: IcebergType | None) -> IcebergType | None:
         # the API validates deletes, updates, and additions don't conflict handle deletes
         if field.field_id in self._deletes:
             return None
@@ -799,16 +799,14 @@ def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Opti

         return field_result

-    def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
+    def list(self, list_type: ListType, element_result: IcebergType | None) -> IcebergType | None:
         element_type = self.field(list_type.element_field, element_result)
         if element_type is None:
             raise ValueError(f"Cannot delete element type from list: {element_result}")

         return ListType(element_id=list_type.element_id, element=element_type, element_required=list_type.element_required)

-    def map(
-        self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
-    ) -> Optional[IcebergType]:
+    def map(self, map_type: MapType, key_result: IcebergType | None, value_result: IcebergType | None) -> IcebergType | None:
         key_id: int = map_type.key_field.field_id

         if key_id in self._deletes:
@@ -836,7 +834,7 @@ def map(
             value_required=map_type.value_required,
         )

-    def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
+    def primitive(self, primitive: PrimitiveType) -> IcebergType | None:
         return primitive
@@ -850,10 +848,10 @@ def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_se
         self.existing_schema = existing_schema
         self.case_sensitive = case_sensitive

-    def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
+    def schema(self, schema: Schema, partner_id: int | None, struct_result: bool) -> bool:
         return struct_result

-    def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
+    def struct(self, struct: StructType, partner_id: int | None, missing_positions: List[bool]) -> bool:
         if partner_id is None:
             return True
@@ -908,10 +906,10 @@ def _find_field_type(self, field_id: int) -> IcebergType:
         else:
             return self.existing_schema.find_field(field_id).field_type

-    def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
+    def field(self, field: NestedField, partner_id: int | None, field_result: bool) -> bool:
         return partner_id is None

-    def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
+    def list(self, list_type: ListType, list_partner_id: int | None, element_missing: bool) -> bool:
         if list_partner_id is None:
             return True
@@ -926,7 +924,7 @@ def list(self, list_type: ListType, list_partner_id: Optional[int], element_miss

         return False

-    def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
+    def map(self, map_type: MapType, map_partner_id: int | None, key_missing: bool, value_missing: bool) -> bool:
         if map_partner_id is None:
             return True
@@ -945,7 +943,7 @@ def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: boo

         return False

-    def primitive(self, primitive: PrimitiveType, primitive_partner_id: Optional[int]) -> bool:
+    def primitive(self, primitive: PrimitiveType, primitive_partner_id: int | None) -> bool:
         return primitive_partner_id is None
@@ -957,10 +955,10 @@ def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
         self.partner_schema = partner_schema
         self.case_sensitive = case_sensitive

-    def schema_partner(self, partner: Optional[int]) -> Optional[int]:
+    def schema_partner(self, partner: int | None) -> int | None:
         return -1

-    def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
+    def field_partner(self, partner_field_id: int | None, field_id: int, field_name: str) -> int | None:
         if partner_field_id is not None:
             if partner_field_id == -1:
                 struct = self.partner_schema.as_struct()
@@ -974,7 +972,7 @@ def field_partner(self, partner_field_id: Optional[int], field_id: int, field_na

         return None

-    def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
+    def list_element_partner(self, partner_list_id: int | None) -> int | None:
         if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)):
             if not isinstance(field.field_type, ListType):
                 raise ValueError(f"Expected ListType: {field}")
@@ -982,7 +980,7 @@ def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
         else:
             return None

-    def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+    def map_key_partner(self, partner_map_id: int | None) -> int | None:
         if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
             if not isinstance(field.field_type, MapType):
                 raise ValueError(f"Expected MapType: {field}")
@@ -990,7 +988,7 @@ def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
         else:
             return None

-    def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+    def map_value_partner(self, partner_map_id: int | None) -> int | None:
         if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
             if not isinstance(field.field_type, MapType):
                 raise ValueError(f"Expected MapType: {field}")
@@ -999,7 +997,7 @@ def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
         return None


-def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
+def _add_fields(fields: Tuple[NestedField, ...], adds: List[NestedField] | None) -> Tuple[NestedField, ...]:
     adds = adds or []
     return fields + tuple(adds)
@@ -1029,7 +1027,7 @@ def _move_fields(fields: Tuple[NestedField, ...], moves: List[_Move]) -> Tuple[N

 def _add_and_move_fields(
     fields: Tuple[NestedField, ...], adds: List[NestedField], moves: List[_Move]
-) -> Optional[Tuple[NestedField, ...]]:
+) -> Tuple[NestedField, ...] | None:
     if len(adds) > 0:
         # always apply adds first so that added fields can be moved
         added = _add_fields(fields, adds)
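The UpdateSchema methods touched above are normally driven through the Table.update_schema() context manager, which commits on exit. A short usage sketch (tbl is a hypothetical pyiceberg Table handle):

    from pyiceberg.types import StringType

    with tbl.update_schema() as update:  # tbl: hypothetical pyiceberg Table
        update.add_column("region", StringType(), doc="sales region")
        update.rename_column("id", "event_id")
        update.move_first("event_id")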
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index a73961b56c..191e4a9bff 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -24,7 +24,7 @@
 from concurrent.futures import Future
 from datetime import datetime
 from functools import cached_property
-from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Set, Tuple

 from sortedcontainers import SortedList
@@ -105,21 +105,21 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
     _io: FileIO
     _operation: Operation
     _snapshot_id: int
-    _parent_snapshot_id: Optional[int]
+    _parent_snapshot_id: int | None
     _added_data_files: List[DataFile]
     _manifest_num_counter: itertools.count[int]
     _deleted_data_files: Set[DataFile]
     _compression: AvroCompressionCodec
-    _target_branch: Optional[str]
+    _target_branch: str | None

     def __init__(
         self,
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
-        commit_uuid: Optional[uuid.UUID] = None,
+        commit_uuid: uuid.UUID | None = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
-        branch: Optional[str] = MAIN_BRANCH,
+        branch: str | None = MAIN_BRANCH,
     ) -> None:
         super().__init__(transaction)
         self.commit_uuid = commit_uuid or uuid.uuid4()
@@ -140,7 +140,7 @@ def __init__(
             snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
         )

-    def _validate_target_branch(self, branch: Optional[str]) -> Optional[str]:
+    def _validate_target_branch(self, branch: str | None) -> str | None:
         # if branch is none, write will be written into a staging snapshot
         if branch is not None:
             if branch in self._transaction.table_metadata.refs:
@@ -298,7 +298,7 @@ def _commit(self) -> UpdatesAndRequirements:
         ) as writer:
             writer.add_manifests(new_manifests)

-        first_row_id: Optional[int] = None
+        first_row_id: int | None = None
         if self._transaction.table_metadata.format_version >= 3:
             first_row_id = self._transaction.table_metadata.next_row_id
@@ -386,8 +386,8 @@ def __init__(
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
-        branch: Optional[str] = MAIN_BRANCH,
-        commit_uuid: Optional[uuid.UUID] = None,
+        branch: str | None = MAIN_BRANCH,
+        commit_uuid: uuid.UUID | None = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ):
         super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
@@ -557,8 +557,8 @@ def __init__(
         operation: Operation,
         transaction: Transaction,
         io: FileIO,
-        branch: Optional[str] = MAIN_BRANCH,
-        commit_uuid: Optional[uuid.UUID] = None,
+        branch: str | None = MAIN_BRANCH,
+        commit_uuid: uuid.UUID | None = None,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         from pyiceberg.table import TableProperties
@@ -678,14 +678,14 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:

 class UpdateSnapshot:
     _transaction: Transaction
     _io: FileIO
-    _branch: Optional[str]
+    _branch: str | None
     _snapshot_properties: Dict[str, str]

     def __init__(
         self,
         transaction: Transaction,
         io: FileIO,
-        branch: Optional[str] = MAIN_BRANCH,
+        branch: str | None = MAIN_BRANCH,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         self._transaction = transaction
@@ -711,7 +711,7 @@ def merge_append(self) -> _MergeAppendFiles:
             snapshot_properties=self._snapshot_properties,
         )

-    def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
+    def overwrite(self, commit_uuid: uuid.UUID | None = None) -> _OverwriteFiles:
         return _OverwriteFiles(
             commit_uuid=commit_uuid,
             operation=Operation.OVERWRITE
@@ -864,7 +864,7 @@ def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
         self._requirements += requirements
         return self

-    def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots:
+    def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: int | None = None) -> ManageSnapshots:
         """
         Create a new tag pointing to the given snapshot id.
@@ -901,9 +901,9 @@ def create_branch(
         self,
         snapshot_id: int,
         branch_name: str,
-        max_ref_age_ms: Optional[int] = None,
-        max_snapshot_age_ms: Optional[int] = None,
-        min_snapshots_to_keep: Optional[int] = None,
+        max_ref_age_ms: int | None = None,
+        max_snapshot_age_ms: int | None = None,
+        min_snapshots_to_keep: int | None = None,
     ) -> ManageSnapshots:
         """
         Create a new branch pointing to the given snapshot id.
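The ManageSnapshots methods above return self, so tag and branch creation can be chained before a single commit. A sketch (again assuming a hypothetical tbl handle):

    snap = tbl.current_snapshot()  # tbl: hypothetical pyiceberg Table
    if snap is not None:
        tbl.manage_snapshots().create_tag(snapshot_id=snap.snapshot_id, tag_name="v1.0").create_branch(
            snapshot_id=snap.snapshot_id, branch_name="audit"
        ).commit()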
diff --git a/pyiceberg/table/update/sorting.py b/pyiceberg/table/update/sorting.py
index a356229f91..7e931b1a33 100644
--- a/pyiceberg/table/update/sorting.py
+++ b/pyiceberg/table/update/sorting.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Tuple

 from pyiceberg.table.sorting import INITIAL_SORT_ORDER_ID, UNSORTED_SORT_ORDER, NullOrder, SortDirection, SortField, SortOrder
 from pyiceberg.table.update import (
@@ -36,7 +36,7 @@ class UpdateSortOrder(UpdateTableMetadata["UpdateSortOrder"]):

     _transaction: Transaction
-    _last_assigned_order_id: Optional[int]
+    _last_assigned_order_id: int | None
     _case_sensitive: bool
     _fields: List[SortField]
@@ -44,7 +44,7 @@ def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> Non
         super().__init__(transaction)
         self._fields: List[SortField] = []
         self._case_sensitive: bool = case_sensitive
-        self._last_assigned_order_id: Optional[int] = None
+        self._last_assigned_order_id: int | None = None

     def _column_name_to_id(self, column_name: str) -> int:
         """Map the column name to the column field id."""
diff --git a/pyiceberg/table/update/spec.py b/pyiceberg/table/update/spec.py
index f52cd3ba81..b1f5f83d8f 100644
--- a/pyiceberg/table/update/spec.py
+++ b/pyiceberg/table/update/spec.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple

 from pyiceberg.expressions import (
     Reference,
@@ -78,8 +78,8 @@ def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> Non
     def add_field(
         self,
         source_column_name: str,
-        transform: Union[str, Transform[Any, Any]],
-        partition_field_name: Optional[str] = None,
+        transform: str | Transform[Any, Any],
+        partition_field_name: str | None = None,
     ) -> UpdateSpec:
         ref = Reference(source_column_name)
         bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive)
@@ -267,7 +267,7 @@ def _add_new_field(
             new_spec_id = spec.spec_id + 1
         return PartitionSpec(*partition_fields, spec_id=new_spec_id)

-    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
+    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: str | None) -> PartitionField:
         if self._transaction.table_metadata.format_version == 2:
             source_id, transform = transform_key
             historical_fields = []
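Since add_field is typed str | Transform[Any, Any], a partition transform can apparently be given by name and parsed into the corresponding transform. A sketch against a hypothetical table with a ts timestamp column:

    with tbl.update_spec() as update:  # tbl: hypothetical pyiceberg Table
        update.add_field("ts", "day", "ts_day")  # "day" stands in for DayTransform()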
@@ -219,8 +219,8 @@ def _added_data_files( def _validate_added_data_files( table: Table, starting_snapshot: Snapshot, - data_filter: Optional[BooleanExpression], - parent_snapshot: Optional[Snapshot], + data_filter: BooleanExpression | None, + parent_snapshot: Snapshot | None, ) -> None: """Validate that no files matching a filter have been added to the table since a starting snapshot. diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 4069a95330..98cfac1146 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -109,7 +109,7 @@ TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE) -def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.ModuleType: +def _try_import(module_name: str, extras_name: str | None = None) -> types.ModuleType: try: return importlib.import_module(module_name) except ImportError: @@ -165,7 +165,7 @@ class Transform(IcebergRootModel[str], ABC, Generic[S, T]): root: str = Field() @abstractmethod - def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[T]]: ... + def transform(self, source: IcebergType) -> Callable[[S | None], T | None]: ... @abstractmethod def can_transform(self, source: IcebergType) -> bool: @@ -183,10 +183,10 @@ def result_type(self, source: IcebergType) -> IcebergType: ... @abstractmethod - def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]: ... + def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: ... @abstractmethod - def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]: ... + def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: ... @property def preserves_order(self) -> bool: @@ -195,7 +195,7 @@ def preserves_order(self) -> bool: def satisfies_order_of(self, other: Any) -> bool: return self == other - def to_human_string(self, _: IcebergType, value: Optional[S]) -> str: + def to_human_string(self, _: IcebergType, value: S | None) -> str: return str(value) if value is not None else "null" @property @@ -263,13 +263,13 @@ def num_buckets(self) -> int: def hash(self, value: S) -> int: raise NotImplementedError() - def apply(self, value: Optional[S]) -> Optional[int]: + def apply(self, value: S | None) -> int | None: return (self.hash(value) & IntegerType.max) % self._num_buckets if value else None def result_type(self, source: IcebergType) -> IcebergType: return IntegerType() - def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]: + def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: transformer = self.transform(pred.term.ref().field.field_type) if isinstance(pred.term, BoundTransform): @@ -286,7 +286,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica # For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected. 
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 4069a95330..98cfac1146 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -109,7 +109,7 @@
 TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE)


-def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.ModuleType:
+def _try_import(module_name: str, extras_name: str | None = None) -> types.ModuleType:
     try:
         return importlib.import_module(module_name)
     except ImportError:
@@ -165,7 +165,7 @@ class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
     root: str = Field()

     @abstractmethod
-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[T]]: ...
+    def transform(self, source: IcebergType) -> Callable[[S | None], T | None]: ...

     @abstractmethod
     def can_transform(self, source: IcebergType) -> bool:
@@ -183,10 +183,10 @@ def result_type(self, source: IcebergType) -> IcebergType: ...

     @abstractmethod
-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]: ...
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None: ...

     @abstractmethod
-    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]: ...
+    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None: ...

     @property
     def preserves_order(self) -> bool:
@@ -195,7 +195,7 @@ def preserves_order(self) -> bool:
     def satisfies_order_of(self, other: Any) -> bool:
         return self == other

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         return str(value) if value is not None else "null"

     @property
@@ -263,13 +263,13 @@ def num_buckets(self) -> int:
     def hash(self, value: S) -> int:
         raise NotImplementedError()

-    def apply(self, value: Optional[S]) -> Optional[int]:
+    def apply(self, value: S | None) -> int | None:
         return (self.hash(value) & IntegerType.max) % self._num_buckets if value else None

     def result_type(self, source: IcebergType) -> IcebergType:
         return IntegerType()

-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         transformer = self.transform(pred.term.ref().field.field_type)

         if isinstance(pred.term, BoundTransform):
@@ -286,7 +286,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
         # For example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected.
         return None

-    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
+    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None:
         transformer = self.transform(pred.term.ref().field.field_type)

         if isinstance(pred.term, BoundTransform):
@@ -321,7 +321,7 @@ def can_transform(self, source: IcebergType) -> bool:
             ),
         )

-    def transform(self, source: IcebergType, bucket: bool = True) -> Callable[[Optional[Any]], Optional[int]]:
+    def transform(self, source: IcebergType, bucket: bool = True) -> Callable[[Any | None], int | None]:
         if isinstance(source, TimeType):

             def hash_func(v: Any) -> int:
@@ -418,9 +418,9 @@ def result_type(self, source: IcebergType) -> IntegerType:
         return IntegerType()

     @abstractmethod
-    def transform(self, source: IcebergType) -> Callable[[Optional[Any]], Optional[int]]: ...
+    def transform(self, source: IcebergType) -> Callable[[Any | None], int | None]: ...

-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         transformer = self.transform(pred.term.ref().field.field_type)
         if isinstance(pred.term, BoundTransform):
             return _project_transform_predicate(self, name, pred)
@@ -433,7 +433,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
         else:
             return None

-    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
+    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None:
         transformer = self.transform(pred.term.ref().field.field_type)
         if isinstance(pred.term, BoundTransform):
             return _project_transform_predicate(self, name, pred)
@@ -466,7 +466,7 @@ class YearTransform(TimeTransform[S]):
     root: LiteralType["year"] = Field(default="year")  # noqa: F821

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], int | None]:
         if isinstance(source, DateType):

             def year_func(v: Any) -> int:
@@ -502,7 +502,7 @@ def can_transform(self, source: IcebergType) -> bool:
     def granularity(self) -> TimeResolution:
         return TimeResolution.YEAR

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         return datetime.to_human_year(value) if isinstance(value, int) else "null"

     def __repr__(self) -> str:
@@ -526,7 +526,7 @@ class MonthTransform(TimeTransform[S]):
     root: LiteralType["month"] = Field(default="month")  # noqa: F821

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], int | None]:
         if isinstance(source, DateType):

             def month_func(v: Any) -> int:
@@ -562,7 +562,7 @@ def can_transform(self, source: IcebergType) -> bool:
     def granularity(self) -> TimeResolution:
         return TimeResolution.MONTH

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         return datetime.to_human_month(value) if isinstance(value, int) else "null"

     def __repr__(self) -> str:
@@ -587,7 +587,7 @@ class DayTransform(TimeTransform[S]):
     root: LiteralType["day"] = Field(default="day")  # noqa: F821

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], int | None]:
         if isinstance(source, DateType):

             def day_func(v: Any) -> int:
@@ -631,7 +631,7 @@ def result_type(self, source: IcebergType) -> IcebergType:
     def granularity(self) -> TimeResolution:
         return TimeResolution.DAY

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         return datetime.to_human_day(value) if isinstance(value, int) else "null"

     def __repr__(self) -> str:
@@ -656,7 +656,7 @@ class HourTransform(TimeTransform[S]):
     root: LiteralType["hour"] = Field(default="hour")  # noqa: F821

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], int | None]:
         if isinstance(source, (TimestampType, TimestamptzType)):

             def hour_func(v: Any) -> int:
@@ -684,7 +684,7 @@ def can_transform(self, source: IcebergType) -> bool:
     def granularity(self) -> TimeResolution:
         return TimeResolution.HOUR

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         return datetime.to_human_hour(value) if isinstance(value, int) else "null"

     def __repr__(self) -> str:
@@ -716,7 +716,7 @@ class IdentityTransform(Transform[S, S]):
     def __init__(self) -> None:
         super().__init__("identity")

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[S]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], S | None]:
         return lambda v: v

     def can_transform(self, source: IcebergType) -> bool:
@@ -725,7 +725,7 @@ def can_transform(self, source: IcebergType) -> bool:
     def result_type(self, source: IcebergType) -> IcebergType:
         return source

-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         if isinstance(pred.term, BoundTransform):
             return _project_transform_predicate(self, name, pred)
         elif isinstance(pred, BoundUnaryPredicate):
@@ -737,7 +737,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
         else:
             return None

-    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
+    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None:
         if isinstance(pred, BoundUnaryPredicate):
             return pred.as_unbound(Reference(name))
         elif isinstance(pred, BoundLiteralPredicate):
@@ -755,7 +755,7 @@ def satisfies_order_of(self, other: Transform[S, T]) -> bool:
         """Ordering by value is the same as long as the other preserves order."""
         return other.preserves_order

-    def to_human_string(self, source_type: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, source_type: IcebergType, value: S | None) -> str:
         return _human_string(value, source_type) if value is not None else "null"

     def __str__(self) -> str:
@@ -801,7 +801,7 @@ def preserves_order(self) -> bool:
     def source_type(self) -> IcebergType:
         return self._source_type

-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         field_type = pred.term.ref().field.field_type

         if isinstance(pred.term, BoundTransform):
@@ -819,7 +819,7 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
                 return _truncate_array(name, pred, self.transform(field_type))
         return None

-    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
+    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None:
         field_type = pred.term.ref().field.field_type

         if isinstance(pred.term, BoundTransform):
@@ -865,7 +865,7 @@ def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[Unbou
     def width(self) -> int:
         return self._width

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[S]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], S | None]:
         if isinstance(source, (IntegerType, LongType)):

             def truncate_func(v: Any) -> Any:
@@ -898,7 +898,7 @@ def satisfies_order_of(self, other: Transform[S, T]) -> bool:

         return False

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         if value is None:
             return "null"
         elif isinstance(value, bytes):
@@ -978,7 +978,7 @@ def __init__(self, transform: str, **data: Any):
         super().__init__(**data)
         self._transform = transform

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[T]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], T | None]:
         raise AttributeError(f"Cannot apply unsupported transform: {self}")

     def can_transform(self, source: IcebergType) -> bool:
@@ -987,10 +987,10 @@ def can_transform(self, source: IcebergType) -> bool:
     def result_type(self, source: IcebergType) -> StringType:
         return StringType()

-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         return None

-    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> Optional[UnboundPredicate[Any]]:
+    def strict_project(self, name: str, pred: BoundPredicate[Any]) -> UnboundPredicate[Any] | None:
         return None

     def __repr__(self) -> str:
@@ -1006,7 +1006,7 @@ class VoidTransform(Transform[S, None], Singleton):
     root: str = "void"

-    def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[T]]:
+    def transform(self, source: IcebergType) -> Callable[[S | None], T | None]:
         return lambda v: None

     def can_transform(self, _: IcebergType) -> bool:
@@ -1015,13 +1015,13 @@ def can_transform(self, _: IcebergType) -> bool:
     def result_type(self, source: IcebergType) -> IcebergType:
         return source

-    def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         return None

-    def strict_project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredicate[Any]]:
+    def strict_project(self, name: str, pred: BoundPredicate[L]) -> UnboundPredicate[Any] | None:
         return None

-    def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
+    def to_human_string(self, _: IcebergType, value: S | None) -> str:
         return "null"

     def __repr__(self) -> str:
@@ -1038,8 +1038,8 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr

 def _truncate_number(
-    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
-) -> Optional[UnboundPredicate[Any]]:
+    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None]
+) -> UnboundPredicate[Any] | None:
     boundary = pred.literal

     if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)):
@@ -1060,8 +1060,8 @@ def _truncate_number(

 def _truncate_number_strict(
-    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
-) -> Optional[UnboundPredicate[Any]]:
+    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None]
+) -> UnboundPredicate[Any] | None:
     boundary = pred.literal

     if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, TimestampLiteral)):
@@ -1086,8 +1086,8 @@ def _truncate_number_strict(

 def _truncate_array_strict(
-    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
-) -> Optional[UnboundPredicate[Any]]:
+    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None]
+) -> UnboundPredicate[Any] | None:
     boundary = pred.literal

     if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
@@ -1101,8 +1101,8 @@ def _truncate_array_strict(

 def _truncate_array(
-    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
-) -> Optional[UnboundPredicate[Any]]:
+    name: str, pred: BoundLiteralPredicate[L], transform: Callable[[L | None], L | None]
+) -> UnboundPredicate[Any] | None:
     boundary = pred.literal

     if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
@@ -1121,7 +1121,7 @@ def _truncate_array(
 def _project_transform_predicate(
     transform: Transform[Any, Any], partition_name: str, pred: BoundPredicate[L]
-) -> Optional[UnboundPredicate[Any]]:
+) -> UnboundPredicate[Any] | None:
     term = pred.term
     if isinstance(term, BoundTransform) and transform == term.transform:
         return _remove_transform(partition_name, pred)
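The `transform(...) -> Callable[[S | None], T | None]` shape above is the core contract of Iceberg partition transforms: the returned function must map null to null and a value to its partition value. A toy illustration of that contract (not the real `TruncateTransform`, which also handles integers, decimals, and bytes):

    from __future__ import annotations

    from typing import Callable

    def string_truncate(width: int) -> Callable[[str | None], str | None]:
        def apply(value: str | None) -> str | None:
            # Nulls pass through untouched, mirroring the transforms above.
            return None if value is None else value[:width]
        return apply

    func = string_truncate(3)
    assert func("iceberg") == "ice"
    assert func(None) is None

Worth noting: the sketch checks `value is None` explicitly, whereas the `if value else None` in `BucketTransform.apply` above would also send falsy values like `0` down the null path.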
diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py
index d9ace9d971..3480adcb32 100644
--- a/pyiceberg/typedef.py
+++ b/pyiceberg/typedef.py
@@ -27,7 +27,6 @@
     Generic,
     List,
     Literal,
-    Optional,
     Protocol,
     Set,
     Tuple,
@@ -127,7 +126,7 @@ class IcebergBaseModel(BaseModel):

     model_config = ConfigDict(populate_by_name=True, frozen=True)

-    def _exclude_private_properties(self, exclude: Optional[Set[str]] = None) -> Set[str]:
+    def _exclude_private_properties(self, exclude: Set[str] | None = None) -> Set[str]:
         # A small trick to exclude private properties. Properties are serialized by pydantic,
         # regardless if they start with an underscore.
         # This will look at the dict, and find the fields and exclude them
@@ -136,14 +135,14 @@ def _exclude_private_properties(self, exclude: Optional[Set[str]] = None) -> Set
         )

     def model_dump(
-        self, exclude_none: bool = True, exclude: Optional[Set[str]] = None, by_alias: bool = True, **kwargs: Any
+        self, exclude_none: bool = True, exclude: Set[str] | None = None, by_alias: bool = True, **kwargs: Any
     ) -> Dict[str, Any]:
         return super().model_dump(
             exclude_none=exclude_none, exclude=self._exclude_private_properties(exclude), by_alias=by_alias, **kwargs
         )

     def model_dump_json(
-        self, exclude_none: bool = True, exclude: Optional[Set[str]] = None, by_alias: bool = True, **kwargs: Any
+        self, exclude_none: bool = True, exclude: Set[str] | None = None, by_alias: bool = True, **kwargs: Any
     ) -> str:
         return super().model_dump_json(
             exclude_none=exclude_none, exclude=self._exclude_private_properties(exclude), by_alias=by_alias, **kwargs
diff --git a/pyiceberg/types.py b/pyiceberg/types.py
index 6872663f84..c22bee092f 100644
--- a/pyiceberg/types.py
+++ b/pyiceberg/types.py
@@ -40,7 +40,6 @@
     ClassVar,
     Dict,
     Literal,
-    Optional,
     Tuple,
 )

@@ -339,9 +338,9 @@ class NestedField(IcebergType):
     name: str = Field()
     field_type: SerializeAsAny[IcebergType] = Field(alias="type")
     required: bool = Field(default=False)
-    doc: Optional[str] = Field(default=None, repr=False)
-    initial_default: Optional[DefaultValue] = Field(alias="initial-default", default=None, repr=True)  # type: ignore
-    write_default: Optional[DefaultValue] = Field(alias="write-default", default=None, repr=True)  # type: ignore
+    doc: str | None = Field(default=None, repr=False)
+    initial_default: DefaultValue | None = Field(alias="initial-default", default=None, repr=True)  # type: ignore
+    write_default: DefaultValue | None = Field(alias="write-default", default=None, repr=True)  # type: ignore

     @field_validator("field_type", mode="before")
     def convert_field_type(cls, v: Any) -> IcebergType:
@@ -355,13 +354,13 @@ def convert_field_type(cls, v: Any) -> IcebergType:

     def __init__(
         self,
-        field_id: Optional[int] = None,
-        name: Optional[str] = None,
-        field_type: Optional[IcebergType | str] = None,
+        field_id: int | None = None,
+        name: str | None = None,
+        field_type: IcebergType | str | None = None,
         required: bool = False,
-        doc: Optional[str] = None,
-        initial_default: Optional[Any] = None,
-        write_default: Optional[L] = None,
+        doc: str | None = None,
+        initial_default: Any | None = None,
+        write_default: L | None = None,
         **data: Any,
     ):
         # We need an init when we want to use positional arguments, but
@@ -416,7 +415,7 @@ def __repr__(self) -> str:
         return f"NestedField({', '.join(parts)})"

-    def __getnewargs__(self) -> Tuple[int, str, IcebergType, bool, Optional[str]]:
+    def __getnewargs__(self) -> Tuple[int, str, IcebergType, bool, str | None]:
         """Pickle the NestedField class."""
         return (self.field_id, self.name, self.field_type, self.required, self.doc)

@@ -447,13 +446,13 @@ def __init__(self, *fields: NestedField, **data: Any):
         super().__init__(**data)
         self._hash = hash(self.fields)

-    def field(self, field_id: int) -> Optional[NestedField]:
+    def field(self, field_id: int) -> NestedField | None:
         for field in self.fields:
             if field.field_id == field_id:
                 return field
         return None

-    def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
+    def field_by_name(self, name: str, case_sensitive: bool = True) -> NestedField | None:
         if case_sensitive:
             for field in self.fields:
                 if field.name == name:
@@ -506,7 +505,7 @@ class ListType(IcebergType):
     _hash: int = PrivateAttr()

     def __init__(
-        self, element_id: Optional[int] = None, element: Optional[IcebergType] = None, element_required: bool = True, **data: Any
+        self, element_id: int | None = None, element: IcebergType | None = None, element_required: bool = True, **data: Any
     ):
         data["element-id"] = data["element-id"] if "element-id" in data else element_id
         data["element"] = element or data["element_type"]
@@ -558,10 +557,10 @@ class MapType(IcebergType):

     def __init__(
         self,
-        key_id: Optional[int] = None,
-        key_type: Optional[IcebergType] = None,
-        value_id: Optional[int] = None,
-        value_type: Optional[IcebergType] = None,
+        key_id: int | None = None,
+        key_type: IcebergType | None = None,
+        value_id: int | None = None,
+        value_type: IcebergType | None = None,
         value_required: bool = True,
         **data: Any,
     ):
diff --git a/pyiceberg/utils/bin_packing.py b/pyiceberg/utils/bin_packing.py
index 0291619685..825420d8b7 100644
--- a/pyiceberg/utils/bin_packing.py
+++ b/pyiceberg/utils/bin_packing.py
@@ -21,7 +21,6 @@
     Generic,
     Iterable,
     List,
-    Optional,
     TypeVar,
 )

@@ -91,7 +90,7 @@ def __next__(self) -> List[T]:
         return self.remove_bin().items

-    def find_bin(self, weight: int) -> Optional[Bin[T]]:
+    def find_bin(self, weight: int) -> Bin[T] | None:
         for bin_ in self.bins:
             if bin_.can_add(weight):
                 return bin_
diff --git a/pyiceberg/utils/concurrent.py b/pyiceberg/utils/concurrent.py
index 54e99dc0ba..a0fccb8131 100644
--- a/pyiceberg/utils/concurrent.py
+++ b/pyiceberg/utils/concurrent.py
@@ -18,17 +18,16 @@
 import os
 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import Optional

 from pyiceberg.utils.config import Config


 class ExecutorFactory:
-    _instance: Optional[Executor] = None
-    _instance_pid: Optional[int] = None
+    _instance: Executor | None = None
+    _instance_pid: int | None = None

     @staticmethod
-    def max_workers() -> Optional[int]:
+    def max_workers() -> int | None:
         """Return the max number of workers configured."""
         return Config().get_int("max-workers")
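`ExecutorFactory` keeps one process-wide executor plus the PID it was created under, so a forked worker does not reuse a thread pool inherited from its parent. A minimal re-creation of that idea — `Factory.get_or_create` is a sketch, assuming the real factory resets its cache when the PID changes:

    from __future__ import annotations

    import os
    from concurrent.futures import Executor, ThreadPoolExecutor

    class Factory:
        _instance: Executor | None = None
        _instance_pid: int | None = None

        @classmethod
        def get_or_create(cls) -> Executor:
            pid = os.getpid()
            if cls._instance is None or cls._instance_pid != pid:
                # The real factory sizes this from the max-workers config key.
                cls._instance = ThreadPoolExecutor()
                cls._instance_pid = pid
            return cls._instance

    assert Factory.get_or_create() is Factory.get_or_create()  # cached within one process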
diff --git a/pyiceberg/utils/config.py b/pyiceberg/utils/config.py
index 78f121a402..98fb292369 100644
--- a/pyiceberg/utils/config.py
+++ b/pyiceberg/utils/config.py
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 import os
-from typing import List, Optional
+from typing import List

 import strictyaml

@@ -66,14 +66,14 @@ def __init__(self) -> None:
         self.config = FrozenDict(**config)

     @staticmethod
-    def _from_configuration_files() -> Optional[RecursiveDict]:
+    def _from_configuration_files() -> RecursiveDict | None:
         """Load the first configuration file that its finds.

         Will first look in the PYICEBERG_HOME env variable, and then in the home directory.
         """

-        def _load_yaml(directory: Optional[str]) -> Optional[RecursiveDict]:
+        def _load_yaml(directory: str | None) -> RecursiveDict | None:
             if directory:
                 path = os.path.join(directory, PYICEBERG_YML)
                 if os.path.isfile(path):
@@ -146,7 +146,7 @@ def get_default_catalog_name(self) -> str:
                 return default_catalog_name
         return DEFAULT

-    def get_catalog_config(self, catalog_name: str) -> Optional[RecursiveDict]:
+    def get_catalog_config(self, catalog_name: str) -> RecursiveDict | None:
         if CATALOG in self.config:
             catalog_name_lower = catalog_name.lower()
             catalogs = self.config[CATALOG]
@@ -165,7 +165,7 @@ def get_known_catalogs(self) -> List[str]:
             raise ValueError("Catalog configurations needs to be an object")
         return list(catalogs.keys())

-    def get_int(self, key: str) -> Optional[int]:
+    def get_int(self, key: str) -> int | None:
         if (val := self.config.get(key)) is not None:
             try:
                 return int(val)  # type: ignore
@@ -173,7 +173,7 @@ def get_int(self, key: str) -> Optional[int]:
                 raise ValueError(f"{key} should be an integer or left unset. Current value: {val}") from err
         return None

-    def get_bool(self, key: str) -> Optional[bool]:
+    def get_bool(self, key: str) -> bool | None:
         if (val := self.config.get(key)) is not None:
             try:
                 return strtobool(val)  # type: ignore
diff --git a/pyiceberg/utils/decimal.py b/pyiceberg/utils/decimal.py
index 99638d2a00..5ef82640d9 100644
--- a/pyiceberg/utils/decimal.py
+++ b/pyiceberg/utils/decimal.py
@@ -19,7 +19,6 @@
 import math
 from decimal import Decimal
-from typing import Optional, Union


 def decimal_to_unscaled(value: Decimal) -> int:
@@ -49,7 +48,7 @@ def unscaled_to_decimal(unscaled: int, scale: int) -> Decimal:
     return Decimal((sign, digits, -scale))


-def bytes_required(value: Union[int, Decimal]) -> int:
+def bytes_required(value: int | Decimal) -> int:
     """Return the minimum number of bytes needed to serialize a decimal or unscaled value.

     Args:
@@ -66,7 +65,7 @@ def bytes_required(value: Union[int, Decimal]) -> int:
     raise ValueError(f"Unsupported value: {value}")


-def decimal_to_bytes(value: Decimal, byte_length: Optional[int] = None) -> bytes:
+def decimal_to_bytes(value: Decimal, byte_length: int | None = None) -> bytes:
     """Return a byte representation of a decimal.

     Args:
diff --git a/pyiceberg/utils/deprecated.py b/pyiceberg/utils/deprecated.py
index b196f47ec6..accbd9d5fe 100644
--- a/pyiceberg/utils/deprecated.py
+++ b/pyiceberg/utils/deprecated.py
@@ -16,10 +16,10 @@
 # under the License.
 import functools
 import warnings
-from typing import Any, Callable, Optional
+from typing import Any, Callable


-def deprecated(deprecated_in: str, removed_in: str, help_message: Optional[str] = None) -> Callable:  # type: ignore
+def deprecated(deprecated_in: str, removed_in: str, help_message: str | None = None) -> Callable:  # type: ignore
     """Mark functions as deprecated.

     Adding this will result in a warning being emitted when the function is used.
@@ -41,12 +41,12 @@ def new_func(*args: Any, **kwargs: Any) -> Any:
     return decorator


-def deprecation_notice(deprecated_in: str, removed_in: str, help_message: Optional[str]) -> str:
+def deprecation_notice(deprecated_in: str, removed_in: str, help_message: str | None) -> str:
     """Return a deprecation notice."""
     return f"Deprecated in {deprecated_in}, will be removed in {removed_in}. {help_message}"


-def deprecation_message(deprecated_in: str, removed_in: str, help_message: Optional[str]) -> None:
+def deprecation_message(deprecated_in: str, removed_in: str, help_message: str | None) -> None:
     """Mark properties or behaviors as deprecated.

     Adding this will result in a warning being emitted.
diff --git a/pyiceberg/utils/lazydict.py b/pyiceberg/utils/lazydict.py
index ea70c78c7a..db5c1f82c5 100644
--- a/pyiceberg/utils/lazydict.py
+++ b/pyiceberg/utils/lazydict.py
@@ -19,10 +19,8 @@
     Dict,
     Iterator,
     Mapping,
-    Optional,
     Sequence,
     TypeVar,
-    Union,
     cast,
 )

@@ -41,9 +39,9 @@ class LazyDict(Mapping[K, V]):
     #
     # Rather than spending the runtime cost of checking the type of each item, we presume
     # that the developer has correctly used the class and that the contents are valid.
-    def __init__(self, contents: Sequence[Sequence[Union[K, V]]]):
+    def __init__(self, contents: Sequence[Sequence[K | V]]):
         self._contents = contents
-        self._dict: Optional[Dict[K, V]] = None
+        self._dict: Dict[K, V] | None = None

     def _build_dict(self) -> Dict[K, V]:
         self._dict = {}
diff --git a/pyiceberg/utils/properties.py b/pyiceberg/utils/properties.py
index 2b228f6e41..11241e485c 100644
--- a/pyiceberg/utils/properties.py
+++ b/pyiceberg/utils/properties.py
@@ -18,7 +18,6 @@
 from typing import (
     Any,
     Dict,
-    Optional,
 )

 from pyiceberg.typedef import Properties
@@ -30,8 +29,8 @@
 def property_as_int(
     properties: Dict[str, str],
     property_name: str,
-    default: Optional[int] = None,
-) -> Optional[int]:
+    default: int | None = None,
+) -> int | None:
     if value := properties.get(property_name):
         try:
             return int(value)
@@ -44,8 +43,8 @@ def property_as_int(
 def property_as_float(
     properties: Dict[str, str],
     property_name: str,
-    default: Optional[float] = None,
-) -> Optional[float]:
+    default: float | None = None,
+) -> float | None:
     if value := properties.get(property_name):
         try:
             return float(value)
@@ -71,7 +70,7 @@ def property_as_bool(
 def get_first_property_value(
     properties: Properties,
     *property_names: str,
-) -> Optional[Any]:
+) -> Any | None:
     for property_name in property_names:
         if property_value := properties.get(property_name):
             return property_value
diff --git a/pyiceberg/utils/schema_conversion.py b/pyiceberg/utils/schema_conversion.py
index 551fa40156..0ec8dce084 100644
--- a/pyiceberg/utils/schema_conversion.py
+++ b/pyiceberg/utils/schema_conversion.py
@@ -21,9 +21,7 @@
     Any,
     Dict,
     List,
-    Optional,
     Tuple,
-    Union,
 )

 from pyiceberg.schema import (
@@ -81,7 +79,7 @@
     ("uuid", "string"): UUIDType(),
 }

-AvroType = Union[str, Any]
+AvroType = str | Any


 class AvroSchemaConversion:
@@ -130,13 +128,11 @@ def avro_to_iceberg(self, avro_schema: Dict[str, Any]) -> Schema:
         """
         return Schema(*[self._convert_field(field) for field in avro_schema["fields"]], schema_id=1)

-    def iceberg_to_avro(self, schema: Schema, schema_name: Optional[str] = None) -> AvroType:
+    def iceberg_to_avro(self, schema: Schema, schema_name: str | None = None) -> AvroType:
         """Convert an Iceberg schema into an Avro dictionary that can be serialized to JSON."""
         return visit(schema, ConvertSchemaToAvro(schema_name))

-    def _resolve_union(
-        self, type_union: Union[Dict[str, str], List[Union[str, Dict[str, str]]], str]
-    ) -> Tuple[Union[str, Dict[str, Any]], bool]:
+    def _resolve_union(self, type_union: Dict[str, str] | List[str | Dict[str, str]] | str) -> Tuple[str | Dict[str, Any], bool]:
         """
         Convert Unions into their type and resolves if the field is required.
@@ -159,7 +155,7 @@ def _resolve_union(
         Raises:
             TypeError: In the case non-optional union types are encountered.
         """
-        avro_types: Union[Dict[str, str], List[Union[Dict[str, str], str]]]
+        avro_types: Dict[str, str] | List[Dict[str, str] | str]
         if isinstance(type_union, str):
             # It is a primitive and required
             return type_union, True
@@ -185,7 +181,7 @@ def _resolve_union(
             # Filter the null value and return the type
             return list(filter(lambda t: t != "null", avro_types))[0], False

-    def _convert_schema(self, avro_type: Union[str, Dict[str, Any]]) -> IcebergType:
+    def _convert_schema(self, avro_type: str | Dict[str, Any]) -> IcebergType:
         """
         Resolve the Avro type.
@@ -496,12 +492,12 @@ def _convert_fixed_type(self, avro_type: Dict[str, Any]) -> FixedType:
 class ConvertSchemaToAvro(SchemaVisitorPerPrimitiveType[AvroType]):
     """Convert an Iceberg schema to an Avro schema."""

-    schema_name: Optional[str]
+    schema_name: str | None
     last_list_field_id: int
     last_map_key_field_id: int
     last_map_value_field_id: int

-    def __init__(self, schema_name: Optional[str]) -> None:
+    def __init__(self, schema_name: str | None) -> None:
         """Convert an Iceberg schema to an Avro schema.

         Args:
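One subtlety in the `AvroType = Union[str, Any]` → `AvroType = str | Any` change: unlike annotations, a module-level alias is an ordinary expression evaluated at import time, so the `from __future__ import annotations` deferral does not apply and the `|` spelling requires Python 3.10+ outright. A small demonstration with an illustrative alias (not a pyiceberg name):

    from __future__ import annotations  # defers annotations only, not ordinary expressions

    import types

    # Evaluated at import time; needs Python 3.10+ regardless of the future import.
    JsonScalar = str | int | float | None

    assert isinstance(JsonScalar, types.UnionType)
    assert isinstance(3.14, JsonScalar)  # PEP 604 unions also work in isinstance()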
diff --git a/pyiceberg/utils/truncate.py b/pyiceberg/utils/truncate.py
index 4ddb2401c4..feaa7d342d 100644
--- a/pyiceberg/utils/truncate.py
+++ b/pyiceberg/utils/truncate.py
@@ -14,10 +14,9 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Optional


-def truncate_upper_bound_text_string(value: str, trunc_length: Optional[int]) -> Optional[str]:
+def truncate_upper_bound_text_string(value: str, trunc_length: int | None) -> str | None:
     result = value[:trunc_length]
     if result != value:
         chars = [*result]
@@ -35,7 +34,7 @@ def truncate_upper_bound_text_string(value: str, trunc_length: Optional[int]) ->
     return result


-def truncate_upper_bound_binary_string(value: bytes, trunc_length: Optional[int]) -> Optional[bytes]:
+def truncate_upper_bound_binary_string(value: bytes, trunc_length: int | None) -> bytes | None:
     result = value[:trunc_length]
     if result != value:
         _bytes = [*result]
diff --git a/ruff.toml b/ruff.toml
index b7bc461cf6..7fb76404c7 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -57,7 +57,7 @@ select = [
     "I", # isort
     "UP", # pyupgrade
 ]
-ignore = ["E501","E203","B024","B028","UP037", "UP035", "UP006"]
+ignore = ["E501","E203","B024","B028","UP037", "UP035", "UP006", "B905"]

 # Allow autofix for all enabled rules (when `--fix`) is provided.
 fixable = ["ALL"]
diff --git a/tests/avro/test_decoder.py b/tests/avro/test_decoder.py
index 608e6ae2d5..c7c64ea096 100644
--- a/tests/avro/test_decoder.py
+++ b/tests/avro/test_decoder.py
@@ -20,7 +20,7 @@
 import struct
 from io import SEEK_SET
 from types import TracebackType
-from typing import Callable, Optional, Type
+from typing import Callable, Type
 from unittest.mock import MagicMock, patch

 import pytest
@@ -129,9 +129,7 @@ def close(self) -> None:
     def __enter__(self) -> OneByteAtATimeInputStream:
         return self

-    def __exit__(
-        self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
-    ) -> None:
+    def __exit__(self, exctype: Type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None) -> None:
         self.close()
diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py
index 26b44e8e23..f0742e946d 100644
--- a/tests/avro/test_resolver.py
+++ b/tests/avro/test_resolver.py
@@ -16,7 +16,6 @@
 # under the License.
 from tempfile import TemporaryDirectory
-from typing import Optional

 import pytest
 from pydantic import Field
@@ -287,7 +286,7 @@ def test_column_assignment() -> None:
     class Ints(Record):
         c: int = Field()
-        d: Optional[int] = Field()
+        d: int | None = Field()

     ints_schema = Schema(
         NestedField(3, "c", IntegerType(), required=True),
diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py
index 80c01f70fa..42702c8c2b 100644
--- a/tests/catalog/test_base.py
+++ b/tests/catalog/test_base.py
@@ -18,7 +18,6 @@

 from pathlib import PosixPath
-from typing import Union

 import pyarrow as pa
 import pytest
@@ -201,7 +200,7 @@ def test_create_table_removes_trailing_slash_from_location(catalog: InMemoryCata
     ],
 )
 def test_convert_schema_if_needed(
-    schema: Union[Schema, pa.Schema],
+    schema: Schema | pa.Schema,
     expected: Schema,
     catalog: InMemoryCatalog,
 ) -> None:
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index 649e2545ba..1a3978a045 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -22,7 +22,6 @@
 import uuid
 from collections.abc import Generator
 from copy import deepcopy
-from typing import Optional
 from unittest.mock import MagicMock, call, patch

 import pytest
@@ -230,7 +229,7 @@ def run(self) -> None:
         pass

     @property
-    def port(self) -> Optional[int]:
+    def port(self) -> int | None:
         self._port_bound.wait()
         return self._port
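The `__exit__` rewrites above (and in the library code earlier in this patch) now fit on one line, but the three-parameter context-manager contract is unchanged. For reference, the standard typing in the new style — `Resource` is an illustrative class, not part of the test suite:

    from __future__ import annotations

    from types import TracebackType

    class Resource:
        def __enter__(self) -> Resource:
            return self

        def __exit__(
            self,
            exctype: type[BaseException] | None,
            excinst: BaseException | None,
            exctb: TracebackType | None,
        ) -> None:
            pass  # release resources here

    with Resource():
        pass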
diff --git a/tests/conftest.py b/tests/conftest.py
index 6734932993..9ac033d1da 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,7 +41,6 @@
     Dict,
     Generator,
     List,
-    Optional,
 )

 import boto3
@@ -786,8 +785,8 @@ def example_table_metadata_no_snapshot_v1() -> Dict[str, Any]:
 def example_table_metadata_v2_with_extensive_snapshots() -> Dict[str, Any]:
     def generate_snapshot(
         snapshot_id: int,
-        parent_snapshot_id: Optional[int] = None,
-        timestamp_ms: Optional[int] = None,
+        parent_snapshot_id: int | None = None,
+        timestamp_ms: int | None = None,
         sequence_number: int = 0,
     ) -> Dict[str, Any]:
         return {
@@ -2368,12 +2367,12 @@ def get_gcs_bucket_name() -> str:
     return bucket_name


-def get_glue_endpoint() -> Optional[str]:
+def get_glue_endpoint() -> str | None:
     """Set the optional environment variable AWS_TEST_GLUE_ENDPOINT for a glue endpoint to test."""
     return os.getenv("AWS_TEST_GLUE_ENDPOINT")


-def get_s3_path(bucket_name: str, database_name: Optional[str] = None, table_name: Optional[str] = None) -> str:
+def get_s3_path(bucket_name: str, database_name: str | None = None, table_name: str | None = None) -> str:
     result_path = f"s3://{bucket_name}"
     if database_name is not None:
         result_path += f"/{database_name}.db"
@@ -2383,7 +2382,7 @@ def get_s3_path(bucket_name: str, database_name: Optional[str] = None, table_nam
     return result_path


-def get_gcs_path(bucket_name: str, database_name: Optional[str] = None, table_name: Optional[str] = None) -> str:
+def get_gcs_path(bucket_name: str, database_name: str | None = None, table_name: str | None = None) -> str:
     result_path = f"gcs://{bucket_name}"
     if database_name is not None:
         result_path += f"/{database_name}.db"
diff --git a/tests/integration/test_inspect_table.py b/tests/integration/test_inspect_table.py
index 8998d7d0bc..7a9617a995 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -18,7 +18,6 @@
 import math
 from datetime import date, datetime
-from typing import Union

 import pyarrow as pa
 import pytest
@@ -643,7 +642,7 @@ def test_inspect_partitions_partitioned_with_filter(spark: SparkSession, session
     tbl = session_catalog.load_table(identifier)

     for snapshot in tbl.metadata.snapshots:
-        test_cases: list[tuple[Union[str, BooleanExpression], str]] = [
+        test_cases: list[tuple[str | BooleanExpression, str]] = [
             ("dt >= '2021-01-01'", "partition.dt >= '2021-01-01'"),
             (GreaterThanOrEqual("dt", "2021-01-01"), "partition.dt >= '2021-01-01'"),
             ("dt >= '2021-01-01' and dt < '2021-03-01'", "partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'"),
diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py
index 9f1f6df043..ce30c19477 100644
--- a/tests/integration/test_writes/utils.py
+++ b/tests/integration/test_writes/utils.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
-from typing import List, Optional, Union
+from typing import List, Union

 import pyarrow as pa

@@ -63,7 +63,7 @@ def _create_table(
     session_catalog: Catalog,
     identifier: str,
     properties: Properties = EMPTY_DICT,
-    data: Optional[List[pa.Table]] = None,
+    data: List[pa.Table] | None = None,
     partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
     schema: Union[Schema, "pa.Schema"] = TABLE_SCHEMA,
 ) -> Table:
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index cd3f9c8034..dbd88a77c8 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -22,7 +22,7 @@
 import warnings
 from datetime import date, datetime, timezone
 from pathlib import Path
-from typing import Any, List, Optional
+from typing import Any, List
 from unittest.mock import MagicMock, patch
 from uuid import uuid4

@@ -1014,7 +1014,7 @@ def file_map(schema_map: Schema, tmpdir: str) -> str:


 def project(
-    schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None
+    schema: Schema, files: List[str], expr: BooleanExpression | None = None, table_schema: Schema | None = None
 ) -> pa.Table:
     def _set_spec_id(datafile: DataFile) -> DataFile:
         datafile.spec_id = 0
diff --git a/tests/io/test_pyarrow_stats.py b/tests/io/test_pyarrow_stats.py
index 513497a338..fd175cae60 100644
--- a/tests/io/test_pyarrow_stats.py
+++ b/tests/io/test_pyarrow_stats.py
@@ -32,9 +32,7 @@
     Any,
     Dict,
     List,
-    Optional,
     Tuple,
-    Union,
 )

 import pyarrow as pa
@@ -78,13 +76,13 @@
 @dataclass(frozen=True)
 class TestStruct:
     __test__ = False
-    x: Optional[int]
-    y: Optional[float]
+    x: int | None
+    y: float | None


 def construct_test_table(
-    write_statistics: Union[bool, List[str]] = True,
-) -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
+    write_statistics: bool | List[str] = True,
+) -> Tuple[pq.FileMetaData, TableMetadataV1 | TableMetadataV2]:
     table_metadata = {
         "format-version": 2,
         "location": "s3://bucket/test/location",
@@ -145,7 +143,7 @@ def construct_test_table(
     _list = [[1, 2, 3], [4, 5, 6], None, [7, 8, 9]]

-    _maps: List[Optional[Dict[int, int]]] = [
+    _maps: List[Dict[int, int] | None] = [
         {1: 2, 3: 4},
         None,
         {5: 6},
@@ -424,7 +422,7 @@ def test_column_metrics_mode() -> None:
     assert 1 not in datafile.upper_bounds


-def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
+def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, TableMetadataV1 | TableMetadataV2]:
     table_metadata = {
         "format-version": 2,
         "location": "s3://bucket/test/location",
@@ -578,7 +576,7 @@ def test_metrics_primitive_types() -> None:
     assert not any(key in datafile.upper_bounds.keys() for key in [16, 17, 18])


-def construct_test_table_invalid_upper_bound() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:
+def construct_test_table_invalid_upper_bound() -> Tuple[pq.FileMetaData, TableMetadataV1 | TableMetadataV2]:
     table_metadata = {
         "format-version": 2,
         "location": "s3://bucket/test/location",
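A detail visible in the return types above: the patch converts unions but deliberately keeps `Tuple`, `Dict`, and `List` (ruff's UP006/UP035 sit in the ignore list), so old-style generics and PEP 604 unions coexist. A sketch of that mixed style, with illustrative classes standing in for `TableMetadataV1`/`TableMetadataV2`:

    from __future__ import annotations

    from dataclasses import dataclass
    from typing import Dict, Tuple

    @dataclass(frozen=True)
    class MetaV1:
        snapshots: Dict[int, str] | None = None  # old-style generic, new-style union

    @dataclass(frozen=True)
    class MetaV2:
        snapshots: Dict[int, str] | None = None

    def load(raw: Dict[str, str]) -> Tuple[str, MetaV1 | MetaV2]:
        version = raw.get("format-version", "1")
        return version, MetaV2() if version == "2" else MetaV1()

    assert isinstance(load({"format-version": "2"})[1], MetaV2)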
diff --git a/tests/table/test_locations.py b/tests/table/test_locations.py
index 4efa64326a..3634151de7 100644
--- a/tests/table/test_locations.py
+++ b/tests/table/test_locations.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Optional
+from typing import Any

 import pytest

@@ -35,7 +35,7 @@
 class CustomLocationProvider(LocationProvider):
-    def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str:
+    def new_data_location(self, data_file_name: str, partition_key: PartitionKey | None = None) -> str:
         return f"custom_location_provider/{data_file_name}"

@@ -107,7 +107,7 @@ def test_object_storage_with_partition() -> None:

 # NB: We test here with None partition key too because disabling partitioned paths still replaces final / with - even in
 # paths of un-partitioned files. This matches the behaviour of the Java implementation.
 @pytest.mark.parametrize("partition_key", [PARTITION_KEY, None])
-def test_object_storage_partitioned_paths_disabled(partition_key: Optional[PartitionKey]) -> None:
+def test_object_storage_partitioned_paths_disabled(partition_key: PartitionKey | None) -> None:
     provider = load_location_provider(
         table_location="table_location",
         table_properties={
diff --git a/tests/test_conversions.py b/tests/test_conversions.py
index 2ee0ba3dd9..b366551cf2 100644
--- a/tests/test_conversions.py
+++ b/tests/test_conversions.py
@@ -85,7 +85,7 @@
     timezone,
 )
 from decimal import Decimal
-from typing import Any, Union
+from typing import Any

 import pytest
@@ -549,7 +549,7 @@ def test_raise_on_incorrect_precision_or_scale(primitive_type: DecimalType, valu
         (TimeType(), time(12, 30, 45, 500000), b"`\xc8\xeb|\n\x00\x00\x00"),
     ],
 )
-def test_datetime_obj_to_bytes(primitive_type: PrimitiveType, value: Union[datetime, date, time], expected_bytes: bytes) -> None:
+def test_datetime_obj_to_bytes(primitive_type: PrimitiveType, value: datetime | date | time, expected_bytes: bytes) -> None:
     bytes_from_value = conversions.to_bytes(primitive_type, value)

     assert bytes_from_value == expected_bytes
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index deaf5d52b6..3d9bfcb555 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -18,7 +18,7 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import Annotated, Any, Callable, Optional, Union
+from typing import Annotated, Any, Callable
 from uuid import UUID

 import mmh3 as mmh3
@@ -1030,7 +1030,7 @@ def test_projection_truncate_string_not_starts_with(bound_reference_str: BoundRe
     ) == NotStartsWith(term="name", literal=literal("he"))


-def _test_projection(lhs: Optional[UnboundPredicate[L]], rhs: Optional[UnboundPredicate[L]]) -> None:
+def _test_projection(lhs: UnboundPredicate[L] | None, rhs: UnboundPredicate[L] | None) -> None:
     assert type(lhs) is type(lhs), f"Different classes: {type(lhs)} != {type(rhs)}"
     if lhs is None and rhs is None:
         # Both null
@@ -1050,7 +1050,7 @@ def _assert_projection_strict(
     pred: BooleanExpression,
     transform: Transform[S, T],
     expected_type: type[BooleanExpression],
-    expected_human_str: Optional[str] = None,
+    expected_human_str: str | None = None,
 ) -> None:
     result = transform.strict_project(name="name", pred=pred)

@@ -1647,8 +1647,8 @@ def test_ymd_pyarrow_transforms(
 )
 def test_bucket_pyarrow_transforms(
     source_type: PrimitiveType,
-    input_arr: Union[pa.Array, pa.ChunkedArray],
-    expected: Union[pa.Array, pa.ChunkedArray],
+    input_arr: pa.Array | pa.ChunkedArray,
+    expected: pa.Array | pa.ChunkedArray,
     num_buckets: int,
 ) -> None:
     transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
@@ -1670,8 +1670,8 @@ def test_bucket_pyarrow_void_transform() -> None:
 )
 def test_truncate_pyarrow_transforms(
     source_type: PrimitiveType,
-    input_arr: Union[pa.Array, pa.ChunkedArray],
-    expected: Union[pa.Array, pa.ChunkedArray],
+    input_arr: pa.Array | pa.ChunkedArray,
+    expected: pa.Array | pa.ChunkedArray,
     width: int,
 ) -> None:
     transform: Transform[Any, Any] = TruncateTransform(width=width)
diff --git a/tests/utils/test_concurrent.py b/tests/utils/test_concurrent.py
index c703f764af..48039e0c24 100644
--- a/tests/utils/test_concurrent.py
+++ b/tests/utils/test_concurrent.py
@@ -18,14 +18,14 @@
 import multiprocessing
 import os
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator
 from unittest import mock

 import pytest

 from pyiceberg.utils.concurrent import ExecutorFactory

-EMPTY_ENV: Dict[str, Optional[str]] = {}
+EMPTY_ENV: Dict[str, str | None] = {}
 VALID_ENV = {"PYICEBERG_MAX_WORKERS": "5"}
 INVALID_ENV = {"PYICEBERG_MAX_WORKERS": "invalid"}
diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py
index d70b0345f8..8953754103 100644
--- a/tests/utils/test_config.py
+++ b/tests/utils/test_config.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 from unittest import mock

 import pytest
@@ -149,13 +149,13 @@ def test_config_lookup_order(
     monkeypatch: pytest.MonkeyPatch,
     tmp_path_factory: pytest.TempPathFactory,
     config_setup: Dict[str, Any],
-    expected_result: Optional[str],
+    expected_result: str | None,
 ) -> None:
     """
     Test that the configuration lookup prioritizes PYICEBERG_HOME, then home (~), then cwd.
     """

-    def create_config_file(path: str, uri: Optional[str]) -> None:
+    def create_config_file(path: str, uri: str | None) -> None:
         if uri:
             config_file_path = os.path.join(path, ".pyiceberg.yaml")
             content = {"catalog": {"default": {"uri": uri}}}
diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py
index a5d5a6fefb..51cbf06bc0 100644
--- a/tests/utils/test_manifest.py
+++ b/tests/utils/test_manifest.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=redefined-outer-name,arguments-renamed,fixme
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional
+from typing import Dict
 from unittest.mock import patch

 import fastavro
@@ -532,7 +532,7 @@ def test_write_manifest_list(
     generated_manifest_file_file_v1: str,
     generated_manifest_file_file_v2: str,
     format_version: TableVersion,
-    parent_snapshot_id: Optional[int],
+    parent_snapshot_id: int | None,
     compression: AvroCompressionCodec,
 ) -> None:
     io = load_file_io()
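Test signatures like `parent_snapshot_id: int | None` above pair naturally with parametrizing over both present and absent values. A final minimal sketch of that idiom — a hypothetical test, not taken from the suite:

    from __future__ import annotations

    import pytest

    @pytest.mark.parametrize("parent_snapshot_id", [3051729675574597004, None])
    def test_parent_snapshot_id_is_optional(parent_snapshot_id: int | None) -> None:
        assert parent_snapshot_id is None or parent_snapshot_id > 0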