diff --git a/libs/api/ebpf_api.cpp b/libs/api/ebpf_api.cpp index c56f24200e..45f4feb3e5 100644 --- a/libs/api/ebpf_api.cpp +++ b/libs/api/ebpf_api.cpp @@ -38,6 +38,7 @@ typedef unsigned char boolean; #include #include +#include #include #include #include @@ -3755,9 +3756,9 @@ _Must_inspect_result_ ebpf_result_t _ebpf_object_load_native( _In_z_ const char* file_name, _Out_ ebpf_handle_t* native_module_handle, - _Out_ size_t* count_of_maps, + _Inout_ size_t* count_of_maps, _Outptr_result_buffer_all_maybenull_(*count_of_maps) ebpf_handle_t** map_handles, - _Out_ size_t* count_of_programs, + _Inout_ size_t* count_of_programs, _Outptr_result_buffer_all_maybenull_(*count_of_programs) ebpf_handle_t** program_handles) NO_EXCEPT_TRY { EBPF_LOG_ENTRY(); @@ -3779,11 +3780,11 @@ _ebpf_object_load_native( std::wstring service_path(SERVICE_PATH_PREFIX); std::wstring parameters_path(PARAMETERS_PATH_PREFIX); ebpf_protocol_buffer_t request_buffer; + size_t real_count_of_maps = 0; + size_t real_count_of_programs = 0; *native_module_handle = ebpf_handle_invalid; - *count_of_maps = 0; *map_handles = nullptr; - *count_of_programs = 0; *program_handles = nullptr; if (UuidCreate(&service_name_guid) != RPC_S_OK) { @@ -3852,7 +3853,7 @@ _ebpf_object_load_native( service_path = service_path + service_name.c_str(); result = _load_native_module( - service_path, &provider_module_id, native_module_handle, count_of_maps, count_of_programs); + service_path, &provider_module_id, native_module_handle, &real_count_of_maps, &real_count_of_programs); if (result != EBPF_SUCCESS) { EBPF_LOG_MESSAGE_WSTRING( EBPF_TRACELOG_LEVEL_ERROR, @@ -3862,8 +3863,19 @@ _ebpf_object_load_native( goto Done; } + if (*count_of_maps < real_count_of_maps || *count_of_programs < real_count_of_programs) { + result = EBPF_NO_MEMORY; + } + + *count_of_maps = real_count_of_maps; + *count_of_programs = real_count_of_programs; + + if (result != EBPF_SUCCESS) { + goto Done; + } + result = _load_native_programs( - &provider_module_id, *count_of_maps, map_handles, *count_of_programs, program_handles); + &provider_module_id, real_count_of_maps, map_handles, real_count_of_programs, program_handles); if (result != EBPF_SUCCESS) { EBPF_LOG_MESSAGE_STRING( EBPF_TRACELOG_LEVEL_ERROR, @@ -3882,14 +3894,15 @@ _ebpf_object_load_native( Done: if (result != EBPF_SUCCESS) { - _ebpf_free_handles(*count_of_maps, *map_handles); - _ebpf_free_handles(*count_of_programs, *program_handles); + _ebpf_free_handles(real_count_of_maps, *map_handles); + _ebpf_free_handles(real_count_of_programs, *program_handles); if (*native_module_handle != ebpf_handle_invalid) { Platform::CloseHandle(*native_module_handle); } - Platform::_stop_service(service_handle); + Platform::_stop_and_delete_service(service_handle, service_name.c_str()); + EBPF_RETURN_RESULT(result); } // Workaround: Querying service status hydrates service reference count in SCM. @@ -3919,36 +3932,24 @@ ebpf_object_load_native_by_fds( EBPF_LOG_ENTRY(); ebpf_assert(count_of_maps); - ebpf_assert(*count_of_maps > 0 && map_fds); + ebpf_assert(*count_of_maps == 0 || map_fds); ebpf_assert(count_of_programs); - ebpf_assert(*count_of_programs > 0 && program_fds); + ebpf_assert(*count_of_programs == 0 || program_fds); ebpf_handle_t native_module_handle; ebpf_handle_t* map_handles = nullptr; ebpf_handle_t* program_handles = nullptr; - size_t real_count_of_maps = 0; - size_t real_count_of_programs = 0; ebpf_result_t result = _ebpf_object_load_native( - file_name, &native_module_handle, &real_count_of_maps, &map_handles, &real_count_of_programs, &program_handles); + file_name, &native_module_handle, count_of_maps, &map_handles, count_of_programs, &program_handles); if (result != EBPF_SUCCESS) { EBPF_RETURN_RESULT(result); } Platform::CloseHandle(native_module_handle); - if (*count_of_maps < real_count_of_maps || *count_of_programs < real_count_of_programs) { - *count_of_maps = real_count_of_maps; - *count_of_programs = real_count_of_programs; - _ebpf_free_handles(real_count_of_maps, map_handles); - _ebpf_free_handles(real_count_of_programs, program_handles); - EBPF_RETURN_RESULT(EBPF_NO_MEMORY); - } - - *count_of_maps = real_count_of_maps; - *count_of_programs = real_count_of_programs; - - for (int i = 0; i < real_count_of_maps; i++) { + __analysis_assume(*count_of_maps == 0 || map_handles != nullptr); + for (int i = 0; i < *count_of_maps; i++) { map_fds[i] = _create_file_descriptor_for_handle(map_handles[i]); if (map_fds[i] == ebpf_fd_invalid) { result = EBPF_NO_MEMORY; @@ -3957,7 +3958,8 @@ ebpf_object_load_native_by_fds( } } - for (int i = 0; i < real_count_of_programs; i++) { + __analysis_assume(*count_of_programs == 0 || program_handles != nullptr); + for (int i = 0; i < *count_of_programs; i++) { program_fds[i] = _create_file_descriptor_for_handle(program_handles[i]); if (program_fds[i] == ebpf_fd_invalid) { result = EBPF_NO_MEMORY; @@ -3968,7 +3970,7 @@ ebpf_object_load_native_by_fds( if (result != EBPF_SUCCESS) { if (map_fds != nullptr) { - for (int i = 0; i < real_count_of_maps; i++) { + for (int i = 0; i < *count_of_maps; i++) { if (map_fds[i] != ebpf_fd_invalid) { Platform::_close(map_fds[i]); map_fds[i] = ebpf_fd_invalid; @@ -3977,7 +3979,7 @@ ebpf_object_load_native_by_fds( } if (program_fds != nullptr) { - for (int i = 0; i < real_count_of_programs; i++) { + for (int i = 0; i < *count_of_programs; i++) { if (program_fds[i] != ebpf_fd_invalid) { Platform::_close(program_fds[i]); program_fds[i] = ebpf_fd_invalid; @@ -3986,8 +3988,8 @@ ebpf_object_load_native_by_fds( } } - _ebpf_free_handles(real_count_of_maps, map_handles); - _ebpf_free_handles(real_count_of_programs, program_handles); + _ebpf_free_handles(*count_of_maps, map_handles); + _ebpf_free_handles(*count_of_programs, program_handles); EBPF_RETURN_RESULT(result); } @@ -4006,8 +4008,8 @@ _ebpf_program_load_native( ebpf_handle_t native_module_handle = ebpf_handle_invalid; ebpf_handle_t* map_handles = nullptr; ebpf_handle_t* program_handles = nullptr; - size_t count_of_maps = 0; - size_t count_of_programs = 0; + size_t count_of_maps = SIZE_MAX; + size_t count_of_programs = SIZE_MAX; try { result = _ebpf_object_load_native( diff --git a/libs/thunk/mock/mock.cpp b/libs/thunk/mock/mock.cpp index e5220e3154..89f8d6043c 100644 --- a/libs/thunk/mock/mock.cpp +++ b/libs/thunk/mock/mock.cpp @@ -17,8 +17,8 @@ std::function duplicate_handle_handler; std::function device_io_control_handler; std::function get_osfhandle_handler; std::function open_osfhandle_handler; -std::function create_service_handler; -std::function delete_service_handler; +std::function create_service_handler; +std::function delete_service_handler; namespace Platform { bool @@ -180,13 +180,10 @@ _query_service_status(SC_HANDLE service_handle, _Inout_ SERVICE_STATUS* status) } uint32_t -_stop_service(SC_HANDLE service_handle) +_stop_and_delete_service(SC_HANDLE service_handle, const wchar_t* service_name) { - // TODO: (Issue# 852) Just a stub currently in order to compile. - // Will be replaced by a proper mock. - - UNREFERENCED_PARAMETER(service_handle); - return ERROR_SUCCESS; + UNREFERENCED_PARAMETER(service_name); + return _delete_service(service_handle); } } // namespace Platform diff --git a/libs/thunk/mock/mock.h b/libs/thunk/mock/mock.h index 7c389424b1..7799ac9de0 100644 --- a/libs/thunk/mock/mock.h +++ b/libs/thunk/mock/mock.h @@ -6,12 +6,16 @@ #include #include +namespace Platform { + uint32_t _create_service(_In_z_ const wchar_t* service_name, _In_z_ const wchar_t* file_path, _Out_ SC_HANDLE* service_handle); uint32_t _delete_service(SC_HANDLE service_handle); +} // namespace Platform + extern std::function close_handler; extern std::function dup_handler; extern std::function cancel_io_ex_handler; @@ -21,5 +25,5 @@ extern std::function duplicate_handle_handler; extern std::function device_io_control_handler; extern std::function get_osfhandle_handler; extern std::function open_osfhandle_handler; -extern std::function create_service_handler; -extern std::function delete_service_handler; +extern std::function create_service_handler; +extern std::function delete_service_handler; diff --git a/libs/thunk/platform.h b/libs/thunk/platform.h index 17c3dbf3b0..ddbfae9a0f 100644 --- a/libs/thunk/platform.h +++ b/libs/thunk/platform.h @@ -79,7 +79,7 @@ uint32_t _delete_service(SC_HANDLE service_handle); uint32_t -_stop_service(SC_HANDLE service_handle); +_stop_and_delete_service(SC_HANDLE service_handle, const wchar_t* service_name); bool _query_service_status(SC_HANDLE service_handle, _Inout_ SERVICE_STATUS* status); diff --git a/libs/thunk/windows/platform.cpp b/libs/thunk/windows/platform.cpp index 29dada01f5..06bce7f791 100644 --- a/libs/thunk/windows/platform.cpp +++ b/libs/thunk/windows/platform.cpp @@ -182,33 +182,36 @@ _update_registry_value( return error; } -static bool -_check_service_state(SC_HANDLE service_handle, unsigned long expected_state, _Out_ unsigned long* final_state) +static DWORD +_check_service_deleted(const wchar_t* service_name) { #define MAX_RETRY_COUNT 20 #define WAIT_TIME_IN_MS 500 - int retry_count = 0; - bool status = false; - int error; - SERVICE_STATUS service_status = {0}; + SC_HANDLE scm_handle = OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT); + if (scm_handle == nullptr) { + return GetLastError(); + } - // Query service state. - while (retry_count < MAX_RETRY_COUNT) { - if (!QueryServiceStatus(service_handle, &service_status)) { + DWORD error = ERROR_SERVICE_REQUEST_TIMEOUT; + for (int retry_count = 0; retry_count < MAX_RETRY_COUNT; retry_count++) { + SC_HANDLE service_handle = OpenService(scm_handle, service_name, SERVICE_QUERY_STATUS); + + if (service_handle == nullptr) { error = GetLastError(); break; - } else if (service_status.dwCurrentState == expected_state) { - status = true; - break; - } else { - Sleep(WAIT_TIME_IN_MS); - retry_count++; } + + CloseServiceHandle(service_handle); + Sleep(WAIT_TIME_IN_MS); + } + + if (error == ERROR_SERVICE_DOES_NOT_EXIST) { + error = ERROR_SUCCESS; } - *final_state = service_status.dwCurrentState; - return status; + CloseServiceHandle(scm_handle); + return error; } uint32_t @@ -285,27 +288,30 @@ _delete_service(SC_HANDLE service_handle) } uint32_t -_stop_service(SC_HANDLE service_handle) +_stop_and_delete_service(SC_HANDLE service_handle, const wchar_t* service_name) { SERVICE_STATUS status; - bool service_stopped = false; - unsigned long service_state; - int error = ERROR_SUCCESS; + DWORD error = ERROR_SUCCESS; if (service_handle == nullptr) { return EBPF_INVALID_ARGUMENT; } if (!ControlService(service_handle, SERVICE_CONTROL_STOP, &status)) { - return GetLastError(); + error = GetLastError(); + if (error != ERROR_SERVICE_NOT_ACTIVE) { + return error; + } + + // Service is already stopped. } - service_stopped = _check_service_state(service_handle, SERVICE_STOPPED, &service_state); - if (!service_stopped) { - error = ERROR_SERVICE_REQUEST_TIMEOUT; + error = _delete_service(service_handle); + if (error != ERROR_SUCCESS) { + return error; } - return error; + return _check_service_deleted(service_name); } } // namespace Platform diff --git a/tests/api_test/api_test.cpp b/tests/api_test/api_test.cpp index affed63c0a..4c5ceab93f 100644 --- a/tests/api_test/api_test.cpp +++ b/tests/api_test/api_test.cpp @@ -1641,6 +1641,40 @@ TEST_CASE("prog_array_map_user_reference-native", "[user_reference]") _test_prog_array_map_user_reference(EBPF_EXECUTION_NATIVE); } +TEST_CASE("native_load_retry_after_insufficient_buffers", "[native_tests]") +{ + native_module_helper_t native_helper; + native_helper.initialize("bindmonitor", EBPF_EXECUTION_NATIVE); + + std::vector map_fds(3, ebpf_fd_invalid); + std::vector program_fds(1, ebpf_fd_invalid); + size_t count_of_maps = 0; + size_t count_of_programs = 0; + + ebpf_result_t result = ebpf_object_load_native_by_fds( + native_helper.get_file_name().c_str(), &count_of_maps, nullptr, &count_of_programs, nullptr); + + REQUIRE(result == EBPF_NO_MEMORY); + REQUIRE(count_of_maps == map_fds.size()); + REQUIRE(count_of_programs == program_fds.size()); + + result = ebpf_object_load_native_by_fds( + native_helper.get_file_name().c_str(), &count_of_maps, map_fds.data(), &count_of_programs, program_fds.data()); + + REQUIRE(result == EBPF_SUCCESS); + REQUIRE(count_of_maps == map_fds.size()); + REQUIRE(count_of_programs == program_fds.size()); + + for (auto fd : map_fds) { + REQUIRE(fd != ebpf_fd_invalid); + _close(fd); + } + for (auto fd : program_fds) { + REQUIRE(fd != ebpf_fd_invalid); + _close(fd); + } +} + TEST_CASE("load_all_sample_programs", "[native_tests]") { struct _ebpf_program_load_test_parameters test_parameters[] = {