diff --git a/django_cassandra_engine/base.py b/django_cassandra_engine/base.py index 0664aa9..e3fd67f 100644 --- a/django_cassandra_engine/base.py +++ b/django_cassandra_engine/base.py @@ -149,12 +149,26 @@ def __init__(self, *args, **kwargs): self.connected = False self.autocommit = True + self.original_db_name = self.settings_dict['NAME'] + self.updated_keyspace = False + del self.connection def connect(self): if not self.connected or self.connection is None: settings = self.settings_dict self.connection = CassandraConnection(**settings) + + # Support django-nose's REUSE_DB flag. + if (self.original_db_name != settings['NAME'] + and not self.updated_keyspace): + self.connection.keyspace = self.original_db_name + for models in self.introspection.cql_models.values(): + for model in models: + model.__keyspace__ = settings['NAME'] + self.connection.keyspace = settings['NAME'] + self.updated_keyspace = True + connection_created.send(sender=self.__class__, connection=self) self.connected = True diff --git a/django_cassandra_engine/creation.py b/django_cassandra_engine/creation.py index a0dfca2..9a334b0 100644 --- a/django_cassandra_engine/creation.py +++ b/django_cassandra_engine/creation.py @@ -58,6 +58,8 @@ def create_test_db(self, verbosity=1, autoclobber=False, **kwargs): # Set all models keyspace to the test keyspace self.set_models_keyspace(test_database_name) + self.connection.connection.keyspace = test_database_name + if verbosity >= 1: test_db_repr = '' if verbosity >= 2: