From c5fcf4b8b74f772edf38a7982834539f18e62672 Mon Sep 17 00:00:00 2001 From: James O'Brien Date: Fri, 18 May 2018 09:57:09 -0700 Subject: [PATCH 1/3] Initial attempt to use authlib --- app/extensions/api/namespace.py | 3 +- app/extensions/auth/__init__.py | 2 +- app/extensions/auth/oauth22.py | 156 +++++++++++++++ app/modules/auth/__init__.py | 23 ++- app/modules/auth/models2.py | 58 ++++++ app/modules/auth/resources.py | 2 +- app/modules/auth/schemas.py | 5 +- app/modules/auth/views.py | 56 ++++-- app/modules/users/models.py | 3 + app/requirements.txt | 41 ++-- config.py | 9 + migrations/__init__.py | 0 migrations/initial_development_data.py | 13 +- migrations/script.py.mako | 2 + migrations/versions/15f27bc43bd_.py | 89 --------- migrations/versions/2b5af066bb9_.py | 39 ---- migrations/versions/2e9d99288cd_.py | 38 ---- migrations/versions/357c2809db4_.py | 28 --- migrations/versions/36954739c63_.py | 52 ----- migrations/versions/4754e1427ac_.py | 26 --- migrations/versions/522430fd0601_.py | 124 ++++++++++++ .../5e2954a2af18_refactored-auth-oauth2.py | 180 ------------------ .../81ce4ac01c45_migrate_static_roles.py | 49 ----- ..._altered-OAuth2Token-token_type-to-Enum.py | 42 ---- migrations/versions/8c8b2d23a5_.py | 24 --- .../beb065460c24_fixed-password-type.py | 66 ------- run_app.py | 10 + tasks/app/db.py | 17 ++ tasks/app/run.py | 1 + tasks/app/users.py | 2 +- tasks/requirements.txt | 1 + tests/modules/auth/conftest.py | 4 +- .../resources/test_creating_oauth2client.py | 2 +- tests/modules/auth/resources/test_token.py | 4 +- tests/utils.py | 2 +- 35 files changed, 468 insertions(+), 705 deletions(-) create mode 100644 app/extensions/auth/oauth22.py create mode 100644 app/modules/auth/models2.py delete mode 100644 migrations/__init__.py delete mode 100644 migrations/versions/15f27bc43bd_.py delete mode 100644 migrations/versions/2b5af066bb9_.py delete mode 100644 migrations/versions/2e9d99288cd_.py delete mode 100644 migrations/versions/357c2809db4_.py delete mode 100644 migrations/versions/36954739c63_.py delete mode 100644 migrations/versions/4754e1427ac_.py create mode 100644 migrations/versions/522430fd0601_.py delete mode 100644 migrations/versions/5e2954a2af18_refactored-auth-oauth2.py delete mode 100644 migrations/versions/81ce4ac01c45_migrate_static_roles.py delete mode 100644 migrations/versions/82184d7d1e88_altered-OAuth2Token-token_type-to-Enum.py delete mode 100644 migrations/versions/8c8b2d23a5_.py delete mode 100644 migrations/versions/beb065460c24_fixed-password-type.py create mode 100644 run_app.py diff --git a/app/extensions/api/namespace.py b/app/extensions/api/namespace.py index c375892c..23ad81ab 100644 --- a/app/extensions/api/namespace.py +++ b/app/extensions/api/namespace.py @@ -154,7 +154,8 @@ def decorator(func_or_class): else: _oauth_scopes = oauth_scopes - oauth_protection_decorator = oauth2.require_oauth(*_oauth_scopes, locations=locations) + # oauth_protection_decorator = oauth2.require_oauth(*_oauth_scopes, locations=locations) + oauth_protection_decorator = oauth2.require_oauth( *_oauth_scopes ) self._register_access_restriction_decorator(protected_func, oauth_protection_decorator) oauth_protected_func = oauth_protection_decorator(protected_func) diff --git a/app/extensions/auth/__init__.py b/app/extensions/auth/__init__.py index 0cd52410..8e255d47 100644 --- a/app/extensions/auth/__init__.py +++ b/app/extensions/auth/__init__.py @@ -4,4 +4,4 @@ ============== """ -from .oauth2 import OAuth2Provider +from .oauth22 import OAuth2Provider diff --git a/app/extensions/auth/oauth22.py b/app/extensions/auth/oauth22.py new file mode 100644 index 00000000..a26dd98d --- /dev/null +++ b/app/extensions/auth/oauth22.py @@ -0,0 +1,156 @@ +import functools, logging +from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector +from authlib.flask.oauth2.sqla import ( + create_query_client_func, + create_save_token_func, + create_revocation_endpoint, + create_bearer_token_validator, +) +from authlib.specs.rfc6749 import grants +from werkzeug.security import gen_salt +from app.extensions import api, db +from app.modules.users.models import User +from app.modules.auth.models2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token +from flask_restplus_patched._http import HTTPStatus +from authlib.specs.rfc6750 import BearerTokenValidator as _BearerTokenValidator + +log = logging.getLogger(__name__) + + +def api_invalid_response(req): + """ + This is a default handler for OAuth2Provider, which raises abort exception + with error message in JSON format. + """ + # pylint: disable=unused-argument + api.abort(code=HTTPStatus.UNAUTHORIZED.value) + + +class BearerTokenValidator(_BearerTokenValidator): + def authenticate_token(self, token_string): + return OAuth2Token.query.filter_by(access_token=token_string).first() + + def request_invalid(self, request): + return False + + def token_revoked(self, token): + # TODO: return token.revoked + return token.revoked + +class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + def create_authorization_code(self, client, grant_user, request): + code = gen_salt(48) + item = OAuth2AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + user_id=grant_user.id, + ) + db.session.add(item) + db.session.commit() + return code + + def parse_authorization_code(self, code, client): + item = OAuth2AuthorizationCode.query.filter_by( + code=code, client_id=client.client_id).first() + if item and not item.is_expired(): + return item + + def delete_authorization_code(self, authorization_code): + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user(self, authorization_code): + return User.query.get(authorization_code.user_id) + + +class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): + def authenticate_user(self, username, password): + return User.find_with_password(username, password) + + +class RefreshTokenGrant(grants.RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + item = OAuth2Token.query.filter_by(refresh_token=refresh_token).first() + if item and not item.is_refresh_token_expired(): + return item + + def authenticate_user(self, credential): + return User.query.get(credential.user_id) + + +class OAuth2ResourceProtector(ResourceProtector): + def __init__( self ): + super().__init__() + + +class OAuth2Provider(AuthorizationServer): + def __init__(self): + super().__init__() + self._require_oauth = None + + def init_app( self, app, query_client=None, save_token=None ): + if query_client is None: + query_client = create_query_client_func(db.session, OAuth2Client) + if save_token is None: + save_token = create_save_token_func(db.session, OAuth2Token) + + super().init_app( + app, query_client=query_client, save_token=save_token) + + # support all grants + self.register_grant(grants.ImplicitGrant) + self.register_grant(grants.ClientCredentialsGrant) + self.register_grant(AuthorizationCodeGrant) + self.register_grant(PasswordGrant) + self.register_grant(RefreshTokenGrant) + + # support revocation + revocation_cls = create_revocation_endpoint(db.session, OAuth2Token) + self.register_endpoint(revocation_cls) + + # protect resource + bearer_cls = create_bearer_token_validator(db.session, OAuth2Token) + OAuth2ResourceProtector.register_token_validator(bearer_cls()) + self._require_oauth = OAuth2ResourceProtector() + + def require_oauth(self, *args, **kwargs): + # pylint: disable=arguments-differ + """ + A decorator to protect a resource with specified scopes. Access Token + can be fetched from the specified locations (``headers`` or ``form``). + + Arguments: + locations (list): a list of locations (``headers``, ``form``) where + the access token should be looked up. + + Returns: + function: a decorator. + """ + locations = kwargs.get('locations', ('cookies',)) # don't want to pop - original decorator may need + origin_decorator = self._require_oauth(*args, **kwargs) + + def decorator(func): + # pylint: disable=missing-docstring + from flask import request + + origin_decorated_func = origin_decorator(func) + + @functools.wraps(origin_decorated_func) + def wrapper(*args, **kwargs): + # pylint: disable=missing-docstring + if 'headers' not in locations: + # Invalidate authorization if developer specifically + # disables the lookup in the headers. (this may or may not be worth all the hassle) + request.authorization = '!' + # don't think we need below lines because bearer validator already registered + # if 'form' in locations: + # if 'access_token' in request.form: + # request.authorization = 'Bearer %s' % request.form['access_token'] + + return origin_decorated_func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/app/modules/auth/__init__.py b/app/modules/auth/__init__.py index e68e3e7f..0e775644 100644 --- a/app/modules/auth/__init__.py +++ b/app/modules/auth/__init__.py @@ -3,7 +3,6 @@ Auth module =========== """ -from app.extensions import login_manager, oauth2 from app.extensions.api import api_v1 @@ -11,14 +10,16 @@ def load_user_from_request(request): """ Load user from OAuth2 Authentication header. """ - user = None - if hasattr(request, 'oauth'): - user = request.oauth.user - else: - is_valid, oauth = oauth2.verify_request(scopes=[]) - if is_valid: - user = oauth.user - return user + from authlib.flask.oauth2 import current_token + from app.modules.users.models import User + if current_token: + user_id = current_token.user_id + if user_id: + return User.query.get(user_id) + elif current_token.user: + return current_token.user + return None + def init_app(app, **kwargs): # pylint: disable=unused-argument @@ -26,6 +27,8 @@ def init_app(app, **kwargs): Init auth module. """ # Bind Flask-Login for current_user + from app.extensions import login_manager + login_manager.request_loader(load_user_from_request) # Register OAuth scopes @@ -33,7 +36,7 @@ def init_app(app, **kwargs): api_v1.add_oauth_scope('auth:write', "Provide write access to auth details") # Touch underlying modules - from . import models, views, resources # pylint: disable=unused-variable + from . import models2, views, resources # pylint: disable=unused-variable # Mount authentication routes app.register_blueprint(views.auth_blueprint) diff --git a/app/modules/auth/models2.py b/app/modules/auth/models2.py new file mode 100644 index 00000000..97992132 --- /dev/null +++ b/app/modules/auth/models2.py @@ -0,0 +1,58 @@ +import time, enum +from authlib.flask.oauth2.sqla import ( + OAuth2ClientMixin, + OAuth2AuthorizationCodeMixin, + OAuth2TokenMixin, +) +from sqlalchemy_utils.types import ScalarListType + +from app.extensions import db +from app.modules.users.models import User + + +class OAuth2Client(db.Model, OAuth2ClientMixin): + __tablename__ = 'oauth2_client' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) + class ClientTypes(str, enum.Enum): + public = 'public' + confidential = 'confidential' + + client_type = db.Column(db.Enum(ClientTypes), default=ClientTypes.public, nullable=False) + default_scopes = db.Column(ScalarListType(separator=' '), nullable=False) + + user = db.relationship('User') + + @property + def default_redirect_uri(self): + return self.get_default_redirect_uri() + + @classmethod + def find(cls, client_id): + if not client_id: + return + return cls.query.get(client_id) + + +class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): + __tablename__ = 'oauth2_code' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) + user = db.relationship('User') + + +class OAuth2Token(db.Model, OAuth2TokenMixin): + __tablename__ = 'oauth2_token' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) + user = db.relationship('User') + + def is_refresh_token_expired(self): + expires_at = self.issued_at + self.expires_in * 2 + return expires_at < time.time() diff --git a/app/modules/auth/resources.py b/app/modules/auth/resources.py index 5ecb7573..bcc6df2a 100644 --- a/app/modules/auth/resources.py +++ b/app/modules/auth/resources.py @@ -15,7 +15,7 @@ from app.extensions.api import Namespace from . import schemas, parameters -from .models import db, OAuth2Client +from .models2 import db, OAuth2Client log = logging.getLogger(__name__) diff --git a/app/modules/auth/schemas.py b/app/modules/auth/schemas.py index fc8ee6ae..f1212ffe 100644 --- a/app/modules/auth/schemas.py +++ b/app/modules/auth/schemas.py @@ -8,7 +8,7 @@ from flask_marshmallow import base_fields from flask_restplus_patched import ModelSchema -from .models import OAuth2Client +from .models2 import OAuth2Client class BaseOAuth2ClientSchema(ModelSchema): @@ -16,7 +16,6 @@ class BaseOAuth2ClientSchema(ModelSchema): Base OAuth2 client schema exposes only the most general fields. """ default_scopes = base_fields.List(base_fields.String, required=True) - redirect_uris = base_fields.List(base_fields.String, required=True) class Meta: # pylint: disable=missing-docstring @@ -26,7 +25,7 @@ class Meta: OAuth2Client.client_id.key, OAuth2Client.client_type.key, OAuth2Client.default_scopes.key, - OAuth2Client.redirect_uris.key, + # OAuth2Client.redirect_uris.fget.__name__, ) dump_only = ( OAuth2Client.user_id.key, diff --git a/app/modules/auth/views.py b/app/modules/auth/views.py index 4ea4a742..03826b18 100644 --- a/app/modules/auth/views.py +++ b/app/modules/auth/views.py @@ -10,41 +10,51 @@ * http://lepture.com/en/2013/create-oauth-server """ -from flask import Blueprint, request, render_template +from flask import Blueprint, request, render_template, session from flask_login import current_user from flask_restplus_patched._http import HTTPStatus - +from authlib.flask.oauth2 import current_token +from authlib.specs.rfc6749 import OAuth2Error from app.extensions import api, oauth2 -from .models import OAuth2Client +from app.modules.users.models import User +from .models2 import OAuth2Client auth_blueprint = Blueprint('auth', __name__, url_prefix='/auth') # pylint: disable=invalid-name +@auth_blueprint.route('/oauth2/invalid_request', methods=['GET']) +def api_invalid_response(req): + """ + This is a default handler for OAuth2Provider, which raises abort exception + with error message in JSON format. + """ + # pylint: disable=unused-argument + api.abort(code=HTTPStatus.UNAUTHORIZED.value) + + @auth_blueprint.route('/oauth2/token', methods=['GET', 'POST']) -@oauth2.token_handler def access_token(*args, **kwargs): # pylint: disable=unused-argument """ This endpoint is for exchanging/refreshing an access token. Returns: - response (dict): a dictionary or None as the extra credentials for - creating the token response. + token response """ - return None + return oauth2.create_token_response() + @auth_blueprint.route('/oauth2/revoke', methods=['POST']) -@oauth2.revoke_handler def revoke_token(): """ This endpoint allows a user to revoke their access token. """ - pass + return oauth2.create_endpoint_response('revocation') + @auth_blueprint.route('/oauth2/authorize', methods=['GET', 'POST']) -@oauth2.authorize_handler def authorize(*args, **kwargs): # pylint: disable=unused-argument """ @@ -57,16 +67,20 @@ def authorize(*args, **kwargs): # can implement a login page and store cookies with a session id. # ALTERNATIVELY, authorize page can be implemented as SPA (single page # application) - if not current_user.is_authenticated: - return api.abort(code=HTTPStatus.UNAUTHORIZED) + user = current_user() if request.method == 'GET': - client_id = kwargs.get('client_id') - oauth2_client = OAuth2Client.query.get_or_404(client_id=client_id) - kwargs['client'] = oauth2_client - kwargs['user'] = current_user - # TODO: improve template design - return render_template('authorize.html', **kwargs) - - confirm = request.form.get('confirm', 'no') - return confirm == 'yes' + try: + grant = oauth2.validate_consent_request(end_user=user) + except OAuth2Error as error: + return error.error + return render_template('authorize.html', user=user, grant=grant) + if not user and 'username' in request.form: + username = request.form.get('username') + user = User.query.filter_by(username=username).first() + if request.form['confirm']: + grant_user = user + else: + grant_user = None + + return oauth2.create_authorization_response(grant_user) diff --git a/app/modules/users/models.py b/app/modules/users/models.py index 13e1b262..2a874ecf 100644 --- a/app/modules/users/models.py +++ b/app/modules/users/models.py @@ -96,6 +96,9 @@ def __repr__(self): ) ) + def get_user_id( self ): + return self.id + def has_static_role(self, role): return (self.static_roles & role.mask) != 0 diff --git a/app/requirements.txt b/app/requirements.txt index 2c919982..f7baa559 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,28 +1,25 @@ -Flask>=1.0,<2.0 +Flask -flask-restplus>=0.10.1 +flask-restplus -Flask-Cors==3.0.2 +Flask-Cors -SQLAlchemy==1.1.5 -SQLAlchemy-Utils==0.32.12 -Flask-SQLAlchemy==2.2 -Alembic==0.8.10 -werkzeug>=0.14.1,<0.15 +SQLAlchemy +sqlalchemy-utils +Flask-SQLAlchemy +Alembic +werkzeug -marshmallow>=2.13.5 -flask-marshmallow==0.7.0 -marshmallow-sqlalchemy==0.12.0 -webargs>=1.4.0 -apispec>=0.20.0 +marshmallow +flask-marshmallow +marshmallow-sqlalchemy +webargs +apispec -bcrypt==3.1.3 -passlib==1.7.1 -Flask-OAuthlib>=0.9.4 -Flask-Login==0.4.0 -permission==0.4.1 +bcrypt +passlib +authlib +Flask-Login +permission -arrow==0.8.0 - -six -enum34; python_version < '3.4' +arrow diff --git a/config.py b/config.py index a7fcd8fa..ab931c36 100644 --- a/config.py +++ b/config.py @@ -29,6 +29,8 @@ class BaseConfig(object): REVERSE_PROXY_SETUP = os.getenv('EXAMPLE_API_REVERSE_PROXY_SETUP', False) + SQLALCHEMY_ECHO=True + AUTHORIZATIONS = { 'oauth2_password': { 'type': 'oauth2', @@ -54,6 +56,13 @@ class BaseConfig(object): 'api', ) + OAUTH2_ERROR_URIS = [('invalid_request', '/oauth2/invalid_request')] + OAUTH2_EXPIRES_IN = { + 'authorization_code':864000, + 'implicit': 3600, + 'password': 864000, + 'client_credentials':864000 + } STATIC_ROOT = os.path.join(PROJECT_ROOT, 'static') SWAGGER_UI_JSONEDITOR = True diff --git a/migrations/__init__.py b/migrations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/migrations/initial_development_data.py b/migrations/initial_development_data.py index 780ae009..74b1b4e3 100644 --- a/migrations/initial_development_data.py +++ b/migrations/initial_development_data.py @@ -8,7 +8,7 @@ from app.extensions import db, api from app.modules.users.models import User -from app.modules.auth.models import OAuth2Client +from app.modules.auth.models2 import OAuth2Client def init_users(): @@ -55,9 +55,9 @@ def init_auth(docs_user): client_id='documentation', client_secret='KQ()SWK)SQK)QWSKQW(SKQ)S(QWSQW(SJ*HQ&HQW*SQ*^SSQWSGQSG', user_id=docs_user.id, - redirect_uris=[], default_scopes=api.api_v1.authorizations['oauth2_password']['scopes'] ) + oauth2_client.redirect_uris = [] db.session.add(oauth2_client) return oauth2_client @@ -69,8 +69,9 @@ def init(): OAuth2Client.default_scopes: api.api_v1.authorizations['oauth2_password']['scopes'], }) - assert User.query.count() == 0, \ - "Database is not empty. You should not re-apply fixtures! Aborted." - - root_user, docs_user, regular_user = init_users() # pylint: disable=unused-variable + # assert User.query.count() == 0, \ + # "Database is not empty. You should not re-apply fixtures! Aborted." + # + # root_user, docs_user, regular_user = init_users() # pylint: disable=unused-variable + docs_user = User.query.filter_by(username='documentation').first() init_auth(docs_user) diff --git a/migrations/script.py.mako b/migrations/script.py.mako index 95702017..b5afa321 100644 --- a/migrations/script.py.mako +++ b/migrations/script.py.mako @@ -12,6 +12,8 @@ down_revision = ${repr(down_revision)} from alembic import op import sqlalchemy as sa +import sqlalchemy_utils + ${imports if imports else ""} def upgrade(): diff --git a/migrations/versions/15f27bc43bd_.py b/migrations/versions/15f27bc43bd_.py deleted file mode 100644 index a6935fa9..00000000 --- a/migrations/versions/15f27bc43bd_.py +++ /dev/null @@ -1,89 +0,0 @@ -"""empty message - -Revision ID: 15f27bc43bd -Revises: None -Create Date: 2015-11-10 18:41:49.419188 - -""" - -# revision identifiers, used by Alembic. -revision = '15f27bc43bd' -down_revision = None - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - ### commands auto generated by Alembic - please adjust! ### - op.create_table('user', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('username', sa.String(length=80), nullable=False), - sa.Column('password', sa.String(length=128), nullable=False), - sa.Column('email', sa.String(length=120), nullable=False), - sa.Column('first_name', sa.String(length=30), nullable=False), - sa.Column('middle_name', sa.String(length=30), nullable=False), - sa.Column('last_name', sa.String(length=30), nullable=False), - sa.Column('static_roles', sa.Integer(), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('email'), - sa.UniqueConstraint('username') - ) - op.create_table('oauth2_client', - sa.Column('client_id', sa.String(length=40), nullable=False), - sa.Column('client_secret', sa.String(length=55), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('_redirect_uris', sa.Text(), nullable=False), - sa.Column('_default_scopes', sa.Text(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('client_id') - ) - op.create_index(op.f('ix_oauth2_client_user_id'), 'oauth2_client', ['user_id'], unique=False) - op.create_table('oauth2_grant', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('client_id', sa.String(length=40), nullable=False), - sa.Column('code', sa.String(length=255), nullable=False), - sa.Column('redirect_uri', sa.String(length=255), nullable=False), - sa.Column('expires', sa.DateTime(), nullable=False), - sa.Column('_scopes', sa.Text(), nullable=False), - sa.ForeignKeyConstraint(['client_id'], ['oauth2_client.client_id'], ), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_oauth2_grant_client_id'), 'oauth2_grant', ['client_id'], unique=False) - op.create_index(op.f('ix_oauth2_grant_code'), 'oauth2_grant', ['code'], unique=False) - op.create_index(op.f('ix_oauth2_grant_user_id'), 'oauth2_grant', ['user_id'], unique=False) - op.create_table('oauth2_token', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('client_id', sa.String(length=40), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('token_type', sa.String(length=40), nullable=False), - sa.Column('access_token', sa.String(length=255), nullable=False), - sa.Column('refresh_token', sa.String(length=255), nullable=True), - sa.Column('expires', sa.DateTime(), nullable=False), - sa.Column('_scopes', sa.Text(), nullable=False), - sa.ForeignKeyConstraint(['client_id'], ['oauth2_client.client_id'], ), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('access_token'), - sa.UniqueConstraint('refresh_token') - ) - op.create_index(op.f('ix_oauth2_token_client_id'), 'oauth2_token', ['client_id'], unique=False) - op.create_index(op.f('ix_oauth2_token_user_id'), 'oauth2_token', ['user_id'], unique=False) - ### end Alembic commands ### - - -def downgrade(): - ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_oauth2_token_user_id'), table_name='oauth2_token') - op.drop_index(op.f('ix_oauth2_token_client_id'), table_name='oauth2_token') - op.drop_table('oauth2_token') - op.drop_index(op.f('ix_oauth2_grant_user_id'), table_name='oauth2_grant') - op.drop_index(op.f('ix_oauth2_grant_code'), table_name='oauth2_grant') - op.drop_index(op.f('ix_oauth2_grant_client_id'), table_name='oauth2_grant') - op.drop_table('oauth2_grant') - op.drop_index(op.f('ix_oauth2_client_user_id'), table_name='oauth2_client') - op.drop_table('oauth2_client') - op.drop_table('user') - ### end Alembic commands ### diff --git a/migrations/versions/2b5af066bb9_.py b/migrations/versions/2b5af066bb9_.py deleted file mode 100644 index 67f317b3..00000000 --- a/migrations/versions/2b5af066bb9_.py +++ /dev/null @@ -1,39 +0,0 @@ -"""empty message - -Revision ID: 2b5af066bb9 -Revises: 2e9d99288cd -Create Date: 2015-11-25 22:16:31.864584 - -""" - -# revision identifiers, used by Alembic. -revision = '2b5af066bb9' -down_revision = '2e9d99288cd' - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - ### commands auto generated by Alembic - please adjust! ### - op.create_table('team', - sa.Column('created', sa.DateTime(), nullable=False), - sa.Column('updated', sa.DateTime(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('title', sa.String(length=50), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('team_members', - sa.Column('team_id', sa.Integer(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['team_id'], ['team.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['user.id'], ) - ) - ### end Alembic commands ### - - -def downgrade(): - ### commands auto generated by Alembic - please adjust! ### - op.drop_table('team_members') - op.drop_table('team') - ### end Alembic commands ### diff --git a/migrations/versions/2e9d99288cd_.py b/migrations/versions/2e9d99288cd_.py deleted file mode 100644 index 8157954e..00000000 --- a/migrations/versions/2e9d99288cd_.py +++ /dev/null @@ -1,38 +0,0 @@ -"""empty message - -Revision ID: 2e9d99288cd -Revises: 36954739c63 -Create Date: 2015-11-23 21:16:54.103342 - -""" - -# revision identifiers, used by Alembic. -revision = '2e9d99288cd' -down_revision = '36954739c63' - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('user') as batch_op: - batch_op.alter_column('created', - existing_type=sa.DATETIME(), - nullable=False) - batch_op.alter_column('updated', - existing_type=sa.DATETIME(), - nullable=False) - ### end Alembic commands ### - - -def downgrade(): - ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('user') as batch_op: - batch_op.alter_column('updated', - existing_type=sa.DATETIME(), - nullable=True) - batch_op.alter_column('created', - existing_type=sa.DATETIME(), - nullable=True) - ### end Alembic commands ### diff --git a/migrations/versions/357c2809db4_.py b/migrations/versions/357c2809db4_.py deleted file mode 100644 index fd281ecc..00000000 --- a/migrations/versions/357c2809db4_.py +++ /dev/null @@ -1,28 +0,0 @@ -"""empty message - -Revision ID: 357c2809db4 -Revises: 4754e1427ac -Create Date: 2015-11-27 20:22:12.644342 - -""" - -# revision identifiers, used by Alembic. -revision = '357c2809db4' -down_revision = '4754e1427ac' - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - with op.batch_alter_table('team_member') as batch_op: - batch_op.alter_column('is_leader', - existing_type=sa.BOOLEAN(), - nullable=False) - - -def downgrade(): - with op.batch_alter_table('team_member') as batch_op: - batch_op.alter_column('is_leader', - existing_type=sa.BOOLEAN(), - nullable=True) diff --git a/migrations/versions/36954739c63_.py b/migrations/versions/36954739c63_.py deleted file mode 100644 index 859ac63f..00000000 --- a/migrations/versions/36954739c63_.py +++ /dev/null @@ -1,52 +0,0 @@ -"""empty message - -Revision ID: 36954739c63 -Revises: 15f27bc43bd -Create Date: 2015-11-23 21:00:24.105026 - -""" - -# revision identifiers, used by Alembic. -revision = '36954739c63' -down_revision = '15f27bc43bd' - -from datetime import datetime - -from alembic import op -import sqlalchemy as sa -import sqlalchemy_utils - - -def upgrade(): - ### commands auto generated by Alembic - please adjust! ### - op.add_column('user', sa.Column('created', sa.DateTime(), nullable=True)) - op.add_column('user', sa.Column('updated', sa.DateTime(), nullable=True)) - with op.batch_alter_table('user') as batch_op: - batch_op.alter_column('password', - existing_type=sa.VARCHAR(length=128), - type_=sqlalchemy_utils.types.password.PasswordType(max_length=128), - existing_nullable=False, - postgresql_using='password::bytea') - ### end Alembic commands ### - - user = sa.Table('user', - sa.MetaData(), - sa.Column('created', sa.DateTime()), - sa.Column('updated', sa.DateTime()), - ) - - op.execute( - user.update().values({'created': datetime.now(), 'updated': datetime.now()}) - ) - - -def downgrade(): - ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('user') as batch_op: - batch_op.alter_column('password', - existing_type=sqlalchemy_utils.types.password.PasswordType(max_length=128), - type_=sa.VARCHAR(length=128), - existing_nullable=False) - batch_op.drop_column('updated') - batch_op.drop_column('created') - ### end Alembic commands ### diff --git a/migrations/versions/4754e1427ac_.py b/migrations/versions/4754e1427ac_.py deleted file mode 100644 index a62f3f5d..00000000 --- a/migrations/versions/4754e1427ac_.py +++ /dev/null @@ -1,26 +0,0 @@ -"""empty message - -Revision ID: 4754e1427ac -Revises: 2b5af066bb9 -Create Date: 2015-11-27 19:43:31.118013 - -""" - -# revision identifiers, used by Alembic. -revision = '4754e1427ac' -down_revision = '2b5af066bb9' - -from alembic import op -import sqlalchemy as sa -import sqlalchemy_utils - - -def upgrade(): - op.rename_table('team_members', 'team_member') - op.add_column('team_member', sa.Column('is_leader', sa.Boolean(), nullable=True)) - - -def downgrade(): - with op.batch_alter_table('team_member') as batch_op: - batch_op.drop_column('is_leader') - op.rename_table('team_member', 'team_members') diff --git a/migrations/versions/522430fd0601_.py b/migrations/versions/522430fd0601_.py new file mode 100644 index 00000000..d3aa66e8 --- /dev/null +++ b/migrations/versions/522430fd0601_.py @@ -0,0 +1,124 @@ +"""empty message + +Revision ID: 522430fd0601 +Revises: None +Create Date: 2018-05-17 17:39:42.990467 + +""" + +# revision identifiers, used by Alembic. +revision = '522430fd0601' +down_revision = None + +from alembic import op +import sqlalchemy as sa +import sqlalchemy_utils + + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('team', + sa.Column('created', sa.DateTime(), nullable=False), + sa.Column('updated', sa.DateTime(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('title', sa.String(length=50), nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('pk_team')) + ) + op.create_table('user', + sa.Column('created', sa.DateTime(), nullable=False), + sa.Column('updated', sa.DateTime(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('username', sa.String(length=80), nullable=False), + sa.Column('password', sa.String(length=128), nullable=False), + sa.Column('email', sa.String(length=120), nullable=False), + sa.Column('first_name', sa.String(length=30), nullable=False), + sa.Column('middle_name', sa.String(length=30), nullable=False), + sa.Column('last_name', sa.String(length=30), nullable=False), + sa.Column('static_roles', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('pk_user')), + sa.UniqueConstraint('email', name=op.f('uq_user_email')), + sa.UniqueConstraint('username', name=op.f('uq_user_username')) + ) + op.create_table('oauth2_client', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('client_secret', sa.String(length=120), nullable=False), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('expires_at', sa.Integer(), nullable=False), + sa.Column('redirect_uri', sa.Text(), nullable=False), + sa.Column('token_endpoint_auth_method', sa.String(length=48), nullable=True), + sa.Column('grant_type', sa.Text(), nullable=False), + sa.Column('response_type', sa.Text(), nullable=False), + sa.Column('scope', sa.Text(), nullable=False), + sa.Column('client_name', sa.String(length=100), nullable=True), + sa.Column('client_uri', sa.Text(), nullable=True), + sa.Column('logo_uri', sa.Text(), nullable=True), + sa.Column('contact', sa.Text(), nullable=True), + sa.Column('tos_uri', sa.Text(), nullable=True), + sa.Column('policy_uri', sa.Text(), nullable=True), + sa.Column('jwks_uri', sa.Text(), nullable=True), + sa.Column('jwks_text', sa.Text(), nullable=True), + sa.Column('i18n_metadata', sa.Text(), nullable=True), + sa.Column('software_id', sa.String(length=36), nullable=True), + sa.Column('software_version', sa.String(length=48), nullable=True), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('client_type', sa.Enum('public', 'confidential', name='clienttypes'), nullable=False), + sa.Column('default_scopes', sqlalchemy_utils.types.scalar_list.ScalarListType(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_client_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_client')) + ) + op.create_index(op.f('ix_oauth2_client_client_id'), 'oauth2_client', ['client_id'], unique=False) + op.create_table('oauth2_code', + sa.Column('code', sa.String(length=120), nullable=False), + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('redirect_uri', sa.Text(), nullable=True), + sa.Column('response_type', sa.Text(), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('nonce', sa.Text(), nullable=True), + sa.Column('auth_time', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_code_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_code')), + sa.UniqueConstraint('code', name=op.f('uq_oauth2_code_code')) + ) + op.create_table('oauth2_token', + sa.Column('client_id', sa.String(length=48), nullable=True), + sa.Column('token_type', sa.String(length=40), nullable=True), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=True), + sa.Column('scope', sa.Text(), nullable=True), + sa.Column('revoked', sa.Boolean(name='revoked'), nullable=True), + sa.Column('issued_at', sa.Integer(), nullable=False), + sa.Column('expires_in', sa.Integer(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_oauth2_token_user_id_user'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2_token')), + sa.UniqueConstraint('access_token', name=op.f('uq_oauth2_token_access_token')) + ) + op.create_index(op.f('ix_oauth2_token_refresh_token'), 'oauth2_token', ['refresh_token'], unique=False) + op.create_table('team_member', + sa.Column('team_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('is_leader', sa.Boolean(name='is_leader'), nullable=False), + sa.ForeignKeyConstraint(['team_id'], ['team.id'], name=op.f('fk_team_member_team_id_team')), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], name=op.f('fk_team_member_user_id_user')), + sa.PrimaryKeyConstraint('team_id', 'user_id', name=op.f('pk_team_member')), + sa.UniqueConstraint('team_id', 'user_id', name=op.f('uq_team_member_team_id')) + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('team_member') + op.drop_index(op.f('ix_oauth2_token_refresh_token'), table_name='oauth2_token') + op.drop_table('oauth2_token') + op.drop_table('oauth2_code') + op.drop_index(op.f('ix_oauth2_client_client_id'), table_name='oauth2_client') + op.drop_table('oauth2_client') + op.drop_table('user') + op.drop_table('team') + # ### end Alembic commands ### diff --git a/migrations/versions/5e2954a2af18_refactored-auth-oauth2.py b/migrations/versions/5e2954a2af18_refactored-auth-oauth2.py deleted file mode 100644 index 19891efb..00000000 --- a/migrations/versions/5e2954a2af18_refactored-auth-oauth2.py +++ /dev/null @@ -1,180 +0,0 @@ -"""Refactored auth.OAuth2 models - -Revision ID: 5e2954a2af18 -Revises: 81ce4ac01c45 -Create Date: 2016-11-10 16:45:41.153837 - -""" - -# revision identifiers, used by Alembic. -revision = '5e2954a2af18' -down_revision = '81ce4ac01c45' - -import enum - -from alembic import op -import sqlalchemy as sa -import sqlalchemy_utils - - -OAuth2Client = sa.Table( - 'oauth2_client', - sa.MetaData(), - sa.Column('default_scopes', sa.String), - sa.Column('_default_scopes', sa.String), - sa.Column('redirect_uris', sa.String), - sa.Column('_redirect_uris', sa.String), -) - -OAuth2Grant = sa.Table( - 'oauth2_grant', - sa.MetaData(), - sa.Column('scopes', sa.String), - sa.Column('_scopes', sa.String), -) - -OAuth2Token = sa.Table( - 'oauth2_token', - sa.MetaData(), - sa.Column('scopes', sa.String), - sa.Column('_scopes', sa.String), -) - - -def upgrade(): - connection = op.get_bind() - - clienttypes = sa.dialects.postgresql.ENUM('public', 'confidential', name='clienttypes') - clienttypes.create(connection) - - with op.batch_alter_table('oauth2_client') as batch_op: - batch_op.add_column( - sa.Column( - 'client_type', - sa.Enum('public', 'confidential', name='clienttypes'), - server_default='public', - nullable=False - ) - ) - batch_op.add_column( - sa.Column( - 'default_scopes', - sqlalchemy_utils.types.scalar_list.ScalarListType(), - server_default='', - nullable=False - ) - ) - batch_op.add_column( - sa.Column( - 'redirect_uris', - sqlalchemy_utils.types.scalar_list.ScalarListType(), - server_default='', - nullable=False - ) - ) - - connection.execute( - OAuth2Client.update().values(default_scopes=OAuth2Client.c._default_scopes) - ) - connection.execute( - OAuth2Client.update().values(redirect_uris=OAuth2Client.c._redirect_uris) - ) - - with op.batch_alter_table('oauth2_client') as batch_op: - batch_op.drop_column('_redirect_uris') - batch_op.drop_column('_default_scopes') - batch_op.alter_column('redirect_uris', server_default=None) - batch_op.alter_column('default_scopes', server_default=None) - - with op.batch_alter_table('oauth2_grant') as batch_op: - batch_op.add_column( - sa.Column( - 'scopes', - sqlalchemy_utils.types.scalar_list.ScalarListType(), - server_default='', - nullable=False - ) - ) - - connection.execute( - OAuth2Grant.update().values(scopes=OAuth2Grant.c._scopes) - ) - - with op.batch_alter_table('oauth2_grant') as batch_op: - batch_op.drop_column('_scopes') - batch_op.alter_column('scopes', server_default=None) - - with op.batch_alter_table('oauth2_token') as batch_op: - batch_op.add_column( - sa.Column( - 'scopes', - sqlalchemy_utils.types.scalar_list.ScalarListType(), - server_default='', - nullable=False - ) - ) - - connection.execute( - OAuth2Token.update().values(scopes=OAuth2Token.c._scopes) - ) - - with op.batch_alter_table('oauth2_token') as batch_op: - batch_op.drop_column('_scopes') - batch_op.alter_column('scopes', server_default=None) - - -def downgrade(): - connection = op.get_bind() - - with op.batch_alter_table('oauth2_token') as batch_op: - batch_op.add_column(sa.Column('_scopes', sa.TEXT(), server_default='', nullable=False)) - - connection.execute( - OAuth2Token.update().values(_scopes=OAuth2Token.c.scopes) - ) - - with op.batch_alter_table('oauth2_token') as batch_op: - batch_op.drop_column('scopes') - - with op.batch_alter_table('oauth2_grant') as batch_op: - batch_op.add_column(sa.Column('_scopes', sa.TEXT(), server_default='', nullable=False)) - - connection.execute( - OAuth2Grant.update().values(_scopes=OAuth2Grant.c.scopes) - ) - - with op.batch_alter_table('oauth2_grant') as batch_op: - batch_op.drop_column('scopes') - - with op.batch_alter_table('oauth2_client') as batch_op: - batch_op.add_column( - sa.Column( - '_default_scopes', - sa.TEXT(), - server_default='', - nullable=False - ) - ) - batch_op.add_column( - sa.Column( - '_redirect_uris', - sa.TEXT(), - server_default='', - nullable=False - ) - ) - - connection.execute( - OAuth2Client.update().values(_default_scopes=OAuth2Client.c.default_scopes) - ) - connection.execute( - OAuth2Client.update().values(_redirect_uris=OAuth2Client.c.redirect_uris) - ) - - with op.batch_alter_table('oauth2_client') as batch_op: - batch_op.drop_column('redirect_uris') - batch_op.drop_column('default_scopes') - batch_op.drop_column('client_type') - - clienttypes = sa.dialects.postgresql.ENUM('public', 'confidential', name='clienttypes') - clienttypes.drop(connection) diff --git a/migrations/versions/81ce4ac01c45_migrate_static_roles.py b/migrations/versions/81ce4ac01c45_migrate_static_roles.py deleted file mode 100644 index c1ae927b..00000000 --- a/migrations/versions/81ce4ac01c45_migrate_static_roles.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Migrate static roles (new "internal" role type requires data migration) - -Revision ID: 81ce4ac01c45 -Revises: beb065460c24 -Create Date: 2016-11-08 15:58:55.932297 - -""" - -# revision identifiers, used by Alembic. -revision = '81ce4ac01c45' -down_revision = 'beb065460c24' - -from alembic import op -import sqlalchemy as sa - -UserHelper = sa.Table( - 'user', - sa.MetaData(), - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('static_roles', sa.Integer), -) - -def upgrade(): - connection = op.get_bind() - for user in connection.execute(UserHelper.select()): - if user.static_roles & 0x1000: - continue - new_static_roles = user.static_roles >> 1 - connection.execute( - UserHelper.update().where( - UserHelper.c.id == user.id - ).values( - static_roles=new_static_roles - ) - ) - -def downgrade(): - connection = op.get_bind() - for user in connection.execute(UserHelper.select()): - if not user.static_roles & 0x1000: - continue - new_static_roles = user.static_roles << 1 - connection.execute( - UserHelper.update().where( - UserHelper.c.id == user.id - ).values( - static_roles=new_static_roles - ) - ) diff --git a/migrations/versions/82184d7d1e88_altered-OAuth2Token-token_type-to-Enum.py b/migrations/versions/82184d7d1e88_altered-OAuth2Token-token_type-to-Enum.py deleted file mode 100644 index dea7b5b4..00000000 --- a/migrations/versions/82184d7d1e88_altered-OAuth2Token-token_type-to-Enum.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Alter OAuth2Token.token_type to Enum - -Revision ID: 82184d7d1e88 -Revises: 5e2954a2af18 -Create Date: 2016-11-10 21:14:33.787194 - -""" - -# revision identifiers, used by Alembic. -revision = '82184d7d1e88' -down_revision = '5e2954a2af18' - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - connection = op.get_bind() - - - with op.batch_alter_table('oauth2_token') as batch_op: - tokentypes = sa.dialects.postgresql.ENUM('Bearer', name='tokentypes') - tokentypes.create(connection) - - batch_op.alter_column('token_type', - existing_type=sa.VARCHAR(length=40), - type_=sa.Enum('Bearer', name='tokentypes'), - existing_nullable=False, - postgresql_using='token_type::tokentypes') - - -def downgrade(): - connection = op.get_bind() - - with op.batch_alter_table('oauth2_token') as batch_op: - batch_op.alter_column('token_type', - existing_type=sa.Enum('Bearer', name='tokentypes'), - type_=sa.VARCHAR(length=40), - existing_nullable=False) - - tokentypes = sa.dialects.postgresql.ENUM('Bearer', name='tokentypes') - tokentypes.drop(connection) diff --git a/migrations/versions/8c8b2d23a5_.py b/migrations/versions/8c8b2d23a5_.py deleted file mode 100644 index 6b3d95f7..00000000 --- a/migrations/versions/8c8b2d23a5_.py +++ /dev/null @@ -1,24 +0,0 @@ -"""empty message - -Revision ID: 8c8b2d23a5 -Revises: 357c2809db4 -Create Date: 2015-11-27 20:43:11.241948 - -""" - -# revision identifiers, used by Alembic. -revision = '8c8b2d23a5' -down_revision = '357c2809db4' - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - with op.batch_alter_table('team_member') as batch_op: - batch_op.create_unique_constraint('_team_user_uc', ['team_id', 'user_id']) - - -def downgrade(): - with op.batch_alter_table('team_member') as batch_op: - batch_op.drop_constraint('_team_user_uc', type_='unique') diff --git a/migrations/versions/beb065460c24_fixed-password-type.py b/migrations/versions/beb065460c24_fixed-password-type.py deleted file mode 100644 index 53ddd899..00000000 --- a/migrations/versions/beb065460c24_fixed-password-type.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Upgraded to the correct PasswordType implementation -https://github.com/kvesteri/sqlalchemy-utils/pull/254 - -Revision ID: beb065460c24 -Revises: 8c8b2d23a5 -Create Date: 2016-11-09 09:10:40.630496 - -""" - -# revision identifiers, used by Alembic. -revision = 'beb065460c24' -down_revision = '8c8b2d23a5' - -from alembic import op -import sqlalchemy as sa -import sqlalchemy_utils - - -UserHelper = sa.Table( - 'user', - sa.MetaData(), - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('password', sa.String), - sa.Column('_password', sa.String), -) - -def upgrade(): - connection = op.get_bind() - if connection.engine.name != 'sqlite': - return - - with op.batch_alter_table('user') as batch_op: - batch_op.add_column(sa.Column('_password', - sqlalchemy_utils.types.password.PasswordType(max_length=128), - server_default='', - nullable=False - )) - - connection.execute( - UserHelper.update().values(_password=UserHelper.c.password) - ) - - with op.batch_alter_table('user') as batch_op: - batch_op.drop_column('password') - batch_op.alter_column('_password', server_default=None, new_column_name='password') - - -def downgrade(): - connection = op.get_bind() - if connection.engine.name != 'sqlite': - return - - with op.batch_alter_table('user') as batch_op: - batch_op.add_column(sa.Column('_password', - type_=sa.NUMERIC(precision=128), - server_default='', - nullable=False - )) - - connection.execute( - UserHelper.update().values(_password=UserHelper.c.password) - ) - - with op.batch_alter_table('user') as batch_op: - batch_op.drop_column('password') - batch_op.alter_column('_password', server_default=None, new_column_name='password') diff --git a/run_app.py b/run_app.py new file mode 100644 index 00000000..c9aa456f --- /dev/null +++ b/run_app.py @@ -0,0 +1,10 @@ +import os +from app import create_app + +def run_app(host='127.0.0.1', port=5000): + flask_config = os.environ[ 'FLASK_CONFIG' ] or 'development' + app = create_app(flask_config) + return app.run( host=host, port=port) + +if __name__ == "__main__": + run_app() diff --git a/tasks/app/db.py b/tasks/app/db.py index 80303b61..b3321ce8 100644 --- a/tasks/app/db.py +++ b/tasks/app/db.py @@ -147,6 +147,23 @@ def merge(context, directory='migrations', revisions='', message=None, branch_la else: raise RuntimeError('Alembic 0.7.0 or greater is required') + +@app_context_task( + help={ + 'tag': "Arbitrary 'tag' name - can be used by custom env.py scripts", + 'sql': "Don't emit SQL to database - dump to standard output instead", + 'revision': "revision identifier", + 'directory': "migration script directory", + 'x_arg': "Additional arguments consumed by custom env.py scripts", + } +) +def droptables(context, directory='migrations', revision='head', sql=False, tag=None, x_arg=None, + app=None): + """Upgrade to a later version""" + + db.drop_all() + + @app_context_task( help={ 'tag': "Arbitrary 'tag' name - can be used by custom env.py scripts", diff --git a/tasks/app/run.py b/tasks/app/run.py index 3703b417..0178e288 100644 --- a/tasks/app/run.py +++ b/tasks/app/run.py @@ -45,6 +45,7 @@ def run( if upgrade_db: # After the installed dependencies the app.db.* tasks might need to be # reloaded to import all necessary dependencies. + import sqlalchemy_utils from . import db as db_tasks reload(db_tasks) diff --git a/tasks/app/users.py b/tasks/app/users.py index 916dfd76..6df89d66 100644 --- a/tasks/app/users.py +++ b/tasks/app/users.py @@ -51,7 +51,7 @@ def create_oauth2_client( Create a new OAuth2 Client associated with a given user (username). """ from app.modules.users.models import User - from app.modules.auth.models import OAuth2Client + from app.modules.auth.models2 import OAuth2Client user = User.query.filter(User.username == username).first() if not user: diff --git a/tasks/requirements.txt b/tasks/requirements.txt index 492bdc14..763d960f 100644 --- a/tasks/requirements.txt +++ b/tasks/requirements.txt @@ -2,3 +2,4 @@ invoke colorlog lockfile requests +sqlalchemy-utils diff --git a/tests/modules/auth/conftest.py b/tests/modules/auth/conftest.py index c55221b7..5dd98421 100644 --- a/tests/modules/auth/conftest.py +++ b/tests/modules/auth/conftest.py @@ -6,7 +6,7 @@ @pytest.yield_fixture() def regular_user_oauth2_client(regular_user, temp_db_instance_helper): # pylint: disable=invalid-name,unused-argument - from app.modules.auth.models import OAuth2Client + from app.modules.auth.models2 import OAuth2Client for _ in temp_db_instance_helper( OAuth2Client( @@ -22,7 +22,7 @@ def regular_user_oauth2_client(regular_user, temp_db_instance_helper): @pytest.yield_fixture() def regular_user_oauth2_token(regular_user_oauth2_client, temp_db_instance_helper): - from app.modules.auth.models import OAuth2Token + from app.modules.auth.models2 import OAuth2Token for _ in temp_db_instance_helper( OAuth2Token( diff --git a/tests/modules/auth/resources/test_creating_oauth2client.py b/tests/modules/auth/resources/test_creating_oauth2client.py index aa9965df..f5bdf6aa 100644 --- a/tests/modules/auth/resources/test_creating_oauth2client.py +++ b/tests/modules/auth/resources/test_creating_oauth2client.py @@ -37,7 +37,7 @@ def test_creating_oauth2_client( assert isinstance(response.json['redirect_uris'], list) # Cleanup - from app.modules.auth.models import OAuth2Client + from app.modules.auth.models2 import OAuth2Client oauth2_client_instance = OAuth2Client.query.get(response.json['client_id']) assert oauth2_client_instance.client_secret == response.json['client_secret'] diff --git a/tests/modules/auth/resources/test_token.py b/tests/modules/auth/resources/test_token.py index 7d032545..a90ab594 100644 --- a/tests/modules/auth/resources/test_token.py +++ b/tests/modules/auth/resources/test_token.py @@ -31,7 +31,7 @@ def test_regular_user_can_retrieve_token( } # Clean up - from app.modules.auth.models import OAuth2Token + from app.modules.auth.models2 import OAuth2Token with db.session.begin(): OAuth2Token.query.filter(OAuth2Token.access_token == response.json['access_token']).delete() @@ -110,7 +110,7 @@ def test_regular_user_can_refresh_token( } # Clean up - from app.modules.auth.models import OAuth2Token + from app.modules.auth.models2 import OAuth2Token with db.session.begin(): OAuth2Token.query.filter( OAuth2Token.access_token == refresh_token_response.json['access_token'] diff --git a/tests/utils.py b/tests/utils.py index 40ff2915..f22a9e78 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -39,7 +39,7 @@ def login(self, user, auth_scopes=None): def open(self, *args, **kwargs): if self._user is not None: from app.extensions import db - from app.modules.auth.models import OAuth2Client, OAuth2Token + from app.modules.auth.models2 import OAuth2Client, OAuth2Token oauth2_client = OAuth2Client( client_id='OAUTH2_%s' % self._user.username, From 95acfb6f237555c94994bbddb7e9c6b94f0a5c0f Mon Sep 17 00:00:00 2001 From: James O'Brien Date: Mon, 21 May 2018 22:43:06 -0700 Subject: [PATCH 2/3] Fixed --- .gitignore | 1 - app/extensions/__init__.py | 2 +- app/extensions/auth/oauth22.py | 21 ++++++++++- app/extensions/flask_sqlalchemy/__init__.py | 6 +-- app/modules/auth/__init__.py | 17 --------- app/modules/auth/models2.py | 14 ++++++- app/modules/auth/resources.py | 1 + app/modules/auth/views.py | 25 ++++++++---- app/requirements.txt | 4 +- app/templates/authorize.html | 41 +++++++++----------- app/templates/create_client.html | 42 +++++++++++++++++++++ app/templates/home.html | 38 +++++++++---------- migrations/initial_development_data.py | 22 ++++++----- 13 files changed, 147 insertions(+), 87 deletions(-) create mode 100644 app/templates/create_client.html diff --git a/.gitignore b/.gitignore index 6b1e44a1..b843c7a7 100644 --- a/.gitignore +++ b/.gitignore @@ -49,7 +49,6 @@ docs/_build/ *.bak local_config.py static/ -example.db .idea/ clients/*/swagger.json clients/*/dist diff --git a/app/extensions/__init__.py b/app/extensions/__init__.py index 0eee3ab4..bfce34cb 100644 --- a/app/extensions/__init__.py +++ b/app/extensions/__init__.py @@ -22,7 +22,7 @@ force_auto_coercion() force_instant_defaults() -from flask_login import LoginManager +from .login import LoginManager login_manager = LoginManager() from flask_marshmallow import Marshmallow diff --git a/app/extensions/auth/oauth22.py b/app/extensions/auth/oauth22.py index a26dd98d..795968ae 100644 --- a/app/extensions/auth/oauth22.py +++ b/app/extensions/auth/oauth22.py @@ -1,5 +1,5 @@ import functools, logging -from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector +from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector, current_token from authlib.flask.oauth2.sqla import ( create_query_client_func, create_save_token_func, @@ -8,7 +8,7 @@ ) from authlib.specs.rfc6749 import grants from werkzeug.security import gen_salt -from app.extensions import api, db +from app.extensions import api, login_manager from app.modules.users.models import User from app.modules.auth.models2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token from flask_restplus_patched._http import HTTPStatus @@ -17,6 +17,20 @@ log = logging.getLogger(__name__) +@login_manager.request_loader +def load_user_from_request(request): + """ + Load user from OAuth2 Authentication header. + """ + from app.modules.users.models import User + if current_token: + user_id = current_token.user.id + if user_id: + return User.query.get(user_id) + elif current_token.user: + return current_token.user + return None + def api_invalid_response(req): """ This is a default handler for OAuth2Provider, which raises abort exception @@ -39,6 +53,7 @@ def token_revoked(self, token): class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): def create_authorization_code(self, client, grant_user, request): + from app.extensions import db code = gen_salt(48) item = OAuth2AuthorizationCode( code=code, @@ -91,6 +106,8 @@ def __init__(self): self._require_oauth = None def init_app( self, app, query_client=None, save_token=None ): + from app.extensions import db + db.init_app(app) if query_client is None: query_client = create_query_client_func(db.session, OAuth2Client) if save_token is None: diff --git a/app/extensions/flask_sqlalchemy/__init__.py b/app/extensions/flask_sqlalchemy/__init__.py index 753706c3..40ec8863 100644 --- a/app/extensions/flask_sqlalchemy/__init__.py +++ b/app/extensions/flask_sqlalchemy/__init__.py @@ -45,9 +45,9 @@ class SQLAlchemy(BaseSQLAlchemy): """ def __init__(self, *args, **kwargs): - if 'session_options' not in kwargs: - kwargs['session_options'] = {} - kwargs['session_options']['autocommit'] = True + # if 'session_options' not in kwargs: + # kwargs['session_options'] = {} + # # kwargs['session_options']['autocommit'] = True # Configure Constraint Naming Conventions: # http://docs.sqlalchemy.org/en/latest/core/constraints.html#constraint-naming-conventions kwargs['metadata'] = MetaData( diff --git a/app/modules/auth/__init__.py b/app/modules/auth/__init__.py index 0e775644..37e4f3c4 100644 --- a/app/modules/auth/__init__.py +++ b/app/modules/auth/__init__.py @@ -6,20 +6,6 @@ from app.extensions.api import api_v1 -def load_user_from_request(request): - """ - Load user from OAuth2 Authentication header. - """ - from authlib.flask.oauth2 import current_token - from app.modules.users.models import User - if current_token: - user_id = current_token.user_id - if user_id: - return User.query.get(user_id) - elif current_token.user: - return current_token.user - return None - def init_app(app, **kwargs): # pylint: disable=unused-argument @@ -27,9 +13,6 @@ def init_app(app, **kwargs): Init auth module. """ # Bind Flask-Login for current_user - from app.extensions import login_manager - - login_manager.request_loader(load_user_from_request) # Register OAuth scopes api_v1.add_oauth_scope('auth:read', "Provide access to auth details") diff --git a/app/modules/auth/models2.py b/app/modules/auth/models2.py index 97992132..206e255e 100644 --- a/app/modules/auth/models2.py +++ b/app/modules/auth/models2.py @@ -5,12 +5,21 @@ OAuth2TokenMixin, ) from sqlalchemy_utils.types import ScalarListType +from sqlalchemy.ext.hybrid import hybrid_property from app.extensions import db -from app.modules.users.models import User -class OAuth2Client(db.Model, OAuth2ClientMixin): +class MyOAuth2ClientMixin(OAuth2ClientMixin): + def check_requested_scopes(self, scopes): + if type(self.scope) == str: + allowed = set(self.scope.split()) + elif type(self.scope) == list: + allowed = set(self.scope) + + return allowed.issuperset(set(scopes)) + +class OAuth2Client(db.Model, MyOAuth2ClientMixin): __tablename__ = 'oauth2_client' id = db.Column(db.Integer, primary_key=True) @@ -22,6 +31,7 @@ class ClientTypes(str, enum.Enum): client_type = db.Column(db.Enum(ClientTypes), default=ClientTypes.public, nullable=False) default_scopes = db.Column(ScalarListType(separator=' '), nullable=False) + scope = db.Column(ScalarListType(separator=' '), nullable=False) user = db.relationship('User') diff --git a/app/modules/auth/resources.py b/app/modules/auth/resources.py index bcc6df2a..3888cf65 100644 --- a/app/modules/auth/resources.py +++ b/app/modules/auth/resources.py @@ -62,6 +62,7 @@ def post(self, args): db.session, default_error_message="Failed to create a new OAuth2 client." ): + # TODO: reconsider using gen_salt new_oauth2_client = OAuth2Client( user_id=current_user.id, diff --git a/app/modules/auth/views.py b/app/modules/auth/views.py index 03826b18..e7bb3f0a 100644 --- a/app/modules/auth/views.py +++ b/app/modules/auth/views.py @@ -11,18 +11,22 @@ """ from flask import Blueprint, request, render_template, session -from flask_login import current_user +# from flask_login import current_user from flask_restplus_patched._http import HTTPStatus from authlib.flask.oauth2 import current_token from authlib.specs.rfc6749 import OAuth2Error -from app.extensions import api, oauth2 +from app.extensions import api, oauth2, db from app.modules.users.models import User from .models2 import OAuth2Client - auth_blueprint = Blueprint('auth', __name__, url_prefix='/auth') # pylint: disable=invalid-name +def current_user(): + if 'id' in session: + uid = session['id'] + return User.query.get(uid) + return None @auth_blueprint.route('/oauth2/invalid_request', methods=['GET']) def api_invalid_response(req): @@ -43,7 +47,10 @@ def access_token(*args, **kwargs): Returns: token response """ - return oauth2.create_token_response() + # with db.session.begin(): + response = oauth2.create_token_response() + + return response @auth_blueprint.route('/oauth2/revoke', methods=['POST']) @@ -51,7 +58,9 @@ def revoke_token(): """ This endpoint allows a user to revoke their access token. """ - return oauth2.create_endpoint_response('revocation') + with db.session.begin(): + response = oauth2.create_endpoint_response('revocation') + return response @auth_blueprint.route('/oauth2/authorize', methods=['GET', 'POST']) @@ -67,8 +76,8 @@ def authorize(*args, **kwargs): # can implement a login page and store cookies with a session id. # ALTERNATIVELY, authorize page can be implemented as SPA (single page # application) - user = current_user() + user = current_user() if request.method == 'GET': try: grant = oauth2.validate_consent_request(end_user=user) @@ -82,5 +91,7 @@ def authorize(*args, **kwargs): grant_user = user else: grant_user = None + with db.session.begin(): + response = oauth2.create_authorization_response(grant_user=grant_user) - return oauth2.create_authorization_response(grant_user) + return response or None diff --git a/app/requirements.txt b/app/requirements.txt index f7baa559..9b219a17 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,6 +1,6 @@ Flask -flask-restplus +flask-restplus==0.10.1 Flask-Cors @@ -18,8 +18,8 @@ apispec bcrypt passlib -authlib Flask-Login permission +Authlib>=0.6 arrow diff --git a/app/templates/authorize.html b/app/templates/authorize.html index a9078453..4e2a6e7d 100644 --- a/app/templates/authorize.html +++ b/app/templates/authorize.html @@ -1,23 +1,18 @@ - - - - - Authorization - - -

Client: {{ client.client_id }}

-

User: {{ user.username }}

-
-

Allow access?

- - - - - {% if state %} - - {% endif %} - - -
- - +

{{grant.client.client_name}} is requesting: +{{ grant.request.scope }} +

+ +
+ + {% if not user %} +

You haven't logged in. Log in with:

+
+ +
+ {% endif %} +
+ +
diff --git a/app/templates/create_client.html b/app/templates/create_client.html new file mode 100644 index 00000000..20b3ed32 --- /dev/null +++ b/app/templates/create_client.html @@ -0,0 +1,42 @@ + + +Home + +
+ + + + + + + + +
diff --git a/app/templates/home.html b/app/templates/home.html index 6532d470..57714520 100644 --- a/app/templates/home.html +++ b/app/templates/home.html @@ -1,20 +1,20 @@ - - - - - - - - {% if user %} -

You are {{ user.username }}

- {% else %} -

You are not authenticated

- {% endif %} +{% if user %} + +
Logged in as {{user}} (Log Out)
-

Type any username:

-
- - -
- - +{% for client in clients %} +
+{{ client.client_info|tojson }}
+{{ client.client_metadata|tojson }}
+
+
+{% endfor %} + +
Create Client + +{% else %} +
+ + +
+{% endif %} diff --git a/migrations/initial_development_data.py b/migrations/initial_development_data.py index 74b1b4e3..aee0fd0b 100644 --- a/migrations/initial_development_data.py +++ b/migrations/initial_development_data.py @@ -55,6 +55,7 @@ def init_auth(docs_user): client_id='documentation', client_secret='KQ()SWK)SQK)QWSKQW(SKQ)S(QWSQW(SJ*HQ&HQW*SQ*^SSQWSGQSG', user_id=docs_user.id, + scope=api.api_v1.authorizations['oauth2_password']['scopes'], default_scopes=api.api_v1.authorizations['oauth2_password']['scopes'] ) oauth2_client.redirect_uris = [] @@ -64,14 +65,15 @@ def init_auth(docs_user): def init(): # Automatically update `default_scopes` for `documentation` OAuth2 Client, # as it is nice to have an ability to evaluate all available API calls. - with db.session.begin(): - OAuth2Client.query.filter(OAuth2Client.client_id == 'documentation').update({ - OAuth2Client.default_scopes: api.api_v1.authorizations['oauth2_password']['scopes'], - }) - # assert User.query.count() == 0, \ - # "Database is not empty. You should not re-apply fixtures! Aborted." - # - # root_user, docs_user, regular_user = init_users() # pylint: disable=unused-variable - docs_user = User.query.filter_by(username='documentation').first() - init_auth(docs_user) + if User.query.count()==0: + root_user, docs_user, regular_user = init_users() # pylint: disable=unused-variable + init_auth( root_user ) + # with db.session.begin(): + root_user = User.query.filter(User.username == 'root').first() + client = OAuth2Client.query.filter(OAuth2Client.user_id == root_user.id).first() + client.default_scopes = api.api_v1.authorizations['oauth2_password']['scopes'] + client.scope = api.api_v1.authorizations['oauth2_password']['scopes'] + client.grant_types = ['authorization_code', 'password'] + db.session.add(client) + db.session.commit() From 1562c0d2142077b0fe5b8b815345205101d0df80 Mon Sep 17 00:00:00 2001 From: James O'Brien Date: Tue, 22 May 2018 11:36:22 -0700 Subject: [PATCH 3/3] Working pretty well - moved oauth2 models and extension to proper filenames. Tests are mostly failing due to the problems I was having with sessions, but also because there is no OAuth2Grant, and also Flask-login anonymous users seem to be handled differently. --- .gitignore | 4 + app/extensions/api/namespace.py | 29 ++- app/extensions/auth/__init__.py | 2 +- app/extensions/auth/oauth2.py | 232 ++++++++++-------- app/extensions/auth/oauth22.py | 173 ------------- app/extensions/flask_sqlalchemy/__init__.py | 2 +- app/extensions/login/__init__.py | 25 ++ app/modules/auth/__init__.py | 9 +- app/modules/auth/models.py | 132 +++------- app/modules/auth/models2.py | 68 ----- app/modules/auth/parameters.py | 4 +- app/modules/auth/resources.py | 12 +- app/modules/auth/schemas.py | 2 +- app/modules/auth/views.py | 24 +- app/modules/users/models.py | 15 +- app/templates/authorize.html | 7 +- migrations/initial_development_data.py | 16 +- tasks/app/users.py | 2 +- tests/modules/auth/conftest.py | 4 +- .../resources/test_creating_oauth2client.py | 2 +- tests/modules/auth/resources/test_token.py | 4 +- .../auth/test_login_manager_integration.py | 5 +- tests/utils.py | 2 +- 23 files changed, 276 insertions(+), 499 deletions(-) delete mode 100644 app/extensions/auth/oauth22.py create mode 100644 app/extensions/login/__init__.py delete mode 100644 app/modules/auth/models2.py diff --git a/.gitignore b/.gitignore index b843c7a7..62646140 100644 --- a/.gitignore +++ b/.gitignore @@ -4,14 +4,17 @@ __pycache__/ # C extensions *.so +.pytest_cache/ # Distribution / packaging .Python +/.venv/ /env/ /build/ *.egg-info/ .installed.cfg *.egg +example.db # Installer logs pip-log.txt @@ -34,6 +37,7 @@ coverage.xml .project .pydevproject + # Rope .ropeproject diff --git a/app/extensions/api/namespace.py b/app/extensions/api/namespace.py index 23ad81ab..aa818b3c 100644 --- a/app/extensions/api/namespace.py +++ b/app/extensions/api/namespace.py @@ -296,21 +296,26 @@ def commit_or_abort(self, session, default_error_message="The operation failed t session: db.session instance default_error_message: Custom error message - Exampple: + Example: >>> with api.commit_or_abort(db.session): ... team = Team(**args) ... db.session.add(team) ... return team """ + from werkzeug.exceptions import HTTPException try: - with session.begin(): - yield - except ValueError as exception: - log.info("Database transaction was rolled back due to: %r", exception) - http_exceptions.abort(code=HTTPStatus.CONFLICT, message=str(exception)) - except sqlalchemy.exc.IntegrityError as exception: - log.info("Database transaction was rolled back due to: %r", exception) - http_exceptions.abort( - code=HTTPStatus.CONFLICT, - message=default_error_message - ) + try: + yield session + session.commit() + except ValueError as exception: + log.info( "Database transaction was rolled back due to: %r", exception ) + http_exceptions.abort( code=HTTPStatus.CONFLICT, message=str( exception ) ) + except sqlalchemy.exc.IntegrityError as exception: + log.info( "Database transaction was rolled back due to: %r", exception ) + http_exceptions.abort( + code=HTTPStatus.CONFLICT, + message=default_error_message + ) + except HTTPException: + session.rollback() + raise diff --git a/app/extensions/auth/__init__.py b/app/extensions/auth/__init__.py index 8e255d47..0cd52410 100644 --- a/app/extensions/auth/__init__.py +++ b/app/extensions/auth/__init__.py @@ -4,4 +4,4 @@ ============== """ -from .oauth22 import OAuth2Provider +from .oauth2 import OAuth2Provider diff --git a/app/extensions/auth/oauth2.py b/app/extensions/auth/oauth2.py index 31236004..10009b7e 100644 --- a/app/extensions/auth/oauth2.py +++ b/app/extensions/auth/oauth2.py @@ -1,102 +1,36 @@ -# encoding: utf-8 -# pylint: disable=no-self-use -""" -OAuth2 provider setup. - -It is based on the code from the example: -https://github.com/lepture/example-oauth2-server - -More details are available here: -* http://flask-oauthlib.readthedocs.org/en/latest/oauth2.html -* http://lepture.com/en/2013/create-oauth-server -""" - -from datetime import datetime, timedelta -import functools -import logging - -from flask_login import current_user -from flask_oauthlib import provider +import functools, logging +from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector, current_token +from authlib.flask.oauth2.sqla import ( + create_query_client_func, + create_save_token_func, + create_revocation_endpoint, + create_bearer_token_validator, +) +from authlib.specs.rfc6749 import grants +from werkzeug.security import gen_salt +from app.extensions import api, login_manager +from app.modules.users.models import User +from app.modules.auth.models import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token from flask_restplus_patched._http import HTTPStatus -import sqlalchemy - -from app.extensions import api, db - +from authlib.specs.rfc6750 import BearerTokenValidator as _BearerTokenValidator log = logging.getLogger(__name__) -class OAuth2RequestValidator(provider.OAuth2RequestValidator): - # pylint: disable=abstract-method +@login_manager.request_loader +def load_user_from_request(request): """ - A project-specific implementation of OAuth2RequestValidator, which connects - our User and OAuth2* implementations together. + Load user from OAuth2 Authentication header. """ - - def __init__(self): - from app.modules.auth.models import OAuth2Client, OAuth2Grant, OAuth2Token - self._client_class = OAuth2Client - self._grant_class = OAuth2Grant - self._token_class = OAuth2Token - super(OAuth2RequestValidator, self).__init__( - usergetter=self._usergetter, - clientgetter=self._client_class.find, - tokengetter=self._token_class.find, - grantgetter=self._grant_class.find, - tokensetter=self._tokensetter, - grantsetter=self._grantsetter, - ) - - def _usergetter(self, username, password, client, request): - # pylint: disable=method-hidden,unused-argument - # Avoid circular dependencies - from app.modules.users.models import User - return User.find_with_password(username, password) - - def _tokensetter(self, token, request, *args, **kwargs): - # pylint: disable=method-hidden,unused-argument - # TODO: review expiration time - expires_in = token['expires_in'] - expires = datetime.utcnow() + timedelta(seconds=expires_in) - - try: - with db.session.begin(): - token_instance = self._token_class( - access_token=token['access_token'], - refresh_token=token.get('refresh_token'), - token_type=token['token_type'], - scopes=[scope for scope in token['scope'].split(' ') if scope], - expires=expires, - client_id=request.client.client_id, - user_id=request.user.id, - ) - db.session.add(token_instance) - except sqlalchemy.exc.IntegrityError: - log.exception("Token-setter has failed.") - return None - return token_instance - - def _grantsetter(self, client_id, code, request, *args, **kwargs): - # pylint: disable=method-hidden,unused-argument - # TODO: review expiration time - # decide the expires time yourself - expires = datetime.utcnow() + timedelta(seconds=100) - try: - with db.session.begin(): - grant_instance = self._grant_class( - client_id=client_id, - code=code['code'], - redirect_uri=request.redirect_uri, - scopes=request.scopes, - user=current_user, - expires=expires - ) - db.session.add(grant_instance) - except sqlalchemy.exc.IntegrityError: - log.exception("Grant-setter has failed.") - return None - return grant_instance - + from app.modules.users.models import User + if current_token: + user = current_token.user + if user: + return user + user_id = current_token.user.id + if user_id: + return User.query.get(user_id) + return None def api_invalid_response(req): """ @@ -107,19 +41,97 @@ def api_invalid_response(req): api.abort(code=HTTPStatus.UNAUTHORIZED.value) -class OAuth2Provider(provider.OAuth2Provider): - """ - A helper class which connects OAuth2RequestValidator with OAuth2Provider. - """ +class BearerTokenValidator(_BearerTokenValidator): + def authenticate_token(self, token_string): + return OAuth2Token.query.filter_by(access_token=token_string).first() + + def request_invalid(self, request): + return False - def __init__(self, *args, **kwargs): - super(OAuth2Provider, self).__init__(*args, **kwargs) - self.invalid_response(api_invalid_response) + def token_revoked(self, token): + # TODO: return token.revoked + return token.revoked + +class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + def create_authorization_code(self, client, grant_user, request): + from app.extensions import db + code = gen_salt(48) + item = OAuth2AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + user_id=grant_user.id, + ) + db.session.add(item) + db.session.commit() + return code + + def parse_authorization_code(self, code, client): + item = OAuth2AuthorizationCode.query.filter_by( + code=code, client_id=client.client_id).first() + if item and not item.is_expired(): + return item + + def delete_authorization_code(self, authorization_code): + from app.extensions import db + db.session.delete(authorization_code) + db.session.commit() + + def authenticate_user(self, authorization_code): + return User.query.get(authorization_code.user_id) + + +class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): + def authenticate_user(self, username, password): + return User.find_with_password(username, password) - def init_app(self, app): - assert app.config['SECRET_KEY'], "SECRET_KEY must be configured!" - super(OAuth2Provider, self).init_app(app) - self._validator = OAuth2RequestValidator() + +class RefreshTokenGrant(grants.RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + item = OAuth2Token.query.filter_by(refresh_token=refresh_token).first() + if item and not item.is_refresh_token_expired(): + return item + + def authenticate_user(self, credential): + return User.query.get(credential.user_id) + + +class OAuth2ResourceProtector(ResourceProtector): + def __init__( self ): + super().__init__() + + +class OAuth2Provider(AuthorizationServer): + def __init__(self): + super().__init__() + self._require_oauth = None + + def init_app( self, app, query_client=None, save_token=None ): + from app.extensions import db + if query_client is None: + query_client = create_query_client_func(db.session, OAuth2Client) + if save_token is None: + save_token = create_save_token_func(db.session, OAuth2Token) + + super().init_app( + app, query_client=query_client, save_token=save_token) + + # support all grants + self.register_grant(grants.ImplicitGrant) + self.register_grant(grants.ClientCredentialsGrant) + self.register_grant(AuthorizationCodeGrant) + self.register_grant(PasswordGrant) + self.register_grant(RefreshTokenGrant) + + # support revocation + revocation_cls = create_revocation_endpoint(db.session, OAuth2Token) + self.register_endpoint(revocation_cls) + + # protect resource + bearer_cls = create_bearer_token_validator(db.session, OAuth2Token) + OAuth2ResourceProtector.register_token_validator(bearer_cls()) + self._require_oauth = OAuth2ResourceProtector() def require_oauth(self, *args, **kwargs): # pylint: disable=arguments-differ @@ -134,8 +146,8 @@ def require_oauth(self, *args, **kwargs): Returns: function: a decorator. """ - locations = kwargs.pop('locations', ('cookies',)) - origin_decorator = super(OAuth2Provider, self).require_oauth(*args, **kwargs) + locations = kwargs.get('locations', ('cookies',)) # don't want to pop - original decorator may need + origin_decorator = self._require_oauth(*args, **kwargs) def decorator(func): # pylint: disable=missing-docstring @@ -148,11 +160,13 @@ def wrapper(*args, **kwargs): # pylint: disable=missing-docstring if 'headers' not in locations: # Invalidate authorization if developer specifically - # disables the lookup in the headers. + # disables the lookup in the headers. (this may or may not be worth all the hassle) request.authorization = '!' - if 'form' in locations: - if 'access_token' in request.form: - request.authorization = 'Bearer %s' % request.form['access_token'] + # don't think we need below lines because bearer validator already registered + # if 'form' in locations: + # if 'access_token' in request.form: + # request.authorization = 'Bearer %s' % request.form['access_token'] + return origin_decorated_func(*args, **kwargs) return wrapper diff --git a/app/extensions/auth/oauth22.py b/app/extensions/auth/oauth22.py deleted file mode 100644 index 795968ae..00000000 --- a/app/extensions/auth/oauth22.py +++ /dev/null @@ -1,173 +0,0 @@ -import functools, logging -from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector, current_token -from authlib.flask.oauth2.sqla import ( - create_query_client_func, - create_save_token_func, - create_revocation_endpoint, - create_bearer_token_validator, -) -from authlib.specs.rfc6749 import grants -from werkzeug.security import gen_salt -from app.extensions import api, login_manager -from app.modules.users.models import User -from app.modules.auth.models2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token -from flask_restplus_patched._http import HTTPStatus -from authlib.specs.rfc6750 import BearerTokenValidator as _BearerTokenValidator - -log = logging.getLogger(__name__) - - -@login_manager.request_loader -def load_user_from_request(request): - """ - Load user from OAuth2 Authentication header. - """ - from app.modules.users.models import User - if current_token: - user_id = current_token.user.id - if user_id: - return User.query.get(user_id) - elif current_token.user: - return current_token.user - return None - -def api_invalid_response(req): - """ - This is a default handler for OAuth2Provider, which raises abort exception - with error message in JSON format. - """ - # pylint: disable=unused-argument - api.abort(code=HTTPStatus.UNAUTHORIZED.value) - - -class BearerTokenValidator(_BearerTokenValidator): - def authenticate_token(self, token_string): - return OAuth2Token.query.filter_by(access_token=token_string).first() - - def request_invalid(self, request): - return False - - def token_revoked(self, token): - # TODO: return token.revoked - return token.revoked - -class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): - def create_authorization_code(self, client, grant_user, request): - from app.extensions import db - code = gen_salt(48) - item = OAuth2AuthorizationCode( - code=code, - client_id=client.client_id, - redirect_uri=request.redirect_uri, - scope=request.scope, - user_id=grant_user.id, - ) - db.session.add(item) - db.session.commit() - return code - - def parse_authorization_code(self, code, client): - item = OAuth2AuthorizationCode.query.filter_by( - code=code, client_id=client.client_id).first() - if item and not item.is_expired(): - return item - - def delete_authorization_code(self, authorization_code): - db.session.delete(authorization_code) - db.session.commit() - - def authenticate_user(self, authorization_code): - return User.query.get(authorization_code.user_id) - - -class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): - def authenticate_user(self, username, password): - return User.find_with_password(username, password) - - -class RefreshTokenGrant(grants.RefreshTokenGrant): - def authenticate_refresh_token(self, refresh_token): - item = OAuth2Token.query.filter_by(refresh_token=refresh_token).first() - if item and not item.is_refresh_token_expired(): - return item - - def authenticate_user(self, credential): - return User.query.get(credential.user_id) - - -class OAuth2ResourceProtector(ResourceProtector): - def __init__( self ): - super().__init__() - - -class OAuth2Provider(AuthorizationServer): - def __init__(self): - super().__init__() - self._require_oauth = None - - def init_app( self, app, query_client=None, save_token=None ): - from app.extensions import db - db.init_app(app) - if query_client is None: - query_client = create_query_client_func(db.session, OAuth2Client) - if save_token is None: - save_token = create_save_token_func(db.session, OAuth2Token) - - super().init_app( - app, query_client=query_client, save_token=save_token) - - # support all grants - self.register_grant(grants.ImplicitGrant) - self.register_grant(grants.ClientCredentialsGrant) - self.register_grant(AuthorizationCodeGrant) - self.register_grant(PasswordGrant) - self.register_grant(RefreshTokenGrant) - - # support revocation - revocation_cls = create_revocation_endpoint(db.session, OAuth2Token) - self.register_endpoint(revocation_cls) - - # protect resource - bearer_cls = create_bearer_token_validator(db.session, OAuth2Token) - OAuth2ResourceProtector.register_token_validator(bearer_cls()) - self._require_oauth = OAuth2ResourceProtector() - - def require_oauth(self, *args, **kwargs): - # pylint: disable=arguments-differ - """ - A decorator to protect a resource with specified scopes. Access Token - can be fetched from the specified locations (``headers`` or ``form``). - - Arguments: - locations (list): a list of locations (``headers``, ``form``) where - the access token should be looked up. - - Returns: - function: a decorator. - """ - locations = kwargs.get('locations', ('cookies',)) # don't want to pop - original decorator may need - origin_decorator = self._require_oauth(*args, **kwargs) - - def decorator(func): - # pylint: disable=missing-docstring - from flask import request - - origin_decorated_func = origin_decorator(func) - - @functools.wraps(origin_decorated_func) - def wrapper(*args, **kwargs): - # pylint: disable=missing-docstring - if 'headers' not in locations: - # Invalidate authorization if developer specifically - # disables the lookup in the headers. (this may or may not be worth all the hassle) - request.authorization = '!' - # don't think we need below lines because bearer validator already registered - # if 'form' in locations: - # if 'access_token' in request.form: - # request.authorization = 'Bearer %s' % request.form['access_token'] - - return origin_decorated_func(*args, **kwargs) - - return wrapper - - return decorator diff --git a/app/extensions/flask_sqlalchemy/__init__.py b/app/extensions/flask_sqlalchemy/__init__.py index 40ec8863..6f796837 100644 --- a/app/extensions/flask_sqlalchemy/__init__.py +++ b/app/extensions/flask_sqlalchemy/__init__.py @@ -47,7 +47,7 @@ class SQLAlchemy(BaseSQLAlchemy): def __init__(self, *args, **kwargs): # if 'session_options' not in kwargs: # kwargs['session_options'] = {} - # # kwargs['session_options']['autocommit'] = True + # kwargs['session_options']['autocommit'] = True # Configure Constraint Naming Conventions: # http://docs.sqlalchemy.org/en/latest/core/constraints.html#constraint-naming-conventions kwargs['metadata'] = MetaData( diff --git a/app/extensions/login/__init__.py b/app/extensions/login/__init__.py new file mode 100644 index 00000000..f4d210a1 --- /dev/null +++ b/app/extensions/login/__init__.py @@ -0,0 +1,25 @@ +from flask import g +from flask.sessions import SecureCookieSessionInterface +from flask_login import user_loaded_from_header +from flask_login import LoginManager as OriginalLoginManager + +class CustomSessionInterface(SecureCookieSessionInterface): + """Prevent creating session from API requests.""" + def save_session(self, *args, **kwargs): + if g.get('login_via_header'): + return + return super(CustomSessionInterface, self).save_session(*args, + **kwargs) + + +@user_loaded_from_header.connect +def user_loaded_from_header(self, user=None): + g.login_via_header = True + + +class LoginManager(OriginalLoginManager): + def init_app(self, app): + app.session_interface = CustomSessionInterface() + super().init_app(app) + + diff --git a/app/modules/auth/__init__.py b/app/modules/auth/__init__.py index 37e4f3c4..f1e7d630 100644 --- a/app/modules/auth/__init__.py +++ b/app/modules/auth/__init__.py @@ -3,9 +3,16 @@ Auth module =========== """ +from flask_login import current_user from app.extensions.api import api_v1 +def load_user_from_request(request): + """ + Load user from OAuth2 Authentication header. + """ + user = current_user + return user def init_app(app, **kwargs): # pylint: disable=unused-argument @@ -19,7 +26,7 @@ def init_app(app, **kwargs): api_v1.add_oauth_scope('auth:write', "Provide write access to auth details") # Touch underlying modules - from . import models2, views, resources # pylint: disable=unused-variable + from . import models, views, resources # pylint: disable=unused-variable # Mount authentication routes app.register_blueprint(views.auth_blueprint) diff --git a/app/modules/auth/models.py b/app/modules/auth/models.py index ef9c80ec..8b208b00 100644 --- a/app/modules/auth/models.py +++ b/app/modules/auth/models.py @@ -1,49 +1,42 @@ -# encoding: utf-8 -""" -OAuth2 provider models. - -It is based on the code from the example: -https://github.com/lepture/example-oauth2-server - -More details are available here: -* http://flask-oauthlib.readthedocs.org/en/latest/oauth2.html -* http://lepture.com/en/2013/create-oauth-server -""" -import enum - +import time, enum +from authlib.flask.oauth2.sqla import ( + OAuth2ClientMixin, + OAuth2AuthorizationCodeMixin, + OAuth2TokenMixin, +) from sqlalchemy_utils.types import ScalarListType from app.extensions import db -from app.modules.users.models import User -class OAuth2Client(db.Model): - """ - Model that binds OAuth2 Client ID and Secret to a specific User. - """ +class MyOAuth2ClientMixin(OAuth2ClientMixin): + def check_requested_scopes(self, scopes): + if type(self.scope) == str: + allowed = set(self.scope.split()) + elif type(self.scope) == list: + allowed = set(self.scope) - __tablename__ = 'oauth2_client' - - client_id = db.Column(db.String(length=40), primary_key=True) - client_secret = db.Column(db.String(length=55), nullable=False) + return allowed.issuperset(set(scopes)) - user_id = db.Column(db.ForeignKey('user.id', ondelete='CASCADE'), index=True, nullable=False) - user = db.relationship(User) +class OAuth2Client(db.Model, MyOAuth2ClientMixin): + __tablename__ = 'oauth2_client' + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) class ClientTypes(str, enum.Enum): public = 'public' confidential = 'confidential' client_type = db.Column(db.Enum(ClientTypes), default=ClientTypes.public, nullable=False) - redirect_uris = db.Column(ScalarListType(separator=' '), default=[], nullable=False) default_scopes = db.Column(ScalarListType(separator=' '), nullable=False) + scope = db.Column(ScalarListType(separator=' '), nullable=False) + + user = db.relationship('User') @property def default_redirect_uri(self): - redirect_uris = self.redirect_uris - if redirect_uris: - return redirect_uris[0] - return None + return self.get_default_redirect_uri() @classmethod def find(cls, client_id): @@ -52,79 +45,24 @@ def find(cls, client_id): return cls.query.get(client_id) -class OAuth2Grant(db.Model): - """ - Intermediate temporary helper for OAuth2 Grants. - """ - - __tablename__ = 'oauth2_grant' +class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): + __tablename__ = 'oauth2_code' - id = db.Column(db.Integer, primary_key=True) # pylint: disable=invalid-name - - user_id = db.Column(db.ForeignKey('user.id', ondelete='CASCADE'), index=True, nullable=False) + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) user = db.relationship('User') - client_id = db.Column( - db.String(length=40), - db.ForeignKey('oauth2_client.client_id'), - index=True, - nullable=False, - ) - client = db.relationship('OAuth2Client') - - code = db.Column(db.String(length=255), index=True, nullable=False) - - redirect_uri = db.Column(db.String(length=255), nullable=False) - expires = db.Column(db.DateTime, nullable=False) - - scopes = db.Column(ScalarListType(separator=' '), nullable=False) - - def delete(self): - db.session.delete(self) - db.session.commit() - return self - - @classmethod - def find(cls, client_id, code): - return cls.query.filter_by(client_id=client_id, code=code).first() - - -class OAuth2Token(db.Model): - """ - OAuth2 Access Tokens storage model. - """ +class OAuth2Token(db.Model, OAuth2TokenMixin): __tablename__ = 'oauth2_token' - id = db.Column(db.Integer, primary_key=True) # pylint: disable=invalid-name - client_id = db.Column( - db.String(length=40), - db.ForeignKey('oauth2_client.client_id'), - index=True, - nullable=False, - ) - client = db.relationship('OAuth2Client') - - user_id = db.Column(db.ForeignKey('user.id', ondelete='CASCADE'), index=True, nullable=False) + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column( + db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) user = db.relationship('User') + revoked = db.Column(db.Boolean(name='revoked'), default=False) # override because of bug in alembic - class TokenTypes(str, enum.Enum): - # currently only bearer is supported - Bearer = 'Bearer' - token_type = db.Column(db.Enum(TokenTypes), nullable=False) - - access_token = db.Column(db.String(length=255), unique=True, nullable=False) - refresh_token = db.Column(db.String(length=255), unique=True, nullable=True) - expires = db.Column(db.DateTime, nullable=False) - scopes = db.Column(ScalarListType(separator=' '), nullable=False) - - @classmethod - def find(cls, access_token=None, refresh_token=None): - if access_token: - return cls.query.filter_by(access_token=access_token).first() - elif refresh_token: - return cls.query.filter_by(refresh_token=refresh_token).first() - - def delete(self): - with db.session.begin(): - db.session.delete(self) + def is_refresh_token_expired(self): + expires_at = self.issued_at + self.expires_in * 2 + return expires_at < time.time() diff --git a/app/modules/auth/models2.py b/app/modules/auth/models2.py deleted file mode 100644 index 206e255e..00000000 --- a/app/modules/auth/models2.py +++ /dev/null @@ -1,68 +0,0 @@ -import time, enum -from authlib.flask.oauth2.sqla import ( - OAuth2ClientMixin, - OAuth2AuthorizationCodeMixin, - OAuth2TokenMixin, -) -from sqlalchemy_utils.types import ScalarListType -from sqlalchemy.ext.hybrid import hybrid_property - -from app.extensions import db - - -class MyOAuth2ClientMixin(OAuth2ClientMixin): - def check_requested_scopes(self, scopes): - if type(self.scope) == str: - allowed = set(self.scope.split()) - elif type(self.scope) == list: - allowed = set(self.scope) - - return allowed.issuperset(set(scopes)) - -class OAuth2Client(db.Model, MyOAuth2ClientMixin): - __tablename__ = 'oauth2_client' - - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) - class ClientTypes(str, enum.Enum): - public = 'public' - confidential = 'confidential' - - client_type = db.Column(db.Enum(ClientTypes), default=ClientTypes.public, nullable=False) - default_scopes = db.Column(ScalarListType(separator=' '), nullable=False) - scope = db.Column(ScalarListType(separator=' '), nullable=False) - - user = db.relationship('User') - - @property - def default_redirect_uri(self): - return self.get_default_redirect_uri() - - @classmethod - def find(cls, client_id): - if not client_id: - return - return cls.query.get(client_id) - - -class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin): - __tablename__ = 'oauth2_code' - - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) - user = db.relationship('User') - - -class OAuth2Token(db.Model, OAuth2TokenMixin): - __tablename__ = 'oauth2_token' - - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column( - db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')) - user = db.relationship('User') - - def is_refresh_token_expired(self): - expires_at = self.issued_at + self.expires_in * 2 - return expires_at < time.time() diff --git a/app/modules/auth/parameters.py b/app/modules/auth/parameters.py index 50e5eea4..3908f012 100644 --- a/app/modules/auth/parameters.py +++ b/app/modules/auth/parameters.py @@ -24,9 +24,9 @@ def validate_user_id(self, data): class CreateOAuth2ClientParameters(PostFormParameters): redirect_uris = base_fields.List(base_fields.String, required=False) - default_scopes = base_fields.List(base_fields.String, required=True) + scopes = base_fields.List(base_fields.String, required=True) - @validates('default_scopes') + @validates('scopes') def validate_default_scopes(self, data): unknown_scopes = set(data) - set(api.api_v1.authorizations['oauth2_password']['scopes']) if unknown_scopes: diff --git a/app/modules/auth/resources.py b/app/modules/auth/resources.py index 3888cf65..ef8805e5 100644 --- a/app/modules/auth/resources.py +++ b/app/modules/auth/resources.py @@ -15,7 +15,7 @@ from app.extensions.api import Namespace from . import schemas, parameters -from .models2 import db, OAuth2Client +from .models import db, OAuth2Client log = logging.getLogger(__name__) @@ -64,11 +64,17 @@ def post(self, args): ): # TODO: reconsider using gen_salt + user = current_user new_oauth2_client = OAuth2Client( - user_id=current_user.id, + user_id=user.id, client_id=security.gen_salt(40), - client_secret=security.gen_salt(50), **args ) + + if new_oauth2_client.token_endpoint_auth_method=='none': + new_oauth2_client.client_secret = '' + else: + new_oauth2_client.client_secret = security.gen_salt( 48 ) + db.session.add(new_oauth2_client) return new_oauth2_client diff --git a/app/modules/auth/schemas.py b/app/modules/auth/schemas.py index f1212ffe..b8add53b 100644 --- a/app/modules/auth/schemas.py +++ b/app/modules/auth/schemas.py @@ -8,7 +8,7 @@ from flask_marshmallow import base_fields from flask_restplus_patched import ModelSchema -from .models2 import OAuth2Client +from .models import OAuth2Client class BaseOAuth2ClientSchema(ModelSchema): diff --git a/app/modules/auth/views.py b/app/modules/auth/views.py index e7bb3f0a..7e1c8fb1 100644 --- a/app/modules/auth/views.py +++ b/app/modules/auth/views.py @@ -11,22 +11,21 @@ """ from flask import Blueprint, request, render_template, session -# from flask_login import current_user +from flask_login import current_user from flask_restplus_patched._http import HTTPStatus -from authlib.flask.oauth2 import current_token from authlib.specs.rfc6749 import OAuth2Error from app.extensions import api, oauth2, db from app.modules.users.models import User -from .models2 import OAuth2Client +from .models import OAuth2Client auth_blueprint = Blueprint('auth', __name__, url_prefix='/auth') # pylint: disable=invalid-name - -def current_user(): - if 'id' in session: - uid = session['id'] - return User.query.get(uid) - return None +# +# def current_user(): +# if 'id' in session: +# uid = session['id'] +# return User.query.get(uid) +# return None @auth_blueprint.route('/oauth2/invalid_request', methods=['GET']) def api_invalid_response(req): @@ -76,6 +75,7 @@ def authorize(*args, **kwargs): # can implement a login page and store cookies with a session id. # ALTERNATIVELY, authorize page can be implemented as SPA (single page # application) + from flask_login import login_user user = current_user() if request.method == 'GET': @@ -86,7 +86,11 @@ def authorize(*args, **kwargs): return render_template('authorize.html', user=user, grant=grant) if not user and 'username' in request.form: username = request.form.get('username') - user = User.query.filter_by(username=username).first() + password = request.form.get('password') + user = User.find_with_password(username, password) + if user: + login_user(user) + if request.form['confirm']: grant_user = user else: diff --git a/app/modules/users/models.py b/app/modules/users/models.py index 2a874ecf..98601ca6 100644 --- a/app/modules/users/models.py +++ b/app/modules/users/models.py @@ -6,6 +6,7 @@ import enum from sqlalchemy_utils import types as column_types, Timestamp +from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method from app.extensions import db @@ -42,7 +43,6 @@ class User(db.Model, Timestamp): """ User database model. """ - id = db.Column(db.Integer, primary_key=True) # pylint: disable=invalid-name username = db.Column(db.String(length=80), unique=True, nullable=False) password = db.Column( @@ -96,6 +96,10 @@ def __repr__(self): ) ) + def __init__( self, **kwargs ): + super().__init__( **kwargs ) + self._authenticated = False + def get_user_id( self ): return self.id @@ -115,9 +119,13 @@ def unset_static_role(self, role): def check_owner(self, user): return self == user - @property + @hybrid_property def is_authenticated(self): - return True + return self._authenticated + + @is_authenticated.setter + def is_authenticated( self, value): + self._authenticated = value @property def is_anonymous(self): @@ -138,5 +146,6 @@ def find_with_password(cls, username, password): if not user: return None if user.password == password: + user.is_authenticated = True return user return None diff --git a/app/templates/authorize.html b/app/templates/authorize.html index 4e2a6e7d..b128906d 100644 --- a/app/templates/authorize.html +++ b/app/templates/authorize.html @@ -7,10 +7,13 @@ Consent? - {% if not user %} + {% if not user or not user.is_authenticated %}

You haven't logged in. Log in with:

- + Username: +
+
+ Password:
{% endif %}
diff --git a/migrations/initial_development_data.py b/migrations/initial_development_data.py index aee0fd0b..8a99669d 100644 --- a/migrations/initial_development_data.py +++ b/migrations/initial_development_data.py @@ -8,7 +8,7 @@ from app.extensions import db, api from app.modules.users.models import User -from app.modules.auth.models2 import OAuth2Client +from app.modules.auth.models import OAuth2Client def init_users(): @@ -70,10 +70,10 @@ def init(): root_user, docs_user, regular_user = init_users() # pylint: disable=unused-variable init_auth( root_user ) # with db.session.begin(): - root_user = User.query.filter(User.username == 'root').first() - client = OAuth2Client.query.filter(OAuth2Client.user_id == root_user.id).first() - client.default_scopes = api.api_v1.authorizations['oauth2_password']['scopes'] - client.scope = api.api_v1.authorizations['oauth2_password']['scopes'] - client.grant_types = ['authorization_code', 'password'] - db.session.add(client) - db.session.commit() + # root_user = User.query.filter(User.username == 'root').first() + # client = OAuth2Client.query.filter(OAuth2Client.user_id == root_user.id).first() + # client.default_scopes = api.api_v1.authorizations['oauth2_password']['scopes'] + # client.scope = api.api_v1.authorizations['oauth2_password']['scopes'] + # client.grant_types = ['authorization_code', 'password'] + # db.session.add(client) + # db.session.commit() diff --git a/tasks/app/users.py b/tasks/app/users.py index 6df89d66..916dfd76 100644 --- a/tasks/app/users.py +++ b/tasks/app/users.py @@ -51,7 +51,7 @@ def create_oauth2_client( Create a new OAuth2 Client associated with a given user (username). """ from app.modules.users.models import User - from app.modules.auth.models2 import OAuth2Client + from app.modules.auth.models import OAuth2Client user = User.query.filter(User.username == username).first() if not user: diff --git a/tests/modules/auth/conftest.py b/tests/modules/auth/conftest.py index 5dd98421..c55221b7 100644 --- a/tests/modules/auth/conftest.py +++ b/tests/modules/auth/conftest.py @@ -6,7 +6,7 @@ @pytest.yield_fixture() def regular_user_oauth2_client(regular_user, temp_db_instance_helper): # pylint: disable=invalid-name,unused-argument - from app.modules.auth.models2 import OAuth2Client + from app.modules.auth.models import OAuth2Client for _ in temp_db_instance_helper( OAuth2Client( @@ -22,7 +22,7 @@ def regular_user_oauth2_client(regular_user, temp_db_instance_helper): @pytest.yield_fixture() def regular_user_oauth2_token(regular_user_oauth2_client, temp_db_instance_helper): - from app.modules.auth.models2 import OAuth2Token + from app.modules.auth.models import OAuth2Token for _ in temp_db_instance_helper( OAuth2Token( diff --git a/tests/modules/auth/resources/test_creating_oauth2client.py b/tests/modules/auth/resources/test_creating_oauth2client.py index f5bdf6aa..aa9965df 100644 --- a/tests/modules/auth/resources/test_creating_oauth2client.py +++ b/tests/modules/auth/resources/test_creating_oauth2client.py @@ -37,7 +37,7 @@ def test_creating_oauth2_client( assert isinstance(response.json['redirect_uris'], list) # Cleanup - from app.modules.auth.models2 import OAuth2Client + from app.modules.auth.models import OAuth2Client oauth2_client_instance = OAuth2Client.query.get(response.json['client_id']) assert oauth2_client_instance.client_secret == response.json['client_secret'] diff --git a/tests/modules/auth/resources/test_token.py b/tests/modules/auth/resources/test_token.py index a90ab594..7d032545 100644 --- a/tests/modules/auth/resources/test_token.py +++ b/tests/modules/auth/resources/test_token.py @@ -31,7 +31,7 @@ def test_regular_user_can_retrieve_token( } # Clean up - from app.modules.auth.models2 import OAuth2Token + from app.modules.auth.models import OAuth2Token with db.session.begin(): OAuth2Token.query.filter(OAuth2Token.access_token == response.json['access_token']).delete() @@ -110,7 +110,7 @@ def test_regular_user_can_refresh_token( } # Clean up - from app.modules.auth.models2 import OAuth2Token + from app.modules.auth.models import OAuth2Token with db.session.begin(): OAuth2Token.query.filter( OAuth2Token.access_token == refresh_token_response.json['access_token'] diff --git a/tests/modules/auth/test_login_manager_integration.py b/tests/modules/auth/test_login_manager_integration.py index 631dd04d..f5ad2e28 100644 --- a/tests/modules/auth/test_login_manager_integration.py +++ b/tests/modules/auth/test_login_manager_integration.py @@ -9,7 +9,10 @@ def test_loading_user_from_anonymous_request(flask_app): with flask_app.test_request_context('/'): - assert auth.load_user_from_request(request) is None + user = auth.load_user_from_request(request) + assert user.is_authenticated == False + assert user.is_active == False + assert user.get_id() == None def test_loading_user_from_request_with_oauth_user_cached(flask_app): mock_user = Mock() diff --git a/tests/utils.py b/tests/utils.py index f22a9e78..40ff2915 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -39,7 +39,7 @@ def login(self, user, auth_scopes=None): def open(self, *args, **kwargs): if self._user is not None: from app.extensions import db - from app.modules.auth.models2 import OAuth2Client, OAuth2Token + from app.modules.auth.models import OAuth2Client, OAuth2Token oauth2_client = OAuth2Client( client_id='OAUTH2_%s' % self._user.username,