diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 25e98b5..0000000 --- a/setup.cfg +++ /dev/null @@ -1,13 +0,0 @@ -[metadata] -description-file = README.md - -[pep8] -max-line-length = 120 -exclude = .tox - -[tool:pytest] -pep8maxlinelength = 120 -addopts = --tb=short - -[bdist_wheel] -universal=1 diff --git a/sqlacodegen/codegen.py b/sqlacodegen/codegen.py index 2949822..f39c418 100644 --- a/sqlacodegen/codegen.py +++ b/sqlacodegen/codegen.py @@ -32,6 +32,7 @@ _flask_prepend = 'db.' _dataclass = False +_sqla_orm = False class _DummyInflectEngine(object): @@ -168,7 +169,11 @@ def _render_column(column, show_name): server_default = 'server_default=' + _flask_prepend + 'FetchedValue()' comment = getattr(column, 'comment', None) - return _flask_prepend + 'Column({0})'.format(', '.join( + if _sqla_orm: + column_string = 'mapped_column' + else: + column_string = 'Column' + return _flask_prepend + '{0}({1})'.format(column_string, ', '.join( ([repr(column.name)] if show_name else []) + ([_render_column_type(column.type)] if render_coltype else []) + [_render_constraint(x) for x in dedicated_fks] + @@ -227,7 +232,8 @@ def _render_index(index): class ImportCollector(OrderedDict): def add_import(self, obj): type_ = type(obj) if not isinstance(obj, type) else obj - pkgname = 'sqlalchemy' if type_.__name__ in sqlalchemy.__all__ else type_.__module__ # @UndefinedVariable + pkgname = 'sqlalchemy' if hasattr(sqlalchemy, type_.__name__) else type_.__module__ # @UndefinedVariable + # pkgname = 'sqlalchemy' if type_.__name__ in sqlalchemy.__all__ else type_.__module__ # @UndefinedVariable self.add_literal_import(pkgname, type_.__name__) def add_literal_import(self, pkgname, name): @@ -326,7 +332,7 @@ def __init__(self, table, association_tables, inflect_engine, detect_joined, col if _dataclass: if column.type.python_type.__module__ != 'builtins': collector.add_literal_import(column.type.python_type.__module__, column.type.python_type.__name__) - + # Add many-to-one relationships pk_column_names = set(col.name for col in table.primary_key.columns) @@ -373,13 +379,12 @@ def add_imports(self, collector): child.add_imports(collector) def render(self): - global _dataclass - + global _dataclass text = 'class {0}({1}):\n'.format(self.name, self.parent_name) - + if _dataclass: text = '@dataclass\n' + text - + text += ' __tablename__ = {0!r}\n'.format(self.table.name) # Render constraints and indexes as __table_args__ @@ -397,7 +402,6 @@ def render(self): table_kwargs = {} if self.schema: table_kwargs['schema'] = self.schema - kwargs_items = ', '.join('{0!r}: {1!r}'.format(key, table_kwargs[key]) for key in table_kwargs) kwargs_items = '{{{0}}}'.format(kwargs_items) if kwargs_items else None if table_kwargs and not table_args: @@ -414,9 +418,12 @@ def render(self): for attr, column in self.attributes.items(): if isinstance(column, Column): show_name = attr != column.name - if _dataclass: - text += ' ' + attr + ' : ' + column.type.python_type.__name__ + '\n' - + if _dataclass: + if _sqla_orm: + text += ' ' + attr + ' : ' + 'Mapped[{0}]\n'.format(column.type.python_type.__name__) + else: + text += ' ' + attr + ' : ' + column.type.python_type.__name__ + '\n' + text += ' {0} = {1}\n'.format(attr, _render_column(column, show_name)) # Render relationships @@ -452,7 +459,7 @@ def render(self): delimiter, end = ', ', ')' args.extend([key + '=' + value for key, value in self.kwargs.items()]) - + return _re_invalid_relationship.sub('_', text + delimiter.join(args) + end) def make_backref(self, relationships, classes): @@ -509,7 +516,7 @@ def __init__(self, source_cls, target_cls, constraint, inflect_engine): # common_fk_constraints = _get_common_fk_constraints(constraint.table, constraint.elements[0].column.table) # if len(common_fk_constraints) > 1: # self.kwargs['primaryjoin'] = "'{0}.{1} == {2}.{3}'".format(source_cls, constraint.columns[0], target_cls, constraint.elements[0].column.name) - if len(constraint.elements) > 1: # or + if len(constraint.elements) > 1: # or self.kwargs['primaryjoin'] = "'and_({0})'".format(', '.join(['{0}.{1} == {2}.{3}'.format(source_cls, k.parent.name, target_cls, k.column.name) for k in constraint.elements])) else: @@ -550,9 +557,8 @@ class CodeGenerator(object): def __init__(self, metadata, noindexes=False, noconstraints=False, nojoined=False, noinflect=False, nobackrefs=False, - flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False, dataclass=False): + flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False, dataclass=False, sqla_orm=False): super(CodeGenerator, self).__init__() - if noinflect: inflect_engine = _DummyInflectEngine() else: @@ -561,19 +567,23 @@ def __init__(self, metadata, noindexes=False, noconstraints=False, # exclude these column names from consideration when generating association tables _ignore_columns = ignore_cols or [] - + self.flask = flask if not self.flask: global _flask_prepend _flask_prepend = '' self.nocomments = nocomments - + self.dataclass = dataclass if self.dataclass: global _dataclass _dataclass = True + self.sqla_orm = sqla_orm + global _sqla_orm + _sqla_orm = sqla_orm + # Pick association tables from the metadata into their own set, don't process them normally links = defaultdict(lambda: []) association_tables = set() @@ -671,15 +681,20 @@ def __init__(self, metadata, noindexes=False, noconstraints=False, if model.parent_name == 'Base': model.parent_name = parent_name else: - self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base') - self.collector.add_literal_import('sqlalchemy', 'MetaData') - - + if self.sqla_orm: + self.collector.add_literal_import('sqlalchemy.orm', 'DeclarativeBase') + self.collector.add_literal_import('sqlalchemy.orm', 'Mapped') + self.collector.add_literal_import('sqlalchemy.orm', 'mapped_column') + else: + self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base') + self.collector.add_literal_import('sqlalchemy', 'MetaData') + + if self.dataclass: self.collector.add_literal_import('dataclasses', 'dataclass') def render(self, outfile=sys.stdout): - + print(self.header, file=outfile) # Render the collected imports @@ -689,7 +704,10 @@ def render(self, outfile=sys.stdout): print('db = SQLAlchemy()', file=outfile) else: if any(isinstance(model, ModelClass) for model in self.models): - print('Base = declarative_base()\nmetadata = Base.metadata', file=outfile) + if self.sqla_orm: + print('class Base(DeclarativeBase):\n pass', file=outfile) + else: + print('Base = declarative_base()\nmetadata = Base.metadata', file=outfile) else: print('metadata = MetaData()', file=outfile) diff --git a/sqlacodegen/main.py b/sqlacodegen/main.py index d8646e9..c97f328 100644 --- a/sqlacodegen/main.py +++ b/sqlacodegen/main.py @@ -8,7 +8,8 @@ from sqlalchemy.engine import create_engine from sqlalchemy.schema import MetaData -from sqlacodegen.codegen import CodeGenerator +# from sqlacodegen.codegen import CodeGenerator +from codegen import CodeGenerator import sqlacodegen import sqlacodegen.dialects @@ -25,7 +26,7 @@ def main(): parser = argparse.ArgumentParser(description='Generates SQLAlchemy model code from an existing database.') parser.add_argument('url', nargs='?', help='SQLAlchemy url to the database') parser.add_argument('--version', action='store_true', help="print the version number and exit") - parser.add_argument('--schema', help='load tables from an alternate schema') + parser.add_argument('--schema', help='alternate schemas to load in addition to local schema (comma-separated)') parser.add_argument('--default-schema', help='default schema name for local schema object') parser.add_argument('--tables', help='tables to process (comma-separated, default: all)') parser.add_argument('--noviews', action='store_true', help="ignore views") @@ -41,6 +42,8 @@ def main(): parser.add_argument('--ignore-cols', help="Don't check foreign key constraints on specified columns (comma-separated)") parser.add_argument('--nocomments', action='store_true', help="don't render column comments") parser.add_argument('--dataclass', action='store_true', help="add dataclass decorators for JSON serialization") + parser.add_argument('--sqlalchemyorm', action='store_true', help="use SQLAlchemy.orm module") + args = parser.parse_args() if args.version: @@ -52,20 +55,23 @@ def main(): return default_schema = args.default_schema if not default_schema: - default_schema = None + default_schema = None engine = create_engine(args.url) import_dialect_specificities(engine) - metadata = MetaData() + metadata = MetaData(schema=default_schema) tables = args.tables.split(',') if args.tables else None ignore_cols = args.ignore_cols.split(',') if args.ignore_cols else None - metadata.reflect(engine, args.schema, not args.noviews, tables) + metadata.reflect(engine, views=not args.noviews, only=tables) + for schema in args.schema.split(','): + metadata.reflect(engine, schema, not args.noviews, tables) + outfile = codecs.open(args.outfile, 'w', encoding='utf-8') if args.outfile else sys.stdout generator = CodeGenerator(metadata, args.noindexes, args.noconstraints, args.nojoined, args.noinflect, args.nobackrefs, - args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables, args.dataclass) + args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables, args.dataclass, args.sqlalchemyorm) generator.render(outfile) if __name__ == '__main__': - main() \ No newline at end of file + main()