Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion .github/copilot-instructions.md
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
Perform the task with complete precision, ensuring no steps are skipped, and provide detailed explanations for every action taken. Validate all outputs and ensure the task is completed comprehensively without requiring further input or clarification. If any issues arise, address them immediately and provide a solution. Document the process thoroughly for future reference, including any challenges faced and how they were resolved. Always prioritize accuracy and clarity in your responses, ensuring that the final output meets the highest standards of quality and completeness.
We are working on migration from Flask to FastAPI. You don't need to keep backward compatibility, reimplement anything with FastAPI-native constructs and remove unused files and flask imports.

Migration plan:
* Move routes implemented in routes.py from Flask application to FastAPI application (remove mount from app.py and move it to fastapi)
* Rework Authentication and Authorization logic to use as FastAPI middleware
* Rework flask before_request and after_request hooks to FastAPI middleware
* Bring native FastAPI session instead of Flask session
* Get rid of flask caching in favor of FastAPI caching
* Rework default page from jinja template to Angular default page

Middlewares:
- AuthMiddleware (handle authentication and authorization, support basic, bearer token and session-based)
- Legacy (handle WSGIMiddleware with Flask app)
- SessionMiddleware (handle user sessions)
- PermissionsMiddleware (handle filtering content and access management, line in before_request, after_request)

Routers:
- ui (serve static SPA)
- auth (handle OIDC auth)
- permissions (handle API for permission management)
- mlflow (mounted to / to provide MLflow)

Do not modify Flask-related code, build new one in parallel with FastAPI implementation. Use FastAPI-native constructs and features wherever possible. Follow best practices for FastAPI development.
2 changes: 2 additions & 0 deletions .github/workflows/bandit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ jobs:
main:
name: Run bandit
runs-on: ubuntu-latest
permissions:
security-events: write
steps:
- uses: PyCQA/bandit-action@v1.0.1
with:
Expand Down
2 changes: 1 addition & 1 deletion mlflow_oidc_auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os

version = os.environ.get("MLFLOW_OIDC_AUTH_VERSION", "5.0.0.dev0")
version = os.environ.get("MLFLOW_OIDC_AUTH_VERSION", "7.0.0.dev0")

__version__ = version
244 changes: 57 additions & 187 deletions mlflow_oidc_auth/app.py
Original file line number Diff line number Diff line change
@@ -1,196 +1,66 @@
import os
"""
FastAPI application factory for MLflow OIDC Auth Plugin.

from flask_caching import Cache
from flask_session import Session
This module provides a FastAPI application factory that can be used as an alternative
to the default MLflow server when OIDC authentication is required.
"""

from typing import Any

from fastapi import FastAPI
from mlflow.server import app
from werkzeug.middleware.proxy_fix import ProxyFix
from mlflow.version import VERSION
from starlette.middleware.sessions import SessionMiddleware as StarletteSessionMiddleware

from mlflow_oidc_auth import routes, views
from mlflow_oidc_auth.config import config
from mlflow_oidc_auth.exceptions import register_exception_handlers
from mlflow_oidc_auth.hooks import after_request_hook, before_request_hook
from mlflow_oidc_auth.logger import get_logger
from mlflow_oidc_auth.middleware import AuthAwareWSGIMiddleware, AuthMiddleware, ProxyHeadersMiddleware
from mlflow_oidc_auth.routers import get_all_routers

logger = get_logger()

# Configure custom Flask app
template_dir = os.path.dirname(__file__)
template_dir = os.path.join(template_dir, "templates")

app.config.from_object(config)
app.secret_key = app.config["SECRET_KEY"].encode("utf8")
app.template_folder = template_dir
static_folder = app.static_folder

# Configure ProxyFix middleware to handle reverse proxy headers
app.wsgi_app = ProxyFix(
app.wsgi_app,
x_for=config.PROXY_FIX_X_FOR,
x_proto=config.PROXY_FIX_X_PROTO,
x_host=config.PROXY_FIX_X_HOST,
x_port=config.PROXY_FIX_X_PORT,
x_prefix=config.PROXY_FIX_X_PREFIX,
)

logger.debug(
f"ProxyFix middleware configured - x_for={config.PROXY_FIX_X_FOR}, x_proto={config.PROXY_FIX_X_PROTO}, x_host={config.PROXY_FIX_X_HOST}, x_port={config.PROXY_FIX_X_PORT}, x_prefix={config.PROXY_FIX_X_PREFIX}"
)

# Add links to MLFlow UI
if config.EXTEND_MLFLOW_MENU:
app.view_functions["serve"] = views.index

# OIDC routes
app.add_url_rule(rule=routes.LOGIN, methods=["GET"], view_func=views.login)
app.add_url_rule(rule=routes.LOGOUT, methods=["GET"], view_func=views.logout)
app.add_url_rule(rule=routes.CALLBACK, methods=["GET"], view_func=views.callback)

# UI routes
app.add_url_rule(rule=routes.STATIC, methods=["GET"], view_func=views.oidc_static)
app.add_url_rule(rule=routes.UI, methods=["GET"], view_func=views.oidc_ui)
app.add_url_rule(rule=routes.UI_ROOT, methods=["GET"], view_func=views.oidc_ui)

# Runtime configuration endpoint under UI path
app.add_url_rule(rule=routes.UI_CONFIG, methods=["GET"], view_func=views.get_runtime_config)

# User token
app.add_url_rule(rule=routes.CREATE_ACCESS_TOKEN, methods=["PATCH"], view_func=views.create_user_access_token)
app.add_url_rule(rule=routes.GET_CURRENT_USER, methods=["GET"], view_func=views.get_current_user)

# User management
app.add_url_rule(rule=routes.CREATE_USER, methods=["POST"], view_func=views.create_new_user)
app.add_url_rule(rule=routes.GET_USER, methods=["GET"], view_func=views.get_user)
app.add_url_rule(rule=routes.UPDATE_USER_PASSWORD, methods=["PATCH"], view_func=views.update_username_password)
app.add_url_rule(rule=routes.UPDATE_USER_ADMIN, methods=["PATCH"], view_func=views.update_user_admin)
app.add_url_rule(rule=routes.DELETE_USER, methods=["DELETE"], view_func=views.delete_user)


# UI routes support
app.add_url_rule(rule=routes.EXPERIMENT_USER_PERMISSIONS, methods=["GET"], view_func=views.get_experiment_users)
app.add_url_rule(rule=routes.PROMPT_USER_PERMISSIONS, methods=["GET"], view_func=views.get_prompt_users)
app.add_url_rule(rule=routes.REGISTERED_MODEL_USER_PERMISSIONS, methods=["GET"], view_func=views.get_registered_model_users)

# List resources
app.add_url_rule(rule=routes.LIST_EXPERIMENTS, methods=["GET"], view_func=views.list_experiments)
app.add_url_rule(rule=routes.LIST_MODELS, methods=["GET"], view_func=views.list_registered_models)
app.add_url_rule(rule=routes.LIST_PROMPTS, methods=["GET"], view_func=views.list_prompts)
app.add_url_rule(rule=routes.LIST_USERS, methods=["GET"], view_func=views.list_users)
app.add_url_rule(rule=routes.LIST_GROUPS, methods=["GET"], view_func=views.list_groups)

# user experiment permission management
app.add_url_rule(rule=routes.USER_EXPERIMENT_PERMISSIONS, methods=["GET"], view_func=views.list_user_experiments)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PERMISSION_DETAIL, methods=["POST"], view_func=views.create_experiment_permission)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_experiment_permission)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_experiment_permission)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_experiment_permission)

# user experiment regex permission management
app.add_url_rule(rule=routes.USER_EXPERIMENT_PATTERN_PERMISSIONS, methods=["POST"], view_func=views.create_experiment_regex_permission)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PATTERN_PERMISSIONS, methods=["GET"], view_func=views.list_user_experiment_regex_permission)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PATTERN_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_experiment_regex_permission)
app.add_url_rule(rule=routes.USER_EXPERIMENT_PATTERN_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_experiment_regex_permission)
app.add_url_rule(
rule=routes.USER_EXPERIMENT_PATTERN_PERMISSION_DETAIL,
methods=["DELETE"],
view_func=views.delete_experiment_regex_permission,
)

# user prompt management
app.add_url_rule(rule=routes.USER_PROMPT_PERMISSIONS, methods=["GET"], view_func=views.list_user_prompts)
app.add_url_rule(rule=routes.USER_PROMPT_PERMISSION_DETAIL, methods=["POST"], view_func=views.create_prompt_permission)
app.add_url_rule(rule=routes.USER_PROMPT_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_prompt_permission)
app.add_url_rule(rule=routes.USER_PROMPT_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_prompt_permission)
app.add_url_rule(rule=routes.USER_PROMPT_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_prompt_permission)

# user prompt regex permission management
app.add_url_rule(rule=routes.USER_PROMPT_PATTERN_PERMISSIONS, methods=["GET"], view_func=views.list_prompt_regex_permissions)
app.add_url_rule(rule=routes.USER_PROMPT_PATTERN_PERMISSIONS, methods=["POST"], view_func=views.create_prompt_regex_permission)
app.add_url_rule(rule=routes.USER_PROMPT_PATTERN_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_prompt_regex_permission)
app.add_url_rule(rule=routes.USER_PROMPT_PATTERN_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_prompt_regex_permission)
app.add_url_rule(rule=routes.USER_PROMPT_PATTERN_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_prompt_regex_permission)

# user registered model management
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PERMISSIONS, methods=["GET"], view_func=views.list_user_models)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["POST"], view_func=views.create_registered_model_permission)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_registered_model_permission)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_registered_model_permission)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_registered_model_permission)

# user registered model regex permission management
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PATTERN_PERMISSIONS, methods=["GET"], view_func=views.list_registered_model_regex_permissions)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PATTERN_PERMISSIONS, methods=["POST"], view_func=views.create_registered_model_regex_permission)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PATTERN_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_registered_model_regex_permission)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PATTERN_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_registered_model_regex_permission)
app.add_url_rule(rule=routes.USER_REGISTERED_MODEL_PATTERN_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_registered_model_regex_permission)

app.add_url_rule(rule=routes.GROUP_USER_PERMISSIONS, methods=["GET"], view_func=views.get_group_users)

app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PERMISSIONS, methods=["GET"], view_func=views.list_group_experiments)
app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PERMISSION_DETAIL, methods=["POST"], view_func=views.create_group_experiment_permission)
app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_group_experiment_permission)
app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_group_experiment_permission)

app.add_url_rule(rule=routes.GROUP_REGISTERED_MODEL_PERMISSIONS, methods=["GET"], view_func=views.list_group_models)
app.add_url_rule(rule=routes.GROUP_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["POST"], view_func=views.create_group_model_permission)
app.add_url_rule(rule=routes.GROUP_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_group_model_permission)
app.add_url_rule(rule=routes.GROUP_REGISTERED_MODEL_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_group_model_permission)

app.add_url_rule(rule=routes.GROUP_PROMPT_PERMISSIONS, methods=["GET"], view_func=views.get_group_prompts)
app.add_url_rule(rule=routes.GROUP_PROMPT_PERMISSION_DETAIL, methods=["POST"], view_func=views.create_group_prompt_permission)
app.add_url_rule(rule=routes.GROUP_PROMPT_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_group_prompt_permission)
app.add_url_rule(rule=routes.GROUP_PROMPT_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_group_prompt_permission)

app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PATTERN_PERMISSIONS, methods=["GET"], view_func=views.list_group_experiment_regex_permissions)
app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PATTERN_PERMISSIONS, methods=["POST"], view_func=views.create_group_experiment_regex_permission)
app.add_url_rule(rule=routes.GROUP_EXPERIMENT_PATTERN_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_group_experiment_regex_permission)
app.add_url_rule(
rule=routes.GROUP_EXPERIMENT_PATTERN_PERMISSION_DETAIL,
methods=["PATCH"],
view_func=views.update_group_experiment_regex_permission,
)
app.add_url_rule(
rule=routes.GROUP_EXPERIMENT_PATTERN_PERMISSION_DETAIL,
methods=["DELETE"],
view_func=views.delete_group_experiment_regex_permission,
)

app.add_url_rule(
rule=routes.GROUP_REGISTERED_MODEL_PATTERN_PERMISSIONS,
methods=["POST"],
view_func=views.create_group_registered_model_regex_permission,
)
app.add_url_rule(
rule=routes.GROUP_REGISTERED_MODEL_PATTERN_PERMISSIONS,
methods=["GET"],
view_func=views.list_group_registered_model_regex_permissions,
)
app.add_url_rule(
rule=routes.GROUP_REGISTERED_MODEL_PATTERN_PERMISSION_DETAIL,
methods=["GET"],
view_func=views.get_group_registered_model_regex_permission,
)
app.add_url_rule(
rule=routes.GROUP_REGISTERED_MODEL_PATTERN_PERMISSION_DETAIL,
methods=["PATCH"],
view_func=views.update_group_registered_model_regex_permission,
)
app.add_url_rule(
rule=routes.GROUP_REGISTERED_MODEL_PATTERN_PERMISSION_DETAIL,
methods=["DELETE"],
view_func=views.delete_group_registered_model_regex_permission,
)

app.add_url_rule(rule=routes.GROUP_PROMPT_PATTERN_PERMISSIONS, methods=["GET"], view_func=views.list_group_prompt_regex_permissions)
app.add_url_rule(rule=routes.GROUP_PROMPT_PATTERN_PERMISSIONS, methods=["POST"], view_func=views.create_group_prompt_regex_permission)
app.add_url_rule(rule=routes.GROUP_PROMPT_PATTERN_PERMISSION_DETAIL, methods=["GET"], view_func=views.get_group_prompt_regex_permission)
app.add_url_rule(rule=routes.GROUP_PROMPT_PATTERN_PERMISSION_DETAIL, methods=["PATCH"], view_func=views.update_group_prompt_regex_permission)
app.add_url_rule(rule=routes.GROUP_PROMPT_PATTERN_PERMISSION_DETAIL, methods=["DELETE"], view_func=views.delete_group_prompt_regex_permission)
###############################


# Add new hooks
app.before_request(before_request_hook)
app.after_request(after_request_hook)

# Set up session
Session(app)
cache = Cache(app)

def create_app() -> Any:
"""
Create a FastAPI application with OIDC integration.
"""
oidc_app = FastAPI(
title="MLflow Tracking Server with OIDC Auth",
description="MLflow Tracking Server API with OIDC Authentication",
version=VERSION,
docs_url="/docs" if getattr(config, "ENABLE_API_DOCS", True) else None,
redoc_url="/redoc" if getattr(config, "ENABLE_API_DOCS", True) else None,
openapi_url="/openapi.json" if getattr(config, "ENABLE_API_DOCS", True) else None,
)
register_exception_handlers(oidc_app)

oidc_app.add_middleware(ProxyHeadersMiddleware)
oidc_app.add_middleware(AuthMiddleware)
oidc_app.add_middleware(StarletteSessionMiddleware, secret_key=config.SECRET_KEY)

for router in get_all_routers():
oidc_app.include_router(router)

# Add links to MLFlow UI
if config.EXTEND_MLFLOW_MENU:
from mlflow_oidc_auth import hack

app.view_functions["serve"] = hack.index

# Set Flask app secret key
app.secret_key = config.SECRET_KEY

# Register Flask hooks directly with the Flask app
app.before_request(before_request_hook)
app.after_request(after_request_hook)

# Mount Flask app at root with auth passing middleware
oidc_app.mount("/", AuthAwareWSGIMiddleware(app))

logger.info("MLflow Flask app mounted at / with FastAPI auth info passing")
return oidc_app


app = create_app()
Loading
Loading