Skip to content

Commit 65cf0d9

Browse files
committed
Fix changing default database by USE in session mode
1 parent 6c9b2c6 commit 65cf0d9

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed

programs/local/LocalServer.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
#include <boost/program_options/options_description.hpp>
5959
#include <base/argsToConfig.h>
6060
#include <filesystem>
61+
#include <fstream>
6162

6263
#include "config.h"
6364

@@ -801,6 +802,42 @@ void LocalServer::processConfig()
801802
global_context->getUserDefinedSQLObjectsLoader().loadObjects();
802803

803804
LOG_DEBUG(log, "Loaded metadata.");
805+
806+
/** Set default database if it is specified in default_database file.
807+
* NOTE: We do it after loading metadata to let the '_local' and 'system' database initialization
808+
* to be done correctly.
809+
* I have tried to set default database during parsing '--database' option, but it leads to
810+
* "Code: 82. DB::Exception: Database db_xxx already exists.: while loading database `db_xxx`
811+
* from path .state_tmp_auxten_usedb_/metadata/db_xxx. (DATABASE_ALREADY_EXISTS)"
812+
* This will also happen if we call:
813+
* `clickhouse local --database=db_xxx --path=.state_tmp_auxten_usedb_ --query="select * FROM log_table_xxx"`
814+
* with existing `db_xxx` database in the '.state_tmp_auxten_usedb_' directory.
815+
*/
816+
auto default_database_path = fs::path(path) / "default_database";
817+
if (std::filesystem::exists(default_database_path))
818+
{
819+
std::ifstream ifs(default_database_path);
820+
std::string user_default_database;
821+
if (ifs.is_open())
822+
{
823+
ifs >> user_default_database;
824+
// strip default_database
825+
user_default_database.erase(
826+
std::remove_if(
827+
user_default_database.begin(), user_default_database.end(), [](unsigned char x) { return std::isspace(x); }),
828+
user_default_database.end());
829+
if (!user_default_database.empty())
830+
{
831+
global_context->setCurrentDatabase(user_default_database);
832+
LOG_DEBUG(log, "Set default database to {} recorded in {}", user_default_database, default_database_path);
833+
}
834+
ifs.close();
835+
}
836+
else
837+
{
838+
LOG_ERROR(log, "Cannot read default database from {}", default_database_path);
839+
}
840+
}
804841
}
805842
else if (!config().has("no-system-tables"))
806843
{

src/Interpreters/InterpreterUseQuery.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,40 @@
22
#include <Interpreters/Context.h>
33
#include <Interpreters/InterpreterUseQuery.h>
44
#include <Access/Common/AccessFlags.h>
5+
#include <Common/Exception.h>
56
#include <Common/typeid_cast.h>
67

8+
#include <fstream>
79

810
namespace DB
911
{
1012

13+
namespace ErrorCodes
14+
{
15+
extern const int CANNOT_OPEN_FILE;
16+
}
17+
1118
BlockIO InterpreterUseQuery::execute()
1219
{
1320
const String & new_database = query_ptr->as<ASTUseQuery &>().getDatabase();
1421
getContext()->checkAccess(AccessType::SHOW_DATABASES, new_database);
1522
getContext()->getSessionContext()->setCurrentDatabase(new_database);
23+
24+
// Save the current using database in default_database stored in getPath()
25+
// for the case when the database is changed in chDB session.
26+
// The default_database content is used in the LocalServer::processConfig() method.
27+
auto default_database_path = fs::path(getContext()->getPath()) / "default_database";
28+
std::ofstream tmp_path_fs(default_database_path, std::ofstream::out | std::ofstream::trunc);
29+
if (tmp_path_fs && tmp_path_fs.is_open())
30+
{
31+
tmp_path_fs << new_database;
32+
tmp_path_fs.close();
33+
}
34+
else
35+
{
36+
throw Exception(ErrorCodes::CANNOT_OPEN_FILE, "Cannot open file {} for writing", default_database_path.string());
37+
}
38+
1639
return {};
1740
}
1841

tests/test_usedb.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!python3
2+
3+
import unittest
4+
import shutil
5+
import psutil
6+
from chdb import session
7+
8+
9+
test_state_dir = ".state_tmp_auxten_usedb_"
10+
current_process = psutil.Process()
11+
check_thread_count = False
12+
13+
14+
class TestStateful(unittest.TestCase):
15+
def setUp(self) -> None:
16+
shutil.rmtree(test_state_dir, ignore_errors=True)
17+
return super().setUp()
18+
19+
def tearDown(self) -> None:
20+
shutil.rmtree(test_state_dir, ignore_errors=True)
21+
return super().tearDown()
22+
23+
def test_path(self):
24+
sess = session.Session(test_state_dir)
25+
26+
sess.query("CREATE DATABASE IF NOT EXISTS db_xxx ENGINE = Atomic", "CSV")
27+
ret = sess.query("SHOW DATABASES", "CSV")
28+
self.assertIn("db_xxx", str(ret))
29+
30+
sess.query("CREATE TABLE IF NOT EXISTS db_xxx.log_table_xxx (x UInt8) ENGINE = Log;")
31+
sess.query("INSERT INTO db_xxx.log_table_xxx VALUES (1), (2), (3), (4);")
32+
33+
ret = sess.query("USE db_xxx; SELECT * FROM log_table_xxx", "Debug")
34+
self.assertEqual(str(ret), "1\n2\n3\n4\n")
35+
36+
sess.query("USE db_xxx")
37+
ret = sess.query("SELECT * FROM log_table_xxx", "Debug")
38+
self.assertEqual(str(ret), "1\n2\n3\n4\n")
39+
40+
41+
if __name__ == '__main__':
42+
unittest.main()

0 commit comments

Comments
 (0)