diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml index 5fd5c76fd83..e84d8d8d3c0 100644 --- a/.github/workflows/build_linux_arm64_wheels-gh.yml +++ b/.github/workflows/build_linux_arm64_wheels-gh.yml @@ -29,7 +29,7 @@ jobs: run: | sudo apt-get update sudo apt-get install -y make build-essential libssl-dev zlib1g-dev \ - libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \ + libbz2-dev libreadline-dev wget curl llvm \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev - name: Scan SQLite vulnerabilities with grype diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml index ccb0180fbcc..2a1ee8047f7 100644 --- a/.github/workflows/build_linux_x86_wheels.yml +++ b/.github/workflows/build_linux_x86_wheels.yml @@ -29,7 +29,7 @@ jobs: run: | sudo apt-get update sudo apt-get install -y make build-essential libssl-dev zlib1g-dev \ - libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \ + libbz2-dev libreadline-dev wget curl llvm \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev - name: Scan SQLite vulnerabilities with grype @@ -62,7 +62,7 @@ jobs: else echo "✅ No SQLite vulnerabilities found" fi - continue-on-error: true + continue-on-error: true - name: Setup pyenv run: | curl https://pyenv.run | bash diff --git a/chdb/session/state.py b/chdb/session/state.py index 6beae880b36..090a4c5e7ad 100644 --- a/chdb/session/state.py +++ b/chdb/session/state.py @@ -4,9 +4,6 @@ from ..state import sqlitelike as chdb_stateful from ..state.sqlitelike import StreamingResult -g_session = None -g_session_path = None - class Session: """ @@ -35,21 +32,15 @@ class Session: - "mode=ro" would be "--readonly=1" for clickhouse (read-only mode) Important: - - There can be only one session at a time. If you want to create a new session, you need to close the existing one. 
- - Creating a new session will close the existing one. + - Multiple sessions can coexist. Each session has its own connection and database context. + - Sessions are thread-safe: Multiple threads can safely use the same session concurrently. + - Internal mutexes protect concurrent access to the underlying connection and client. + - For optimal performance in multi-threaded scenarios, consider creating a separate session for each thread + to avoid lock contention, though sharing a session across threads is safe. """ def __init__(self, path=None): self._conn = None - global g_session, g_session_path - if g_session is not None: - warnings.warn( - "There is already an active session. Creating a new session will close the existing one. " - "It is recommended to close the existing session before creating a new one. " - f"Closing the existing session {g_session_path}" - ) - g_session.close() - g_session_path = None if path is None: self._path = ":memory:" else: @@ -68,8 +59,6 @@ def __init__(self, path=None): self._udf_path = "" self._conn_str = f"{self._path}" self._conn = chdb_stateful.Connection(self._conn_str) - g_session = self - g_session_path = self._path def __del__(self): self.close() @@ -102,9 +91,6 @@ def close(self): if self._conn is not None: self._conn.close() self._conn = None - global g_session, g_session_path - g_session = None - g_session_path = None def cleanup(self): """Cleanup session resources with exception handling. 
diff --git a/contrib/corrosion-cmake/CMakeLists.txt b/contrib/corrosion-cmake/CMakeLists.txt index 7c82987dfdf..c33ef609a95 100644 --- a/contrib/corrosion-cmake/CMakeLists.txt +++ b/contrib/corrosion-cmake/CMakeLists.txt @@ -192,7 +192,7 @@ endfunction() function(clickhouse_config_crate_flags target_name) corrosion_set_env_vars(${target_name} "CFLAGS=${RUST_CFLAGS}") corrosion_set_env_vars(${target_name} "CXXFLAGS=${RUST_CXXFLAGS}") - corrosion_set_env_vars(${target_name} "RUSTFLAGS=${RUSTFLAGS}") + corrosion_set_env_vars(${target_name} "RUSTFLAGS=${RUSTFLAGS} --cfg osslconf=\"OPENSSL_NO_DEPRECATED_3_0\"") corrosion_set_env_vars(${target_name} "RUSTDOCFLAGS=${RUSTFLAGS}") if (CMAKE_OSX_SYSROOT) corrosion_set_env_vars(${target_name} "SDKROOT=${CMAKE_OSX_SYSROOT}") diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index f84770e6392..891043c98de 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -4,6 +4,8 @@ set (CLICKHOUSE_LOCAL_SOURCES ArrowStreamWrapper.cpp ArrowTableReader.cpp LocalServer.cpp + EmbeddedServer.cpp + ChdbClient.cpp ) if (NOT USE_PYTHON) @@ -25,10 +27,11 @@ endif() if (USE_PYTHON) set (CHDB_SOURCES chdb.cpp - FormatHelper.cpp ListScan.cpp LocalChdb.cpp LocalServer.cpp + EmbeddedServer.cpp + ChdbClient.cpp NumpyType.cpp PandasAnalyzer.cpp PandasDataFrame.cpp diff --git a/programs/local/ChdbClient.cpp b/programs/local/ChdbClient.cpp new file mode 100644 index 00000000000..4ed706b8b09 --- /dev/null +++ b/programs/local/ChdbClient.cpp @@ -0,0 +1,394 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if USE_PYTHON +#include +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; +} + +ChdbClient::ChdbClient(EmbeddedServerPtr server_ptr) + : ClientBase() + , server(server_ptr) +{ + if (!server) + throw Exception(ErrorCodes::LOGICAL_ERROR, 
"EmbeddedServer pointer is null"); + + configuration = ConfigHelper::createEmpty(); + layered_configuration = new Poco::Util::LayeredConfiguration(); + layered_configuration->addWriteable(configuration, 0); + session = std::make_unique(server->getGlobalContext(), ClientInfo::Interface::LOCAL); +#if USE_PYTHON + python_table_cache = std::make_shared(); + session->setPythonTableCache(python_table_cache); +#endif + session->authenticate("default", "", Poco::Net::SocketAddress{}); + global_context = session->makeSessionContext(); + global_context->setCurrentDatabase("default"); + global_context->setApplicationType(Context::ApplicationType::LOCAL); + initClientContext(global_context); + server_display_name = "chDB-embedded"; + query_processing_stage = QueryProcessingStage::Enum::Complete; + is_interactive = false; + ignore_error = false; + echo_queries = false; + print_stack_trace = false; +} + +std::unique_ptr ChdbClient::create(EmbeddedServerPtr server_ptr) +{ + if (!server_ptr) + { + server_ptr = EmbeddedServer::getInstance(); + } + return std::make_unique(server_ptr); +} + +ChdbClient::~ChdbClient() +{ + std::lock_guard lock(client_mutex); + cleanup(); + resetQueryOutputVector(); +} + +void ChdbClient::cleanup() +{ + try + { + if (streaming_query_context && streaming_query_context->streaming_result) + CHDB::cancelStreamQuery(this, streaming_query_context->streaming_result); + streaming_query_context.reset(); + connection.reset(); + client_context.reset(); + } + catch (...) 
+ { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +void ChdbClient::connect() +{ + connection_parameters = ConnectionParameters::createForEmbedded( + session->sessionContext()->getUserName(), + "default"); + connection = LocalConnection::createConnection( + connection_parameters, + std::move(session), + std_in.get(), + false, + false, + server_display_name); + connection->setDefaultDatabase("default"); +} + +Poco::Util::LayeredConfiguration & ChdbClient::getClientConfiguration() +{ + chassert(layered_configuration); + return *layered_configuration; +} + +void ChdbClient::processError(std::string_view) const +{ + if (server_exception) + server_exception->rethrow(); + if (client_exception) + client_exception->rethrow(); +} + +bool ChdbClient::hasStreamingQuery() const +{ + std::lock_guard lock(client_mutex); + return streaming_query_context != nullptr; +} + +size_t ChdbClient::getStorageRowsRead() const +{ + if (connection) + { + auto * local_connection = static_cast(connection.get()); + return local_connection->getCHDBProgress().read_rows; + } + return 0; +} + +size_t ChdbClient::getStorageBytesRead() const +{ + if (connection) + { + auto * local_connection = static_cast(connection.get()); + return local_connection->getCHDBProgress().read_bytes; + } + return 0; +} + +#if USE_PYTHON +void ChdbClient::findQueryableObjFromPyCache(const String & query_str) const +{ + python_table_cache->findQueryableObjFromQuery(query_str); +} + +#endif + +#if USE_PYTHON +static bool isJSONSupported(const char * format, size_t format_len) +{ + if (format) + { + String lower_format{format, format_len}; + std::transform(lower_format.begin(), lower_format.end(), lower_format.begin(), ::tolower); + + return !( + lower_format == "arrow" || lower_format == "parquet" || lower_format == "arrowstream" || lower_format == "protobuf" + || lower_format == "protobuflist" || lower_format == "protobufsingle"); + } + + return true; +} +#endif + +bool 
ChdbClient::parseQueryTextWithOutputFormat(const String & query, const String & format) +{ + if (!format.empty()) + { + client_context->setDefaultFormat(format); + setDefaultFormat(format); + } + + if (!connection || !connection->checkConnected(connection_parameters.timeouts)) + connect(); +#if USE_PYTHON + (static_cast(connection.get()))->getSession().setJSONSupport(isJSONSupported(format.c_str(), format.size())); +#endif + return processQueryText(query); +} + +CHDB::QueryResultPtr ChdbClient::executeMaterializedQuery( + const char * query, size_t query_len, + const char * format, size_t format_len) +{ + std::lock_guard lock(client_mutex); + + String query_str(query, query_len); + String format_str(format, format_len); + + try + { + DB::ThreadStatus thread_status; + if (!parseQueryTextWithOutputFormat(query_str, format_str)) + { + return std::make_unique(getErrorMsg()); + } + auto * local_connection = static_cast(connection.get()); + size_t storage_rows_read = local_connection->getCHDBProgress().read_rows; + size_t storage_bytes_read = local_connection->getCHDBProgress().read_bytes; + auto res = std::make_unique( + CHDB::ResultBuffer(stealQueryOutputVector()), + getElapsedTime(), + getProcessedRows(), + getProcessedBytes(), + storage_rows_read, + storage_bytes_read); +#if USE_PYTHON + python_table_cache->clear(); +#endif + return res; + } + catch (const Exception & e) + { +#if USE_PYTHON + python_table_cache->clear(); +#endif + return std::make_unique(getExceptionMessage(e, false)); + } + catch (...) 
+ { +#if USE_PYTHON + python_table_cache->clear(); +#endif + return std::make_unique(getCurrentExceptionMessage(true)); + } +} + +CHDB::QueryResultPtr ChdbClient::executeStreamingInit( + const char * query, size_t query_len, + const char * format, size_t format_len) +{ + std::lock_guard lock(client_mutex); + + String query_str(query, query_len); + String format_str(format, format_len); + + try + { + DB::ThreadStatus thread_status; + + streaming_query_context = std::make_shared(); + if (!parseQueryTextWithOutputFormat(query_str, format_str)) + { + streaming_query_context.reset(); + return std::make_unique(getErrorMsg()); + } + streaming_query_context->thread_group = DB::CurrentThread::getGroup(); + auto result = std::make_unique(); + streaming_query_context->streaming_result = result.get(); + return result; + } + catch (const Exception & e) + { + streaming_query_context.reset(); + return std::make_unique(getExceptionMessage(e, false)); + } + catch (...) + { + streaming_query_context.reset(); + return std::make_unique(getCurrentExceptionMessage(true)); + } +} + +CHDB::QueryResultPtr ChdbClient::executeStreamingIterate(void * streaming_result, bool is_canceled) +{ + std::lock_guard lock(client_mutex); + + if (!streaming_query_context) + return std::make_unique("No active streaming query"); + + try + { + DB::ThreadStatus thread_status; + if (streaming_query_context->thread_group) + { + DB::CurrentThread::attachToGroupIfDetached(streaming_query_context->thread_group); + } + auto * local_connection = static_cast(connection.get()); + const auto old_processed_rows = getProcessedRows(); + const auto old_processed_bytes = getProcessedBytes(); + size_t old_storage_rows_read = local_connection->getCHDBProgress().read_rows; + size_t old_storage_bytes_read = local_connection->getCHDBProgress().read_bytes; + const auto old_elapsed_time = getElapsedTime(); + + std::unique_ptr res; + if (!processStreamingQuery(streaming_result, is_canceled)) + { + res = 
std::make_unique(getErrorMsg()); + } + else + { + const auto processed_rows = getProcessedRows(); + const auto processed_bytes = getProcessedBytes(); + size_t storage_rows_read = local_connection->getCHDBProgress().read_rows; + size_t storage_bytes_read = local_connection->getCHDBProgress().read_bytes; + const auto elapsed_time = getElapsedTime(); + auto * output_vec = stealQueryOutputVector(); + bool has_output_data = output_vec && !output_vec->empty(); + if (has_output_data) + { + res = std::make_unique( + CHDB::ResultBuffer(output_vec), + elapsed_time - old_elapsed_time, + processed_rows - old_processed_rows, + processed_bytes - old_processed_bytes, + storage_rows_read - old_storage_rows_read, + storage_bytes_read - old_storage_bytes_read); + } + else + { + delete output_vec; + res = std::make_unique(nullptr, 0.0, 0, 0, 0, 0); + } + } + + bool is_end = !res->getError().empty() || res->rows_read == 0 || is_canceled; + if (is_end) + { + // End of stream reached or cancelled, cleanup + streaming_query_context.reset(); +#if USE_PYTHON + if (connection) + { + auto * local_connection = static_cast(connection.get()); + local_connection->resetQueryContext(); + local_connection->getSession().getPythonTableCache()->clear(); + } +#endif + } + return res; + } + catch (const Exception & e) + { + streaming_query_context.reset(); +#if USE_PYTHON + if (connection) + { + auto * local_connection = static_cast(connection.get()); + local_connection->resetQueryContext(); + } + python_table_cache->clear(); +#endif + return std::make_unique(getExceptionMessage(e, false)); + } + catch (...) 
+ { + streaming_query_context.reset(); +#if USE_PYTHON + if (connection) + { + auto * local_connection = static_cast(connection.get()); + local_connection->resetQueryContext(); + } + python_table_cache->clear(); +#endif + return std::make_unique(getCurrentExceptionMessage(true)); + } +} + +void ChdbClient::cancelStreamingQuery(void * streaming_result) +{ + std::lock_guard lock(client_mutex); + + if (streaming_query_context) + { + try + { + // Process the cancellation through ClientBase's streaming query method + processStreamingQuery(streaming_result, true); + } + catch (...) + { + // Ignore errors during cancellation + tryLogCurrentException(__PRETTY_FUNCTION__); + } + + // Ensure cleanup happens + streaming_query_context.reset(); +#if USE_PYTHON + if (connection) + { + auto * local_connection = static_cast(connection.get()); + local_connection->resetQueryContext(); + } + python_table_cache->clear(); +#endif + } +} + + +} // namespace DB diff --git a/programs/local/ChdbClient.h b/programs/local/ChdbClient.h new file mode 100644 index 00000000000..2b35fa2e5f6 --- /dev/null +++ b/programs/local/ChdbClient.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include +#include +#include "QueryResult.h" + +#include +#include + +namespace DB +{ +class EmbeddedServer; +using EmbeddedServerPtr = std::shared_ptr; + +/** + * ChdbClient - Client for executing queries in chDB + * + * Designed for chDB's embedded use case and inherits from ClientBase + * to reuse all query execution logic. + * Each client has its own LocalConnection. + * Holds a shared_ptr to EmbeddedServer to ensure it stays alive while client exists. 
+ */ +class ChdbClient : public ClientBase +{ +public: + static std::unique_ptr create(EmbeddedServerPtr server_ptr = nullptr); + + explicit ChdbClient(EmbeddedServerPtr server_ptr); + ~ChdbClient() override; + + CHDB::QueryResultPtr executeMaterializedQuery(const char * query, size_t query_len, const char * format, size_t format_len); + + CHDB::QueryResultPtr executeStreamingInit(const char * query, size_t query_len, const char * format, size_t format_len); + + CHDB::QueryResultPtr executeStreamingIterate(void * streaming_result, bool is_canceled = false); + + void cancelStreamingQuery(void * streaming_result); + + bool hasStreamingQuery() const; + + size_t getStorageRowsRead() const; + size_t getStorageBytesRead() const; + +#if USE_PYTHON + void findQueryableObjFromPyCache(const String & query_str) const; +#endif + +protected: + void connect() override; + Poco::Util::LayeredConfiguration & getClientConfiguration() override; + void processError(std::string_view query) const override; + String getName() const override { return "chdb"; } + bool isEmbeeddedClient() const override { return false; } + + void printHelpMessage(const OptionsDescription &) override {} + void addExtraOptions(OptionsDescription &) override {} + void processOptions(const OptionsDescription &, const CommandLineOptions &, + const std::vector &, const std::vector &) override {} + void processConfig() override {} + void setupSignalHandler() override {} + +private: + void cleanup(); + bool parseQueryTextWithOutputFormat(const String & query, const String & format); + + EmbeddedServerPtr server; + std::unique_ptr session; + ConfigurationPtr configuration; + Poco::AutoPtr layered_configuration; + std::unique_ptr input; +#if USE_PYTHON + std::shared_ptr python_table_cache; +#endif + mutable std::mutex client_mutex; +}; + +} // namespace DB diff --git a/programs/local/EmbeddedServer.cpp b/programs/local/EmbeddedServer.cpp new file mode 100644 index 00000000000..c03e803a680 --- /dev/null +++ 
b/programs/local/EmbeddedServer.cpp @@ -0,0 +1,962 @@ +#include "EmbeddedServer.h" + +#if USE_PYTHON +# include "StoragePython.h" +# include "TableFunctionPython.h" +#else +# include "StorageArrowStream.h" +# include "TableFunctionArrowStream.h" +#endif +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "config.h" + +#if USE_AZURE_BLOB_STORAGE +# include +#endif + + +namespace fs = std::filesystem; + +namespace CurrentMetrics +{ +extern const Metric MemoryTracking; +} + +namespace DB +{ + +namespace Setting +{ +extern const SettingsBool allow_introspection_functions; +extern const SettingsBool implicit_select; +extern const SettingsLocalFSReadMethod storage_file_read_method; +} + +namespace ServerSetting +{ +extern const ServerSettingsUInt32 allow_feature_tier; +extern const ServerSettingsDouble cache_size_to_ram_max_ratio; +extern const ServerSettingsUInt64 compiled_expression_cache_elements_size; +extern const ServerSettingsUInt64 compiled_expression_cache_size; +extern const ServerSettingsUInt64 database_catalog_drop_table_concurrency; +extern const ServerSettingsString default_database; +extern const ServerSettingsString index_mark_cache_policy; +extern const ServerSettingsUInt64 index_mark_cache_size; +extern const ServerSettingsDouble index_mark_cache_size_ratio; +extern const ServerSettingsString index_uncompressed_cache_policy; +extern const ServerSettingsUInt64 index_uncompressed_cache_size; +extern const 
ServerSettingsDouble index_uncompressed_cache_size_ratio; +extern const ServerSettingsString vector_similarity_index_cache_policy; +extern const ServerSettingsUInt64 vector_similarity_index_cache_size; +extern const ServerSettingsUInt64 vector_similarity_index_cache_max_entries; +extern const ServerSettingsDouble vector_similarity_index_cache_size_ratio; +extern const ServerSettingsUInt64 io_thread_pool_queue_size; +extern const ServerSettingsString mark_cache_policy; +extern const ServerSettingsUInt64 mark_cache_size; +extern const ServerSettingsDouble mark_cache_size_ratio; +extern const ServerSettingsString iceberg_metadata_files_cache_policy; +extern const ServerSettingsUInt64 iceberg_metadata_files_cache_size; +extern const ServerSettingsUInt64 iceberg_metadata_files_cache_max_entries; +extern const ServerSettingsDouble iceberg_metadata_files_cache_size_ratio; +extern const ServerSettingsUInt64 max_active_parts_loading_thread_pool_size; +extern const ServerSettingsUInt64 max_io_thread_pool_free_size; +extern const ServerSettingsUInt64 max_io_thread_pool_size; +extern const ServerSettingsUInt64 max_outdated_parts_loading_thread_pool_size; +extern const ServerSettingsUInt64 max_parts_cleaning_thread_pool_size; +extern const ServerSettingsUInt64 max_server_memory_usage; +extern const ServerSettingsDouble max_server_memory_usage_to_ram_ratio; +extern const ServerSettingsUInt64 max_thread_pool_free_size; +extern const ServerSettingsUInt64 max_thread_pool_size; +extern const ServerSettingsUInt64 max_unexpected_parts_loading_thread_pool_size; +extern const ServerSettingsUInt64 mmap_cache_size; +extern const ServerSettingsBool show_addresses_in_stack_traces; +extern const ServerSettingsUInt64 thread_pool_queue_size; +extern const ServerSettingsString uncompressed_cache_policy; +extern const ServerSettingsUInt64 uncompressed_cache_size; +extern const ServerSettingsDouble uncompressed_cache_size_ratio; +extern const ServerSettingsString primary_index_cache_policy; 
+extern const ServerSettingsUInt64 primary_index_cache_size; +extern const ServerSettingsDouble primary_index_cache_size_ratio; +extern const ServerSettingsUInt64 max_prefixes_deserialization_thread_pool_size; +extern const ServerSettingsUInt64 max_prefixes_deserialization_thread_pool_free_size; +extern const ServerSettingsUInt64 prefixes_deserialization_thread_pool_thread_pool_queue_size; +extern const ServerSettingsUInt64 max_format_parsing_thread_pool_size; +extern const ServerSettingsUInt64 max_format_parsing_thread_pool_free_size; +extern const ServerSettingsUInt64 format_parsing_thread_pool_queue_size; +extern const ServerSettingsUInt64 memory_worker_period_ms; +extern const ServerSettingsBool memory_worker_correct_memory_tracker; +} + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +extern const int CANNOT_LOAD_CONFIG; +extern const int FILE_ALREADY_EXISTS; +extern const int UNKNOWN_FORMAT; +} + +static void applySettingsOverridesForLocal(ContextMutablePtr context) +{ + Settings settings = context->getSettingsCopy(); + + settings[Setting::allow_introspection_functions] = true; + settings[Setting::storage_file_read_method] = LocalFSReadMethod::mmap; + settings[Setting::implicit_select] = true; + + context->setSettings(settings); +} + +EmbeddedServer::~EmbeddedServer() +{ + cleanup(); +} + +void EmbeddedServer::initialize(Poco::Util::Application & self) +{ + Poco::Util::Application::initialize(self); + + const char * home_path_cstr = getenv("HOME"); // NOLINT(concurrency-mt-unsafe) + if (home_path_cstr) + home_path = home_path_cstr; + + /// Load config files if exists + std::string config_path; + if (config().has("config-file")) + config_path = config().getString("config-file"); + else if (config_path.empty() && fs::exists("config.xml")) + config_path = "config.xml"; + else if (config_path.empty()) + config_path = getLocalConfigPath(home_path).value_or(""); + + if (fs::exists(config_path)) + { + ConfigProcessor config_processor(config_path); + 
ConfigProcessor::setConfigPath(fs::path(config_path).parent_path()); + auto loaded_config = config_processor.loadConfig(); + config().add(loaded_config.configuration.duplicate(), PRIO_DEFAULT, false); + } + + server_settings.loadSettingsFromConfig(config()); + + GlobalThreadPool::initialize( + server_settings[ServerSetting::max_thread_pool_size], + server_settings[ServerSetting::max_thread_pool_free_size], + server_settings[ServerSetting::thread_pool_queue_size]); + +#if USE_AZURE_BLOB_STORAGE + /// See the explanation near the same line in Server.cpp + GlobalThreadPool::instance().addOnDestroyCallback([] { Azure::Storage::_internal::XmlGlobalDeinitialize(); }); +#endif + +#if defined(OS_LINUX) + memory_worker = std::make_unique( + server_settings[ServerSetting::memory_worker_period_ms], + server_settings[ServerSetting::memory_worker_correct_memory_tracker], + /* use_cgroup */ true, + nullptr); + memory_worker->start(); +#endif + + getIOThreadPool().initialize( + server_settings[ServerSetting::max_io_thread_pool_size], + server_settings[ServerSetting::max_io_thread_pool_free_size], + server_settings[ServerSetting::io_thread_pool_queue_size]); + + const size_t active_parts_loading_threads = server_settings[ServerSetting::max_active_parts_loading_thread_pool_size]; + getActivePartsLoadingThreadPool().initialize( + active_parts_loading_threads, + 0, // We don't need any threads once all the parts will be loaded + active_parts_loading_threads); + + const size_t outdated_parts_loading_threads = server_settings[ServerSetting::max_outdated_parts_loading_thread_pool_size]; + getOutdatedPartsLoadingThreadPool().initialize( + outdated_parts_loading_threads, + 0, // We don't need any threads once all the parts will be loaded + outdated_parts_loading_threads); + + getOutdatedPartsLoadingThreadPool().setMaxTurboThreads(active_parts_loading_threads); + + const size_t unexpected_parts_loading_threads = server_settings[ServerSetting::max_unexpected_parts_loading_thread_pool_size]; + 
getUnexpectedPartsLoadingThreadPool().initialize( + unexpected_parts_loading_threads, + 0, // We don't need any threads once all the parts will be loaded + unexpected_parts_loading_threads); + + getUnexpectedPartsLoadingThreadPool().setMaxTurboThreads(active_parts_loading_threads); + + const size_t cleanup_threads = server_settings[ServerSetting::max_parts_cleaning_thread_pool_size]; + getPartsCleaningThreadPool().initialize( + cleanup_threads, + 0, // We don't need any threads once all the parts will be deleted + cleanup_threads); + + getDatabaseCatalogDropTablesThreadPool().initialize( + server_settings[ServerSetting::database_catalog_drop_table_concurrency], + 0, // We don't need any threads if there are no DROP queries. + server_settings[ServerSetting::database_catalog_drop_table_concurrency]); + + getMergeTreePrefixesDeserializationThreadPool().initialize( + server_settings[ServerSetting::max_prefixes_deserialization_thread_pool_size], + server_settings[ServerSetting::max_prefixes_deserialization_thread_pool_free_size], + server_settings[ServerSetting::prefixes_deserialization_thread_pool_thread_pool_queue_size]); + + getFormatParsingThreadPool().initialize( + server_settings[ServerSetting::max_format_parsing_thread_pool_size], + server_settings[ServerSetting::max_format_parsing_thread_pool_free_size], + server_settings[ServerSetting::format_parsing_thread_pool_queue_size]); +} + +static DatabasePtr createMemoryDatabaseIfNotExists(ContextPtr context, const String & database_name) +{ + DatabasePtr system_database = DatabaseCatalog::instance().tryGetDatabase(database_name); + if (!system_database) + { + /// TODO: add attachTableDelayed into DatabaseMemory to speedup loading + system_database = std::make_shared(database_name, context); + DatabaseCatalog::instance().attachDatabase(database_name, system_database); + } + return system_database; +} + +static DatabasePtr createClickHouseLocalDatabaseOverlay(const String & name_, ContextPtr context) +{ + auto overlay = 
std::make_shared(name_, context); + + UUID default_database_uuid; + + fs::path existing_path_symlink = fs::weakly_canonical(context->getPath()) / "metadata" / "default"; + if (FS::isSymlinkNoThrow(existing_path_symlink)) + default_database_uuid = parse(FS::readSymlink(existing_path_symlink).filename()); + else + default_database_uuid = UUIDHelpers::generateV4(); + + fs::path default_database_metadata_path + = fs::weakly_canonical(context->getPath()) / "store" / DatabaseCatalog::getPathForUUID(default_database_uuid); + + overlay->registerNextDatabase(std::make_shared(name_, default_database_metadata_path, default_database_uuid, context)); + overlay->registerNextDatabase(std::make_shared(name_, "", context)); + return overlay; +} + +/// If path is specified and not empty, will try to setup server environment and load existing metadata +void EmbeddedServer::tryInitPath() +{ + std::string path; + + if (config().has("path")) + { + /// User-supplied path. + path = config().getString("path"); + Poco::trimInPlace(path); + + if (path.empty()) + { + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "Cannot work with empty storage path that is explicitly specified" + " by the --path option. Please check the program options and" + " correct the --path."); + } + } + else + { + /// The user requested to use a temporary path - use a unique path in the system temporary directory + /// (or in the current dir if a temporary doesn't exist) + LoggerRawPtr log = &logger(); + std::filesystem::path parent_folder; + std::filesystem::path default_path; + + try + { + /// Try to guess a tmp folder name, and check if it's a directory (throw an exception otherwise). + parent_folder = std::filesystem::temp_directory_path(); + } + catch (const fs::filesystem_error & e) + { + // The tmp folder doesn't exist? Is it a misconfiguration? Or chroot? 
+ LOG_DEBUG(log, "Can not get temporary folder: {}", e.what()); + parent_folder = std::filesystem::current_path(); + + std::filesystem::is_directory(parent_folder); // that will throw an exception if it's not a directory + LOG_DEBUG(log, "Will create working directory inside current directory: {}", parent_folder.string()); + } + + /// we can have another clickhouse-embedded running simultaneously, even with the same PID (for ex. - several dockers mounting the same folder) + /// or it can be some leftovers from other clickhouse-embedded runs + /// as we can't accurately distinguish those situations we don't touch any existent folders + /// we just try to pick some free name for our working folder + + default_path = parent_folder / fmt::format("clickhouse-embedded-{}", UUIDHelpers::generateV4()); + + if (fs::exists(default_path)) + throw Exception( + ErrorCodes::FILE_ALREADY_EXISTS, + "Unsuccessful attempt to set up the working directory: {} already exists.", + default_path.string()); + + /// The directory can be created lazily during the runtime. 
+ temporary_directory_to_delete = default_path; + + path = default_path.string(); + + LOG_DEBUG(log, "Working directory will be created as needed: {}", path); + + config().setString("path", path); + } + fs::create_directories(path); + + global_context->setPath(fs::path(path) / ""); + DatabaseCatalog::instance().fixPath(global_context->getPath()); + + global_context->setTemporaryStoragePath(fs::path(path) / "tmp" / "", 0); + global_context->setFlagsPath(fs::path(path) / "flags" / ""); + + global_context->setUserFilesPath(""); /// user's files are everywhere + + std::string user_scripts_path = config().getString("user_scripts_path", fs::path(path) / "user_scripts" / ""); + global_context->setUserScriptsPath(user_scripts_path); + + /// Set path for filesystem caches + String filesystem_caches_path(config().getString("filesystem_caches_path", fs::path(path) / "cache" / "")); + if (!filesystem_caches_path.empty()) + global_context->setFilesystemCachesPath(filesystem_caches_path); + + /// top_level_domains_lists + const std::string & top_level_domains_path = config().getString("top_level_domains_path", fs::path(path) / "top_level_domains/"); + if (!top_level_domains_path.empty()) + TLDListsHolder::getInstance().parseConfig(fs::path(top_level_domains_path) / "", config()); +} + + +void EmbeddedServer::cleanup() +{ + try + { + if (global_context) + { + global_context->shutdown(); + global_context.reset(); + } + status.reset(); + + // Delete the temporary directory if needed. + if (temporary_directory_to_delete) + { + LOG_DEBUG(&logger(), "Removing temporary directory: {}", temporary_directory_to_delete->string()); + fs::remove_all(*temporary_directory_to_delete); + temporary_directory_to_delete.reset(); + } + } + catch (...) 
+ { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + +static ConfigurationPtr getConfigurationFromXMLString(const char * xml_data) +{ + std::stringstream ss{std::string{xml_data}}; // STYLE_CHECK_ALLOW_STD_STRING_STREAM + Poco::XML::InputSource input_source{ss}; + return {new Poco::Util::XMLConfiguration{&input_source}}; +} + + +void EmbeddedServer::setupUsers() +{ + static const char * minimal_default_user_xml = "" + " " + " " + " " + " " + " " + " " + " " + " ::/0" + " " + " default" + " default" + " 1" + " " + " " + " " + " " + " " + ""; + + ConfigurationPtr users_config; + auto & access_control = global_context->getAccessControl(); + access_control.setNoPasswordAllowed(config().getBool("allow_no_password", true)); + access_control.setPlaintextPasswordAllowed(config().getBool("allow_plaintext_password", true)); + if (config().has("config-file") || fs::exists("config.xml")) + { + String config_path = config().getString("config-file", ""); + bool has_user_directories = config().has("user_directories"); + const auto config_dir = fs::path{config_path}.remove_filename().string(); + String users_config_path = config().getString("users_config", ""); + + if (users_config_path.empty() && has_user_directories) + { + users_config_path = config().getString("user_directories.users_xml.path"); + if (fs::path(users_config_path).is_relative() && fs::exists(fs::path(config_dir) / users_config_path)) + users_config_path = fs::path(config_dir) / users_config_path; + } + + if (users_config_path.empty()) + users_config = getConfigurationFromXMLString(minimal_default_user_xml); + else + { + ConfigProcessor config_processor(users_config_path); + const auto loaded_config = config_processor.loadConfig(); + users_config = loaded_config.configuration; + } + } + else + users_config = getConfigurationFromXMLString(minimal_default_user_xml); + if (users_config) + { + global_context->setUsersConfig(users_config); + // NamedCollectionUtils::loadIfNot(); + } + else + throw 
Exception(ErrorCodes::CANNOT_LOAD_CONFIG, "Can't load config for users"); +} + +int EmbeddedServer::main(const std::vector & /*args*/) +try +{ + StackTrace::setShowAddresses(server_settings[ServerSetting::show_addresses_in_stack_traces]); + std::cout << std::fixed << std::setprecision(3); + std::cerr << std::fixed << std::setprecision(3); + + /// Try to increase limit on number of open files. + { + rlimit rlim; + if (getrlimit(RLIMIT_NOFILE, &rlim)) + throw Poco::Exception("Cannot getrlimit"); + + if (rlim.rlim_cur < rlim.rlim_max) + { + rlim.rlim_cur = config().getUInt("max_open_files", static_cast(rlim.rlim_max)); + int rc = setrlimit(RLIMIT_NOFILE, &rlim); + if (rc != 0) + std::cerr << fmt::format( + "Cannot set max number of file descriptors to {}. Try to specify max_open_files according to your system limits. " + "error: {}", + rlim.rlim_cur, + errnoToString()) + << '\n'; + } + } + + std::call_once( + global_register_once_flag, + []() + { + registerInterpreters(); + /// Don't initialize DateLUT + registerFunctions(); + registerAggregateFunctions(); + + registerTableFunctions(); + auto & table_function_factory = TableFunctionFactory::instance(); +#if USE_PYTHON + registerTableFunctionPython(table_function_factory); +#else + registerTableFunctionArrowStream(table_function_factory); +#endif + + registerDatabases(); + registerStorages(); +#if USE_PYTHON + auto & storage_factory = StorageFactory::instance(); + registerStoragePython(storage_factory); +#endif + + registerDictionaries(); + registerDisks(/* global_skip_access_check= */ true); + registerFormats(); + }); + + processConfig(); + /// try to load user defined executable functions, throw on error and die + try + { + global_context->loadOrReloadUserDefinedExecutableFunctions(config()); + } + catch (...) 
+ { + tryLogCurrentException(&logger(), "Caught exception while loading user defined executable functions."); + throw; + } + +#if USE_FUZZING_MODE + runLibFuzzer(); +#endif + + return Application::EXIT_OK; +} +catch (DB::Exception & e) +{ + bool need_print_stack_trace = config().getBool("stacktrace", false); + std::cerr << getExceptionMessageForLogging(e, need_print_stack_trace, true) << std::endl; + auto code = DB::getCurrentExceptionCode(); + return static_cast(code) ? code : 1; +} +catch (...) +{ + error_message_oss << DB::getCurrentExceptionMessage(true) << '\n'; + auto code = DB::getCurrentExceptionCode(); + return static_cast(code) ? code : 1; +} + +void EmbeddedServer::processConfig() +{ + + auto logging + = (config().has("logger.console") || config().has("logger.level") + || config().has("log-level") || config().has("logger.log")); + + auto level = config().getString("log-level", config().getString("send_logs_level", "trace")); + config().setString("logger", "logger"); + config().setString("logger.level", logging ? level : "fatal"); + buildLoggers(config(), logger(), "clickhouse-embedded"); + shared_context = Context::createSharedHolder(); + global_context = Context::createGlobal(shared_context.get()); + global_context->makeGlobalContext(); + global_context->setApplicationType(Context::ApplicationType::LOCAL); + + tryInitPath(); + + LoggerRawPtr log = &logger(); + + /// Maybe useless + if (config().has("macros")) + global_context->setMacros(std::make_unique(config(), "macros", log)); + + /// Sets external authenticators config (LDAP, Kerberos). + global_context->setExternalAuthenticatorsConfig(config()); + + setupUsers(); + + /// Limit on total number of concurrently executing queries. + /// There is no need for concurrent queries, override max_concurrent_queries. 
+ global_context->getProcessList().setMaxSize(0); + + size_t max_server_memory_usage = server_settings[ServerSetting::max_server_memory_usage]; + const double max_server_memory_usage_to_ram_ratio = server_settings[ServerSetting::max_server_memory_usage_to_ram_ratio]; + const size_t physical_server_memory = getMemoryAmount(); + const size_t default_max_server_memory_usage = static_cast(physical_server_memory * max_server_memory_usage_to_ram_ratio); + + if (max_server_memory_usage == 0) + { + max_server_memory_usage = default_max_server_memory_usage; + LOG_INFO( + log, + "Changed setting 'max_server_memory_usage' to {}" + " ({} available memory * {:.2f} max_server_memory_usage_to_ram_ratio)", + formatReadableSizeWithBinarySuffix(max_server_memory_usage), + formatReadableSizeWithBinarySuffix(physical_server_memory), + max_server_memory_usage_to_ram_ratio); + } + else if (max_server_memory_usage > default_max_server_memory_usage) + { + max_server_memory_usage = default_max_server_memory_usage; + LOG_INFO( + log, + "Lowered setting 'max_server_memory_usage' to {}" + " because the system has little few memory. 
The new value was" + " calculated as {} available memory * {:.2f} max_server_memory_usage_to_ram_ratio", + formatReadableSizeWithBinarySuffix(max_server_memory_usage), + formatReadableSizeWithBinarySuffix(physical_server_memory), + max_server_memory_usage_to_ram_ratio); + } + + total_memory_tracker.setHardLimit(max_server_memory_usage); + total_memory_tracker.setDescription("(total)"); + total_memory_tracker.setMetric(CurrentMetrics::MemoryTracking); + + const double cache_size_to_ram_max_ratio = server_settings[ServerSetting::cache_size_to_ram_max_ratio]; + const size_t max_cache_size = static_cast(physical_server_memory * cache_size_to_ram_max_ratio); + + String uncompressed_cache_policy = server_settings[ServerSetting::uncompressed_cache_policy]; + size_t uncompressed_cache_size = server_settings[ServerSetting::uncompressed_cache_size]; + double uncompressed_cache_size_ratio = server_settings[ServerSetting::uncompressed_cache_size_ratio]; + if (uncompressed_cache_size > max_cache_size) + { + uncompressed_cache_size = max_cache_size; + LOG_DEBUG( + log, + "Lowered uncompressed cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(uncompressed_cache_size)); + } + global_context->setUncompressedCache(uncompressed_cache_policy, uncompressed_cache_size, uncompressed_cache_size_ratio); + + String mark_cache_policy = server_settings[ServerSetting::mark_cache_policy]; + size_t mark_cache_size = server_settings[ServerSetting::mark_cache_size]; + double mark_cache_size_ratio = server_settings[ServerSetting::mark_cache_size_ratio]; + if (!mark_cache_size) + LOG_ERROR(log, "Too low mark cache size will lead to severe performance degradation."); + if (mark_cache_size > max_cache_size) + { + mark_cache_size = max_cache_size; + LOG_DEBUG( + log, "Lowered mark cache size to {} because the system has limited RAM", formatReadableSizeWithBinarySuffix(mark_cache_size)); + } + global_context->setMarkCache(mark_cache_policy, mark_cache_size, 
mark_cache_size_ratio); + + String index_uncompressed_cache_policy = server_settings[ServerSetting::index_uncompressed_cache_policy]; + size_t index_uncompressed_cache_size = server_settings[ServerSetting::index_uncompressed_cache_size]; + double index_uncompressed_cache_size_ratio = server_settings[ServerSetting::index_uncompressed_cache_size_ratio]; + if (index_uncompressed_cache_size > max_cache_size) + { + index_uncompressed_cache_size = max_cache_size; + LOG_INFO( + log, + "Lowered index uncompressed cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(index_uncompressed_cache_size)); + } + global_context->setIndexUncompressedCache( + index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio); + + String index_mark_cache_policy = server_settings[ServerSetting::index_mark_cache_policy]; + size_t index_mark_cache_size = server_settings[ServerSetting::index_mark_cache_size]; + double index_mark_cache_size_ratio = server_settings[ServerSetting::index_mark_cache_size_ratio]; + if (index_mark_cache_size > max_cache_size) + { + index_mark_cache_size = max_cache_size; + LOG_INFO( + log, + "Lowered index mark cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(index_mark_cache_size)); + } + global_context->setIndexMarkCache(index_mark_cache_policy, index_mark_cache_size, index_mark_cache_size_ratio); + + String primary_index_cache_policy = server_settings[ServerSetting::primary_index_cache_policy]; + size_t primary_index_cache_size = server_settings[ServerSetting::primary_index_cache_size]; + double primary_index_cache_size_ratio = server_settings[ServerSetting::primary_index_cache_size_ratio]; + if (primary_index_cache_size > max_cache_size) + { + primary_index_cache_size = max_cache_size; + LOG_INFO( + log, + "Lowered primary index cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(primary_index_cache_size)); + 
} + global_context->setPrimaryIndexCache(primary_index_cache_policy, primary_index_cache_size, primary_index_cache_size_ratio); + + String vector_similarity_index_cache_policy = server_settings[ServerSetting::vector_similarity_index_cache_policy]; + size_t vector_similarity_index_cache_size = server_settings[ServerSetting::vector_similarity_index_cache_size]; + size_t vector_similarity_index_cache_max_count = server_settings[ServerSetting::vector_similarity_index_cache_max_entries]; + double vector_similarity_index_cache_size_ratio = server_settings[ServerSetting::vector_similarity_index_cache_size_ratio]; + if (vector_similarity_index_cache_size > max_cache_size) + { + vector_similarity_index_cache_size = max_cache_size; + LOG_INFO( + log, + "Lowered vector similarity index cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(vector_similarity_index_cache_size)); + } + global_context->setVectorSimilarityIndexCache( + vector_similarity_index_cache_policy, + vector_similarity_index_cache_size, + vector_similarity_index_cache_max_count, + vector_similarity_index_cache_size_ratio); + + size_t mmap_cache_size = server_settings[ServerSetting::mmap_cache_size]; + if (mmap_cache_size > max_cache_size) + { + mmap_cache_size = max_cache_size; + LOG_INFO( + log, + "Lowered mmap file cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(mmap_cache_size)); + } + global_context->setMMappedFileCache(mmap_cache_size); + +#if USE_AVRO + String iceberg_metadata_files_cache_policy = server_settings[ServerSetting::iceberg_metadata_files_cache_policy]; + size_t iceberg_metadata_files_cache_size = server_settings[ServerSetting::iceberg_metadata_files_cache_size]; + size_t iceberg_metadata_files_cache_max_entries = server_settings[ServerSetting::iceberg_metadata_files_cache_max_entries]; + double iceberg_metadata_files_cache_size_ratio = server_settings[ServerSetting::iceberg_metadata_files_cache_size_ratio]; + 
if (iceberg_metadata_files_cache_size > max_cache_size) + { + iceberg_metadata_files_cache_size = max_cache_size; + LOG_INFO( + log, + "Lowered Iceberg metadata cache size to {} because the system has limited RAM", + formatReadableSizeWithBinarySuffix(iceberg_metadata_files_cache_size)); + } + global_context->setIcebergMetadataFilesCache( + iceberg_metadata_files_cache_policy, + iceberg_metadata_files_cache_size, + iceberg_metadata_files_cache_max_entries, + iceberg_metadata_files_cache_size_ratio); +#endif + + /// Initialize a dummy query condition cache. + global_context->setQueryConditionCache(DEFAULT_QUERY_CONDITION_CACHE_POLICY, 0, 0); + + /// Initialize a dummy query result cache. + global_context->setQueryResultCache(0, 0, 0, 0); + + /// Initialize allowed tiers + global_context->getAccessControl().setAllowTierSettings(server_settings[ServerSetting::allow_feature_tier]); + +#if USE_EMBEDDED_COMPILER + size_t compiled_expression_cache_max_size_in_bytes = server_settings[ServerSetting::compiled_expression_cache_size]; + size_t compiled_expression_cache_max_elements = server_settings[ServerSetting::compiled_expression_cache_elements_size]; + CompiledExpressionCacheFactory::instance().init(compiled_expression_cache_max_size_in_bytes, compiled_expression_cache_max_elements); +#endif + + NamedCollectionFactory::instance().loadIfNot(); + FileCacheFactory::instance().loadDefaultCaches(config(), global_context); + applySettingsOverridesForLocal(global_context); + applyCmdOptions(global_context); + + /// Load global settings from default_profile and system_profile. + global_context->setDefaultProfiles(config()); + /// We load temporary database first, because projections need it. 
+ DatabaseCatalog::instance().initializeAndLoadTemporaryDatabase(); + + std::string server_default_database = server_settings[ServerSetting::default_database]; + if (!server_default_database.empty()) + { + DatabasePtr database = createClickHouseLocalDatabaseOverlay(server_default_database, global_context); + if (UUID uuid = database->getUUID(); uuid != UUIDHelpers::Nil) + DatabaseCatalog::instance().addUUIDMapping(uuid); + DatabaseCatalog::instance().attachDatabase(server_default_database, database); + global_context->setCurrentDatabase(server_default_database); + } + + if (config().has("path")) + { + attachInformationSchema(global_context, *createMemoryDatabaseIfNotExists(global_context, DatabaseCatalog::INFORMATION_SCHEMA)); + attachInformationSchema( + global_context, *createMemoryDatabaseIfNotExists(global_context, DatabaseCatalog::INFORMATION_SCHEMA_UPPERCASE)); + + /// Attaching "automatic" tables in the system database is done after attaching the system database. + /// Consequently, it depends on whether we load it from the path. + /// If it is loaded from a user-specified path, we load it as usual. If not, we create it as a memory (ephemeral) database. 
+ bool attached_system_database = false; + + String path = global_context->getPath(); + + /// Lock path directory before read + fs::create_directories(fs::path(path)); + status.emplace(fs::path(path) / "status", StatusFile::write_full_info); + + if (fs::exists(fs::path(path) / "metadata")) + { + LOG_DEBUG(log, "Loading metadata from {}", path); + + if (fs::exists(std::filesystem::path(path) / "metadata" / "system.sql")) + { + LoadTaskPtrs load_system_metadata_tasks = loadMetadataSystem(global_context); + waitLoad(TablesLoaderForegroundPoolId, load_system_metadata_tasks); + + attachSystemTablesServer( + global_context, *DatabaseCatalog::instance().tryGetDatabase(DatabaseCatalog::SYSTEM_DATABASE), false); + attached_system_database = true; + } + + if (!config().has("only-system-tables")) + { + DatabaseCatalog::instance().loadMarkedAsDroppedTables(); + DatabaseCatalog::instance().createBackgroundTasks(); + waitLoad(loadMetadata(global_context)); + DatabaseCatalog::instance().startupBackgroundTasks(); + } + + LOG_DEBUG(log, "Loaded metadata."); + } + + if (!attached_system_database) + attachSystemTablesServer( + global_context, *createMemoryDatabaseIfNotExists(global_context, DatabaseCatalog::SYSTEM_DATABASE), false); + + if (fs::exists(fs::path(path) / "user_defined")) + global_context->getUserDefinedSQLObjectsStorage().loadObjects(); + } + else if (!config().has("no-system-tables")) + { + attachSystemTablesServer(global_context, *createMemoryDatabaseIfNotExists(global_context, DatabaseCatalog::SYSTEM_DATABASE), false); + attachInformationSchema(global_context, *createMemoryDatabaseIfNotExists(global_context, DatabaseCatalog::INFORMATION_SCHEMA)); + attachInformationSchema( + global_context, *createMemoryDatabaseIfNotExists(global_context, DatabaseCatalog::INFORMATION_SCHEMA_UPPERCASE)); + + /// Create background tasks necessary for DDL operations like DROP VIEW SYNC, + /// even in temporary mode (--path not set) without persistent storage + 
DatabaseCatalog::instance().createBackgroundTasks(); + DatabaseCatalog::instance().startupBackgroundTasks(); + } + else + { + /// Similarly, for other cases, create background tasks for DDL operations like + /// DROP VIEW SYNC in temporaty mode (--path not set) without persistent storage + DatabaseCatalog::instance().createBackgroundTasks(); + DatabaseCatalog::instance().startupBackgroundTasks(); + } + + std::string default_database = config().getString("database", server_default_database); + if (default_database.empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "default_database cannot be empty"); + global_context->setCurrentDatabase(default_database); +} + +void EmbeddedServer::applyCmdOptions(ContextMutablePtr context) +{ + context->setDefaultFormat( + config().getString( + "output-format", config().getString("format", "TSV"))); +} + +std::weak_ptr EmbeddedServer::global_instance; +std::mutex EmbeddedServer::instance_mutex; + +std::shared_ptr EmbeddedServer::getInstance(int argc, char ** argv) +{ + std::lock_guard lock(instance_mutex); + + auto instance = global_instance.lock(); + if (instance) + { + if (argc > 0 && argv) + { + std::string path = ":memory:"; // Default path + for (int i = 1; i < argc; i++) + { + if (strncmp(argv[i], "--path=", 7) == 0) + { + path = argv[i] + 7; + break; + } + } + if (!instance->db_path.empty() && instance->db_path != path) + { + throw DB::Exception( + ErrorCodes::BAD_ARGUMENTS, + "EmbeddedServer already initialized with path '{}', cannot connect with different path '{}'", + instance->db_path, + path); + } + } + return instance; + } + + instance = std::make_shared(); + if (argc == 0 || !argv) + { + const char * default_argv[] = {"chdb"}; + instance->initializeWithArgs(1, const_cast(default_argv)); + } + else + { + instance->initializeWithArgs(argc, argv); + } + + global_instance = instance; + return instance; +} + +void EmbeddedServer::initializeWithArgs(int argc, char ** argv) +{ + db_path = ":memory:"; // Default path + for 
(int i = 1; i < argc; i++) + { + if (strncmp(argv[i], "--path=", 7) == 0) + { + db_path = argv[i] + 7; + break; + } + } + + try + { + std::vector args; + for (int i = 0; i < argc; ++i) + { + args.push_back(argv[i]); + } + + Poco::Util::Application::ArgVec arg_vec; + for (const auto & arg : args) + { + arg_vec.push_back(arg); + } + argsToConfig(arg_vec, config(), 100); + + initialize(*this); + int ret = main(args); + if (ret != 0) + { + auto err_msg = getErrorMsg(); + LOG_ERROR(&logger(), "Error initializing EmbeddedServer: {}", err_msg); + throw DB::Exception(ErrorCodes::BAD_ARGUMENTS, "Error initializing EmbeddedServer: {}", err_msg); + } + } + catch (const std::exception & e) + { + LOG_ERROR(&Poco::Logger::get("EmbeddedServer"), "Failed to initialize EmbeddedServer: {}", e.what()); + throw; + } +} +} // namespace DB diff --git a/programs/local/EmbeddedServer.h b/programs/local/EmbeddedServer.h new file mode 100644 index 00000000000..8297e532238 --- /dev/null +++ b/programs/local/EmbeddedServer.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + + +namespace DB +{ + +/// Lightweight Application for embeeded server +/// No networking, no extra configs and working directories, no pid and status files, no dictionaries, no logging. +/// Quiet mode by default +/// +/// EmbeddedServer is managed via shared_ptr by ChdbClient instances. +/// When the last ChdbClient is destroyed, the EmbeddedServer is automatically destroyed. +/// Only one EmbeddedServer instance can exist globally at a time. 
+class EmbeddedServer : public Poco::Util::Application, public IHints<2>, public Loggers +{ +public: + EmbeddedServer() = default; + + ~EmbeddedServer() override; + + void initialize(Poco::Util::Application & self) override; + + int main(const std::vector & /*args*/) override; + + std::vector getAllRegisteredNames() const override { return {}; } + + ContextMutablePtr getGlobalContext() { return global_context; } + + std::string getErrorMsg() const { return error_message_oss.str(); } + + static std::shared_ptr getInstance(int argc = 0, char ** argv = nullptr); + + std::string getPath() const { return db_path; } + +private: + void tryInitPath(); + void setupUsers(); + void cleanup(); + void processConfig(); + void applyCmdOptions(ContextMutablePtr context); + void initializeWithArgs(int argc, char ** argv); + static std::weak_ptr global_instance; + static std::mutex instance_mutex; + std::string db_path; + ServerSettings server_settings; + std::optional status; + std::optional temporary_directory_to_delete; + std::unique_ptr memory_worker; + ContextMutablePtr global_context; + String home_path; + std::stringstream error_message_oss; + SharedPtrContextHolder shared_context; +}; +} // namespace DB diff --git a/programs/local/FormatHelper.cpp b/programs/local/FormatHelper.cpp deleted file mode 100644 index 894558a5cb1..00000000000 --- a/programs/local/FormatHelper.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "FormatHelper.h" - -#include -#include -#include - -namespace CHDB { - -static bool is_json_supported = true; - -void SetCurrentFormat(const char * format, size_t format_len) -{ - if (format) - { - String lower_format{format, format_len}; - std::transform(lower_format.begin(), lower_format.end(), lower_format.begin(), ::tolower); - - is_json_supported - = !(lower_format == "arrow" || lower_format == "parquet" || lower_format == "arrowstream" || lower_format == "protobuf" - || lower_format == "protobuflist" || lower_format == "protobufsingle"); - - return; - } - - 
is_json_supported = true; -} - -bool isJSONSupported() -{ - return is_json_supported; -} - -} // namespace CHDB diff --git a/programs/local/FormatHelper.h b/programs/local/FormatHelper.h deleted file mode 100644 index bd02d1b95c1..00000000000 --- a/programs/local/FormatHelper.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -namespace CHDB { - -void SetCurrentFormat(const char * format, size_t format_len); - -bool isJSONSupported(); - -} // namespace CHDB diff --git a/programs/local/LocalChdb.cpp b/programs/local/LocalChdb.cpp index a35f6669c56..93da76786fb 100644 --- a/programs/local/LocalChdb.cpp +++ b/programs/local/LocalChdb.cpp @@ -21,6 +21,7 @@ extern bool inside_main = true; namespace CHDB { extern chdb_connection * connect_chdb_with_exception(int argc, char ** argv); +extern void cachePythonTablesFromQuery(chdb_conn * conn, const std::string & query_str); } const static char * CURSOR_DEFAULT_FORMAT = "JSONCompactEachRowWithNamesAndTypes"; @@ -265,8 +266,7 @@ void connection_wrapper::commit() query_result * connection_wrapper::query(const std::string & query_str, const std::string & format) { - CHDB::PythonTableCache::findQueryableObjFromQuery(query_str); - + CHDB::cachePythonTablesFromQuery(reinterpret_cast(*conn), query_str); py::gil_scoped_release release; auto * result = chdb_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size()); if (chdb_result_length(result)) @@ -286,8 +286,7 @@ query_result * connection_wrapper::query(const std::string & query_str, const st streaming_query_result * connection_wrapper::send_query(const std::string & query_str, const std::string & format) { - CHDB::PythonTableCache::findQueryableObjFromQuery(query_str); - + CHDB::cachePythonTablesFromQuery(reinterpret_cast(*conn), query_str); py::gil_scoped_release release; auto * result = chdb_stream_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size()); auto error_msg = CHDB::chdb_result_error_string(result); @@ -337,8 
+336,7 @@ void connection_wrapper::streaming_cancel_query(streaming_query_result * streami void cursor_wrapper::execute(const std::string & query_str) { release_result(); - CHDB::PythonTableCache::findQueryableObjFromQuery(query_str); - + CHDB::cachePythonTablesFromQuery(reinterpret_cast(conn->get_conn()), query_str); // Use JSONCompactEachRowWithNamesAndTypes format for better type support py::gil_scoped_release release; current_result = chdb_query_n(conn->get_conn(), query_str.data(), query_str.size(), CURSOR_DEFAULT_FORMAT, CURSOR_DEFAULT_FORMAT_LEN); @@ -511,13 +509,11 @@ PYBIND11_MODULE(_chdb, m) py::arg("udf_path") = "", "Query chDB and return a query_result object"); - auto destroy_import_cache = []() + auto destroy_import_cache = []() { - CHDB::chdbCleanupConnection(); - CHDB::PythonTableCache::clear(); - CHDB::PythonImporter::destroy(); - }; - m.add_object("_destroy_import_cache", py::capsule(destroy_import_cache)); + CHDB::PythonImporter::destroy(); + }; + m.add_object("_destroy_import_cache", py::capsule(destroy_import_cache)); } # endif // PY_TEST_MAIN diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index af5590c8186..a91b04d05db 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -154,7 +154,7 @@ namespace ErrorCodes extern const int UNKNOWN_FORMAT; } -void applySettingsOverridesForLocal(ContextMutablePtr context) +static void applySettingsOverridesForLocal(ContextMutablePtr context) { Settings settings = context->getSettingsCopy(); @@ -429,8 +429,6 @@ void LocalServer::cleanup() { try { - cleanStreamingQuery(); - connection.reset(); /// Suggestions are loaded async in a separate thread and it can use global context. 
@@ -647,38 +645,38 @@ try Poco::ErrorHandler::set(&error_handler); } - // run only once - static std::once_flag register_once_flag; - std::call_once(register_once_flag, []() - { - registerInterpreters(); - /// Don't initialize DateLUT - registerFunctions(); - registerAggregateFunctions(); + std::call_once( + global_register_once_flag, + []() + { + registerInterpreters(); + /// Don't initialize DateLUT + registerFunctions(); + registerAggregateFunctions(); - registerTableFunctions(); + registerTableFunctions(); - auto & table_function_factory = TableFunctionFactory::instance(); + auto & table_function_factory = TableFunctionFactory::instance(); #if USE_PYTHON - registerTableFunctionPython(table_function_factory); + registerTableFunctionPython(table_function_factory); #else - registerTableFunctionArrowStream(table_function_factory); + registerTableFunctionArrowStream(table_function_factory); #endif - registerDatabases(); + registerDatabases(); - registerStorages(); - auto & storage_factory = StorageFactory::instance(); + registerStorages(); + auto & storage_factory = StorageFactory::instance(); #if USE_PYTHON - registerStoragePython(storage_factory); + registerStoragePython(storage_factory); #else - registerStorageArrowStream(storage_factory); + registerStorageArrowStream(storage_factory); #endif - registerDictionaries(); - registerDisks(/* global_skip_access_check= */ true); - registerFormats(); - }); + registerDictionaries(); + registerDisks(/* global_skip_access_check= */ true); + registerFormats(); + }); processConfig(); @@ -805,12 +803,10 @@ void LocalServer::processConfig() getClientConfiguration().setString("logger.level", logging ? 
level : "fatal"); buildLoggers(getClientConfiguration(), logger(), "clickhouse-local"); } - shared_context = Context::createSharedHolder(); global_context = Context::createGlobal(shared_context.get()); global_context->makeGlobalContext(); global_context->setApplicationType(Context::ApplicationType::LOCAL); - tryInitPath(); LoggerRawPtr log = &logger(); @@ -1226,16 +1222,6 @@ void LocalServer::readArguments(int argc, char ** argv, Arguments & common_argum } } } - - -void LocalServer::cleanStreamingQuery() -{ - if (streaming_query_context && streaming_query_context->streaming_result) - CHDB::cancelStreamQuery(this, streaming_query_context->streaming_result); - - streaming_query_context.reset(); -} - } #pragma clang diagnostic ignored "-Wunused-function" diff --git a/programs/local/LocalServer.h b/programs/local/LocalServer.h index 42defd1d3c7..43f2e66f70c 100644 --- a/programs/local/LocalServer.h +++ b/programs/local/LocalServer.h @@ -90,8 +90,6 @@ class LocalServer : public ClientApplicationBase, public Loggers } private: - void cleanStreamingQuery(); - std::unique_ptr memory_worker; }; diff --git a/programs/local/PandasDataFrame.cpp b/programs/local/PandasDataFrame.cpp index c0baa6659dc..64cdf2e7e0a 100644 --- a/programs/local/PandasDataFrame.cpp +++ b/programs/local/PandasDataFrame.cpp @@ -1,5 +1,4 @@ #include "PandasDataFrame.h" -#include "FormatHelper.h" #include "NumpyType.h" #include "PandasAnalyzer.h" #include "PandasCacheItem.h" @@ -69,7 +68,7 @@ static DataTypePtr inferDataTypeFromPandasColumn(PandasBindColumn & column, Cont if (numpy_type.type == NumpyNullableType::OBJECT) { - if (!isJSONSupported()) + if (!context->getQueryContext() || !context->getQueryContext()->isJSONSupported()) { numpy_type.type = NumpyNullableType::STRING; return NumpyToDataType(numpy_type); diff --git a/programs/local/PythonDict.cpp b/programs/local/PythonDict.cpp index f4356f606df..19e4f19736d 100644 --- a/programs/local/PythonDict.cpp +++ b/programs/local/PythonDict.cpp @@ -1,5 
+1,4 @@ #include "PythonDict.h" -#include "FormatHelper.h" #include "StoragePython.h" #include @@ -38,7 +37,7 @@ ColumnsDescription PythonDict::getActualTableStructure(const py::object & object { py::handle element = values[0]; - if (isJSONSupported() && py::isinstance(element)) + if (context->getQueryContext() && context->getQueryContext()->isJSONSupported() && py::isinstance(element)) { schema.emplace_back(key, "json"); } @@ -50,7 +49,7 @@ ColumnsDescription PythonDict::getActualTableStructure(const py::object & object } } - return StoragePython::getTableStructureFromData(schema); + return StoragePython::getTableStructureFromData(schema, context); } bool PythonDict::isPythonDict(const py::object & object) diff --git a/programs/local/PythonReader.cpp b/programs/local/PythonReader.cpp index 54bcd2aff16..03ea40fbc06 100644 --- a/programs/local/PythonReader.cpp +++ b/programs/local/PythonReader.cpp @@ -27,7 +27,7 @@ ColumnsDescription PythonReader::getActualTableStructure(const py::object & obje schema = object.attr("get_schema")().cast>>(); - return StoragePython::getTableStructureFromData(schema); + return StoragePython::getTableStructureFromData(schema, context); } bool PythonReader::isPythonReader(const py::object & object) diff --git a/programs/local/PythonTableCache.cpp b/programs/local/PythonTableCache.cpp index acb32bcfd78..7cbf5d28739 100644 --- a/programs/local/PythonTableCache.cpp +++ b/programs/local/PythonTableCache.cpp @@ -6,8 +6,6 @@ namespace CHDB { -std::unordered_map PythonTableCache::py_table_cache; - /// Function to find instance of PyReader, pandas DataFrame, or PyArrow Table, filtered by variable name static py::object findQueryableObj(const String & var_name) { diff --git a/programs/local/PythonTableCache.h b/programs/local/PythonTableCache.h index 55e93a49c34..be09d71bc16 100644 --- a/programs/local/PythonTableCache.h +++ b/programs/local/PythonTableCache.h @@ -9,14 +9,14 @@ namespace CHDB { class PythonTableCache { public: - static void 
findQueryableObjFromQuery(const String & query_str); + void findQueryableObjFromQuery(const String & query_str); - static py::handle getQueryableObj(const String & table_name); + py::handle getQueryableObj(const String & table_name); - static void clear(); + void clear(); private: - static std::unordered_map py_table_cache; + std::unordered_map py_table_cache; }; } // namespace CHDB diff --git a/programs/local/StoragePython.cpp b/programs/local/StoragePython.cpp index 8f3b4f8002f..938308e3529 100644 --- a/programs/local/StoragePython.cpp +++ b/programs/local/StoragePython.cpp @@ -1,5 +1,4 @@ #include "StoragePython.h" -#include "FormatHelper.h" #include "PybindWrapper.h" #include "PythonSource.h" #include "PyArrowTable.h" @@ -13,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -27,9 +28,8 @@ #include #include #include -#include "PythonUtils.h" #include -#include +#include "PythonUtils.h" #include @@ -147,7 +147,7 @@ void StoragePython::prepareColumnCache(const Names & names, const Columns & colu } } -ColumnsDescription StoragePython::getTableStructureFromData(std::vector> & schema) +ColumnsDescription StoragePython::getTableStructureFromData(std::vector> & schema, const ContextPtr & context) { py::gil_assert(); @@ -187,7 +187,7 @@ dtype\('S|dtype\('O|||| data_type; std::string type_capture, bits, precision, scale; - if (CHDB::isJSONSupported() && RE2::PartialMatch(typeStr, pattern_json)) + if (context->getQueryContext() && context->getQueryContext()->isJSONSupported() && RE2::PartialMatch(typeStr, pattern_json)) { data_type = std::make_shared(DataTypeObject::SchemaFormat::JSON); } diff --git a/programs/local/StoragePython.h b/programs/local/StoragePython.h index 7985256478a..89eee5bac28 100644 --- a/programs/local/StoragePython.h +++ b/programs/local/StoragePython.h @@ -174,7 +174,7 @@ class StoragePython : public IStorage, public WithContext Block prepareSampleBlock(const Names & column_names, const StorageSnapshotPtr & 
storage_snapshot); - static ColumnsDescription getTableStructureFromData(std::vector> & schema); + static ColumnsDescription getTableStructureFromData(std::vector> & schema, const ContextPtr & context); private: void prepareColumnCache(const Names & names, const Columns & columns, const Block & sample_block); diff --git a/programs/local/TableFunctionPython.cpp b/programs/local/TableFunctionPython.cpp index afa068b0ded..418345bf610 100644 --- a/programs/local/TableFunctionPython.cpp +++ b/programs/local/TableFunctionPython.cpp @@ -7,11 +7,12 @@ #include "PythonTableCache.h" #include "PythonUtils.h" -#include #include #include +#include #include #include +#include #include #include #include @@ -65,8 +66,7 @@ void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr py_reader_arg_str.erase( std::remove_if(py_reader_arg_str.begin(), py_reader_arg_str.end(), [](char c) { return c == '\'' || c == '\"' || c == '`'; }), py_reader_arg_str.end()); - - auto instance = PythonTableCache::getQueryableObj(py_reader_arg_str); + auto instance = context->getQueryContext()->getPythonTableCache()->getQueryableObj(py_reader_arg_str); if (instance == nullptr || instance.is_none()) throw Exception(ErrorCodes::PY_OBJECT_NOT_FOUND, "Python object not found in the Python environment\n" @@ -131,7 +131,7 @@ ColumnsDescription TableFunctionPython::getActualTableStructure(ContextPtr conte return PythonReader::getActualTableStructure(reader, context); auto schema = PyReader::getSchemaFromPyObj(reader); - return StoragePython::getTableStructureFromData(schema); + return StoragePython::getTableStructureFromData(schema, context); } void registerTableFunctionPython(TableFunctionFactory & factory) diff --git a/programs/local/chdb-arrow.cpp b/programs/local/chdb-arrow.cpp index e899e47e0c5..2b7389c3022 100644 --- a/programs/local/chdb-arrow.cpp +++ b/programs/local/chdb-arrow.cpp @@ -2,7 +2,6 @@ #include "chdb-internal.h" #include "ArrowStreamRegistry.h" -#include #include 
#include @@ -91,8 +90,6 @@ static chdb_state chdb_inner_arrow_scan( chdb_connection conn, const char * table_name, chdb_arrow_stream arrow_stream, bool is_owner) { - std::shared_lock global_lock(global_connection_mutex); - if (!table_name || !arrow_stream) return CHDBError; @@ -165,9 +162,6 @@ chdb_state chdb_arrow_array_scan( chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name) { ChdbDestructorGuard guard; - - std::shared_lock global_lock(global_connection_mutex); - if (!table_name) return CHDBError; diff --git a/programs/local/chdb-internal.h b/programs/local/chdb-internal.h index 945cf4ba3ae..3e0dfb26bab 100644 --- a/programs/local/chdb-internal.h +++ b/programs/local/chdb-internal.h @@ -3,19 +3,15 @@ #include "chdb.h" #include "QueryResult.h" -#include #include -#include -#include #include #include namespace DB { - class LocalServer; + class ChdbClient; } -extern std::shared_mutex global_connection_mutex; extern thread_local bool chdb_destructor_cleanup_in_progress; /** @@ -37,7 +33,7 @@ class ChdbDestructorGuard /// Connection validity check function inline bool checkConnectionValidity(chdb_conn * connection) { - return connection && connection->connected && connection->queue; + return connection && connection->connected; } namespace CHDB @@ -89,37 +85,13 @@ struct StreamingIterateRequest : QueryRequestBase bool isIteration() const override { return true; } }; -enum class QueryType : uint8_t -{ - TYPE_MATERIALIZED = 0, - TYPE_STREAMING_INIT = 1, - TYPE_STREAMING_ITER = 2 -}; - -struct QueryQueue -{ - std::mutex mutex; - std::condition_variable query_cv; // For query submission - std::condition_variable result_cv; // For query result retrieval - std::unique_ptr current_query; - QueryResultPtr current_result; - bool has_result = false; - bool has_query = false; - bool has_streaming_query = false; - bool shutdown = false; - bool cleanup_done = false; -}; - std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv); -void 
chdbCleanupConnection(); - -void cancelStreamQuery(DB::LocalServer * server, void * stream_result); +void cancelStreamQuery(DB::ChdbClient * client, void * stream_result); const std::string & chdb_result_error_string(chdb_result * result); const std::string & chdb_streaming_result_error_string(chdb_streaming_result * result); void chdb_destroy_arrow_stream(ArrowArrayStream * arrow_stream); - } diff --git a/programs/local/chdb.cpp b/programs/local/chdb.cpp index c964fe6b2f9..2eedb585acb 100644 --- a/programs/local/chdb.cpp +++ b/programs/local/chdb.cpp @@ -4,6 +4,11 @@ #include #include +#include +#include +#if USE_PYTHON +# include +#endif #include #include @@ -18,11 +23,6 @@ void chdb_musl_compile_stub(int arg) } #endif -#if USE_PYTHON -#include "FormatHelper.h" -#include "PythonTableCache.h" -#endif - #if USE_JEMALLOC # include #endif @@ -38,7 +38,6 @@ namespace DB #endif extern thread_local bool chdb_destructor_cleanup_in_progress; -std::shared_mutex global_connection_mutex; namespace CHDB { @@ -54,48 +53,12 @@ extern "C" }; #endif +// used only in pyEntryClickHouseLocal static std::mutex CHDB_MUTEX; -chdb_conn * global_conn_ptr = nullptr; -std::string global_db_path; - -static std::unique_ptr bgClickHouseLocal(int argc, char ** argv) -{ - std::unique_ptr app; - try - { - app = std::make_unique(); - app->setBackground(true); - app->init(argc, argv); - int ret = app->run(); - if (ret != 0) - { - auto err_msg = app->getErrorMsg(); - LOG_ERROR(&app->logger(), "Error running bgClickHouseLocal: {}", err_msg); - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Error running bgClickHouseLocal: {}", err_msg); - } - return app; - } - catch (const DB::Exception & e) - { - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "bgClickHouseLocal {}", DB::getExceptionMessage(e, false)); - } - catch (const Poco::Exception & e) - { - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "bgClickHouseLocal {}", e.displayText()); - } - catch (const std::exception & e) - { - throw 
std::domain_error(e.what()); - } - catch (...) - { - throw std::domain_error(DB::getCurrentExceptionMessage(true)); - } -} static local_result_v2 * convert2LocalResultV2(QueryResult * query_result) { - auto local_result = new local_result_v2(); + auto * local_result = new local_result_v2(); auto * materialized_query_result = static_cast(query_result); if (!materialized_query_result) @@ -131,102 +94,40 @@ static local_result_v2 * convert2LocalResultV2(QueryResult * query_result) static local_result_v2 * createErrorLocalResultV2(const String & error) { - auto local_result = new local_result_v2(); + auto * local_result = new local_result_v2(); local_result->error_message = new char[error.size() + 1]; std::memcpy(local_result->error_message, error.c_str(), error.size() + 1); return local_result; } -static QueryResultPtr createMaterializedLocalQueryResult(DB::LocalServer * server, const CHDB::QueryRequestBase & req) -{ - QueryResultPtr query_result; - const auto & materialized_request = static_cast(req); - - try - { - if (!server->parseQueryTextWithOutputFormat(materialized_request.query, materialized_request.format)) - { - query_result = std::make_unique(server->getErrorMsg()); - } - else - { - query_result = std::make_unique( - ResultBuffer(server->stealQueryOutputVector()), - server->getElapsedTime(), - server->getProcessedRows(), - server->getProcessedBytes(), - server->getStorgaeRowsRead(), - server->getStorageBytesRead()); - } - } - catch (const DB::Exception & e) - { - query_result = std::make_unique(DB::getExceptionMessage(e, false)); - } - catch (...) 
- { - String error_message = "Unknown error occurred"; - query_result = std::make_unique(error_message); - } - - server->resetQueryOutputVector(); - - return query_result; -} - -static QueryResultPtr createStreamingQueryResult(DB::LocalServer * server, const CHDB::QueryRequestBase & req) -{ - QueryResultPtr query_result; - const auto & streaming_init_request = static_cast(req); - - try - { - if (!server->parseQueryTextWithOutputFormat(streaming_init_request.query, streaming_init_request.format)) - query_result = std::make_unique(server->getErrorMsg()); - else - query_result = std::make_unique(); - } - catch (const DB::Exception& e) - { - query_result = std::make_unique(DB::getExceptionMessage(e, false)); - } - catch (...) - { - String error_message = "Unknown error occurred"; - query_result = std::make_unique(error_message); - } - - return query_result; -} - -static QueryResultPtr createStreamingIterateQueryResult(DB::LocalServer * server, const CHDB::QueryRequestBase & req) +static QueryResultPtr createStreamingIterateQueryResult(DB::ChdbClient * client, const CHDB::QueryRequestBase & req) { QueryResultPtr query_result; const auto & streaming_iter_request = static_cast(req); - const auto old_processed_rows = server->getProcessedRows(); - const auto old_processed_bytes = server->getProcessedBytes(); - const auto old_storage_rows_read = server->getStorgaeRowsRead(); - const auto old_storage_bytes_read = server->getStorageBytesRead(); - const auto old_elapsed_time = server->getElapsedTime(); + const auto old_processed_rows = client->getProcessedRows(); + const auto old_processed_bytes = client->getProcessedBytes(); + const auto old_storage_rows_read = client->getStorageRowsRead(); + const auto old_storage_bytes_read = client->getStorageBytesRead(); + const auto old_elapsed_time = client->getElapsedTime(); try { - if (!server->processStreamingQuery(streaming_iter_request.streaming_result, streaming_iter_request.is_canceled)) + if 
(!client->processStreamingQuery(streaming_iter_request.streaming_result, streaming_iter_request.is_canceled)) { - query_result = std::make_unique(server->getErrorMsg()); + query_result = std::make_unique(client->getErrorMsg()); } else { - const auto processed_rows = server->getProcessedRows(); - const auto processed_bytes = server->getProcessedBytes(); - const auto storage_rows_read = server->getStorgaeRowsRead(); - const auto storage_bytes_read = server->getStorageBytesRead(); - const auto elapsed_time = server->getElapsedTime(); + const auto processed_rows = client->getProcessedRows(); + const auto processed_bytes = client->getProcessedBytes(); + const auto storage_rows_read = client->getStorageRowsRead(); + const auto storage_bytes_read = client->getStorageBytesRead(); + const auto elapsed_time = client->getElapsedTime(); if (processed_rows <= old_processed_rows) query_result = std::make_unique(nullptr, 0.0, 0, 0, 0, 0); else query_result = std::make_unique( - ResultBuffer(server->stealQueryOutputVector()), + ResultBuffer(client->stealQueryOutputVector()), elapsed_time - old_elapsed_time, processed_rows - old_processed_rows, processed_bytes - old_processed_bytes, @@ -244,173 +145,17 @@ static QueryResultPtr createStreamingIterateQueryResult(DB::LocalServer * server query_result = std::make_unique(error_message); } - server->resetQueryOutputVector(); + client->resetQueryOutputVector(); return query_result; } -static std::pair createQueryResult(DB::LocalServer * server, const CHDB::QueryRequestBase & req) -{ - QueryResultPtr query_result; - bool is_end = false; - - if (!req.isStreaming()) - { - query_result = createMaterializedLocalQueryResult(server, req); - is_end = true; - } - else if (!req.isIteration()) - { - server->streaming_query_context = std::make_shared(); - query_result = createStreamingQueryResult(server, req); - is_end = !query_result->getError().empty(); - - if (!is_end) - server->streaming_query_context->streaming_result = query_result.get(); - } - 
else - { - query_result = createStreamingIterateQueryResult(server, req); - const auto & streaming_iter_request = static_cast(req); - auto materialized_query_result_ptr = static_cast(query_result.get()); - - is_end = !materialized_query_result_ptr->getError().empty() || materialized_query_result_ptr->rows_read == 0 - || streaming_iter_request.is_canceled; - } - - if (is_end) - { - if (server->streaming_query_context) - { - server->streaming_query_context.reset(); - } -#if USE_PYTHON - if (auto * local_connection = static_cast(server->connection.get())) - { - /// Must clean up Context objects whether the query succeeds or fails. - /// During process exit, if LocalServer destructor triggers while cached PythonStorage - /// objects still exist in Context, their destruction will attempt to acquire GIL. - /// Acquiring GIL during process termination leads to immediate thread termination. - local_connection->resetQueryContext(); - } - - CHDB::PythonTableCache::clear(); -#endif - } - - return std::make_pair(std::move(query_result), is_end); -} - -static QueryResultPtr executeQueryRequest( - CHDB::QueryQueue * queue, - const char * query, - size_t query_len, - const char * format, - size_t format_len, - CHDB::QueryType query_type, - void * streaming_result_ = nullptr, - bool is_canceled = false) -{ - QueryResultPtr query_result; - - try - { - { - std::unique_lock lock(queue->mutex); - // Wait until any ongoing query completes - if (query_type == CHDB::QueryType::TYPE_STREAMING_ITER) - queue->result_cv.wait(lock, [queue]() { return (!queue->has_query && !queue->has_result) || queue->shutdown; }); - else - queue->result_cv.wait( - lock, - [queue]() { return (!queue->has_query && !queue->has_result && !queue->has_streaming_query) || queue->shutdown; }); - - if (queue->shutdown) - { - String error_message = "connection is shutting down"; - if (query_type == CHDB::QueryType::TYPE_STREAMING_INIT) - { - query_result.reset(new StreamQueryResult(error_message)); - } - else - { - 
query_result.reset(new MaterializedQueryResult(error_message)); - } - return query_result; - } - - if (query_type == CHDB::QueryType::TYPE_STREAMING_INIT) - { - queue->current_query = std::make_unique(query, query_len, format, format_len); -#if USE_PYTHON - CHDB::SetCurrentFormat(format, format_len); -#endif - } - else if (query_type == CHDB::QueryType::TYPE_MATERIALIZED) - { - queue->current_query = std::make_unique(query, query_len, format, format_len); -#if USE_PYTHON - CHDB::SetCurrentFormat(format, format_len); -#endif - } - else - { - auto streaming_iter_req = std::make_unique(); - streaming_iter_req->streaming_result = streaming_result_; - streaming_iter_req->is_canceled = is_canceled; - queue->current_query = std::move(streaming_iter_req); - } - - queue->has_query = true; - queue->current_result.reset(); - queue->has_result = false; - } - queue->query_cv.notify_one(); - - { - std::unique_lock lock(queue->mutex); - queue->result_cv.wait(lock, [queue]() { return queue->has_result || queue->shutdown; }); - - if (!queue->shutdown && queue->has_result) - { - query_result = std::move(queue->current_result); - queue->has_result = false; - queue->has_query = false; - } - } - queue->result_cv.notify_all(); - } - catch (...) - { - // Handle any exceptions during query processing - String error_message = "Error occurred while processing query"; - if (query_type == CHDB::QueryType::TYPE_STREAMING_INIT) - query_result.reset(new StreamQueryResult(error_message)); - else - query_result.reset(new MaterializedQueryResult(error_message)); - } - - return query_result; -} - -void chdbCleanupConnection() -{ - try - { - close_conn(&global_conn_ptr); - } - catch (...) 
- { - } -} - -void cancelStreamQuery(DB::LocalServer * server, void * stream_result) +void cancelStreamQuery(DB::ChdbClient * client, void * stream_result) { auto streaming_iter_req = std::make_unique(); streaming_iter_req->streaming_result = stream_result; streaming_iter_req->is_canceled = true; - - createStreamingIterateQueryResult(server, *streaming_iter_req); + createStreamingIterateQueryResult(client, *streaming_iter_req); } std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv) @@ -468,152 +213,46 @@ const std::string & chdb_streaming_result_error_string(chdb_streaming_result * r chdb_connection * connect_chdb_with_exception(int argc, char ** argv) { - std::lock_guard global_lock(global_connection_mutex); - - std::string path = ":memory:"; // Default path - for (int i = 1; i < argc; i++) + try { - if (strncmp(argv[i], "--path=", 7) == 0) + DB::ThreadStatus thread_status; + auto server = DB::EmbeddedServer::getInstance(argc, argv); + auto client = DB::ChdbClient::create(server); + if (!client) { - path = argv[i] + 7; - break; + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Failed to create ChdbClient"); } - } - if (global_conn_ptr != nullptr) + auto * conn = new chdb_conn(); + conn->server = client.release(); + conn->connected = true; + auto ** conn_ptr = new chdb_conn *(conn); + return reinterpret_cast(conn_ptr); + } + catch (const DB::Exception & e) { - if (path == global_db_path) - return reinterpret_cast(&global_conn_ptr); - - throw DB::Exception( - DB::ErrorCodes::BAD_ARGUMENTS, - "Another connection is already active with different path. 
Old path = {}, new path = {}, " - "please close the existing connection first.", - global_db_path, - path); + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Failed to create connection: {}", DB::getExceptionMessage(e, false)); } - - auto * conn = new chdb_conn(); - auto * q_queue = new CHDB::QueryQueue(); - conn->queue = q_queue; - - std::mutex init_mutex; - std::condition_variable init_cv; - bool init_done = false; - bool init_success = false; - std::exception_ptr init_exception; - - // Start query processing thread - std::thread( - [&]() - { - auto * queue = static_cast(conn->queue); - std::unique_ptr server; - try - { - DB::ThreadStatus thread_status; - server = bgClickHouseLocal(argc, argv); - conn->server = nullptr; - conn->connected = true; - - global_conn_ptr = conn; - global_db_path = path; - - // Signal successful initialization - { - std::lock_guard init_lock(init_mutex); - init_success = true; - init_done = true; - } - init_cv.notify_one(); - while (true) - { - { - std::unique_lock lock(queue->mutex); - queue->query_cv.wait(lock, [queue]() { return queue->has_query || queue->shutdown; }); - - if (queue->shutdown) - { - server.reset(); - queue->cleanup_done = true; - queue->query_cv.notify_all(); - break; - } - } - - CHDB::QueryRequestBase & req = *(queue->current_query); - auto result = createQueryResult(server.get(), req); - bool is_end = result.second; - - { - std::lock_guard lock(queue->mutex); - if (req.isStreaming() && !req.isIteration() && !is_end) - queue->has_streaming_query = true; - - if (req.isStreaming() && req.isIteration() && is_end) - queue->has_streaming_query = false; - - queue->current_result = std::move(result.first); - queue->has_result = true; - queue->has_query = false; - } - queue->result_cv.notify_all(); - } - } - catch (const DB::Exception & e) - { - // Log the error - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Query thread terminated with error: {}", e.what()); - - // Signal thread termination - { - std::lock_guard 
init_lock(init_mutex); - init_exception = std::current_exception(); - init_done = true; - std::lock_guard lock(queue->mutex); - queue->shutdown = true; - queue->cleanup_done = true; - } - init_cv.notify_one(); - queue->query_cv.notify_all(); - queue->result_cv.notify_all(); - } - catch (...) - { - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Query thread terminated with unknown error"); - - { - std::lock_guard init_lock(init_mutex); - init_exception = std::current_exception(); - init_done = true; - std::lock_guard lock(queue->mutex); - queue->shutdown = true; - queue->cleanup_done = true; - } - init_cv.notify_one(); - queue->query_cv.notify_all(); - queue->result_cv.notify_all(); - } - }) - .detach(); - - // Wait for initialization to complete - { - std::unique_lock init_lock(init_mutex); - init_cv.wait(init_lock, [&init_done]() { return init_done; }); - - if (!init_success) - { - delete q_queue; - delete conn; - if (init_exception) - std::rethrow_exception(init_exception); - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Failed to create connection"); - } + catch (const std::exception & e) + { + throw std::domain_error(std::string("Connection failed: ") + e.what()); } + catch (...) 
+ { + throw std::domain_error(DB::getCurrentExceptionMessage(true)); + } +} - return reinterpret_cast(&global_conn_ptr); +#if USE_PYTHON +void cachePythonTablesFromQuery(chdb_conn * conn, const std::string & query_str) +{ + if (!conn || !conn->server || !conn->connected) + return; + auto * client = reinterpret_cast(conn->server); + client->findQueryableObjFromPyCache(query_str); } +#endif + } // namespace CHDB using namespace CHDB; @@ -663,7 +302,6 @@ local_result_v2 * query_stable_v2(int argc, char ** argv) try { auto query_result = pyEntryClickHouseLocal(argc, argv); - return convert2LocalResultV2(query_result.get()); } catch (const std::exception & e) @@ -708,46 +346,26 @@ chdb_conn ** connect_chdb(int argc, char ** argv) void close_conn(chdb_conn ** conn) { - std::lock_guard global_lock(global_connection_mutex); - if (!conn || !*conn) return; - if ((*conn)->connected) + try { - if ((*conn)->queue) + if ((*conn)->connected && (*conn)->server) { - auto * queue = static_cast((*conn)->queue); - - { - std::unique_lock queue_lock(queue->mutex); - queue->shutdown = true; - queue->query_cv.notify_all(); // Wake up query processing thread - queue->result_cv.notify_all(); // Wake up any waiting result threads - - // Wait for server cleanup - queue->query_cv.wait(queue_lock, [queue] { return queue->cleanup_done; }); - - // Clean up current result if any - queue->current_result.reset(); - queue->has_result = false; - } - - delete queue; - (*conn)->queue = nullptr; + auto * client = static_cast((*conn)->server); + delete client; + (*conn)->server = nullptr; } - // Mark as disconnected BEFORE deleting queue and nulling global pointer (*conn)->connected = false; + delete *conn; + *conn = nullptr; } - // Clear global pointer under lock before queue deletion - if (*conn != global_conn_ptr) + catch (...) 
{ - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Connection mismatch during close_conn"); + DB::tryLogCurrentException(__PRETTY_FUNCTION__); } - global_conn_ptr = nullptr; - delete *conn; - *conn = nullptr; } struct local_result_v2 * query_conn(chdb_conn * conn, const char * query, const char * format) @@ -758,16 +376,23 @@ struct local_result_v2 * query_conn(chdb_conn * conn, const char * query, const struct local_result_v2 * query_conn_n(struct chdb_conn * conn, const char * query, size_t query_len, const char * format, size_t format_len) { ChdbDestructorGuard guard; - - // Add connection validity check under global lock - std::shared_lock global_lock(global_connection_mutex); - if (!checkConnectionValidity(conn)) return createErrorLocalResultV2("Invalid or closed connection"); - auto * queue = static_cast(conn->queue); - auto query_result = executeQueryRequest(queue, query, query_len, format, format_len, CHDB::QueryType::TYPE_MATERIALIZED); - return convert2LocalResultV2(query_result.get()); + try + { + auto * client = static_cast(conn->server); + auto query_result = client->executeMaterializedQuery(query, query_len, format, format_len); + return convert2LocalResultV2(query_result.get()); + } + catch (const std::exception & e) + { + return createErrorLocalResultV2(std::string("Error: ") + e.what()); + } + catch (...) 
+ { + return createErrorLocalResultV2(DB::getCurrentExceptionMessage(true)); + } } chdb_streaming_result * query_conn_streaming(chdb_conn * conn, const char * query, const char * format) @@ -780,25 +405,36 @@ query_conn_streaming_n(struct chdb_conn * conn, const char * query, size_t query { ChdbDestructorGuard guard; - // Add connection validity check under global lock - std::shared_lock global_lock(global_connection_mutex); - if (!checkConnectionValidity(conn)) { auto * result = new StreamQueryResult("Invalid or closed connection"); return reinterpret_cast(result); } - auto * queue = static_cast(conn->queue); - auto query_result = executeQueryRequest(queue, query, query_len, format, format_len, CHDB::QueryType::TYPE_STREAMING_INIT); + try + { + // Use the client from the connection + auto * client = static_cast(conn->server); + auto query_result = client->executeStreamingInit(query, query_len, format, format_len); - if (!query_result) + if (!query_result) + { + auto * result = new StreamQueryResult("Query processing failed"); + return reinterpret_cast(result); + } + + return reinterpret_cast(query_result.release()); + } + catch (const std::exception & e) { - auto * result = new StreamQueryResult("Query processing failed"); + auto * result = new StreamQueryResult(std::string("Error: ") + e.what()); + return reinterpret_cast(result); + } + catch (...) 
+ { + auto * result = new StreamQueryResult(DB::getCurrentExceptionMessage(true)); return reinterpret_cast(result); } - - return reinterpret_cast(query_result.release()); } const char * chdb_streaming_result_error(chdb_streaming_result * result) @@ -806,7 +442,7 @@ const char * chdb_streaming_result_error(chdb_streaming_result * result) if (!result) return nullptr; - auto stream_query_result = reinterpret_cast(result); + auto * stream_query_result = reinterpret_cast(result); const auto & error_message = stream_query_result->getError(); if (!error_message.empty()) @@ -819,32 +455,53 @@ local_result_v2 * chdb_streaming_fetch_result(chdb_conn * conn, chdb_streaming_r { ChdbDestructorGuard guard; - // Add connection validity check under global lock - std::shared_lock global_lock(global_connection_mutex); - if (!checkConnectionValidity(conn)) return createErrorLocalResultV2("Invalid or closed connection"); - auto * queue = static_cast(conn->queue); - auto query_result = executeQueryRequest(queue, nullptr, 0, nullptr, 0, CHDB::QueryType::TYPE_STREAMING_ITER, result); + if (!result) + return createErrorLocalResultV2("Invalid streaming result"); - return convert2LocalResultV2(query_result.get()); + try + { + auto * client = static_cast(conn->server); + if (!client->hasStreamingQuery()) + return createErrorLocalResultV2("No active streaming query"); + auto query_result = client->executeStreamingIterate(result, false); + if (!query_result) + return createErrorLocalResultV2("Failed to fetch streaming results"); + auto * local_result = convert2LocalResultV2(query_result.release()); + return local_result; + } + catch (const std::exception & e) + { + return createErrorLocalResultV2(std::string("Error fetching streaming results: ") + e.what()); + } + catch (...) 
+ { + return createErrorLocalResultV2(std::string("Unknown error fetching streaming results: ") + DB::getCurrentExceptionMessage(true)); + } } void chdb_streaming_cancel_query(chdb_conn * conn, chdb_streaming_result * result) { ChdbDestructorGuard guard; - // Add connection validity check under global lock - std::shared_lock global_lock(global_connection_mutex); - if (!checkConnectionValidity(conn)) return; - auto * queue = static_cast(conn->queue); - auto query_result = executeQueryRequest(queue, nullptr, 0, nullptr, 0, CHDB::QueryType::TYPE_STREAMING_ITER, result, true); + if (!result) + return; - query_result.reset(); + try + { + auto * client = static_cast(conn->server); + client->cancelStreamingQuery(result); + } + catch (...) + { + DB::tryLogCurrentException(__PRETTY_FUNCTION__); + } + // Note: The result object should be freed by chdb_destroy_result(), not here } void chdb_destroy_result(chdb_streaming_result * result) @@ -854,7 +511,7 @@ void chdb_destroy_result(chdb_streaming_result * result) if (!result) return; - auto stream_query_result = reinterpret_cast(result); + auto * stream_query_result = reinterpret_cast(result); delete stream_query_result; } @@ -869,22 +526,23 @@ chdb_connection * chdb_connect(int argc, char ** argv) } catch (const DB::Exception & e) { - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Connection failed with DB::Exception: {}", DB::getExceptionMessage(e, false)); + LOG_ERROR(&Poco::Logger::get("EmbeddedServer"), "Connection failed with DB::Exception: {}", DB::getExceptionMessage(e, false)); return nullptr; } catch (const boost::program_options::error & e) { - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Connection failed with bad arguments: {}", e.what()); + LOG_ERROR(&Poco::Logger::get("EmbeddedServer"), "Connection failed with bad arguments: {}", e.what()); return nullptr; } catch (const Poco::Exception & e) { - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Connection failed with Poco::Exception: {}", e.displayText()); + 
LOG_ERROR(&Poco::Logger::get("EmbeddedServer"), "Connection failed with Poco::Exception: {}", e.displayText()); return nullptr; } catch (...) { - LOG_ERROR(&Poco::Logger::get("LocalServer"), "Connection failed with unknown exception: {}", DB::getCurrentExceptionMessage(true)); + LOG_ERROR( + &Poco::Logger::get("EmbeddedServer"), "Connection failed with unknown exception: {}", DB::getCurrentExceptionMessage(true)); return nullptr; } } @@ -894,7 +552,7 @@ void chdb_close_conn(chdb_connection * conn) if (!conn || !*conn) return; - auto connection = reinterpret_cast(conn); + auto * connection = reinterpret_cast(conn); close_conn(connection); } @@ -908,11 +566,9 @@ chdb_result * chdb_query_n(chdb_connection conn, const char * query, size_t quer { ChdbDestructorGuard guard; - std::shared_lock global_lock(global_connection_mutex); - if (!conn) { - auto * result = new MaterializedQueryResult("Unexepected null connection"); + auto * result = new MaterializedQueryResult("Unexpected null connection"); return reinterpret_cast(result); } @@ -923,10 +579,24 @@ chdb_result * chdb_query_n(chdb_connection conn, const char * query, size_t quer return reinterpret_cast(result); } - auto * queue = static_cast(connection->queue); - auto query_result = executeQueryRequest(queue, query, query_len, format, format_len, CHDB::QueryType::TYPE_MATERIALIZED); + try + { + // Use the client from the connection + auto * client = static_cast(connection->server); + auto query_result = client->executeMaterializedQuery(query, query_len, format, format_len); - return reinterpret_cast(query_result.release()); + return reinterpret_cast(query_result.release()); + } + catch (const std::exception & e) + { + auto * result = new MaterializedQueryResult(std::string("Error: ") + e.what()); + return reinterpret_cast(result); + } + catch (...) 
+ { + auto * result = new MaterializedQueryResult(DB::getCurrentExceptionMessage(true)); + return reinterpret_cast(result); + } } chdb_result * chdb_query_cmdline(int argc, char ** argv) @@ -961,11 +631,9 @@ chdb_result * chdb_stream_query_n(chdb_connection conn, const char * query, size { ChdbDestructorGuard guard; - std::shared_lock global_lock(global_connection_mutex); - if (!conn) { - auto * result = new StreamQueryResult("Unexepected null connection"); + auto * result = new StreamQueryResult("Unexpected null connection"); return reinterpret_cast(result); } @@ -976,24 +644,34 @@ chdb_result * chdb_stream_query_n(chdb_connection conn, const char * query, size return reinterpret_cast(result); } - auto * queue = static_cast(connection->queue); - auto query_result = executeQueryRequest(queue, query, query_len, format, format_len, CHDB::QueryType::TYPE_STREAMING_INIT); + try + { + auto * client = static_cast(connection->server); + auto query_result = client->executeStreamingInit(query, query_len, format, format_len); + + if (!query_result) + { + auto * result = new StreamQueryResult("Query processing failed"); + return reinterpret_cast(result); + } - if (!query_result) + return reinterpret_cast(query_result.release()); + } + catch (const std::exception & e) { - auto * result = new StreamQueryResult("Query processing failed"); + auto * result = new StreamQueryResult(std::string("Error: ") + e.what()); + return reinterpret_cast(result); + } + catch (...) 
+ { + auto * result = new StreamQueryResult(DB::getCurrentExceptionMessage(true)); return reinterpret_cast(result); } - - return reinterpret_cast(query_result.release()); } chdb_result * chdb_stream_fetch_result(chdb_connection conn, chdb_result * result) { ChdbDestructorGuard guard; - - std::shared_lock global_lock(global_connection_mutex); - if (!conn) { auto * query_result = new MaterializedQueryResult("Unexpected null connection"); @@ -1006,7 +684,6 @@ chdb_result * chdb_stream_fetch_result(chdb_connection conn, chdb_result * resul return reinterpret_cast(query_result); } - auto * connection = reinterpret_cast(conn); if (!checkConnectionValidity(connection)) { @@ -1014,18 +691,34 @@ chdb_result * chdb_stream_fetch_result(chdb_connection conn, chdb_result * resul return reinterpret_cast(query_result); } - auto * queue = static_cast(connection->queue); - auto query_result = executeQueryRequest(queue, nullptr, 0, nullptr, 0, CHDB::QueryType::TYPE_STREAMING_ITER, result); + try + { + auto * client = static_cast(connection->server); + if (!client->hasStreamingQuery()) + return reinterpret_cast(new MaterializedQueryResult("No active streaming query")); + auto * stream_result = reinterpret_cast(result); + auto query_result = client->executeStreamingIterate(stream_result, false); + if (!query_result) + return reinterpret_cast(new MaterializedQueryResult("Failed to fetch streaming results")); - return reinterpret_cast(query_result.release()); + return reinterpret_cast(query_result.release()); + } + catch (const std::exception & e) + { + auto * query_result = new MaterializedQueryResult(std::string("Error: ") + e.what()); + return reinterpret_cast(query_result); + } + catch (...) 
+ { + auto * query_result = new MaterializedQueryResult(DB::getCurrentExceptionMessage(true)); + return reinterpret_cast(query_result); + } } void chdb_stream_cancel_query(chdb_connection conn, chdb_result * result) { ChdbDestructorGuard guard; - std::shared_lock global_lock(global_connection_mutex); - if (!result || !conn) return; @@ -1033,9 +726,17 @@ void chdb_stream_cancel_query(chdb_connection conn, chdb_result * result) if (!checkConnectionValidity(connection)) return; - auto * queue = static_cast(connection->queue); - auto query_result = executeQueryRequest(queue, nullptr, 0, nullptr, 0, CHDB::QueryType::TYPE_STREAMING_ITER, result, true); - query_result.reset(); + try + { + auto * client = static_cast(connection->server); + auto * stream_result = reinterpret_cast(result); + client->cancelStreamingQuery(stream_result); + } + catch (...) + { + DB::tryLogCurrentException(__PRETTY_FUNCTION__); + } + // Note: The result object should be freed by chdb_destroy_query_result(), not here } void chdb_destroy_query_result(chdb_result * result) @@ -1045,7 +746,7 @@ void chdb_destroy_query_result(chdb_result * result) if (!result) return; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); delete query_result; } @@ -1054,11 +755,11 @@ char * chdb_result_buffer(chdb_result * result) if (!result) return nullptr; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->result_buffer ? 
materialized_result->result_buffer->data() : nullptr; } @@ -1070,11 +771,10 @@ size_t chdb_result_length(chdb_result * result) if (!result) return 0; - auto query_result = reinterpret_cast(result); - + auto * query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->result_buffer ? materialized_result->result_buffer->size() : 0; } @@ -1086,14 +786,13 @@ double chdb_result_elapsed(chdb_result * result) if (!result) return 0.0; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->elapsed; } - return 0.0; } @@ -1102,11 +801,11 @@ uint64_t chdb_result_rows_read(chdb_result * result) if (!result) return 0; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->rows_read; } @@ -1118,11 +817,11 @@ uint64_t chdb_result_bytes_read(chdb_result * result) if (!result) return 0; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->bytes_read; } @@ -1134,11 +833,11 @@ uint64_t chdb_result_storage_rows_read(chdb_result * result) if (!result) return 0; - auto query_result = reinterpret_cast(result); + auto * 
query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->storage_rows_read; } @@ -1150,11 +849,11 @@ uint64_t chdb_result_storage_bytes_read(chdb_result * result) if (!result) return 0; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); if (query_result->getType() == QueryResultType::RESULT_TYPE_MATERIALIZED) { - auto materialized_result = reinterpret_cast(result); + auto * materialized_result = reinterpret_cast(result); return materialized_result->storage_bytes_read; } @@ -1166,7 +865,7 @@ const char * chdb_result_error(chdb_result * result) if (!result) return nullptr; - auto query_result = reinterpret_cast(result); + auto * query_result = reinterpret_cast(result); if (query_result->getError().empty()) return nullptr; diff --git a/programs/local/chdb.h b/programs/local/chdb.h index d16f5172f43..35d017d3e28 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -50,13 +50,12 @@ struct local_result_v2 /** * Connection structure for chDB - * Contains server instance, connection state, and query processing queue + * Contains ChdbClient instance and connection state */ struct chdb_conn { - void * server; /* ClickHouse LocalServer instance */ + void * server; /* ChdbClient instance */ bool connected; /* Connection state flag */ - void * queue; /* Query processing queue */ }; typedef struct diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index 8e08e60e541..fccbcdddc35 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -83,6 +83,8 @@ class TerminalKeystrokeInterceptor; class WriteBufferFromFileDescriptor; struct Settings; struct MergeTreeSettings; +class ThreadGroup; +using ThreadGroupPtr = std::shared_ptr; struct StreamingQueryContext { @@ -90,6 +92,7 @@ struct StreamingQueryContext 
ASTPtr parsed_query; void * streaming_result = nullptr; bool is_streaming_query = true; + ThreadGroupPtr thread_group = nullptr; StreamingQueryContext() = default; }; diff --git a/src/Client/LocalConnection.cpp b/src/Client/LocalConnection.cpp index c104cb94583..ed9e47212fd 100644 --- a/src/Client/LocalConnection.cpp +++ b/src/Client/LocalConnection.cpp @@ -147,6 +147,10 @@ void LocalConnection::sendQuery( else query_context = session->makeQueryContext(); query_context->setCurrentQueryId(query_id); +#if USE_PYTHON + query_context->setJSONSupport(session->isJSONSupported()); + query_context->setPythonTableCache(session->getPythonTableCache()); +#endif if (send_progress) { diff --git a/src/Client/LocalConnection.h b/src/Client/LocalConnection.h index 8a24138cbef..0b404979112 100644 --- a/src/Client/LocalConnection.h +++ b/src/Client/LocalConnection.h @@ -166,6 +166,7 @@ class LocalConnection : public IServerConnection, WithContext const Progress & getCHDBProgress() const { return chdb_progress; } #if USE_PYTHON void resetQueryContext(); + Session & getSession() const { return *session; } #endif private: diff --git a/src/IO/ReadBuffer.cpp b/src/IO/ReadBuffer.cpp index be355e26697..40c0c1e0935 100644 --- a/src/IO/ReadBuffer.cpp +++ b/src/IO/ReadBuffer.cpp @@ -17,6 +17,7 @@ namespace ErrorCodes { extern const int ATTEMPT_TO_READ_AFTER_EOF; extern const int CANNOT_READ_ALL_DATA; + extern const int CANNOT_READ_FROM_FILE_DESCRIPTOR; } namespace @@ -95,8 +96,10 @@ bool ReadBuffer::next() { res = nextImpl(); } - catch (...) 
+ catch (const Exception & e) { + if (e.code() == ErrorCodes::CANNOT_READ_FROM_FILE_DESCRIPTOR) + return false; cancel(); throw; } diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index c04af7bab4b..407ceabd523 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -53,6 +53,13 @@ namespace Coordination struct OvercommitTracker; +#if USE_PYTHON +namespace CHDB +{ +class PythonTableCache; +} +#endif + namespace DB { @@ -345,6 +352,10 @@ class ContextData mutable bool need_recalculate_access = true; String current_database; std::unique_ptr settings{}; /// Setting for query execution. +#if USE_PYTHON + bool is_json_supported = true; + std::shared_ptr py_table_cache; +#endif using ProgressCallback = std::function; ProgressCallback progress_callback; /// Callback for tracking progress of query execution. @@ -738,6 +749,12 @@ class Context: public ContextData, public std::enable_shared_from_this /// Global application configuration settings. void setConfig(const ConfigurationPtr & config); const Poco::Util::AbstractConfiguration & getConfigRef() const; +#if USE_PYTHON + bool isJSONSupported() const { return is_json_supported; } + void setJSONSupport(bool support) { is_json_supported = support; } + CHDB::PythonTableCache * getPythonTableCache() const { return py_table_cache.get(); } + void setPythonTableCache(const std::shared_ptr & cache) { this->py_table_cache = cache; } +#endif AccessControl & getAccessControl(); const AccessControl & getAccessControl() const; diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index 9ee5390aa53..97c67b21192 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -19,12 +19,11 @@ #include #include -#include - #include #include #include #include +#include #include diff --git a/src/Interpreters/Session.h b/src/Interpreters/Session.h index 277670f818f..b6cbddc55ec 100644 --- a/src/Interpreters/Session.h +++ b/src/Interpreters/Session.h @@ -14,6 +14,13 @@ 
namespace Poco::Net { class SocketAddress; } +#if USE_PYTHON +namespace CHDB +{ +class PythonTableCache; +} +#endif + namespace DB { class Credentials; @@ -103,6 +110,14 @@ class Session /// Closes and removes session void closeSession(const String & session_id); + +#if USE_PYTHON + bool isJSONSupported() const { return is_json_supported; } + void setJSONSupport(bool support) { is_json_supported = support; } + std::shared_ptr & getPythonTableCache() { return py_table_cache; } + void setPythonTableCache(std::shared_ptr py_table_cache_) { py_table_cache = py_table_cache_; } +#endif + private: std::shared_ptr getSessionLog() const; ContextMutablePtr makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const; @@ -133,6 +148,11 @@ class Session SettingsChanges settings_from_auth_server; LoggerPtr log = nullptr; + +#if USE_PYTHON + bool is_json_supported = true; + std::shared_ptr py_table_cache; +#endif }; } diff --git a/src/Interpreters/registerInterpreters.cpp b/src/Interpreters/registerInterpreters.cpp index f716ee39f25..2a663d77dee 100644 --- a/src/Interpreters/registerInterpreters.cpp +++ b/src/Interpreters/registerInterpreters.cpp @@ -1,8 +1,11 @@ +#include #include namespace DB { +std::once_flag global_register_once_flag; + void registerInterpreterSelectQuery(InterpreterFactory & factory); void registerInterpreterSelectQueryAnalyzer(InterpreterFactory & factory); void registerInterpreterSelectWithUnionQuery(InterpreterFactory & factory); diff --git a/src/Interpreters/registerInterpreters.h b/src/Interpreters/registerInterpreters.h index 9f0c3bbec22..19659ab596d 100644 --- a/src/Interpreters/registerInterpreters.h +++ b/src/Interpreters/registerInterpreters.h @@ -1,6 +1,10 @@ #pragma once +#include + namespace DB { void registerInterpreters(); + +extern std::once_flag global_register_once_flag; } diff --git a/tests/test_session_concurrency.py b/tests/test_session_concurrency.py new file mode 100644 index 
00000000000..c4541737fdf --- /dev/null +++ b/tests/test_session_concurrency.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +import unittest +import shutil +import os +import threading +import platform +import subprocess +from chdb import session + + +test_concurrent_dir = ".tmp_test_session_concurrency" + + +def is_musl_linux(): + if platform.system() != "Linux": + return False + try: + result = subprocess.run(['ldd', '--version'], capture_output=True, text=True) + print(f"stdout: {result.stdout.lower()}") + print(f"stderr: {result.stderr.lower()}") + # Check both stdout and stderr for musl + output_text = (result.stdout + result.stderr).lower() + return 'musl' in output_text + except Exception as e: + print(f"Exception in is_musl_linux: {e}") + return False + + +class TestSessionConcurrency(unittest.TestCase): + def setUp(self) -> None: + shutil.rmtree(test_concurrent_dir, ignore_errors=True) + return super().setUp() + + def tearDown(self) -> None: + shutil.rmtree(test_concurrent_dir, ignore_errors=True) + return super().tearDown() + + def test_multiple_sessions_same_path(self): + sess1 = session.Session(test_concurrent_dir) + sess1.query("CREATE DATABASE IF NOT EXISTS test_db") + sess1.query("CREATE TABLE IF NOT EXISTS test_db.data (id Int32, value String) ENGINE = MergeTree() ORDER BY id") + sess1.query("INSERT INTO test_db.data VALUES (1, 'first')") + sess2 = session.Session(test_concurrent_dir) + result1 = sess1.query("SELECT * FROM test_db.data ORDER BY id", "CSV") + self.assertIn("1", str(result1)) + self.assertIn("first", str(result1)) + result2 = sess2.query("SELECT * FROM test_db.data ORDER BY id", "CSV") + self.assertIn("1", str(result2)) + self.assertIn("first", str(result2)) + sess2.query("INSERT INTO test_db.data VALUES (2, 'second')") + result1 = sess1.query("SELECT * FROM test_db.data ORDER BY id", "CSV") + self.assertIn("1", str(result1)) + self.assertIn("2", str(result1)) + sess1.close() + sess2.close() + + def test_sessions_are_thread_safe(self): 
+ sess = session.Session(test_concurrent_dir) + sess.query("CREATE DATABASE IF NOT EXISTS test_db") + sess.query("CREATE TABLE IF NOT EXISTS test_db.shared_counter (id Int32, thread_id Int32) ENGINE = MergeTree() ORDER BY id") + + errors = [] + + def shared_session_worker(thread_id): + try: + for i in range(3): + sess.query(f"INSERT INTO test_db.shared_counter VALUES ({i}, {thread_id})") + except Exception as e: + errors.append((thread_id, str(e))) + + threads = [] + for i in range(3): + t = threading.Thread(target=shared_session_worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + self.assertEqual(len(errors), 0, f"Unexpected errors when sharing session across threads: {errors}") + + result = sess.query("SELECT COUNT(*) FROM test_db.shared_counter") + self.assertIn("9", str(result)) + + sess.close() + + def test_correct_multi_threaded_access(self): + setup_sess = session.Session(test_concurrent_dir) + setup_sess.query("CREATE DATABASE IF NOT EXISTS test_db") + setup_sess.query("CREATE TABLE IF NOT EXISTS test_db.thread_data (thread_id Int32, value Int32) ENGINE = MergeTree() ORDER BY (thread_id, value)") + setup_sess.close() + + results = [] + errors = [] + + def worker(thread_id): + try: + thread_sess = session.Session(test_concurrent_dir) + for i in range(5): + thread_sess.query(f"INSERT INTO test_db.thread_data VALUES ({thread_id}, {i})") + + result = thread_sess.query(f"SELECT COUNT(*) FROM test_db.thread_data WHERE thread_id = {thread_id}") + results.append((thread_id, result)) + + thread_sess.close() + except Exception as e: + errors.append((thread_id, e)) + + threads = [] + for i in range(3): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + self.assertEqual(len(errors), 0, f"Errors occurred: {errors}") + self.assertEqual(len(results), 3) + + verify_sess = session.Session(test_concurrent_dir) + final_count = verify_sess.query("SELECT COUNT(*) FROM 
test_db.thread_data") + self.assertIn("15", str(final_count)) # 3 threads * 5 inserts each + verify_sess.close() + + def test_session_reopen_after_close(self): + sess1 = session.Session(test_concurrent_dir) + sess1.query("CREATE TABLE IF NOT EXISTS test (id Int32) ENGINE = MergeTree() ORDER BY id") + sess1.query("INSERT INTO test VALUES (1)") + sess1.close() + + sess2 = session.Session(test_concurrent_dir) + result = sess2.query("SELECT * FROM test") + self.assertIn("1", str(result)) + sess2.close() + sess3 = session.Session(test_concurrent_dir) + result = sess3.query("SELECT * FROM test") + self.assertIn("1", str(result)) + sess3.close() + + @unittest.skipIf(is_musl_linux(), "Skip test on musl systems") + def test_session_path_consistency(self): + sess1 = session.Session(test_concurrent_dir) + sess1.query("SELECT 1") + + # Attempting to create a session with a different path will fail + try: + sess2 = session.Session(test_concurrent_dir + "_different") + sess2.close() + self.fail("Should have raised an exception for different path") + except RuntimeError as e: + self.assertIn("already initialized", str(e).lower()) + self.assertIn("path", str(e).lower()) + + sess1.close() + + def test_session_usage_after_close(self): + sess = session.Session(test_concurrent_dir) + sess.query("SELECT 1") + sess.close() + try: + sess.query("SELECT 1") + self.fail("Should raise error when using closed session") + except Exception as e: + error_msg = str(e) + self.assertIsNotNone(error_msg) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_signal_handler.py b/tests/test_signal_handler.py index e62230ac339..38c8e7f4481 100644 --- a/tests/test_signal_handler.py +++ b/tests/test_signal_handler.py @@ -51,13 +51,17 @@ def test_signal_response(self): def test_data_integrity_after_interrupt(self): def data_writer(): + writer_sess = session.Session(test_signal_handler_dir) global insert_counter i = 500000 - while not exit_event1.is_set(): - self.sess.query(f"INSERT INTO 
signal_handler_table VALUES ({i})") - insert_counter += 1 - i += 1 - exit_event2.set() + try: + while not exit_event1.is_set(): + writer_sess.query(f"INSERT INTO test.signal_handler_table VALUES ({i})") + insert_counter += 1 + i += 1 + finally: + writer_sess.close() + exit_event2.set() self.sess.query("CREATE DATABASE IF NOT EXISTS test") self.sess.query("USE test") @@ -80,7 +84,7 @@ def data_writer(): start_time = time.time() try: while time.time() - start_time < 60: - self.sess.query("SELECT * FROM signal_handler_table") + self.sess.query("SELECT * FROM test.signal_handler_table") except KeyboardInterrupt: print("receive signal") exit_event1.set() diff --git a/tests/test_stateful.py b/tests/test_stateful.py index 6ea670ca5f2..5d84d9bcc96 100644 --- a/tests/test_stateful.py +++ b/tests/test_stateful.py @@ -96,9 +96,9 @@ def test_two_sessions(self): sess2 = None try: sess1 = session.Session() - with self.assertWarns(Warning): - sess2 = session.Session() - self.assertIsNone(sess1._conn) + sess2 = session.Session() + self.assertIsNotNone(sess1._conn) + self.assertIsNotNone(sess2._conn) finally: if sess1: sess1.cleanup() diff --git a/tests/test_streaming_query.py b/tests/test_streaming_query.py index b0b84dd6aa8..43468185900 100644 --- a/tests/test_streaming_query.py +++ b/tests/test_streaming_query.py @@ -99,87 +99,117 @@ def test_cancel_streaming_query(self): def streaming_worker1(self): global result_counter_streaming1 - query_count = 0 - while not stop_event.is_set(): - query_count += 1 - if query_count % 10 == 0: - with self.assertRaises(Exception): - ret = self.sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") - - if query_count % 10 == 1: - ret = self.sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") - ret.cancel() - - stream = self.sess.send_query("SELECT * FROM streaming_test2", "CSVWITHNAMES") - for chunk in stream: - result_counter_streaming1 += chunk.rows_read() + # Create a separate session for 
this thread + worker_sess = session.Session(test_streaming_query_dir) + worker_sess.query("USE test") + try: + query_count = 0 + while not stop_event.is_set(): + query_count += 1 + if query_count % 10 == 0: + with self.assertRaises(Exception): + ret = worker_sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") + + if query_count % 10 == 1: + ret = worker_sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") + ret.cancel() + + stream = worker_sess.send_query("SELECT * FROM streaming_test2", "CSVWITHNAMES") + for chunk in stream: + result_counter_streaming1 += chunk.rows_read() + finally: + worker_sess.close() def streaming_worker2(self): global result_counter_streaming2 - query_count = 0 - while not stop_event.is_set(): - query_count += 1 - if query_count % 10 == 0: - with self.assertRaises(Exception): - ret = self.sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") - - if query_count % 10 == 2: - ret = self.sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") - ret.cancel() - - stream = self.sess.send_query("SELECT * FROM streaming_test2", "CSVWITHNAMES") - for chunk in stream: - result_counter_streaming2 += chunk.rows_read() + # Create a separate session for this thread + worker_sess = session.Session(test_streaming_query_dir) + worker_sess.query("USE test") + try: + query_count = 0 + while not stop_event.is_set(): + query_count += 1 + if query_count % 10 == 0: + with self.assertRaises(Exception): + ret = worker_sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") + + if query_count % 10 == 2: + ret = worker_sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") + ret.cancel() + + stream = worker_sess.send_query("SELECT * FROM streaming_test2", "CSVWITHNAMES") + for chunk in stream: + result_counter_streaming2 += chunk.rows_read() + finally: + worker_sess.close() def normal_query_worker1(self): global 
result_counter_normal1 - query_count = 0 - while not stop_event.is_set(): - query_count += 1 - if query_count % 10 == 0: - with self.assertRaises(Exception): - ret = self.sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") - - if query_count % 10 == 3: - ret = self.sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") - ret.cancel() - - result = self.sess.query("SELECT * FROM streaming_test2", "CSVWITHNAMES") - result_counter_normal1 += result.rows_read() + # Create a separate session for this thread + worker_sess = session.Session(test_streaming_query_dir) + worker_sess.query("USE test") + try: + query_count = 0 + while not stop_event.is_set(): + query_count += 1 + if query_count % 10 == 0: + with self.assertRaises(Exception): + ret = worker_sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") + + if query_count % 10 == 3: + ret = worker_sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") + ret.cancel() + + result = worker_sess.query("SELECT * FROM streaming_test2", "CSVWITHNAMES") + result_counter_normal1 += result.rows_read() + finally: + worker_sess.close() def normal_query_worker2(self): global result_counter_normal2 - query_count = 0 - while not stop_event.is_set(): - query_count += 1 - if query_count % 10 == 0: - with self.assertRaises(Exception): - ret = self.sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") - - if query_count % 10 == 4: - ret = self.sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") - ret.cancel() - - result = self.sess.query("SELECT * FROM streaming_test2", "CSVWITHNAMES") - result_counter_normal2 += result.rows_read() + # Create a separate session for this thread + worker_sess = session.Session(test_streaming_query_dir) + worker_sess.query("USE test") + try: + query_count = 0 + while not stop_event.is_set(): + query_count += 1 + if query_count % 10 == 0: + with 
self.assertRaises(Exception): + ret = worker_sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") + + if query_count % 10 == 4: + ret = worker_sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") + ret.cancel() + + result = worker_sess.query("SELECT * FROM streaming_test2", "CSVWITHNAMES") + result_counter_normal2 += result.rows_read() + finally: + worker_sess.close() def normal_insert_worker(self): global insert_counter - query_count = 0 - i = 500000 - while not stop_event.is_set(): - query_count += 1 - if query_count % 10 == 0: - with self.assertRaises(Exception): - ret = self.sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") - - if query_count % 10 == 5: - ret = self.sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") - ret.cancel() - - self.sess.query(f"INSERT INTO streaming_test2 VALUES ({i})") - insert_counter += 1 - i += 1 + # Create a separate session for this thread + worker_sess = session.Session(test_streaming_query_dir) + worker_sess.query("USE test") + try: + query_count = 0 + i = 500000 + while not stop_event.is_set(): + query_count += 1 + if query_count % 10 == 0: + with self.assertRaises(Exception): + ret = worker_sess.send_query("SELECT * FROM streaming_test2;SELECT * FROM streaming_test2", "CSVWITHNAMES") + + if query_count % 10 == 5: + ret = worker_sess.send_query("SELECT * FROM streaming_test2;", "CSVWITHNAMES") + ret.cancel() + + worker_sess.query(f"INSERT INTO streaming_test2 VALUES ({i})") + insert_counter += 1 + i += 1 + finally: + worker_sess.close() def test_multi_thread_streaming_query(self): self.sess.query("CREATE DATABASE IF NOT EXISTS test")