88 *
99 *****************************************************************************/
1010
11+ #pragma once
12+ #include < windows.h>
13+ #include < cassert>
14+ #include " SharedUtil.MemAccess.h"
15+ #include " SharedUtil.Logging.h"
16+ #include " SString.h"
17+
1118namespace SharedUtil
1219{
1320 // Returns true if matching memset would have no affect
@@ -74,13 +81,12 @@ namespace SharedUtil
7481 // Temporarily unprotect slow mem area
7582 SMemWrite OpenMemWrite (const void * pAddr, uint uiAmount)
7683 {
77- SMemWrite hMem;
84+ SMemWrite hMem{} ;
7885
7986 // Check for incorrect use of function
8087 if (!IsSlowMem (pAddr, uiAmount))
8188 {
8289 dassert (0 && " Should use Mem*Fast function" );
83- hMem.dwFirstPage = 0 ;
8490 return hMem;
8591 }
8692
@@ -91,10 +97,18 @@ namespace SharedUtil
9197 hMem.dwFirstPage = ((DWORD)pAddr) & ~0xFFF ;
9298 DWORD dwLastPage = (((DWORD)pAddr) + uiAmount - 1 ) & ~0xFFF ;
9399 hMem.dwSize = dwLastPage - hMem.dwFirstPage + 0x1000 ;
94- VirtualProtect ((LPVOID)hMem.dwFirstPage , 0x1000 , PAGE_EXECUTE_READWRITE, &hMem.oldProt );
100+
101+ if (!VirtualProtect ((LPVOID)hMem.dwFirstPage , 0x1000 , PAGE_EXECUTE_READWRITE, &hMem.oldProt ))
102+ {
103+ DWORD error = GetLastError ();
104+ OutputDebugLine (SString (" MemAccess::OpenMemWrite: VirtualProtect failed at %08x, error: %d" , hMem.dwFirstPage , error));
105+ hMem = {};
106+ assert (!" Failed to unprotect memory" );
107+ return hMem;
108+ }
95109
96110 // Make sure not using this slow function too much
97- OutputDebugLine (SString (" [Mem ] OpenMemWrite at %08x for %d bytes (oldProt:%04x)" , pAddr, uiAmount, hMem.oldProt ));
111+ OutputDebugLine (SString (" [MemAccess ] OpenMemWrite at %08x for %d bytes (oldProt:%04x)" , pAddr, uiAmount, hMem.oldProt ));
98112
99113#ifdef MTA_DEBUG
100114#if 0 // Annoying
@@ -103,14 +117,25 @@ namespace SharedUtil
103117 assert( hMem.oldProt == PAGE_EXECUTE_READ || hMem.oldProt == PAGE_READONLY );
104118 else
105119 assert( hMem.oldProt == PAGE_EXECUTE_READWRITE || hMem.oldProt == PAGE_EXECUTE_WRITECOPY );
106- #endif
120+ #endif
107121#endif
108122
109123 // Extra if more than one page
110124 for (uint i = 0x1000 ; i < hMem.dwSize ; i += 0x1000 )
111125 {
112126 DWORD oldProtCheck;
113- VirtualProtect ((LPVOID)(hMem.dwFirstPage + i), 0x1000 , PAGE_EXECUTE_READWRITE, &oldProtCheck);
127+ if (!VirtualProtect ((LPVOID)(hMem.dwFirstPage + i), 0x1000 , PAGE_EXECUTE_READWRITE, &oldProtCheck))
128+ {
129+ // Try to rollback
130+ DWORD temp;
131+ VirtualProtect ((LPVOID)hMem.dwFirstPage , i, hMem.oldProt , &temp);
132+
133+ DWORD error = GetLastError ();
134+ OutputDebugLine (SString (" [MemAccess] OpenMemWrite VirtualProtect failed at %08x, error: %d" , hMem.dwFirstPage + i, error));
135+ hMem = {};
136+ assert (!" Failed to unprotect multi-page memory region" );
137+ return hMem;
138+ }
114139 dassert (hMem.oldProt == oldProtCheck);
115140 }
116141
@@ -122,9 +147,22 @@ namespace SharedUtil
122147 {
123148 if (hMem.dwFirstPage == 0 )
124149 return ;
150+
125151 DWORD oldProt;
126- VirtualProtect ((LPVOID)hMem.dwFirstPage , hMem.dwSize , hMem.oldProt , &oldProt);
127- dassert (oldProt == PAGE_EXECUTE_READWRITE);
152+ BOOL result = VirtualProtect ((LPVOID)hMem.dwFirstPage , hMem.dwSize , hMem.oldProt , &oldProt);
153+
154+ if (!result)
155+ {
156+ DWORD error = GetLastError ();
157+ OutputDebugLine (SString (" MemAccess::CloseMemWrite: VirtualProtect failed at %08x, size %08x, error: %d" , hMem.dwFirstPage , hMem.dwSize , error));
158+ assert (!" Failed to restore memory protection - critical" );
159+ }
160+ else
161+ {
162+ dassert (oldProt == PAGE_EXECUTE_READWRITE);
163+ }
164+
165+ hMem.dwFirstPage = 0 ; // Invalidate handle
128166 }
129167
130168} // namespace SharedUtil
0 commit comments