Skip to content

Commit 83916b3

Browse files
authored
Merge pull request #2 from geekwhocodes/dev
tidy up
2 parents 007e15a + 9c46aec commit 83916b3

File tree

5 files changed

+18
-64
lines changed

5 files changed

+18
-64
lines changed

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,10 @@
44
"editor.formatOnSave": true,
55
"editor.codeActionsOnSave": {
66
"source.organizeImports": "always"
7-
}
7+
},
8+
"python.testing.pytestArgs": [
9+
"tests"
10+
],
11+
"python.testing.unittestEnabled": false,
12+
"python.testing.pytestEnabled": true
813
}

src/source_msgraph/async_interator.py

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,63 +3,10 @@
33
import asyncio
44
from typing import AsyncGenerator, Iterator, Any
55

6-
class AsyncToSyncIterator:
7-
"""
8-
Converts an async generator into a synchronous iterator while ensuring proper event loop handling.
9-
"""
10-
11-
def __init__(self, async_gen: AsyncGenerator[Any, None]):
12-
"""
13-
Initializes the iterator by consuming an async generator synchronously.
14-
15-
Args:
16-
async_gen (AsyncGenerator): The async generator yielding results.
17-
"""
18-
self.async_gen = async_gen
19-
self.loop = self._get_event_loop()
20-
self.iterator = self._to_iterator()
21-
22-
def _get_event_loop(self) -> asyncio.AbstractEventLoop:
23-
"""Returns the currently running event loop or creates a new one if none exists."""
24-
try:
25-
loop = asyncio.get_running_loop()
26-
if loop.is_running():
27-
return None # Indicate an already running loop (handled in `_to_iterator()`)
28-
except RuntimeError:
29-
loop = asyncio.new_event_loop()
30-
asyncio.set_event_loop(loop)
31-
return loop
32-
33-
def _to_iterator(self) -> Iterator:
34-
"""
35-
Ensures that the async generator is consumed using the correct event loop.
36-
Uses streaming (does not load all results into memory).
37-
"""
38-
if self.loop:
39-
return iter(self.loop.run_until_complete(self._stream_results()))
40-
else:
41-
return iter(asyncio.run(self._stream_results())) # Safe for Jupyter, PySpark
42-
43-
# Caution : prone to OOM errors
44-
async def _stream_results(self):
45-
# """Streams async generator results without collecting all in memory."""
46-
# page_count = 0
47-
# async for item in self.async_gen:
48-
# if page_count >= self.max_pages:
49-
# raise RuntimeError("Pagination limit reached, possible infinite loop detected!")
50-
# yield item
51-
# page_count += 1 # Track pages to prevent infinite loops
52-
return [item async for item in self.async_gen]
53-
54-
def __iter__(self) -> Iterator:
55-
"""Returns the synchronous iterator."""
56-
return self.iterator
57-
58-
596
import asyncio
607
from typing import AsyncGenerator, Iterator, Any
618

62-
class AsyncToSyncIteratorV2:
9+
class AsyncToSyncIterator:
6310
"""
6411
Converts an async generator into a synchronous iterator while ensuring proper event loop handling.
6512
"""

src/source_msgraph/client.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from kiota_abstractions.base_request_configuration import RequestConfiguration
33
from msgraph.generated.models.o_data_errors.o_data_error import ODataError
44
from azure.identity import ClientSecretCredential
5-
from source_msgraph.async_interator import AsyncToSyncIterator, AsyncToSyncIteratorV2
5+
from source_msgraph.async_interator import AsyncToSyncIterator
66
from source_msgraph.models import ConnectorOptions
77
from source_msgraph.utils import get_python_schema, to_json, to_pyspark_schema
8+
from typing import Dict, Any
89

910
class GraphClient:
1011
def __init__(self, options: ConnectorOptions):
@@ -48,7 +49,6 @@ async def fetch_data(self):
4849
builder = self.options.resource.get_request_builder_cls()(self.graph_client.request_adapter, self.options.resource.resource_params)
4950
items = await builder.get(request_configuration=request_configuration)
5051
while True:
51-
print("Page fetched....")
5252
for item in items.value:
5353
yield item
5454
if not items.odata_next_link:
@@ -72,9 +72,7 @@ def iter_records(options: ConnectorOptions):
7272
async_gen = fetcher.fetch_data()
7373
return AsyncToSyncIterator(async_gen)
7474

75-
import json
76-
from typing import Dict, Any
77-
from dataclasses import asdict
75+
7876

7977
def get_resource_schema(options: ConnectorOptions) -> Dict[str, Any]:
8078
"""
@@ -89,7 +87,7 @@ def get_resource_schema(options: ConnectorOptions) -> Dict[str, Any]:
8987
async_gen = fetcher.fetch_data()
9088

9189
try:
92-
record = next(AsyncToSyncIteratorV2(async_gen), None)
90+
record = next(AsyncToSyncIterator(async_gen), None)
9391
if not record:
9492
raise ValueError(f"No records found for resource: {options.resource.resource_name}")
9593
record = to_json(record)

src/source_msgraph/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
from typing import Any, Dict
66
from source_msgraph.constants import MSGRAPH_SDK_PACKAGE
7-
from urllib.parse import unquote, quote
7+
from urllib.parse import unquote
88
from kiota_abstractions.base_request_builder import BaseRequestBuilder
99

1010
@dataclass
@@ -133,8 +133,9 @@ def map_options_to_params(self, options: Dict[str, Any]) -> 'BaseResource':
133133
if missing_params:
134134
raise ValueError(f"Missing required resource parameters: {', '.join(missing_params)}")
135135

136+
# TODO: add max $top value validation.
137+
136138
mapped_query_params = {"%24"+k: v for k, v in options.items() if k in self.query_params}
137-
138139
mapped_resource_params = {k.replace("-", "%2D"): v for k, v in options.items() if k in self.resource_params}
139140

140141
invalid_params = {k: v for k, v in options.items() if k not in self.query_params and k not in self.resource_params}

src/source_msgraph/source.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Any, Dict, Union
23
from pyspark.sql.datasource import DataSource, DataSourceReader
34
from pyspark.sql.types import StructType
@@ -7,6 +8,7 @@
78
from source_msgraph.resources import get_resource
89
# Reference https://learn.microsoft.com/en-us/azure/databricks/pyspark/datasources
910

11+
logger = logging.getLogger(__name__)
1012

1113
class MSGraphDataSource(DataSource):
1214
"""
@@ -37,8 +39,9 @@ def name(cls):
3739
return "msgraph"
3840

3941
def schema(self):
40-
print("getting aschema")
42+
logger.info("Schema not provided, infering from the source.")
4143
_, schema = get_resource_schema(self.connector_options)
44+
logger.debug(f"Infered schema : {schema}")
4245
return schema
4346

4447
def reader(self, schema: StructType):

0 commit comments

Comments
 (0)