Skip to content

Commit 29a7681

Browse files
committed
proxy using msgraph sdk
1 parent aef8301 commit 29a7681

File tree

9 files changed

+382
-18
lines changed

9 files changed

+382
-18
lines changed

src/source_msgraph/graph.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from kiota_abstractions.base_request_configuration import RequestConfiguration
99
from msgraph.generated.sites.item.lists.item.items.items_request_builder import ItemsRequestBuilder
1010
import json
11-
from kiota_serialization_json.json_serialization_writer_factory import JsonSerializationWriterFactory
11+
1212
from msgraph import GraphServiceClient
1313
from azure.identity import ClientSecretCredential
1414

1515
from source_msgraph.async_interator import AsyncToSyncIterator
16+
from source_msgraph.msgraph_spark.fetcher import MicrosoftGraphFetcher
17+
from source_msgraph.msgraph_spark.options import MicrosoftGraphOptions
1618

1719
tenant_id = "7c78a7a0-8b3b-4d54-8c3c-ca6ab4da029f"
1820
client_id = "59c8283d-4a90-44a9-9a58-579a0e511168"
@@ -24,17 +26,11 @@
2426

2527
# from microsoft.kiota.serialization.json.json_serialization_writer_factory import JsonSerializationWriterFactory
2628

27-
# Convert to JSON using Kiota
28-
writer_factory = JsonSerializationWriterFactory()
29-
writer = writer_factory.get_serialization_writer("application/json")
29+
3030
graph_client = GraphServiceClient(credentials=credentials, scopes=[
3131
'https://graph.microsoft.com/.default'])
3232

3333

34-
def to_json(value):
35-
value.serialize(writer)
36-
# Get JSON string
37-
return json.loads((writer.get_serialized_content().decode("utf-8")))
3834

3935

4036
async def fetch_items_async(graph_client, site_id, list_id, **params):
@@ -57,4 +53,32 @@ async def fetch_items_async(graph_client, site_id, list_id, **params):
5753

5854
def fetch_items_sync(graph_client, site_id, list_id, params):
    """Wrap the async item generator in an ``AsyncToSyncIterator``.

    :param graph_client: Client forwarded to ``fetch_items_async``.
    :param site_id: Site identifier forwarded to ``fetch_items_async``.
    :param list_id: List identifier forwarded to ``fetch_items_async``.
    :param params: Mapping expanded into keyword arguments for ``fetch_items_async``.
    :return: ``AsyncToSyncIterator`` over the generator returned by ``fetch_items_async``.
    """
    return AsyncToSyncIterator(
        fetch_items_async(graph_client, site_id, list_id, **params)
    )
57+
58+
59+
def main():
    """Manually exercise ``MicrosoftGraphFetcher`` against one SharePoint list.

    Builds a ``MicrosoftGraphOptions`` from the module-level credentials so the
    query parameters and path identifiers are validated, then prints every row
    the fetcher yields through the module-level ``graph_client``.
    """
    options = MicrosoftGraphOptions(
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret,
        resource_path="sites/by_site_id/lists/by_list_id/items",
        # Query-string options only; path identifiers belong in resource_options.
        query_params={
            "top": 5,
            "expand": ["fields"],
        },
        resource_options={
            "site_id": "37d7dde8-0b6b-4b7c-a2fd-2e217f54a263",  # Actual site ID
            "list_id": "5ecf26db-0161-4069-b763-856217415099",  # Actual list ID
        },
    )
    fetcher = MicrosoftGraphFetcher(
        graph_client,
        options.resource_path,
        options.query_params,
        options.resource_options,
    )
    for row in AsyncToSyncIterator(fetcher.fetch_data()):
        print(row)
81+
82+
83+
# Run the manual check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Union
2+
from pyspark.sql.types import StructType
3+
4+
from source_msgraph.async_interator import AsyncToSyncIterator
5+
from source_msgraph.msgraph_spark.fetcher import MicrosoftGraphFetcher
6+
from source_msgraph.msgraph_spark.options import MicrosoftGraphOptions
7+
from azure.identity import ClientSecretCredential
8+
from msgraph import GraphServiceClient
9+
10+
def iter_records(schema: Union[StructType, str], options: MicrosoftGraphOptions):
    """Build an authenticated Graph client and return an iterator over fetched records.

    :param schema: Accepted for the caller's signature; not read by this function.
    :param options: Supplies the credentials, resource path, query parameters
        and resource options.
    :return: ``AsyncToSyncIterator`` wrapping ``MicrosoftGraphFetcher.fetch_data()``.
    """
    # Client-secret credentials come straight from the options object.
    graph_client = GraphServiceClient(
        credentials=ClientSecretCredential(
            tenant_id=options.tenant_id,
            client_id=options.client_id,
            client_secret=options.client_secret,
        ),
        scopes=["https://graph.microsoft.com/.default"],
    )

    record_stream = MicrosoftGraphFetcher(
        graph_client,
        options.resource_path,
        options.query_params,
        options.resource_options,
    ).fetch_data()
    return AsyncToSyncIterator(record_stream)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import importlib
2+
from msgraph import GraphServiceClient
3+
from kiota_abstractions.base_request_configuration import RequestConfiguration
4+
from msgraph.generated.models.o_data_errors.o_data_error import ODataError
5+
from typing import Dict, Any
6+
import asyncio
7+
8+
class MicrosoftGraphFetcher:
    """Resolve a Microsoft Graph request builder from a path string and stream its items.

    The resource path uses the SDK's attribute names separated by ``/``;
    segments that start with ``by_`` are placeholders whose actual value is
    looked up in ``resource_params`` (``by_site_id`` reads ``site_id``).
    """

    def __init__(
        self,
        graph_client: "GraphServiceClient",
        resource_path: str,
        query_params: Dict[str, Any],
        resource_params: "Dict[str, str] | None" = None,
    ):
        """
        Initializes the fetcher with the Graph client, resource path, and parameters.

        :param graph_client: The authenticated GraphServiceClient instance.
        :param resource_path: The resource path (e.g., "sites/by_site_id/lists/by_list_id/items").
        :param query_params: Query parameters for the request; keys the resolved
            query-parameter class does not declare are dropped in ``fetch_data``.
        :param resource_params: Actual values for the ``by_*`` placeholders, keyed
            without the ``by_`` prefix. Optional; defaults to an empty mapping.
        """
        self.graph_client = graph_client
        self.resource_path = resource_path
        self.query_params = query_params if query_params is not None else {}
        self.resource_params = resource_params if resource_params is not None else {}

    def _get_request_builder_class(self):
        """
        Dynamically resolves the correct RequestBuilder class from `msgraph.generated`.

        Example: For "sites/by_site_id/lists/by_list_id/items", it will resolve:
        `msgraph.generated.sites.item.lists.item.items.items_request_builder.ItemsRequestBuilder`

        :raises ValueError: If the module cannot be imported or does not contain
            the expected class.
        """
        # Each "by_<name>" placeholder maps to an ``item`` package in the module path.
        module_path = [
            "item" if part.startswith("by_") else part
            for part in self.resource_path.split("/")
        ]
        leaf = f"{module_path[-1]}_request_builder"
        module_name = "msgraph.generated." + ".".join(module_path) + f".{leaf}"

        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            raise ValueError(
                f"Could not resolve RequestBuilder for resource path: {self.resource_path}"
            ) from exc

        pascal_case_class = self._pascal_case(leaf)
        builder_class = getattr(module, pascal_case_class, None)
        if builder_class is None:
            raise ValueError(f"Could not find {pascal_case_class} in {module_name}")
        return builder_class

    def _pascal_case(self, snake_str: str) -> str:
        """
        Converts snake_case to PascalCase.
        Example: "items_request_builder" -> "ItemsRequestBuilder"
        """
        return "".join(word.title() for word in snake_str.split("_"))

    def _get_query_parameters_class(self, request_builder_class):
        """
        Fetches the corresponding `RequestBuilderGetQueryParameters` class dynamically.

        :raises ValueError: If the builder class exposes no such attribute.
        """
        for attr in dir(request_builder_class):
            if attr.endswith("RequestBuilderGetQueryParameters"):
                return getattr(request_builder_class, attr)
        raise ValueError(f"No QueryParameters class found for {request_builder_class.__name__}")

    async def fetch_data(self):
        """
        Fetches data from Microsoft Graph using the dynamically built request.
        Follows ``odata_next_link`` until no further page is advertised.

        :raises Exception: Wrapping any ``ODataError`` raised by the SDK.
        """
        request_builder_class = self._get_request_builder_class()
        query_params_class = self._get_query_parameters_class(request_builder_class)

        # Only forward the keys the query-parameter class actually declares.
        valid_params = set(getattr(query_params_class, "__annotations__", {}))
        filtered_params = {k: v for k, v in self.query_params.items() if k in valid_params}
        request_configuration = RequestConfiguration(
            query_parameters=query_params_class(**filtered_params),
        )

        builder = self._get_request_builder_instance(request_builder_class)

        try:
            page = await builder.get(request_configuration=request_configuration)
            # Guard against an empty response or a page without a value list.
            while page is not None:
                print("Page fetched....")
                for item in page.value or []:
                    yield item
                if not page.odata_next_link:
                    break
                page = await builder.with_url(page.odata_next_link).get()
        except ODataError as e:
            detail = e.error.message if e.error else str(e)
            raise Exception(f"Graph API Error: {detail}") from e

    def _get_request_builder_instance(self, request_builder_class):
        """
        Uses the `graph_client` to resolve the correct instance of the request builder dynamically,
        passing actual `site_id`, `list_id`, etc., instead of placeholders.

        :param request_builder_class: Unused; kept for the existing call signature.
        :raises ValueError: If a placeholder has no value in ``resource_params`` or a
            path segment does not exist on the builder reached so far.
        """
        builder = self.graph_client

        for part in self.resource_path.split("/"):
            is_placeholder = part.startswith("by_")
            actual_id = None
            if is_placeholder:
                param_name = part[3:]  # "by_site_id" -> "site_id"
                actual_id = self.resource_params.get(param_name)
                if not actual_id:
                    raise ValueError(f"Missing required parameter: {param_name}")

            # Reject unknown segments for placeholders too, rather than skipping them.
            if not hasattr(builder, part):
                raise ValueError(f"Invalid resource path: '{part}' not found in {builder}")

            attribute = getattr(builder, part)
            builder = attribute(actual_id) if is_placeholder else attribute

        return builder
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from dataclasses import dataclass, field
2+
from typing import Optional, Dict, Any
3+
4+
5+
@dataclass
class MicrosoftGraphOptions:
    """Credentials and request settings for a single Microsoft Graph resource.

    Instantiation raises ``ValueError`` when ``query_params`` or
    ``resource_options`` contains a key that is not allowed for
    ``resource_path``.
    """

    tenant_id: str
    client_id: str
    client_secret: str
    # Lookup key for the allow-lists below, e.g. "sites/by_site_id/lists/by_list_id/items".
    resource_path: str
    # Query-string options such as 'top' or 'filter'.
    query_params: Dict[str, Any] = field(default_factory=dict)
    # Other resource-specific options.
    resource_options: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Run both validations as soon as the dataclass is constructed."""
        self.validate_query_params()
        self.validate_resource_options()

    def validate_query_params(self):
        """Raise ``ValueError`` naming any query parameter not allowed for the resource path."""
        allowed = self.get_valid_query_options()
        invalid_keys = []
        for name in self.query_params:
            if name not in allowed:
                invalid_keys.append(name)
        if not invalid_keys:
            return
        raise ValueError(f"Invalid query parameters for {self.resource_path}: {invalid_keys}")

    def validate_resource_options(self):
        """Raise ``ValueError`` naming any resource option not allowed for the resource path."""
        allowed = self.get_valid_resource_options()
        invalid_keys = []
        for name in self.resource_options:
            if name not in allowed:
                invalid_keys.append(name)
        if not invalid_keys:
            return
        raise ValueError(f"Invalid resource options for {self.resource_path}: {invalid_keys}")

    def get_valid_query_options(self):
        """Return the set of query parameter names allowed for ``resource_path`` (empty if unknown)."""
        allowed_by_path = {
            "sites/by_site_id/lists/by_list_id/items": {"top", "filter", "orderby", "expand"},
        }
        try:
            return allowed_by_path[self.resource_path]
        except KeyError:
            return set()

    def get_valid_resource_options(self):
        """Return the set of resource option names allowed for ``resource_path`` (empty if unknown)."""
        allowed_by_path = {
            "sites/by_site_id/lists/by_list_id/items": {"site_id", "list_id"},
        }
        try:
            return allowed_by_path[self.resource_path]
        except KeyError:
            return set()

    def get_request_parameters(self):
        """Return the query parameters and resource options bundled in one dict."""
        return dict(
            query_params=self.query_params,
            resource_options=self.resource_options,
        )
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from kiota_serialization_json.json_serialization_writer_factory import JsonSerializationWriterFactory
2+
import json
3+
# Create a Kiota JSON serialization writer factory and one "application/json"
# writer at import time.
writer_factory = JsonSerializationWriterFactory()
writer = writer_factory.get_serialization_writer("application/json")
6+
7+
def to_json(value):
    """Serialize ``value`` with Kiota and return the parsed JSON as Python data.

    :param value: Object exposing ``serialize(writer)``.
    :return: Result of ``json.loads`` on the serialized UTF-8 content.
    """
    # Use a writer private to this call; a single shared writer can carry
    # state over between values or after a serialize() that failed midway.
    fresh_writer = writer_factory.get_serialization_writer("application/json")
    value.serialize(fresh_writer)
    return json.loads(fresh_writer.get_serialized_content().decode("utf-8"))
11+
12+
def to_jsonValue(value):
    """Serialize ``value`` with Kiota and return ``str()`` of the parsed JSON data.

    The result is the Python ``str`` representation of the parsed object, so it
    is not itself a JSON document.

    :param value: Object exposing ``serialize(writer)``.
    :return: ``str`` of the object produced by ``json.loads``.
    """
    # Use a writer private to this call; a single shared writer can carry
    # state over between values or after a serialize() that failed midway.
    fresh_writer = writer_factory.get_serialization_writer("application/json")
    value.serialize(fresh_writer)
    return str(json.loads(fresh_writer.get_serialized_content().decode("utf-8")))
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import importlib
2+
3+
def resolve_resource_class(resource_path: str):
    """Find the ``ResourceHandler`` class for a slash-separated resource path.

    Walks at most the first five path segments, importing
    ``source_msgraph.<seg1>``, ``source_msgraph.<seg1>.<seg2>`` and so on, and
    returns the first ``ResourceHandler`` attribute found on an imported module.

    :param resource_path: Path whose segments name nested modules.
    :return: The ``ResourceHandler`` attribute of the first matching module.
    :raises ValueError: If no imported module exposes ``ResourceHandler``.
    """
    dotted_name = "source_msgraph"

    # Slicing caps the lookup at five segments.
    for segment in resource_path.split('/')[:5]:
        dotted_name = f"{dotted_name}.{segment}"
        try:
            candidate = importlib.import_module(dotted_name)
            if hasattr(candidate, "ResourceHandler"):
                return candidate.ResourceHandler
        except ModuleNotFoundError:
            # Nothing importable at this depth; keep descending.
            pass

    # TODO: Implement BaseResource if required at all
    raise ValueError(f"No handler found for resource: {resource_path}")
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from dataclasses import dataclass, field
2+
from typing import Optional, Dict, Any
3+
4+
5+
@dataclass
class MicrosoftGraphResourceOptions:
    """Base dataclass for Microsoft Graph resource options."""

    # NOTE(review): everything below is commented out, so this class currently
    # declares no fields and no validate() method; subclasses cannot rely on
    # `resource`, `sub_resource`, `params` or `validate()` being inherited.

    # resource: str  # e.g., "users", "sites", "groups"
    # sub_resource: Optional[str] = None  # e.g., for sites: "lists", "drive"
    # params: Dict[str, Any] = field(default_factory=dict)  # Additional query params

    # def validate(self):
    #     """Base validation for all resources."""
    #     if not self.resource:
    #         raise ValueError("Resource is required.")
    #     if self.sub_resource and not isinstance(self.sub_resource, str):
    #         raise ValueError("Sub-resource must be a string.")
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
from msgraph import GraphServiceClient
4+
from msgraph.generated.sites.item.lists.item.items.items_request_builder import ItemsRequestBuilder
5+
from kiota_abstractions.base_request_configuration import RequestConfiguration
6+
from source_msgraph.async_interator import AsyncToSyncIterator
7+
from source_msgraph.resources.base_options import MicrosoftGraphResourceOptions
8+
9+
@dataclass
class SharePointListOptions(MicrosoftGraphResourceOptions):
    """Options identifying a SharePoint site and, optionally, one of its lists."""

    site_id: str
    list_id: Optional[str] = None

    def validate(self):
        """Check that the identifiers needed for the request are present.

        :raises ValueError: If ``site_id`` is empty, or if ``sub_resource`` is
            ``"lists"`` and ``list_id`` is empty.
        """
        # The base class may not define validate() or sub_resource; look both
        # up defensively so this method does not fail with AttributeError.
        base_validate = getattr(super(), "validate", None)
        if callable(base_validate):
            base_validate()
        if not self.site_id:
            raise ValueError("site_id is required for SharePoint lists.")
        if getattr(self, "sub_resource", None) == "lists" and not self.list_id:
            raise ValueError("list_id is required when accessing lists.")
20+
21+
22+
class ResourceHandler:
    """Fetch SharePoint list items through a ``GraphServiceClient``."""

    def __init__(self, client: GraphServiceClient, params: dict):
        """
        :param client: Graph client used for every request.
        :param params: Keyword arguments for ``SharePointListOptions``.
        """
        self.client = client
        self.params = SharePointListOptions(**params)

    async def fetch_items_async(self, site_id, list_id, **params):
        """Yield every item of the list, following ``odata_next_link`` between pages.

        :param site_id: Identifier passed to ``sites.by_site_id``.
        :param list_id: Identifier passed to ``lists.by_list_id``.
        :param params: Keyword arguments for the items query-parameter class.
            ``expand`` defaults to ``["fields"]`` and may be overridden.
        """
        # Merge rather than pass `expand` twice, which raised TypeError
        # whenever the caller supplied its own `expand`.
        query_kwargs = {"expand": ["fields"], **params}
        query_parameters = ItemsRequestBuilder.ItemsRequestBuilderGetQueryParameters(**query_kwargs)
        request_configuration = RequestConfiguration(
            query_parameters=query_parameters,
        )

        items_builder = self.client.sites.by_site_id(site_id).lists.by_list_id(list_id).items
        page = await items_builder.get(request_configuration=request_configuration)

        # Guard against an empty response or a page without a value list.
        while page is not None:
            print("Page fetched....")
            for item in page.value or []:
                yield item
            if not page.odata_next_link:
                break
            page = await items_builder.with_url(page.odata_next_link).get()

    def fetch_items_sync(self, site_id, list_id, params):
        """Wrap ``fetch_items_async`` in an ``AsyncToSyncIterator``.

        :param params: Mapping expanded into keyword arguments for ``fetch_items_async``.
        """
        # fetch_items_async is a bound method, so the client must not be passed again.
        async_gen = self.fetch_items_async(site_id, list_id, **params)
        return AsyncToSyncIterator(async_gen)
48+

0 commit comments

Comments
 (0)