diff --git a/source/main.cpp b/source/main.cpp index c714b31..e549ef4 100644 --- a/source/main.cpp +++ b/source/main.cpp @@ -68,6 +68,9 @@ GMOD_MODULE_OPEN( ) LUA->PushCFunction( redis_subscriber::Create ); LUA->SetField( -2, "CreateSubscriber" ); + LUA->PushCFunction( redis_client::IsError ); + LUA->SetField( -2, "IsError" ); + LUA->SetField( GarrysMod::Lua::INDEX_GLOBAL, redis::table_name ); return 0; diff --git a/source/redis_client.cpp b/source/redis_client.cpp index 76c1dd4..7104897 100644 --- a/source/redis_client.cpp +++ b/source/redis_client.cpp @@ -8,6 +8,8 @@ #include "redis_client.hpp" #include "main.hpp" +using namespace GarrysMod::Lua; + namespace redis_client { @@ -50,6 +52,8 @@ class Container static const char metaname[] = "redis_client"; static int32_t metatype = GarrysMod::Lua::Type::NONE; +static const char errormetaname[] = "redis_clienterror"; +static int32_t errormetatype = GarrysMod::Lua::Type::NONE; static const char invalid_error[] = "invalid redis_client"; static const char table_name[] = "redis_clients"; @@ -84,6 +88,12 @@ LUA_FUNCTION( Create ) return 1; } +LUA_FUNCTION( IsError ) +{ + LUA->PushBool( LUA->GetMetaTable(-1) && LUA->PushMetaTable(errormetatype) && LUA->RawEqual(-1, -2) ); + return 1; +} + inline void CheckType( GarrysMod::Lua::ILuaBase *LUA, int32_t index ) { if( !LUA->IsType( index, metatype ) ) @@ -114,6 +124,12 @@ LUA_FUNCTION_STATIC( tostring ) return 1; } +LUA_FUNCTION_STATIC( errortostring ) +{ + LUA->GetField(1, "error"); + return 1; +} + LUA_FUNCTION_STATIC( eq ) { LUA->PushBool( Get( LUA, 1 ) == Get( LUA, 2 ) ); @@ -172,7 +188,7 @@ LUA_FUNCTION_STATIC( IsValid ) LUA_FUNCTION_STATIC( IsConnected ) { cpp_redis::client *client = Get( LUA, 1 ); - LUA->PushBool( client->is_connected( ) ); + LUA->PushBool( client != nullptr && client->is_connected( ) ); return 1; } @@ -211,40 +227,49 @@ LUA_FUNCTION_STATIC( Disconnect ) return 0; } +void BuildReply(GarrysMod::Lua::ILuaBase *LUA, const cpp_redis::reply &reply); + inline void BuildTable( GarrysMod::Lua::ILuaBase *LUA, const std::vector &replies ) { LUA->CreateTable( ); for( size_t k = 0; k < replies.size( ); ++k ) { - LUA->PushNumber( static_cast( k ) ); - - const cpp_redis::reply &reply = replies[k]; - switch( reply.get_type( ) ) - { - case cpp_redis::reply::type::error: - case cpp_redis::reply::type::bulk_string: - case cpp_redis::reply::type::simple_string: - LUA->PushString( reply.as_string( ).c_str( ) ); - break; - - case cpp_redis::reply::type::integer: - LUA->PushNumber( static_cast( reply.as_integer( ) ) ); - break; - - case cpp_redis::reply::type::array: - BuildTable( LUA, reply.as_array( ) ); - break; - - case cpp_redis::reply::type::null: - LUA->PushNil( ); - break; - } - + LUA->PushNumber( static_cast( k + 1 ) ); + BuildReply(LUA, replies[k]); LUA->SetTable( -3 ); } } +inline void BuildReply(GarrysMod::Lua::ILuaBase *LUA, const cpp_redis::reply &reply) { + switch( reply.get_type( ) ) + { + case cpp_redis::reply::type::error: + LUA->CreateTable(); + LUA->PushString( reply.as_string( ).c_str( ) ); + LUA->SetField(-2, "error"); + LUA->PushMetaTable( errormetatype ); + LUA->SetMetaTable( -2 ); + break; + case cpp_redis::reply::type::bulk_string: + case cpp_redis::reply::type::simple_string: + LUA->PushString( reply.as_string( ).c_str( ) ); + break; + + case cpp_redis::reply::type::integer: + LUA->PushNumber( static_cast( reply.as_integer( ) ) ); + break; + + case cpp_redis::reply::type::array: + BuildTable( LUA, reply.as_array( ) ); + break; + + case cpp_redis::reply::type::null: + LUA->PushBool( false ); + break; + } +} + LUA_FUNCTION_STATIC( Poll ) { Container *container = nullptr; @@ -283,26 +308,7 @@ LUA_FUNCTION_STATIC( Poll ) LUA->ReferencePush( response.reference ); LUA->Push( 1 ); - switch( response.reply.get_type( ) ) - { - case cpp_redis::reply::type::error: - case cpp_redis::reply::type::bulk_string: - case cpp_redis::reply::type::simple_string: - LUA->PushString( response.reply.as_string( ).c_str( ) ); - break; - - case cpp_redis::reply::type::integer: - LUA->PushNumber( static_cast( response.reply.as_integer( ) ) ); - break; - - case cpp_redis::reply::type::array: - BuildTable( LUA, response.reply.as_array( ) ); - break; - - case cpp_redis::reply::type::null: - LUA->PushNil( ); - break; - } + BuildReply(LUA, response.reply); if( LUA->PCall( 2, 0, -4 ) != 0 ) { @@ -492,12 +498,22 @@ void Initialize( GarrysMod::Lua::ILuaBase *LUA ) LUA->SetField( -2, "Commit" ); LUA->Pop( 1 ); + + errormetatype = LUA->CreateMetaTable( errormetaname ); + + LUA->PushCFunction( errortostring ); + LUA->SetField( -2, "__tostring" ); + + LUA->Pop( 1 ); } void Deinitialize( GarrysMod::Lua::ILuaBase *LUA ) { LUA->PushNil( ); LUA->SetField( GarrysMod::Lua::INDEX_REGISTRY, metaname ); + + LUA->PushNil( ); + LUA->SetField( GarrysMod::Lua::INDEX_REGISTRY, errormetaname ); } } diff --git a/source/redis_client.hpp b/source/redis_client.hpp index a8e351f..9d2c2ed 100644 --- a/source/redis_client.hpp +++ b/source/redis_client.hpp @@ -16,5 +16,5 @@ namespace redis_client void Initialize( GarrysMod::Lua::ILuaBase *LUA ); void Deinitialize( GarrysMod::Lua::ILuaBase *LUAe ); LUA_FUNCTION_DECLARE( Create ); - +LUA_FUNCTION_DECLARE( IsError ); } diff --git a/source/redis_subscriber.cpp b/source/redis_subscriber.cpp index 7ba8555..123553f 100644 --- a/source/redis_subscriber.cpp +++ b/source/redis_subscriber.cpp @@ -8,13 +8,16 @@ #include "redis_subscriber.hpp" #include "main.hpp" +using namespace GarrysMod::Lua; + namespace redis_subscriber { enum class Action { Disconnection, - Message + Message, + AuthFail }; struct Response @@ -210,6 +213,36 @@ LUA_FUNCTION_STATIC( Disconnect ) return 0; } + + +LUA_FUNCTION_STATIC( Auth ) +{ + Container *container = nullptr; + cpp_redis::subscriber *subscriber = Get( LUA, 1, &container ); + std::string password = LUA->CheckString( 2 ); + + try + { + subscriber->auth( password, [container]( cpp_redis::reply &reply ) + { + if (!reply.ok()) { + container->EnqueueResponse( { Action::AuthFail, "", "" } ); + } + } ); + } + catch( const cpp_redis::redis_error &e ) + { + LUA->PushNil( ); + LUA->PushString( e.what( ) ); + return 2; + } + + LUA->PushBool( true ); + return 1; +} + + + LUA_FUNCTION_STATIC( Poll ) { Container *container = nullptr; @@ -224,6 +257,18 @@ LUA_FUNCTION_STATIC( Poll ) { switch( response.type ) { + case Action::AuthFail: + { + const char* err = LUA->GetString(-1); + LUA->PushSpecial(GarrysMod::Lua::SPECIAL_GLOB); + LUA->GetField(-1, "ErrorNoHalt"); + LUA->PushString("\n\n[redis subscriber auth failed] \n\n\n"); + LUA->Call(1, 0); + LUA->Pop(2); + } + + break; + case Action::Disconnection: if( !redis::GetMetaField( LUA, 1, "OnDisconnected" ) ) break; @@ -413,6 +458,9 @@ void Initialize( GarrysMod::Lua::ILuaBase *LUA ) LUA->PushCFunction( Disconnect ); LUA->SetField( -2, "Disconnect" ); + LUA->PushCFunction( Auth ); + LUA->SetField( -2, "Auth" ); + LUA->PushCFunction( Poll ); LUA->SetField( -2, "Poll" );