Skip to content

Commit 3b6cdfd

Browse files
author
Dan Dees
committed
test.cpp - command line accept path to a.dll
1 parent 24a7a6d commit 3b6cdfd

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

test/test.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "../MemoryModule/stdafx.h"
22
#include "../MemoryModule/LoadDllMemoryApi.h"
33
#include <cstdio>
4+
#include <string>
45
#pragma comment(lib,"ntdll.lib")
56

67
static void DisplayStatus() {
@@ -28,11 +29,11 @@ LdrpHashTable = %p\n\n\
2829
);
2930
}
3031

31-
static PVOID ReadDllFile(LPCSTR FileName) {
32+
static PVOID ReadDllFile(const std::string& FilePath) {
3233
LPVOID buffer;
3334
size_t size;
3435
FILE* f;
35-
fopen_s(&f, FileName, "rb");
36+
fopen_s(&f, FilePath.c_str(), "rb");
3637
if (!f)return 0;
3738
_fseeki64(f, 0, SEEK_END);
3839
if (!(size = _ftelli64(f))) {
@@ -48,23 +49,24 @@ static PVOID ReadDllFile(LPCSTR FileName) {
4849
return buffer;
4950
}
5051

51-
PVOID ReadDllFile2(LPCSTR FileName) {
52+
static std::string ResolveWithModulePath(const std::string& dll_path)
53+
{
54+
std::string rc(dll_path) ;
55+
56+
// expect default dll in module directory
5257
CHAR path[MAX_PATH + 4];
53-
DWORD len = GetModuleFileNameA(nullptr, path, sizeof(path));
58+
const auto len = GetModuleFileNameA(nullptr, path, sizeof(path));
5459

5560
if (len) {
56-
while (len && path[len] != '\\') --len;
57-
58-
if (len) {
59-
strcpy_s(&path[len + 1], sizeof(path) - len - 1, FileName);
60-
return ReadDllFile(path);
61-
}
61+
const std::string mod_path(path);
62+
const auto last_slash = mod_path.find_last_of("/\\") ;
63+
rc = mod_path.substr(0, last_slash + 1) + dll_path;
6264
}
6365

64-
return nullptr;
66+
return rc ;
6567
}
6668

67-
int test() {
69+
int test(const std::string& dll_path) {
6870

6971
HMODULE hModule = nullptr;
7072
FARPROC pfn = nullptr;
@@ -76,11 +78,12 @@ int test() {
7678
HGLOBAL gRes;
7779
char str[10];
7880

79-
LPVOID buffer = ReadDllFile2("a.dll");
81+
LPVOID buffer = ReadDllFile(dll_path);
8082
if ( !buffer ) {
81-
printf("failed to find a.dll.\n");
83+
printf("failed to find %s.\n", dll_path.c_str());
8284
goto end;
8385
}
86+
printf("%s read into memory.\n", dll_path.c_str());
8487

8588
if (!NT_SUCCESS(LdrLoadDllMemoryExW(&hModule, nullptr, 0, buffer, 0, L"kernel64", nullptr))) {
8689
printf("LdrLoadDllMemoryExW failed.\n");
@@ -136,11 +139,14 @@ int test() {
136139
return 0;
137140
}
138141

139-
int main() {
142+
int main(int argc, char* argv[]) {
140143

141144
DisplayStatus();
142145

143-
test();
146+
std::string dll_path("a.dll"); // default
147+
dll_path = argc > 1 ? argv[1] : ResolveWithModulePath(dll_path);
148+
149+
test(dll_path);
144150

145151
return 0;
146152
}

0 commit comments

Comments
 (0)