Skip to content

Commit 47c4739

Browse files
committed
add protected commands and protected help menu items
1 parent 39b4e3f commit 47c4739

File tree

5 files changed

+166
-95
lines changed

5 files changed

+166
-95
lines changed

src/ban.rs

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -79,43 +79,41 @@ pub(crate) fn start_unban_thread(cx: Context) {
7979
///
8080
/// Requires the ban members permission
8181
pub(crate) fn temp_ban(args: Args) -> Result<()> {
82-
if api::is_mod(&args)? {
83-
let user_id = parse_username(
84-
&args
85-
.params
86-
.get("user")
87-
.ok_or("unable to retrieve user param")?,
88-
)
89-
.ok_or("unable to retrieve user id")?;
82+
let user_id = parse_username(
83+
&args
84+
.params
85+
.get("user")
86+
.ok_or("unable to retrieve user param")?,
87+
)
88+
.ok_or("unable to retrieve user id")?;
9089

91-
use std::str::FromStr;
90+
use std::str::FromStr;
9291

93-
let hours = u64::from_str(
94-
args.params
95-
.get("hours")
96-
.ok_or("unable to retrieve hours param")?,
97-
)?;
92+
let hours = u64::from_str(
93+
args.params
94+
.get("hours")
95+
.ok_or("unable to retrieve hours param")?,
96+
)?;
9897

99-
let reason = args
100-
.params
101-
.get("reason")
102-
.ok_or("unable to retrieve reason param")?;
98+
let reason = args
99+
.params
100+
.get("reason")
101+
.ok_or("unable to retrieve reason param")?;
103102

104-
if let Some(guild) = args.msg.guild(&args.cx) {
105-
info!("Banning user from guild");
106-
let user = UserId::from(user_id);
103+
if let Some(guild) = args.msg.guild(&args.cx) {
104+
info!("Banning user from guild");
105+
let user = UserId::from(user_id);
107106

108-
user.create_dm_channel(args.cx)?
109-
.say(args.cx, ban_message(reason, hours))?;
107+
user.create_dm_channel(args.cx)?
108+
.say(args.cx, ban_message(reason, hours))?;
110109

111-
guild.read().ban(args.cx, &user, &"all")?;
110+
guild.read().ban(args.cx, &user, &"all")?;
112111

113-
save_ban(
114-
format!("{}", user_id),
115-
format!("{}", guild.read().id),
116-
hours,
117-
)?;
118-
}
112+
save_ban(
113+
format!("{}", user_id),
114+
format!("{}", guild.read().id),
115+
hours,
116+
)?;
119117
}
120118
Ok(())
121119
}

src/commands.rs

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,22 @@ use std::{collections::HashMap, sync::Arc};
55

66
const PREFIX: &'static str = "?";
77
pub(crate) type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
8-
pub(crate) type CmdPtr = Arc<dyn for<'m> Fn(Args<'m>) -> Result<()> + Send + Sync>;
8+
pub(crate) type GuardFn = fn(&Args) -> Result<bool>;
9+
10+
struct Command {
11+
guard: GuardFn,
12+
ptr: Box<dyn for<'m> Fn(Args<'m>) -> Result<()> + Send + Sync>,
13+
}
14+
15+
impl Command {
16+
fn authorize(&self, args: &Args) -> Result<bool> {
17+
(self.guard)(&args)
18+
}
19+
20+
fn call(&self, args: Args) -> Result<()> {
21+
(self.ptr)(args)
22+
}
23+
}
924

1025
pub struct Args<'m> {
1126
pub http: &'m HttpClient,
@@ -15,24 +30,33 @@ pub struct Args<'m> {
1530
}
1631

1732
pub(crate) struct Commands {
18-
state_machine: StateMachine,
33+
state_machine: StateMachine<Arc<Command>>,
1934
client: HttpClient,
20-
menu: HashMap<&'static str, &'static str>,
35+
menu: Option<HashMap<&'static str, (&'static str, GuardFn)>>,
2136
}
2237

2338
impl Commands {
2439
pub(crate) fn new() -> Self {
2540
Self {
2641
state_machine: StateMachine::new(),
2742
client: HttpClient::new(),
28-
menu: HashMap::new(),
43+
menu: Some(HashMap::new()),
2944
}
3045
}
3146

3247
pub(crate) fn add(
3348
&mut self,
3449
command: &'static str,
3550
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
51+
) {
52+
self.add_protected(command, handler, |_| Ok(true));
53+
}
54+
55+
pub(crate) fn add_protected(
56+
&mut self,
57+
command: &'static str,
58+
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
59+
guard: GuardFn,
3660
) {
3761
info!("Adding command {}", &command);
3862
let mut state = 0;
@@ -89,7 +113,10 @@ impl Commands {
89113
}
90114
});
91115

92-
let handler = Arc::new(handler);
116+
let handler = Arc::new(Command {
117+
guard,
118+
ptr: Box::new(handler),
119+
});
93120

94121
if opt_lambda_state.is_some() {
95122
opt_final_states.iter().for_each(|state| {
@@ -108,34 +135,60 @@ impl Commands {
108135
desc: &'static str,
109136
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
110137
) {
111-
let base_cmd = &cmd[1..];
138+
self.help_protected(cmd, desc, handler, |_| Ok(true));
139+
}
140+
141+
pub(crate) fn help_protected(
142+
&mut self,
143+
cmd: &'static str,
144+
desc: &'static str,
145+
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
146+
guard: GuardFn,
147+
) {
112148
info!("Adding command ?help {}", &base_cmd);
149+
let base_cmd = &cmd[1..];
113150
let mut state = 0;
114151

115-
self.menu.insert(cmd, desc);
116-
state = add_help_menu(&mut self.state_machine, base_cmd, state);
152+
self.menu.as_mut().map(|menu| {
153+
menu.insert(cmd, (desc, guard));
154+
menu
155+
});
117156

157+
state = add_help_menu(&mut self.state_machine, base_cmd, state);
118158
self.state_machine.set_final_state(state);
119-
self.state_machine.set_handler(state, Arc::new(handler));
159+
self.state_machine.set_handler(
160+
state,
161+
Arc::new(Command {
162+
guard,
163+
ptr: Box::new(handler),
164+
}),
165+
);
120166
}
121167

122-
pub(crate) fn menu(&mut self) -> &HashMap<&'static str, &'static str> {
123-
&self.menu
168+
pub(crate) fn menu(&mut self) -> Option<HashMap<&'static str, (&'static str, GuardFn)>> {
169+
self.menu.take()
124170
}
125171

126172
pub(crate) fn execute<'m>(&'m self, cx: Context, msg: Message) {
127173
let message = &msg.content;
128174
if !msg.is_own(&cx) && message.starts_with(PREFIX) {
129175
self.state_machine.process(message).map(|matched| {
130-
info!("Executing command {}", message);
131176
let args = Args {
132177
http: &self.client,
133178
cx: &cx,
134179
msg: &msg,
135180
params: matched.params,
136181
};
137-
if let Err(e) = (matched.handler)(args) {
138-
println!("{}", e);
182+
info!("Checking permissions");
183+
match matched.handler.authorize(&args) {
184+
Ok(true) => {
185+
info!("Executing command {}", message);
186+
if let Err(e) = matched.handler.call(args) {
187+
println!("{}", e);
188+
}
189+
}
190+
Ok(false) => {}
191+
Err(e) => error!("{}", e),
139192
}
140193
});
141194
}
@@ -156,7 +209,7 @@ fn key_value_pair(s: &'static str) -> Option<&'static str> {
156209
.flatten()
157210
}
158211

159-
fn add_space(state_machine: &mut StateMachine, mut state: usize, i: usize) -> usize {
212+
fn add_space<T>(state_machine: &mut StateMachine<T>, mut state: usize, i: usize) -> usize {
160213
if i > 0 {
161214
let mut char_set = CharacterSet::from_char(' ');
162215
char_set.insert('\n');
@@ -167,8 +220,8 @@ fn add_space(state_machine: &mut StateMachine, mut state: usize, i: usize) -> us
167220
state
168221
}
169222

170-
fn add_help_menu(
171-
mut state_machine: &mut StateMachine,
223+
fn add_help_menu<T>(
224+
mut state_machine: &mut StateMachine<T>,
172225
cmd: &'static str,
173226
mut state: usize,
174227
) -> usize {
@@ -183,8 +236,8 @@ fn add_help_menu(
183236
state
184237
}
185238

186-
fn add_dynamic_segment(
187-
state_machine: &mut StateMachine,
239+
fn add_dynamic_segment<T>(
240+
state_machine: &mut StateMachine<T>,
188241
name: &'static str,
189242
mut state: usize,
190243
) -> usize {
@@ -198,8 +251,8 @@ fn add_dynamic_segment(
198251
state
199252
}
200253

201-
fn add_remaining_segment(
202-
state_machine: &mut StateMachine,
254+
fn add_remaining_segment<T>(
255+
state_machine: &mut StateMachine<T>,
203256
name: &'static str,
204257
mut state: usize,
205258
) -> usize {
@@ -212,8 +265,8 @@ fn add_remaining_segment(
212265
state
213266
}
214267

215-
fn add_code_segment_multi_line(
216-
state_machine: &mut StateMachine,
268+
fn add_code_segment_multi_line<T>(
269+
state_machine: &mut StateMachine<T>,
217270
name: &'static str,
218271
mut state: usize,
219272
) -> usize {
@@ -246,8 +299,8 @@ fn add_code_segment_multi_line(
246299
state
247300
}
248301

249-
fn add_code_segment_single_line(
250-
state_machine: &mut StateMachine,
302+
fn add_code_segment_single_line<T>(
303+
state_machine: &mut StateMachine<T>,
251304
name: &'static str,
252305
mut state: usize,
253306
n_backticks: usize,
@@ -266,7 +319,11 @@ fn add_code_segment_single_line(
266319
state
267320
}
268321

269-
fn add_key_value(state_machine: &mut StateMachine, name: &'static str, mut state: usize) -> usize {
322+
fn add_key_value<T>(
323+
state_machine: &mut StateMachine<T>,
324+
name: &'static str,
325+
mut state: usize,
326+
) -> usize {
270327
name.chars().for_each(|c| {
271328
state = state_machine.add(state, CharacterSet::from_char(c));
272329
});

0 commit comments

Comments
 (0)