diff --git a/jupyterlab_tinyapp/handlers.py b/jupyterlab_tinyapp/handlers.py index 72d4b93..b822c49 100644 --- a/jupyterlab_tinyapp/handlers.py +++ b/jupyterlab_tinyapp/handlers.py @@ -41,6 +41,8 @@ from .generation.generator import MockStreamingGenerator, OpenAIStreamingGenerator from .generation.streaming import StreamParser import ldap3 +import requests +from urllib.parse import urlencode app = Application.instance() logger = logging.getLogger(app.log.name) @@ -133,6 +135,15 @@ class TinyAppServerEndpoint(Enum): LDAP_BIND_PASSWORD = os.getenv('LDAP_BIND_PASSWORD') LDAP_USER_SEARCH_FILTER = os.getenv('LDAP_USER_SEARCH_FILTER', '(|(givenName=*{search}*)(sn=*{search}*))') LDAP_USER_ATTRIBUTES = os.getenv('LDAP_USER_ATTRIBUTES', 'cn,uid,displayName,mail,givenName,sn').split(',') +LDAP_MAX_RESULTS = int(os.getenv('LDAP_MAX_RESULTS', '1000')) + +# SCIM Configuration +USER_PROVIDER = os.getenv('USER_PROVIDER', 'ldap') # 'ldap' or 'scim' +SCIM_ENDPOINT = os.getenv('SCIM_ENDPOINT') # e.g., https://dev-12345.okta.com/api/v1/users +SCIM_TOKEN = os.getenv('SCIM_TOKEN') # API token for SCIM endpoint +SCIM_SEARCH_FILTER = os.getenv('SCIM_SEARCH_FILTER', 'profile.firstName sw "{search}" or profile.lastName sw "{search}" or profile.login sw "{search}"') +SCIM_MAX_RESULTS = int(os.getenv('SCIM_MAX_RESULTS', '100')) +SCIM_TIMEOUT = int(os.getenv('SCIM_TIMEOUT', '30')) # VOLUME_CLAIM_NAME will be mounted on published app container. It is assumed that # BASE_DIR and VOLUME_CLAIM_NAME refer to same file system - otherwise files @@ -875,83 +886,189 @@ async def post(self): })) +def get_ldap_attr(entry, attr_name): + """Safely extract LDAP attribute value, handling multi-valued attributes""" + if hasattr(entry, attr_name): + attr_value = getattr(entry, attr_name) + if attr_value: + # Handle multi-valued attributes by taking the first value + if isinstance(attr_value, list) and len(attr_value) > 0: + first_value = attr_value[0] + else: + first_value = attr_value + + # Convert bytes to string if necessary + if isinstance(first_value, bytes): + first_value = first_value.decode('utf-8') + + # Return stripped string + return str(first_value).strip() + return '' + + +async def search_users_ldap(search_query): + """Search users using LDAP""" + if not LDAP_ADDR or not LDAP_BASE_DN: + logger.warning('LDAP not configured') + return None, 'ldap is not configured: missing LDAP_ADDR or LDAP_BASE_DN' + + # Create LDAP server object + try: + server = ldap3.Server(LDAP_ADDR, get_info=ldap3.ALL) + # Connect and bind to LDAP server + if LDAP_BIND_DN and LDAP_BIND_PASSWORD: + conn = ldap3.Connection(server, LDAP_BIND_DN, LDAP_BIND_PASSWORD, auto_bind=True) + else: + conn = ldap3.Connection(server, auto_bind=True) + except Exception as e: + logger.error(f'Error connecting to LDAP server: {str(e)}') + return None, 'Unable to connect to LDAP server' + + # Search for users + search_filter = LDAP_USER_SEARCH_FILTER.format(search=ldap3.utils.conv.escape_filter_chars(search_query)) + + try: + success = conn.search( + search_base=LDAP_BASE_DN, + search_filter=search_filter, + search_scope=ldap3.SUBTREE, + attributes=LDAP_USER_ATTRIBUTES, + size_limit=LDAP_MAX_RESULTS + ) + + if not success: + conn.unbind() + logger.info('No LDAP results found') + return [], None + except Exception as e: + conn.unbind() + logger.error(f'Error during LDAP search: {str(e)}') + return None, 'error searching ldap directory' + + # Process search results + users = [] + for entry in conn.entries: + user_data = { + 'uid': get_ldap_attr(entry, 'uid'), + 'cn': get_ldap_attr(entry, 'cn'), + } + + user_data['label'] = user_data['cn'] + user_data['value'] = user_data['uid'] + + if user_data['value']: # Only include users with a valid identifier + users.append(user_data) + + conn.unbind() + logger.info(f'Found {len(users)} LDAP users for query: {search_query}') + return users, None + + +async def search_users_scim(search_query): + """Search users using SCIM API""" + if not SCIM_ENDPOINT or not SCIM_TOKEN: + logger.warning('SCIM not configured') + return None, 'SCIM is not configured: missing SCIM_ENDPOINT or SCIM_TOKEN' + + try: + # Build SCIM query parameters + filter_expr = SCIM_SEARCH_FILTER.format(search=search_query) + params = { + 'filter': filter_expr, + 'count': SCIM_MAX_RESULTS, + 'startIndex': 1 + } + + headers = { + 'Authorization': f'SSWS {SCIM_TOKEN}', + 'Accept': 'application/scim+json', + 'Content-Type': 'application/scim+json' + } + + # Make SCIM API request + response = requests.get( + SCIM_ENDPOINT, + params=params, + headers=headers, + timeout=SCIM_TIMEOUT + ) + + if response.status_code != 200: + logger.error(f'SCIM API error: {response.status_code} {response.text}') + return None, f'SCIM API error: {response.status_code}' + + scim_data = response.json() + + # Process SCIM results + users = [] + resources = scim_data.get('Resources', []) + + for user in resources: + # Extract user data from SCIM response + user_data = { + 'uid': user.get('userName', ''), + 'cn': user.get('displayName', ''), + } + + # Fallback to profile data if needed + if not user_data['cn'] and 'profile' in user: + profile = user['profile'] + first_name = profile.get('firstName', '') + last_name = profile.get('lastName', '') + if first_name or last_name: + user_data['cn'] = f"{first_name} {last_name}".strip() + + # Set label and value for UI + user_data['label'] = user_data['cn'] or user_data['uid'] + user_data['value'] = user_data['uid'] + + if user_data['value']: # Only include users with a valid identifier + users.append(user_data) + + logger.info(f'Found {len(users)} SCIM users for query: {search_query}') + return users, None + + except requests.exceptions.RequestException as e: + logger.error(f'SCIM request error: {str(e)}') + return None, 'Error connecting to SCIM endpoint' + except Exception as e: + logger.error(f'SCIM processing error: {str(e)}') + return None, 'Error processing SCIM response' + + class SearchUsersHandler(CustomAPIHandler): @tornado.web.authenticated async def get(self): - logger.info('Received request to SearchUsersHandler') + logger.info(f'Received request to SearchUsersHandler (provider: {USER_PROVIDER})') # Get search query parameter search_query = self.get_argument('query', '') if not search_query or len(search_query.strip()) < 2: logger.info('Invalid query parameter: must be at least 2 characters') self._return_error(400, 'query parameter must be at least 2 characters') - - if not LDAP_ADDR or not LDAP_BASE_DN: - logger.warning('LDAP not configured') - self._return_error(500, 'ldap is not configured: missing LDAP_ADDR or LDAP_BASE_DN') return - # Create LDAP server object - try: - server = ldap3.Server(LDAP_ADDR, get_info=ldap3.ALL) - # Connect and bind to LDAP server - if LDAP_BIND_DN and LDAP_BIND_PASSWORD: - conn = ldap3.Connection(server, LDAP_BIND_DN, LDAP_BIND_PASSWORD, auto_bind=True) - else: - conn = ldap3.Connection(server, auto_bind=True) - except Exception as e: - logger.error(f'Error connecting to LDAP server: {str(e)}') - self._return_error(500, 'Unable to connect to LDAP server') + # Route to appropriate search function based on configuration + if USER_PROVIDER.lower() == 'scim': + users, error = await search_users_scim(search_query) + elif USER_PROVIDER.lower() == 'ldap': + users, error = await search_users_ldap(search_query) + else: + logger.error(f'Unknown user provider: {USER_PROVIDER}') + self._return_error(500, f'Unknown user provider: {USER_PROVIDER}. Must be "ldap" or "scim"') return - # Search for users - search_filter = LDAP_USER_SEARCH_FILTER.format(search=ldap3.utils.conv.escape_filter_chars(search_query)) - - try: - success = conn.search( - search_base=LDAP_BASE_DN, - search_filter=search_filter, - search_scope=ldap3.SUBTREE, - attributes=LDAP_USER_ATTRIBUTES, - size_limit=1000 - ) - - if not success: - conn.unbind() - logger.error('No results found or error during LDAP search') - self._return_error(500, 'No results found or error during LDAP search') - return - except Exception as e: - conn.unbind() - logger.error(f'Error during LDAP search: {str(e)}') - self._return_error(500, 'error searching ldap directory') + # Handle errors + if error: + self._return_error(500, error) return - - # Process search results - users = [] - for entry in conn.entries: - user_data = { - 'uid': str(entry.uid) if hasattr(entry, 'uid') and entry.uid else '', - 'cn': str(entry.cn) if hasattr(entry, 'cn') and entry.cn else '', - 'displayName': str(entry.displayName) if hasattr(entry, 'displayName') and entry.displayName else '', - 'mail': str(entry.mail) if hasattr(entry, 'mail') and entry.mail else '' - } - - user_data['label'] = user_data['cn'] - user_data['value'] = user_data['uid'] - - if user_data['value']: # Only include users with a valid identifier - users.append(user_data) - - logger.info(f'Found {len(users)} users for query: {search_query}') + # Return results self.finish(json.dumps({ 'data': { 'users': users } })) - - conn.unbind() class PingHandler(CustomAPIHandler): diff --git a/pyproject.toml b/pyproject.toml index a2c206a..d2ebdad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "openai==2.6.1", "aiofiles==23.2.1", "ldap3>=2.9.1", + "requests>=2.25.0", ] dynamic = ["version", "description", "authors", "urls", "keywords"]