Skip to content

Commit 63087b8

Browse files
committed
#52 Improved stability of MmpTls
1 parent 06968d8 commit 63087b8

File tree

2 files changed

+138
-32
lines changed

2 files changed

+138
-32
lines changed

MemoryModule/MemoryModulePP.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
LIBRARY
22
EXPORTS
3+
4+
MmInitialize
5+
MmCleanup
6+
37
LoadLibraryMemory
48
LoadLibraryMemoryExA
59
LoadLibraryMemoryExW

MemoryModule/MmpTls.cpp

Lines changed: 134 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <cassert>
88
#include <algorithm>
99
#include <3rdparty/Detours/detours.h>
10+
#include <set>
1011

1112

1213
PVOID NTAPI MmpQuerySystemInformation(
@@ -402,6 +403,96 @@ BOOL NTAPI PreHookNtSetInformationProcess() {
402403
return success;
403404
}
404405

406+
int MmpSyncThreadTlsData() {
407+
PSYSTEM_PROCESS_INFORMATION pspi = (PSYSTEM_PROCESS_INFORMATION)MmpQuerySystemInformation(SYSTEM_INFORMATION_CLASS::SystemProcessInformation, nullptr);
408+
PSYSTEM_PROCESS_INFORMATION current = pspi;
409+
std::set<HANDLE>threads;
410+
int count = 0;
411+
412+
//
413+
// Build thread id set.
414+
//
415+
416+
PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink;
417+
while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) {
418+
PMMP_TLSP_RECORD j = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer);
419+
threads.insert(j->UniqueThread);
420+
421+
entry = entry->Flink;
422+
}
423+
424+
while (pspi) {
425+
426+
if (current->UniqueProcessId == NtCurrentTeb()->ClientId.UniqueProcess) {
427+
428+
for (ULONG index = 0; index < current->NumberOfThreads; ++index) {
429+
CLIENT_ID cid = current->Threads[index].ClientId;
430+
431+
if (threads.find(cid.UniqueThread) == threads.end()) {
432+
433+
HANDLE hThread;
434+
OBJECT_ATTRIBUTES oa{};
435+
NTSTATUS status = NtOpenThread(&hThread, THREAD_QUERY_INFORMATION, &oa, &cid);
436+
if (NT_SUCCESS(status)) {
437+
438+
THREAD_BASIC_INFORMATION tbi{};
439+
status = NtQueryInformationThread(hThread, THREADINFOCLASS::ThreadBasicInformation, &tbi, sizeof(tbi), nullptr);
440+
if (NT_SUCCESS(status)) {
441+
442+
PTEB teb = tbi.TebBaseAddress;
443+
if (teb->ThreadLocalStoragePointer) {
444+
445+
//
446+
// Allocate TLS record
447+
//
448+
449+
auto record = PMMP_TLSP_RECORD(RtlAllocateHeap(RtlProcessHeap(), 0, sizeof(MMP_TLSP_RECORD)));
450+
if (record) {
451+
record->TlspLdrBlock = (PVOID*)teb->ThreadLocalStoragePointer;
452+
record->TlspMmpBlock = (PVOID*)MmpAllocateTlsp();
453+
record->UniqueThread = cid.UniqueThread;
454+
if (record->TlspMmpBlock) {
455+
record->TlspMmpBlock = ((PTLS_VECTOR)record->TlspMmpBlock)->ModuleTlsData;
456+
457+
auto size = CONTAINING_RECORD(record->TlspLdrBlock, TLS_VECTOR, ModuleTlsData)->Length;
458+
if ((HANDLE)(ULONG_PTR)size != record->UniqueThread) {
459+
RtlCopyMemory(
460+
record->TlspMmpBlock,
461+
record->TlspLdrBlock,
462+
size * sizeof(PVOID)
463+
);
464+
}
465+
466+
teb->ThreadLocalStoragePointer = record->TlspMmpBlock;
467+
InsertTailList(&MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer, &record->InMmpThreadLocalStoragePointer);
468+
InterlockedIncrement(&MmpGlobalDataPtr->MmpTls->MmpActiveThreadCount);
469+
470+
++count;
471+
}
472+
else {
473+
RtlFreeHeap(RtlProcessHeap(), 0, record);
474+
}
475+
}
476+
}
477+
}
478+
479+
NtClose(hThread);
480+
}
481+
482+
}
483+
}
484+
485+
break;
486+
}
487+
488+
if (!current->NextEntryOffset)break;
489+
current = (PSYSTEM_PROCESS_INFORMATION)((PBYTE)current + current->NextEntryOffset);
490+
}
491+
492+
RtlFreeHeap(RtlProcessHeap(), 0, pspi);
493+
return count;
494+
}
495+
405496
NTSTATUS NTAPI HookNtSetInformationProcess(
406497
_In_opt_ HANDLE ProcessHandle,
407498
_In_ PROCESSINFOCLASS ProcessInformationClass,
@@ -423,6 +514,12 @@ NTSTATUS NTAPI HookNtSetInformationProcess(
423514
PPROCESS_TLS_INFORMATION Tls = nullptr;
424515
NTSTATUS status = STATUS_SUCCESS;
425516

517+
//
518+
// Sync thread data with ntdll!Ldr.
519+
//
520+
521+
MmpSyncThreadTlsData();
522+
426523
do {
427524
if (ProcessTlsInformation->OperationType >= MaxProcessTlsOperation) {
428525
status = STATUS_INVALID_PARAMETER;
@@ -456,7 +553,7 @@ NTSTATUS NTAPI HookNtSetInformationProcess(
456553
break;
457554
}
458555

459-
// reserved 0x50 PVOID for ntdll loader
556+
// reserved 0x80 PVOID for ntdll loader
460557
if (ProcessTlsInformation->TlsVectorLength >= MMP_START_TLS_INDEX) {
461558
status = STATUS_NO_MEMORY;
462559
break;
@@ -496,50 +593,55 @@ NTSTATUS NTAPI HookNtSetInformationProcess(
496593
//
497594
EnterCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock);
498595
for (ULONG i = 0; i < Tls->ThreadDataCount; ++i) {
499-
BOOL found = FALSE;
500-
PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink;
501596

502-
// Find thread-spec tlsp
503-
while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) {
597+
if (Tls->ThreadData[i].Flags == 2) {
504598

505-
PMMP_TLSP_RECORD j = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer);
599+
BOOL found = FALSE;
600+
PLIST_ENTRY entry = MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer.Flink;
506601

507-
if (ProcessTlsInformation->OperationType == ProcessTlsReplaceVector) {
508-
if (j->TlspMmpBlock[ProcessTlsInformation->TlsVectorLength] == ProcessTlsInformation->ThreadData[i].TlsVector[ProcessTlsInformation->TlsVectorLength]) {
509-
found = TRUE;
602+
// Find thread-spec tlsp
603+
while (entry != &MmpGlobalDataPtr->MmpTls->MmpThreadLocalStoragePointer) {
510604

511-
// Copy old data to new pointer
512-
RtlCopyMemory(
513-
ProcessTlsInformation->ThreadData[i].TlsVector,
514-
j->TlspMmpBlock,
515-
sizeof(PVOID) * ProcessTlsInformation->TlsVectorLength
516-
);
605+
PMMP_TLSP_RECORD j = CONTAINING_RECORD(entry, MMP_TLSP_RECORD, InMmpThreadLocalStoragePointer);
517606

518-
// Swap the tlsp
519-
std::swap(
520-
j->TlspLdrBlock,
521-
ProcessTlsInformation->ThreadData[i].TlsVector
522-
);
607+
if (ProcessTlsInformation->OperationType == ProcessTlsReplaceVector) {
608+
if (j->TlspMmpBlock[ProcessTlsInformation->TlsVectorLength] == ProcessTlsInformation->ThreadData[i].TlsVector[ProcessTlsInformation->TlsVectorLength]) {
609+
found = TRUE;
610+
611+
// Copy old data to new pointer
612+
RtlCopyMemory(
613+
ProcessTlsInformation->ThreadData[i].TlsVector,
614+
j->TlspMmpBlock,
615+
sizeof(PVOID) * ProcessTlsInformation->TlsVectorLength
616+
);
617+
618+
// Swap the tlsp
619+
std::swap(
620+
j->TlspLdrBlock,
621+
ProcessTlsInformation->ThreadData[i].TlsVector
622+
);
623+
}
523624
}
524-
}
525-
else {
526-
if (j->TlspMmpBlock[ProcessTlsInformation->TlsIndex] == ProcessTlsInformation->ThreadData[i].TlsModulePointer) {
527-
found = TRUE;
625+
else {
626+
if (j->TlspMmpBlock[ProcessTlsInformation->TlsIndex] == ProcessTlsInformation->ThreadData[i].TlsModulePointer) {
627+
found = TRUE;
628+
629+
if (ProcessHandle) {
630+
j->TlspLdrBlock[ProcessTlsInformation->TlsIndex] = ProcessTlsInformation->ThreadData[i].TlsModulePointer;
631+
}
528632

529-
if (ProcessHandle) {
530-
j->TlspLdrBlock[ProcessTlsInformation->TlsIndex] = ProcessTlsInformation->ThreadData[i].TlsModulePointer;
633+
ProcessTlsInformation->ThreadData[i].TlsModulePointer = Tls->ThreadData[i].TlsModulePointer;
531634
}
532-
533-
ProcessTlsInformation->ThreadData[i].TlsModulePointer = Tls->ThreadData[i].TlsModulePointer;
534635
}
636+
637+
if (found)break;
638+
entry = entry->Flink;
535639
}
536640

537-
if (found)break;
538-
entry = entry->Flink;
641+
ProcessTlsInformation->ThreadData[i].Flags = Tls->ThreadData[i].Flags;
642+
ProcessTlsInformation->ThreadData[i].ThreadId = Tls->ThreadData[i].ThreadId;
539643
}
540644

541-
ProcessTlsInformation->ThreadData[i].Flags = Tls->ThreadData[i].Flags;
542-
ProcessTlsInformation->ThreadData[i].ThreadId = Tls->ThreadData[i].ThreadId;
543645
}
544646
LeaveCriticalSection(&MmpGlobalDataPtr->MmpTls->MmpTlspLock);
545647

0 commit comments

Comments
 (0)