316 changes: 315 additions & 1 deletion flink-python/pyflink/datastream/connectors/kafka.py
@@ -29,6 +29,8 @@
from pyflink.util.java_utils import to_jarray, get_field, get_field_value

__all__ = [
'DynamicKafkaSource',
'DynamicKafkaSourceBuilder',
'KafkaSource',
'KafkaSourceBuilder',
'KafkaSink',
@@ -38,9 +40,151 @@
'KafkaOffsetResetStrategy',
'KafkaRecordSerializationSchema',
'KafkaRecordSerializationSchemaBuilder',
'KafkaTopicSelector',
'KafkaStream',
'ClusterMetadata',
'KafkaMetadataService'
]

# ---- DynamicKafkaSource ----


class DynamicKafkaSource(Source):
"""Python wrapper of the Java DynamicKafkaSource.

A DynamicKafkaSource enables consuming records from dynamically discovered Kafka streams
(topics that may span multiple clusters) without restarting the Flink job.

Use :py:meth:`builder` to construct an instance, for example::

>>> source = DynamicKafkaSource.builder() \\
... .set_stream_ids({"stream-a", "stream-b"}) \\
... .set_kafka_metadata_service(metadata_service) \\
... .set_value_only_deserializer(SimpleStringSchema()) \\
... .set_property("group.id", "my_group") \\
... .build()

The builder methods closely mirror their Java counterparts defined on
``org.apache.flink.connector.kafka.dynamic.source.DynamicKafkaSourceBuilder``.
"""

def __init__(self, j_dynamic_kafka_source: JavaObject):
super().__init__(j_dynamic_kafka_source)

@staticmethod
def builder() -> 'DynamicKafkaSourceBuilder':
"""Create and return a new :class:`DynamicKafkaSourceBuilder`."""
return DynamicKafkaSourceBuilder()
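
# A minimal usage sketch (not part of this module's API): attaching the source to a
# job, assuming an existing ``StreamExecutionEnvironment`` named ``env``, a
# ``metadata_service`` built via :class:`KafkaMetadataService`, and
# ``SimpleStringSchema``/``WatermarkStrategy`` imported from ``pyflink.common``::
#
#     source = DynamicKafkaSource.builder() \
#         .set_stream_ids({"stream-a", "stream-b"}) \
#         .set_kafka_metadata_service(metadata_service) \
#         .set_value_only_deserializer(SimpleStringSchema()) \
#         .set_property("group.id", "my_group") \
#         .build()
#     ds = env.from_source(source, WatermarkStrategy.no_watermarks(), "dynamic-kafka")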


def single_cluster_topic_metadata_service(kafka_cluster_id: str,
                                          properties: Dict[str, str]) -> JavaObject:
    """Create a Java ``SingleClusterTopicMetadataService`` scoped to a single Kafka cluster.

    :param kafka_cluster_id: the id of the Kafka cluster.
    :param properties: the properties to connect to the cluster.
    """
    gateway = get_gateway()
    j_properties = gateway.jvm.java.util.Properties()
    for k, v in properties.items():
        j_properties.setProperty(k, v)
    return gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.\
        SingleClusterTopicMetadataService(kafka_cluster_id, j_properties)


class DynamicKafkaSourceBuilder(object):
"""Builder for :class:`DynamicKafkaSource`.

The builder is a thin Python wrapper delegating to the underlying Java builder defined in
``org.apache.flink.connector.kafka.dynamic.source.DynamicKafkaSourceBuilder``.
"""

def __init__(self):
self._gateway = get_gateway()
self._j_builder = (
self._gateway.jvm.org.apache.flink.connector.kafka.dynamic.source.DynamicKafkaSource
.builder()
)

def build(self) -> 'DynamicKafkaSource':
"""Finalize the configuration and return a :class:`DynamicKafkaSource`."""
return DynamicKafkaSource(self._j_builder.build())

def set_stream_ids(self, stream_ids: Set[str]) -> 'DynamicKafkaSourceBuilder':
"""Subscribe to a fixed set of stream IDs.

:param stream_ids: A Python ``set`` of stream IDs.
"""
j_set = self._gateway.jvm.java.util.HashSet()
for stream_id in stream_ids:
j_set.add(stream_id)
self._j_builder.setStreamIds(j_set)
return self

def set_stream_pattern(self, stream_pattern: str) -> 'DynamicKafkaSourceBuilder':
"""Subscribe to streams whose IDs match the given Java regex pattern."""
j_pattern = self._gateway.jvm.java.util.regex.Pattern.compile(stream_pattern)
self._j_builder.setStreamPattern(j_pattern)
return self

def set_kafka_stream_subscriber(self, kafka_stream_subscriber: JavaObject) -> \
'DynamicKafkaSourceBuilder':
"""Use a custom ``KafkaStreamSubscriber`` implementation."""
self._j_builder.setKafkaStreamSubscriber(kafka_stream_subscriber)
return self

def set_kafka_metadata_service(
            self,
            kafka_metadata_service: Union[JavaObject, 'KafkaMetadataService']) -> \
            'DynamicKafkaSourceBuilder':
        """Specify the :class:`KafkaMetadataService` that resolves stream IDs to clusters.

        Accepts either a :class:`KafkaMetadataService` wrapper or a raw Java
        ``KafkaMetadataService`` object.
        """
        if isinstance(kafka_metadata_service, KafkaMetadataService):
            kafka_metadata_service = kafka_metadata_service._j_service
        self._j_builder.setKafkaMetadataService(kafka_metadata_service)
        return self

def set_starting_offsets(self, starting_offsets_initializer: 'KafkaOffsetsInitializer') -> \
            'DynamicKafkaSourceBuilder':
        """Specify from which offsets the source starts consuming."""
        self._j_builder.setStartingOffsets(starting_offsets_initializer._j_initializer)
        return self

    def set_bounded(self, stopping_offsets_initializer: 'KafkaOffsetsInitializer') -> \
            'DynamicKafkaSourceBuilder':
        """Make the source bounded; consumption stops at the offsets produced by the given
        :class:`KafkaOffsetsInitializer`."""
        self._j_builder.setBounded(stopping_offsets_initializer._j_initializer)
        return self

def set_deserializer(self, kafka_record_deserializer: JavaObject) -> \
'DynamicKafkaSourceBuilder':
"""Set a custom Java ``KafkaRecordDeserializationSchema`` instance."""
self._j_builder.setDeserializer(kafka_record_deserializer)
return self

def set_value_only_deserializer(self, deserialization_schema: DeserializationSchema) -> \
'DynamicKafkaSourceBuilder':
"""Convenience method to deserialize the *value* of each Kafka record using the provided
:class:`~pyflink.common.serialization.DeserializationSchema`. Other fields (key, headers,
etc.) are ignored.
"""
j_schema = deserialization_schema._j_deserialization_schema
j_value_only_wrapper = (
self._gateway.jvm.org.apache.flink.connector.kafka.source.reader.deserializer
.KafkaRecordDeserializationSchema.valueOnly(j_schema)
)
self._j_builder.setDeserializer(j_value_only_wrapper)
return self

def set_properties(self, props: Dict[str, str]) -> 'DynamicKafkaSourceBuilder':
        """Set arbitrary properties for the consumers from a Python ``dict``."""
        j_props = self._gateway.jvm.java.util.Properties()
        for k, v in props.items():
            j_props.setProperty(k, v)
        self._j_builder.setProperties(j_props)
        return self

    def set_property(self, key: str, value: str) -> 'DynamicKafkaSourceBuilder':
        """Set a single property for the consumers."""
        self._j_builder.setProperty(key, value)
        return self

    def set_group_id(self, group_id: str) -> 'DynamicKafkaSourceBuilder':
        """Set the consumer group id, equivalent to setting the ``group.id`` property."""
        self._j_builder.setGroupId(group_id)
        return self

    def set_client_id_prefix(self, prefix: str) -> 'DynamicKafkaSourceBuilder':
        """Set the prefix used to generate the Kafka consumer client ids."""
        self._j_builder.setClientIdPrefix(prefix)
        return self
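
# A short sketch of a bounded read (assumptions: ``metadata_service`` is already
# built, and ``KafkaOffsetsInitializer`` is imported from this module): streams are
# selected by a Java regex pattern instead of a fixed id set::
#
#     bounded_source = DynamicKafkaSource.builder() \
#         .set_stream_pattern("stream-.*") \
#         .set_kafka_metadata_service(metadata_service) \
#         .set_value_only_deserializer(SimpleStringSchema()) \
#         .set_bounded(KafkaOffsetsInitializer.latest()) \
#         .set_property("group.id", "bounded_group") \
#         .build()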


# ---- KafkaSource ----

@@ -390,6 +534,176 @@ def __hash__(self):
return 31 * (31 + self._partition) + hash(self._topic)


class ClusterMetadata(object):
"""
Corresponding to Java `org.apache.flink.connector.kafka.dynamic.metadata.ClusterMetadata` class.
"""

def __init__(self, topics: Set[str], properties: Dict[str, str]):
self._topics = topics
self._properties = properties

def _to_j_cluster_metadata(self):
gateway = get_gateway()
j_topics = gateway.jvm.java.util.HashSet()
for t in self._topics:
j_topics.add(t)
j_props = gateway.jvm.java.util.Properties()
for k, v in self._properties.items():
j_props.setProperty(k, v)
return gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.ClusterMetadata(
j_topics, j_props)

@staticmethod
def _from_j_cluster_metadata(j_cluster_metadata: JavaObject):
# Extract topics
j_topics = j_cluster_metadata.getTopics()
topics = set()
for t in j_topics:
topics.add(t)
# Extract properties
j_props = j_cluster_metadata.getProperties()
props = {}
for name in j_props.stringPropertyNames():
props[name] = j_props.getProperty(name)
return ClusterMetadata(topics, props)

def __eq__(self, other):
if not isinstance(other, ClusterMetadata):
return False
return self._topics == other._topics and self._properties == other._properties

def __hash__(self):
return hash(frozenset(self._topics)) ^ hash(frozenset(self._properties.items()))
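
# Example (illustrative topics and broker address): metadata for one physical
# cluster, pairing a topic set with the client properties needed to reach it::
#
#     metadata = ClusterMetadata(
#         {"orders", "payments"},
#         {"bootstrap.servers": "broker-1:9092"})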


class KafkaStream(object):
"""
Corresponding to Java `org.apache.flink.connector.kafka.dynamic.metadata.KafkaStream` class.
"""

def __init__(self, stream_id: str, cluster_metadata_map: Dict[str, ClusterMetadata]):
self._stream_id = stream_id
self._cluster_metadata_map = cluster_metadata_map

def _to_j_kafka_stream(self):
gateway = get_gateway()
j_map = gateway.jvm.java.util.HashMap()
for cluster_id, cm in self._cluster_metadata_map.items():
j_map.put(cluster_id, cm._to_j_cluster_metadata())
return gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.KafkaStream(
self._stream_id, j_map)

@staticmethod
def _from_j_kafka_stream(j_kafka_stream: JavaObject):
stream_id = j_kafka_stream.getStreamId()
j_map = j_kafka_stream.getClusterMetadataMap()
cluster_map = {}
for entry in j_map.entrySet():
cid = entry.getKey()
cm = ClusterMetadata._from_j_cluster_metadata(entry.getValue())
cluster_map[cid] = cm
return KafkaStream(stream_id, cluster_map)

def __eq__(self, other):
if not isinstance(other, KafkaStream):
return False
return (self._stream_id == other._stream_id and
self._cluster_metadata_map == other._cluster_metadata_map)

def __hash__(self):
return hash(self._stream_id) ^ hash(frozenset(self._cluster_metadata_map.items()))
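
# Example (illustrative ids and addresses): a logical stream whose physical topics
# span two clusters::
#
#     stream = KafkaStream("orders-stream", {
#         "cluster-a": ClusterMetadata({"orders"}, {"bootstrap.servers": "a:9092"}),
#         "cluster-b": ClusterMetadata({"orders"}, {"bootstrap.servers": "b:9092"}),
#     })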


class KafkaMetadataService(object):
"""
Python wrapper for Java `org.apache.flink.connector.kafka.dynamic.metadata.KafkaMetadataService` interface.
"""

def __init__(self, j_service: JavaObject):
self._j_service = j_service

@staticmethod
def single_cluster_service(kafka_cluster_id: str,
properties: Dict[str, str]) -> 'KafkaMetadataService':
"""
Create a KafkaMetadataService scoped to a single cluster.

:param kafka_cluster_id: the id of the Kafka cluster.
:param properties: the properties to connect to the cluster.
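
        Example (illustrative cluster id and broker address)::

            >>> service = KafkaMetadataService.single_cluster_service(
            ...     "cluster-a", {"bootstrap.servers": "broker-1:9092"})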
"""
gateway = get_gateway()
j_props = gateway.jvm.java.util.Properties()
for k, v in properties.items():
j_props.setProperty(k, v)
j_service = gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.\
SingleClusterTopicMetadataService(kafka_cluster_id, j_props)
return KafkaMetadataService(j_service)

@staticmethod
def yaml_file_service(metadata_file_path: str,
refresh_interval: Any) -> 'KafkaMetadataService':
"""
Create a KafkaMetadataService backed by a YAML metadata file, with periodic refresh.

:param metadata_file_path: path to the metadata YAML file.
:param refresh_interval: java.time.Duration, or Python timedelta/int/float (seconds).
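
        Example (illustrative path; ``timedelta`` comes from the standard
        :mod:`datetime` module)::

            >>> from datetime import timedelta
            >>> service = KafkaMetadataService.yaml_file_service(
            ...     "/tmp/kafka-metadata.yaml", timedelta(minutes=5))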
"""
from datetime import timedelta

gateway = get_gateway()
# Convert Python types to Java Duration if necessary
if isinstance(refresh_interval, (int, float)):
j_duration = gateway.jvm.java.time.Duration.ofMillis(int(refresh_interval * 1000))
elif isinstance(refresh_interval, timedelta):
j_duration = gateway.jvm.java.time.Duration.ofMillis(
int(refresh_interval.total_seconds() * 1000))
else:
j_duration = refresh_interval
j_service = gateway.jvm.org.apache.flink.connector.kafka.testutils.\
YamlFileMetadataService(metadata_file_path, j_duration)
return KafkaMetadataService(j_service)

def get_all_streams(self) -> Set[KafkaStream]:
"""
Get current metadata for all streams.
"""
result = set()
for j_stream in self._j_service.getAllStreams():
result.add(KafkaStream._from_j_kafka_stream(j_stream))
return result

def describe_streams(self, stream_ids: Set[str]) -> Dict[str, KafkaStream]:
"""
Get current metadata for queried streams.

:param stream_ids: a set of stream IDs to describe.
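
        Example (``metadata_service`` being an instance of this class)::

            >>> streams = metadata_service.describe_streams({"stream-a", "stream-b"})
            >>> stream_a = streams["stream-a"]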
"""
gateway = get_gateway()
j_set = gateway.jvm.java.util.HashSet()
for sid in stream_ids:
j_set.add(sid)
j_map = self._j_service.describeStreams(j_set)
result = {}
for entry in j_map.entrySet():
key = entry.getKey()
val = KafkaStream._from_j_kafka_stream(entry.getValue())
result[key] = val
return result

def is_cluster_active(self, kafka_cluster_id: str) -> bool:
"""
Check if the given cluster is active.
"""
return self._j_service.isClusterActive(kafka_cluster_id)

def close(self):
"""
Close the metadata service.
"""
self._j_service.close()


class KafkaOffsetResetStrategy(Enum):
"""
Corresponding to Java ``org.apache.kafka.clients.consumer.OffsetResetStrategy`` class.