22import logging
33import inspect
44import re
5+ import collections
6+ import itertools
57from .connection import conn
6- from .diagram import Diagram
78from .settings import config
89from .errors import DataJointError , AccessError
910from .jobs import JobTable
1011from .external import ExternalMapping
1112from .heading import Heading
1213from .utils import user_choice , to_camel_case
13- from .user_tables import Part , Computed , Imported , Manual , Lookup
14- from .table import lookup_class_name , Log
14+ from .user_tables import Part , Computed , Imported , Manual , Lookup , _get_tier
15+ from .table import lookup_class_name , Log , FreeTable
1516import types
1617
1718logger = logging .getLogger (__name__ .split ("." )[0 ])
@@ -399,6 +400,76 @@ def jobs(self):
399400 self ._jobs = JobTable (self .connection , self .database )
400401 return self ._jobs
401402
403+ @property
404+ def code (self ):
405+ self ._assert_exists ()
406+ return self .save ()
407+
408+ def save (self , python_filename = None ):
409+ """
410+ Generate the code for a module that recreates the schema.
411+ This method is in preparation for a future release and is not officially supported.
412+
413+ :return: a string containing the body of a complete Python module defining this schema.
414+ """
415+ self ._assert_exists ()
416+ module_count = itertools .count ()
417+ # add virtual modules for referenced modules with names vmod0, vmod1, ...
418+ module_lookup = collections .defaultdict (
419+ lambda : "vmod" + str (next (module_count ))
420+ )
421+ db = self .database
422+
423+ def make_class_definition (table ):
424+ tier = _get_tier (table ).__name__
425+ class_name = table .split ("." )[1 ].strip ("`" )
426+ indent = ""
427+ if tier == "Part" :
428+ class_name = class_name .split ("__" )[- 1 ]
429+ indent += " "
430+ class_name = to_camel_case (class_name )
431+
432+ def replace (s ):
433+ d , tabs = s .group (1 ), s .group (2 )
434+ return ("" if d == db else (module_lookup [d ] + "." )) + "." .join (
435+ to_camel_case (tab ) for tab in tabs .lstrip ("__" ).split ("__" )
436+ )
437+
438+ return ("" if tier == "Part" else "\n @schema\n " ) + (
439+ "{indent}class {class_name}(dj.{tier}):\n "
440+ '{indent} definition = """\n '
441+ '{indent} {defi}"""'
442+ ).format (
443+ class_name = class_name ,
444+ indent = indent ,
445+ tier = tier ,
446+ defi = re .sub (
447+ r"`([^`]+)`.`([^`]+)`" ,
448+ replace ,
449+ FreeTable (self .connection , table ).describe (),
450+ ).replace ("\n " , "\n " + indent ),
451+ )
452+
453+ tables = self .connection .dependencies .topo_sort ()
454+ body = "\n \n " .join (make_class_definition (table ) for table in tables )
455+ python_code = "\n \n " .join (
456+ (
457+ '"""This module was auto-generated by datajoint from an existing schema"""' ,
458+ "import datajoint as dj\n \n schema = dj.Schema('{db}')" .format (db = db ),
459+ "\n " .join (
460+ "{module} = dj.VirtualModule('{module}', '{schema_name}')" .format (
461+ module = v , schema_name = k
462+ )
463+ for k , v in module_lookup .items ()
464+ ),
465+ body ,
466+ )
467+ )
468+ if python_filename is None :
469+ return python_code
470+ with open (python_filename , "wt" ) as f :
471+ f .write (python_code )
472+
402473 def list_tables (self ):
403474 """
404475 Return a list of all tables in the schema except tables with ~ in first character such
@@ -410,7 +481,7 @@ def list_tables(self):
410481 t
411482 for d , t in (
412483 full_t .replace ("`" , "" ).split ("." )
413- for full_t in Diagram (self ).topological_sort ()
484+ for full_t in Diagram (self ).topo_sort ()
414485 )
415486 if d == self .database
416487 ]
0 commit comments