diff --git a/flink-python/pyflink/datastream/connectors/kafka.py b/flink-python/pyflink/datastream/connectors/kafka.py
index 4a3fbacc0..23e14d6cb 100644
--- a/flink-python/pyflink/datastream/connectors/kafka.py
+++ b/flink-python/pyflink/datastream/connectors/kafka.py
@@ -29,6 +29,8 @@ from pyflink.util.java_utils import to_jarray, get_field, get_field_value
 
 __all__ = [
+    'DynamicKafkaSource',
+    'DynamicKafkaSourceBuilder',
     'KafkaSource',
     'KafkaSourceBuilder',
     'KafkaSink',
@@ -38,9 +40,151 @@
     'KafkaOffsetResetStrategy',
     'KafkaRecordSerializationSchema',
     'KafkaRecordSerializationSchemaBuilder',
-    'KafkaTopicSelector'
+    'KafkaTopicSelector',
+    'KafkaStream',
+    'ClusterMetadata',
+    'KafkaMetadataService'
 ]
 
+# ---- DynamicKafkaSource ----
+
+
+class DynamicKafkaSource(Source):
+    """Python wrapper of the Java DynamicKafkaSource.
+
+    A DynamicKafkaSource enables consuming records from dynamically discovered Kafka streams
+    (topics that may span multiple clusters) without restarting the Flink job.
+
+    Use :py:meth:`builder` to construct an instance, for example::
+
+        >>> source = DynamicKafkaSource.builder() \\
+        ...     .set_stream_ids({"stream-a", "stream-b"}) \\
+        ...     .set_kafka_metadata_service(metadata_service) \\
+        ...     .set_value_only_deserializer(SimpleStringSchema()) \\
+        ...     .set_property("group.id", "my_group") \\
+        ...     .build()
+
+    The builder methods closely mirror their Java counterparts defined on
+    ``org.apache.flink.connector.kafka.dynamic.source.DynamicKafkaSourceBuilder``.
+    """
+
+    def __init__(self, j_dynamic_kafka_source: JavaObject):
+        super().__init__(j_dynamic_kafka_source)
+
+    @staticmethod
+    def builder() -> 'DynamicKafkaSourceBuilder':
+        """Create and return a new :class:`DynamicKafkaSourceBuilder`."""
+        return DynamicKafkaSourceBuilder()
+
+
+def single_cluster_topic_metadata_service(kafka_cluster_id: str,
+                                          properties: Dict[str, str]) -> JavaObject:
+    """Create a Java ``SingleClusterTopicMetadataService`` for a single Kafka cluster.
+
+    This returns the raw Java object; :meth:`KafkaMetadataService.single_cluster_service` returns
+    the equivalent Python wrapper.
+
+    :param kafka_cluster_id: the id of the Kafka cluster.
+    :param properties: the properties to connect to the cluster.
+    """
+    gateway = get_gateway()
+    j_properties = gateway.jvm.java.util.Properties()
+    for key, value in properties.items():
+        j_properties.setProperty(key, value)
+    return gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata \
+        .SingleClusterTopicMetadataService(kafka_cluster_id, j_properties)
+
+
+class DynamicKafkaSourceBuilder(object):
+    """Builder for :class:`DynamicKafkaSource`.
+
+    The builder is a thin Python wrapper delegating to the underlying Java builder defined in
+    ``org.apache.flink.connector.kafka.dynamic.source.DynamicKafkaSourceBuilder``.
+    """
+
+    def __init__(self):
+        self._gateway = get_gateway()
+        self._j_builder = (
+            self._gateway.jvm.org.apache.flink.connector.kafka.dynamic.source.DynamicKafkaSource
+            .builder()
+        )
+
+    def build(self) -> 'DynamicKafkaSource':
+        """Finalize the configuration and return a :class:`DynamicKafkaSource`."""
+        return DynamicKafkaSource(self._j_builder.build())
+
+    def set_stream_ids(self, stream_ids: Set[str]) -> 'DynamicKafkaSourceBuilder':
+        """Subscribe to a fixed set of stream IDs.
+
+        :param stream_ids: A Python ``set`` of stream IDs.
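+
+        For example, subscribing to two illustrative stream IDs::
+
+            >>> builder.set_stream_ids({"stream-a", "stream-b"})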
+ """ + j_set = self._gateway.jvm.java.util.HashSet() + for stream_id in stream_ids: + j_set.add(stream_id) + self._j_builder.setStreamIds(j_set) + return self + + def set_stream_pattern(self, stream_pattern: str) -> 'DynamicKafkaSourceBuilder': + """Subscribe to streams whose IDs match the given Java regex pattern.""" + j_pattern = self._gateway.jvm.java.util.regex.Pattern.compile(stream_pattern) + self._j_builder.setStreamPattern(j_pattern) + return self + + def set_kafka_stream_subscriber(self, kafka_stream_subscriber: JavaObject) -> \ + 'DynamicKafkaSourceBuilder': + """Use a custom ``KafkaStreamSubscriber`` implementation.""" + self._j_builder.setKafkaStreamSubscriber(kafka_stream_subscriber) + return self + + def set_kafka_metadata_service(self, kafka_metadata_service: JavaObject | KafkaMetadataService) -> \ + 'DynamicKafkaSourceBuilder': + """Specify the :class:`KafkaMetadataService` that resolves stream IDs to clusters.""" + if isinstance(kafka_metadata_service, KafkaMetadataService): + kafka_metadata_service = kafka_metadata_service._j_kafka_metadata_service + return self + else: + self._j_builder.setKafkaMetadataService(kafka_metadata_service) + return self + + def set_starting_offsets(self, starting_offsets_initializer: 'KafkaOffsetsInitializer') -> \ + 'DynamicKafkaSourceBuilder': + self._j_builder.setStartingOffsets(starting_offsets_initializer._j_initializer) + return self + + def set_bounded(self, stopping_offsets_initializer: 'KafkaOffsetsInitializer') -> \ + 'DynamicKafkaSourceBuilder': + self._j_builder.setBounded(stopping_offsets_initializer._j_initializer) + return self + + def set_deserializer(self, kafka_record_deserializer: JavaObject) -> \ + 'DynamicKafkaSourceBuilder': + """Set a custom Java ``KafkaRecordDeserializationSchema`` instance.""" + self._j_builder.setDeserializer(kafka_record_deserializer) + return self + + def set_value_only_deserializer(self, deserialization_schema: DeserializationSchema) -> \ + 'DynamicKafkaSourceBuilder': + """Convenience method to deserialize the *value* of each Kafka record using the provided + :class:`~pyflink.common.serialization.DeserializationSchema`. Other fields (key, headers, + etc.) are ignored. + """ + j_schema = deserialization_schema._j_deserialization_schema + j_value_only_wrapper = ( + self._gateway.jvm.org.apache.flink.connector.kafka.source.reader.deserializer + .KafkaRecordDeserializationSchema.valueOnly(j_schema) + ) + self._j_builder.setDeserializer(j_value_only_wrapper) + return self + + def set_properties(self, props: Dict[str, str]) -> 'DynamicKafkaSourceBuilder': + j_props = self._gateway.jvm.java.util.Properties() + for k, v in props.items(): + j_props.setProperty(k, v) + self._j_builder.setProperties(j_props) + return self + + def set_property(self, key: str, value: str) -> 'DynamicKafkaSourceBuilder': + self._j_builder.setProperty(key, value) + return self + + def set_group_id(self, group_id: str) -> 'DynamicKafkaSourceBuilder': + self._j_builder.setGroupId(group_id) + return self + + def set_client_id_prefix(self, prefix: str) -> 'DynamicKafkaSourceBuilder': + self._j_builder.setClientIdPrefix(prefix) + return self + # ---- KafkaSource ---- @@ -390,6 +534,176 @@ def __hash__(self): return 31 * (31 + self._partition) + hash(self._topic) +class ClusterMetadata(object): + """ + Corresponding to Java `org.apache.flink.connector.kafka.dynamic.metadata.ClusterMetadata` class. 
+ """ + + def __init__(self, topics: Set[str], properties: Dict[str, str]): + self._topics = topics + self._properties = properties + + def _to_j_cluster_metadata(self): + gateway = get_gateway() + j_topics = gateway.jvm.java.util.HashSet() + for t in self._topics: + j_topics.add(t) + j_props = gateway.jvm.java.util.Properties() + for k, v in self._properties.items(): + j_props.setProperty(k, v) + return gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.ClusterMetadata( + j_topics, j_props) + + @staticmethod + def _from_j_cluster_metadata(j_cluster_metadata: JavaObject): + # Extract topics + j_topics = j_cluster_metadata.getTopics() + topics = set() + for t in j_topics: + topics.add(t) + # Extract properties + j_props = j_cluster_metadata.getProperties() + props = {} + for name in j_props.stringPropertyNames(): + props[name] = j_props.getProperty(name) + return ClusterMetadata(topics, props) + + def __eq__(self, other): + if not isinstance(other, ClusterMetadata): + return False + return self._topics == other._topics and self._properties == other._properties + + def __hash__(self): + return hash(frozenset(self._topics)) ^ hash(frozenset(self._properties.items())) + + +class KafkaStream(object): + """ + Corresponding to Java `org.apache.flink.connector.kafka.dynamic.metadata.KafkaStream` class. + """ + + def __init__(self, stream_id: str, cluster_metadata_map: Dict[str, ClusterMetadata]): + self._stream_id = stream_id + self._cluster_metadata_map = cluster_metadata_map + + def _to_j_kafka_stream(self): + gateway = get_gateway() + j_map = gateway.jvm.java.util.HashMap() + for cluster_id, cm in self._cluster_metadata_map.items(): + j_map.put(cluster_id, cm._to_j_cluster_metadata()) + return gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.KafkaStream( + self._stream_id, j_map) + + @staticmethod + def _from_j_kafka_stream(j_kafka_stream: JavaObject): + stream_id = j_kafka_stream.getStreamId() + j_map = j_kafka_stream.getClusterMetadataMap() + cluster_map = {} + for entry in j_map.entrySet(): + cid = entry.getKey() + cm = ClusterMetadata._from_j_cluster_metadata(entry.getValue()) + cluster_map[cid] = cm + return KafkaStream(stream_id, cluster_map) + + def __eq__(self, other): + if not isinstance(other, KafkaStream): + return False + return (self._stream_id == other._stream_id and + self._cluster_metadata_map == other._cluster_metadata_map) + + def __hash__(self): + return hash(self._stream_id) ^ hash(frozenset(self._cluster_metadata_map.items())) + + +class KafkaMetadataService(object): + """ + Python wrapper for Java `org.apache.flink.connector.kafka.dynamic.metadata.KafkaMetadataService` interface. + """ + + def __init__(self, j_service: JavaObject): + self._j_service = j_service + + @staticmethod + def single_cluster_service(kafka_cluster_id: str, + properties: Dict[str, str]) -> 'KafkaMetadataService': + """ + Create a KafkaMetadataService scoped to a single cluster. + + :param kafka_cluster_id: the id of the Kafka cluster. + :param properties: the properties to connect to the cluster. 
+ """ + gateway = get_gateway() + j_props = gateway.jvm.java.util.Properties() + for k, v in properties.items(): + j_props.setProperty(k, v) + j_service = gateway.jvm.org.apache.flink.connector.kafka.dynamic.metadata.\ + SingleClusterTopicMetadataService(kafka_cluster_id, j_props) + return KafkaMetadataService(j_service) + + @staticmethod + def yaml_file_service(metadata_file_path: str, + refresh_interval: Any) -> 'KafkaMetadataService': + """ + Create a KafkaMetadataService backed by a YAML metadata file, with periodic refresh. + + :param metadata_file_path: path to the metadata YAML file. + :param refresh_interval: java.time.Duration, or Python timedelta/int/float (seconds). + """ + from datetime import timedelta + + gateway = get_gateway() + # Convert Python types to Java Duration if necessary + if isinstance(refresh_interval, (int, float)): + j_duration = gateway.jvm.java.time.Duration.ofMillis(int(refresh_interval * 1000)) + elif isinstance(refresh_interval, timedelta): + j_duration = gateway.jvm.java.time.Duration.ofMillis( + int(refresh_interval.total_seconds() * 1000)) + else: + j_duration = refresh_interval + j_service = gateway.jvm.org.apache.flink.connector.kafka.testutils.\ + YamlFileMetadataService(metadata_file_path, j_duration) + return KafkaMetadataService(j_service) + + def get_all_streams(self) -> Set[KafkaStream]: + """ + Get current metadata for all streams. + """ + result = set() + for j_stream in self._j_service.getAllStreams(): + result.add(KafkaStream._from_j_kafka_stream(j_stream)) + return result + + def describe_streams(self, stream_ids: Set[str]) -> Dict[str, KafkaStream]: + """ + Get current metadata for queried streams. + + :param stream_ids: a set of stream full names. + """ + gateway = get_gateway() + j_set = gateway.jvm.java.util.HashSet() + for sid in stream_ids: + j_set.add(sid) + j_map = self._j_service.describeStreams(j_set) + result = {} + for entry in j_map.entrySet(): + key = entry.getKey() + val = KafkaStream._from_j_kafka_stream(entry.getValue()) + result[key] = val + return result + + def is_cluster_active(self, kafka_cluster_id: str) -> bool: + """ + Check if the given cluster is active. + """ + return self._j_service.isClusterActive(kafka_cluster_id) + + def close(self): + """ + Close the metadata service. + """ + self._j_service.close() + + class KafkaOffsetResetStrategy(Enum): """ Corresponding to Java ``org.apache.kafka.client.consumer.OffsetResetStrategy`` class. 
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py index c34a9ffc5..b054ac28f 100644 --- a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py +++ b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py @@ -28,6 +28,8 @@ from pyflink.datastream.connectors.base import DeliveryGuarantee from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \ KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink +# Import newly added DynamicKafkaSource +from pyflink.datastream.connectors.kafka import DynamicKafkaSource from pyflink.datastream.formats.avro import AvroRowDeserializationSchema, AvroRowSerializationSchema from pyflink.datastream.formats.csv import CsvRowDeserializationSchema, CsvRowSerializationSchema from pyflink.datastream.formats.json import JsonRowDeserializationSchema, JsonRowSerializationSchema @@ -380,6 +382,116 @@ def _get_kafka_source_configuration(source: KafkaSource): return Configuration(j_configuration=j_configuration) +class DynamicKafkaSourceTests(PyFlinkStreamingTestCase): + + def _create_metadata_service(self): + """Create a simple metadata service pointing to a dummy cluster.""" + jvm = get_gateway().jvm + props = jvm.java.util.Properties() + # Minimal bootstrap servers property to satisfy Kafka client validation. + props.setProperty("bootstrap.servers", "localhost:9092") + return jvm.org.apache.flink.connector.kafka.dynamic.metadata.SingleClusterTopicMetadataService( + "dummy-cluster", props) + + def _build_base_source(self): + return DynamicKafkaSource.builder() \ + .set_stream_ids({"stream_test"}) \ + .set_kafka_metadata_service(self._create_metadata_service()) \ + .set_value_only_deserializer(SimpleStringSchema()) \ + .set_group_id("test_group") \ + .build() + + def test_compiling(self): + source = self._build_base_source() + ds = self.env.from_source( + source=source, + watermark_strategy=WatermarkStrategy.for_monotonous_timestamps(), + source_name="dynamic kafka source") + ds.print() + plan = json.loads(self.env.get_execution_plan()) + self.assertEqual('Source: dynamic kafka source', plan['nodes'][0]['type']) + + def test_set_properties(self): + source = DynamicKafkaSource.builder() \ + .set_stream_ids({"stream_test"}) \ + .set_kafka_metadata_service(self._create_metadata_service()) \ + .set_group_id("test_group_id") \ + .set_client_id_prefix("test_client_id_prefix") \ + .set_property("test_property", "test_value") \ + .set_value_only_deserializer(SimpleStringSchema()) \ + .build() + + # Extract the internal properties field for verification. 
+ props = get_field_value(source.get_java_function(), 'properties') + self.assertEqual(props.getProperty('group.id'), 'test_group_id') + self.assertEqual(props.getProperty('client.id.prefix'), 'test_client_id_prefix') + self.assertEqual(props.getProperty('test_property'), 'test_value') + + def test_set_stream_ids(self): + stream_ids = {"stream_a", "stream_b"} + source = DynamicKafkaSource.builder() \ + .set_stream_ids(stream_ids) \ + .set_kafka_metadata_service(self._create_metadata_service()) \ + .set_value_only_deserializer(SimpleStringSchema()) \ + .build() + + subscriber = get_field_value(source.get_java_function(), 'kafkaStreamSubscriber') + self.assertEqual( + subscriber.getClass().getCanonicalName(), + 'org.apache.flink.connector.kafka.dynamic.source.enumerator.subscriber.KafkaStreamSetSubscriber' + ) + + subscribed_ids = get_field_value(subscriber, 'streamIds') + self.assertTrue(is_instance_of(subscribed_ids, get_gateway().jvm.java.util.Set)) + self.assertEqual(subscribed_ids.size(), len(stream_ids)) + for s in stream_ids: + self.assertTrue(subscribed_ids.contains(s)) + + def test_set_stream_pattern(self): + pattern = 'stream_*' + source = DynamicKafkaSource.builder() \ + .set_stream_pattern(pattern) \ + .set_kafka_metadata_service(self._create_metadata_service()) \ + .set_value_only_deserializer(SimpleStringSchema()) \ + .build() + + subscriber = get_field_value(source.get_java_function(), 'kafkaStreamSubscriber') + self.assertEqual( + subscriber.getClass().getCanonicalName(), + 'org.apache.flink.connector.kafka.dynamic.source.enumerator.subscriber.StreamPatternSubscriber' + ) + j_pattern = get_field_value(subscriber, 'streamPattern') + self.assertTrue(is_instance_of(j_pattern, get_gateway().jvm.java.util.regex.Pattern)) + self.assertEqual(j_pattern.toString(), pattern) + + def test_bounded(self): + source = DynamicKafkaSource.builder() \ + .set_stream_ids({"stream_test"}) \ + .set_kafka_metadata_service(self._create_metadata_service()) \ + .set_value_only_deserializer(SimpleStringSchema()) \ + .set_bounded(KafkaOffsetsInitializer.latest()) \ + .build() + + self.assertEqual( + get_field_value(source.get_java_function(), 'boundedness').toString(), 'BOUNDED' + ) + + def test_starting_offsets(self): + source = DynamicKafkaSource.builder() \ + .set_stream_ids({"stream_test"}) \ + .set_kafka_metadata_service(self._create_metadata_service()) \ + .set_value_only_deserializer(SimpleStringSchema()) \ + .set_starting_offsets(KafkaOffsetsInitializer.latest()) \ + .build() + + initializer = get_field_value(source.get_java_function(), 'startingOffsetsInitializer') + self.assertEqual( + initializer.getClass().getCanonicalName(), + 'org.apache.flink.connector.kafka.source.enumerator.initializer.LatestOffsetsInitializer' + ) + + + class KafkaSinkTests(PyFlinkStreamingTestCase): def test_compile(self):