@@ -1203,39 +1203,102 @@ void SharedUtil::RandomizeRandomSeed()
12031203}
12041204
12051205#ifdef WIN32
1206+ static LONG SafeNtQueryInformationThread (HANDLE ThreadHandle, INT ThreadInformationClass, PVOID ThreadInformation, ULONG ThreadInformationLength,
1207+ PULONG ReturnLength)
1208+ {
1209+ using FunctionPointer = LONG (__stdcall*)(HANDLE, INT /* = THREADINFOCLASS*/ , PVOID, ULONG, PULONG);
1210+
1211+ struct FunctionLookup
1212+ {
1213+ FunctionPointer function;
1214+ bool once;
1215+ };
1216+
1217+ static FunctionLookup lookup = {};
1218+
1219+ if (!lookup.once )
1220+ {
1221+ lookup.once = true ;
1222+
1223+ HMODULE ntdll = LoadLibraryA (" ntdll.dll" );
1224+
1225+ if (ntdll)
1226+ lookup.function = (FunctionPointer)GetProcAddress (ntdll, " NtQueryInformationThread" );
1227+ else
1228+ return 0xC0000135L ; // STATUS_DLL_NOT_FOUND
1229+ }
1230+
1231+ if (lookup.function )
1232+ return lookup.function (ThreadHandle, ThreadInformationClass, ThreadInformation, ThreadInformationLength, ReturnLength);
1233+ else
1234+ return 0xC00000BBL ; // STATUS_NOT_SUPPORTED
1235+ }
1236+
1237+ bool SharedUtil::QueryThreadEntryPointAddress (void * thread, DWORD* entryPointAddress)
1238+ {
1239+ return SafeNtQueryInformationThread (thread, 9 /* ThreadQuerySetWin32StartAddress*/ , entryPointAddress, sizeof (DWORD), nullptr ) == 0 ;
1240+ }
1241+
12061242DWORD SharedUtil::GetMainThreadId ()
12071243{
12081244 static DWORD dwMainThreadID = 0 ;
1245+
12091246 if (dwMainThreadID == 0 )
12101247 {
1248+ // Get the module information for the currently running process
1249+ MODULEINFO moduleInfo = {};
1250+ GetModuleInformation (GetCurrentProcess (), GetModuleHandle (nullptr ), &moduleInfo, sizeof (MODULEINFO));
1251+
1252+ DWORD processEntryPointAddress = reinterpret_cast <DWORD>(moduleInfo.EntryPoint );
1253+
12111254 // Find oldest thread in the current process ( http://www.codeproject.com/Questions/78801/How-to-get-the-main-thread-ID-of-a-process-known-b )
12121255 HANDLE hThreadSnap = CreateToolhelp32Snapshot (TH32CS_SNAPTHREAD, 0 );
1256+
12131257 if (hThreadSnap != INVALID_HANDLE_VALUE)
12141258 {
1215- ULONGLONG ullMinCreateTime = ULLONG_MAX;
1216- THREADENTRY32 th32;
1259+ ULONGLONG ullMinCreateTime = ULLONG_MAX;
1260+
1261+ THREADENTRY32 th32 = {};
12171262 th32.dwSize = sizeof (THREADENTRY32);
1263+
12181264 for (BOOL bOK = Thread32First (hThreadSnap, &th32); bOK; bOK = Thread32Next (hThreadSnap, &th32))
12191265 {
12201266 if (th32.th32OwnerProcessID == GetCurrentProcessId ())
12211267 {
12221268 HANDLE hThread = OpenThread (THREAD_QUERY_INFORMATION, TRUE , th32.th32ThreadID );
1269+
12231270 if (hThread)
12241271 {
1225- FILETIME afTimes[4 ] = {0 };
1272+ // Check the thread by entry point first
1273+ DWORD entryPointAddress = 0 ;
1274+
1275+ if (QueryThreadEntryPointAddress (hThread, &entryPointAddress) && entryPointAddress == processEntryPointAddress)
1276+ {
1277+ dwMainThreadID = th32.th32ThreadID ;
1278+ CloseHandle (hThread);
1279+ CloseHandle (hThreadSnap);
1280+ return dwMainThreadID;
1281+ }
1282+
1283+ // If entry point check failed, find the oldest thread in the system
1284+ FILETIME afTimes[4 ] = {};
1285+
12261286 if (GetThreadTimes (hThread, &afTimes[0 ], &afTimes[1 ], &afTimes[2 ], &afTimes[3 ]))
12271287 {
12281288 ULONGLONG ullTest = (ULONGLONG (afTimes[0 ].dwHighDateTime ) << 32 ) + afTimes[0 ].dwLowDateTime ;
1289+
12291290 if (ullTest && ullTest < ullMinCreateTime)
12301291 {
12311292 ullMinCreateTime = ullTest;
12321293 dwMainThreadID = th32.th32ThreadID ;
12331294 }
12341295 }
1296+
12351297 CloseHandle (hThread);
12361298 }
12371299 }
12381300 }
1301+
12391302 CloseHandle (hThreadSnap);
12401303 }
12411304
@@ -1245,6 +1308,7 @@ DWORD SharedUtil::GetMainThreadId()
12451308 dwMainThreadID = GetCurrentThreadId ();
12461309 }
12471310 }
1311+
12481312 return dwMainThreadID;
12491313}
12501314#endif
0 commit comments