Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions litellm/proxy/management_endpoints/scim/scim_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
NewTeamRequest,
NewUserRequest,
NewUserResponse,
ProxyErrorTypes,
ProxyException,
TeamMemberAddRequest,
TeamMemberDeleteRequest,
UserAPIKeyAuth,
Expand Down Expand Up @@ -797,6 +799,9 @@ async def patch_team_membership(
) -> bool:
"""
Add or remove user from teams

Handles duplicate membership gracefully (idempotent operation).
If a user is already in a team, that's fine - we don't treat it as an error.
"""
for _team_id in teams_ids_to_add_user_to:
try:
Expand All @@ -809,6 +814,16 @@ async def patch_team_membership(
user_role=LitellmUserRoles.PROXY_ADMIN
),
)
except ProxyException as e:
# Handle duplicate membership gracefully - this is idempotent
if e.type == ProxyErrorTypes.team_member_already_in_team:
verbose_proxy_logger.debug(
f"User {user_id} is already in team {_team_id}, skipping add"
)
else:
verbose_proxy_logger.exception(
f"Error adding user to team {_team_id}: {e}"
)
except Exception as e:
verbose_proxy_logger.exception(f"Error adding user to team {_team_id}: {e}")

Expand Down Expand Up @@ -1302,20 +1317,42 @@ async def patch_group(
patch_ops, existing_team, prisma_client
)

# Track current members for comparison
# Track current members BEFORE update for comparison
current_members = set(await _get_team_member_user_ids_from_team(existing_team))

# Apply updates to the database
updated_team = await _apply_group_patch_updates(
group_id, update_data, final_members, prisma_client
)

# Refresh team data from database to get the latest state after concurrent updates
# This prevents race conditions when multiple PATCH requests come in simultaneously
refreshed_team = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": group_id}
)
if refreshed_team:
# Re-read current members from refreshed team to account for concurrent updates
refreshed_current_members = set(
await _get_team_member_user_ids_from_team(
LiteLLM_TeamTable(**refreshed_team.model_dump())
)
)
# Use the refreshed members for comparison
current_members = refreshed_current_members

# Handle user-team relationship changes
await _handle_group_membership_changes(group_id, current_members, final_members)

# Refresh team one more time to get final state after membership changes
final_team = await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": group_id}
)
if final_team:
updated_team = final_team

# Convert to SCIM format and return
scim_group = await ScimTransformations.transform_litellm_team_to_scim_group(
updated_team
LiteLLM_TeamTable(**updated_team.model_dump())
)
return scim_group

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
create_group,
create_user,
get_service_provider_config,
patch_group,
patch_user,
update_group,
update_user,
Expand Down Expand Up @@ -1189,4 +1190,204 @@ def mock_new_user_side_effect(data):
# Verify response
assert result.id == group_id
assert result.displayName == "Updated Group Name"
assert len(result.members) == 3
assert len(result.members) == 3


@pytest.mark.asyncio
async def test_patch_group_refreshes_team_data_to_prevent_race_conditions(mocker):
"""
Test that patch_group refreshes team data from database:
1. After applying updates (to get latest state before membership changes)
2. After membership changes (to get final state for response)

This prevents race conditions when multiple PATCH requests come in simultaneously.
"""
from litellm.proxy._types import LiteLLM_TeamTable, Member

group_id = "test-group-123"

# Mock existing team
existing_team = LiteLLM_TeamTable(
team_id=group_id,
team_alias="Original Team",
members=["user1", "user2"],
members_with_roles=[
Member(user_id="user1", role="user"),
Member(user_id="user2", role="user")
],
metadata={}
)

# Mock team after applying updates (simulating what _apply_group_patch_updates returns)
updated_team_after_patch = LiteLLM_TeamTable(
team_id=group_id,
team_alias="Updated Team",
members=["user1", "user2", "user3"], # user3 added in patch
members_with_roles=[
Member(user_id="user1", role="user"),
Member(user_id="user2", role="user"),
Member(user_id="user3", role="user")
],
metadata={}
)

# Mock refreshed team (simulating concurrent update - user4 was added by another request)
refreshed_team_before_membership = LiteLLM_TeamTable(
team_id=group_id,
team_alias="Updated Team",
members=["user1", "user2", "user3", "user4"], # user4 added concurrently
members_with_roles=[
Member(user_id="user1", role="user"),
Member(user_id="user2", role="user"),
Member(user_id="user3", role="user"),
Member(user_id="user4", role="user") # Concurrent addition
],
metadata={}
)

# Mock final refreshed team after membership changes
final_refreshed_team = LiteLLM_TeamTable(
team_id=group_id,
team_alias="Updated Team",
members=["user1", "user2", "user3", "user4", "user5"], # user5 added via membership change
members_with_roles=[
Member(user_id="user1", role="user"),
Member(user_id="user2", role="user"),
Member(user_id="user3", role="user"),
Member(user_id="user4", role="user"),
Member(user_id="user5", role="user") # Added via membership change
],
metadata={}
)

# Mock SCIM patch operations - adding user3 and user5
patch_ops = SCIMPatchOp(
schemas=["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
Operations=[
SCIMPatchOperation(op="add", path="members", value=[{"value": "user3"}, {"value": "user5"}])
]
)

# Mock prisma client
mock_prisma_client = mocker.MagicMock()
mock_prisma_client.db = mocker.MagicMock()
mock_prisma_client.db.litellm_teamtable = mocker.MagicMock()
mock_prisma_client.db.litellm_usertable = mocker.MagicMock()

# Mock user lookups (all users exist)
mock_user = mocker.MagicMock()
mock_user.user_id = "test-user"
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=mock_user)

# Mock dependencies
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._get_prisma_client_or_raise_exception",
AsyncMock(return_value=mock_prisma_client)
)
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._check_team_exists",
AsyncMock(return_value=existing_team)
)

# Mock _process_group_patch_operations
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._process_group_patch_operations",
AsyncMock(return_value=(
{"team_alias": "Updated Team"},
{"user1", "user2", "user3", "user5"} # final_members after processing patch
))
)

# Mock _apply_group_patch_updates to return updated_team_after_patch
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._apply_group_patch_updates",
AsyncMock(return_value=updated_team_after_patch)
)

# Mock find_unique calls for refresh operations
# First refresh (after applying updates) - returns team with concurrent update (user4)
# Second refresh (after membership changes) - returns final team (with user5)
# Need to add model_dump() method to mock Prisma model objects
mock_refreshed_team_before_membership = mocker.MagicMock()
# model_dump() should return a dict that can be used to construct LiteLLM_TeamTable
mock_refreshed_team_before_membership.model_dump = mocker.Mock(return_value={
"team_id": refreshed_team_before_membership.team_id,
"team_alias": refreshed_team_before_membership.team_alias,
"members": refreshed_team_before_membership.members,
"members_with_roles": refreshed_team_before_membership.members_with_roles,
"metadata": refreshed_team_before_membership.metadata,
})

mock_final_refreshed_team = mocker.MagicMock()
mock_final_refreshed_team.model_dump = mocker.Mock(return_value={
"team_id": final_refreshed_team.team_id,
"team_alias": final_refreshed_team.team_alias,
"members": final_refreshed_team.members,
"members_with_roles": final_refreshed_team.members_with_roles,
"metadata": final_refreshed_team.metadata,
})

refresh_calls = [mock_refreshed_team_before_membership, mock_final_refreshed_team]
mock_prisma_client.db.litellm_teamtable.find_unique = AsyncMock(side_effect=refresh_calls)

# Mock _handle_group_membership_changes
mock_handle_group_membership_changes = mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2._handle_group_membership_changes",
AsyncMock()
)

# Mock SCIM transformation
expected_scim_response = SCIMGroup(
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
id=group_id,
displayName="Updated Team",
members=[
SCIMMember(value="user1", display="user1"),
SCIMMember(value="user2", display="user2"),
SCIMMember(value="user3", display="user3"),
SCIMMember(value="user4", display="user4"),
SCIMMember(value="user5", display="user5")
]
)
mocker.patch(
"litellm.proxy.management_endpoints.scim.scim_v2.ScimTransformations.transform_litellm_team_to_scim_group",
AsyncMock(return_value=expected_scim_response)
)

# Execute patch_group
result = await patch_group(group_id=group_id, patch_ops=patch_ops)

# Verify that find_unique was called twice (for the two refreshes)
assert mock_prisma_client.db.litellm_teamtable.find_unique.call_count == 2

# Verify first refresh was called after applying updates
first_refresh_call = mock_prisma_client.db.litellm_teamtable.find_unique.call_args_list[0]
assert first_refresh_call[1]["where"]["team_id"] == group_id

# Verify that _handle_group_membership_changes was called with refreshed members
# It should use refreshed_current_members (user1, user2, user3, user4) not updated_team_after_patch members
mock_handle_group_membership_changes.assert_called_once()
membership_call = mock_handle_group_membership_changes.call_args
# _handle_group_membership_changes is called with positional arguments: (group_id, current_members, final_members)
assert membership_call[0][0] == group_id
# current_members should be from refreshed_team_before_membership (includes user4 from concurrent update)
assert membership_call[0][1] == {"user1", "user2", "user3", "user4"}
# final_members should be from patch operations (user1, user2, user3, user5)
assert membership_call[0][2] == {"user1", "user2", "user3", "user5"}

# Verify second refresh was called after membership changes
second_refresh_call = mock_prisma_client.db.litellm_teamtable.find_unique.call_args_list[1]
assert second_refresh_call[1]["where"]["team_id"] == group_id

# Verify SCIM transformation was called with final_refreshed_team (not updated_team_after_patch)
from litellm.proxy.management_endpoints.scim.scim_v2 import ScimTransformations
ScimTransformations.transform_litellm_team_to_scim_group.assert_called_once()
transform_call = ScimTransformations.transform_litellm_team_to_scim_group.call_args[0][0]
# Verify it was called with final_refreshed_team (has user5)
assert isinstance(transform_call, LiteLLM_TeamTable)
member_ids = {member.user_id for member in transform_call.members_with_roles}
assert member_ids == {"user1", "user2", "user3", "user4", "user5"}

# Verify response
assert result.id == group_id
assert result.displayName == "Updated Team"
Loading