diff --git a/crates/deno_task_shell/src/grammar.pest b/crates/deno_task_shell/src/grammar.pest index 094e058..f137f8e 100644 --- a/crates/deno_task_shell/src/grammar.pest +++ b/crates/deno_task_shell/src/grammar.pest @@ -191,9 +191,9 @@ pipeline = !{ Bang? ~ pipe_sequence } pipe_sequence = !{ command ~ ((StdoutStderr | Stdout) ~ linebreak ~ pipe_sequence)? } command = !{ + function_definition | compound_command ~ redirect_list? | - simple_command | - function_definition + simple_command } compound_command = { @@ -378,10 +378,14 @@ binary_posix_conditional_op = !{ while_clause = !{ While ~ conditional_expression ~ do_group } until_clause = !{ Until ~ conditional_expression ~ do_group } -function_definition = !{ fname ~ "(" ~ ")" ~ linebreak ~ function_body } -function_body = !{ compound_command ~ redirect_list? } +function_definition = { + fname ~ "(" ~ ")" ~ linebreak ~ function_body | + "function" ~ fname ~ ("(" ~ ")")? ~ linebreak ~ function_body +} + +function_body = !{ Lbrace ~ compound_list ~ Rbrace } -fname = @{ RESERVED_WORD | NAME | ASSIGNMENT_WORD | UNQUOTED_PENDING_WORD } +fname = @{ NAME } name = @{ NAME } brace_group = !{ Lbrace ~ compound_list ~ Rbrace } diff --git a/crates/deno_task_shell/src/parser.rs b/crates/deno_task_shell/src/parser.rs index 6c41ab5..e4890b5 100644 --- a/crates/deno_task_shell/src/parser.rs +++ b/crates/deno_task_shell/src/parser.rs @@ -149,6 +149,15 @@ pub struct Command { pub redirect: Option, } +#[cfg_attr(feature = "serialization", derive(serde::Serialize))] +#[cfg_attr(feature = "serialization", serde(rename_all = "camelCase"))] +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[error("Invalid function")] +pub struct Function { + pub name: String, + pub body: SequentialList, +} + #[cfg_attr(feature = "serialization", derive(serde::Serialize))] #[cfg_attr( feature = "serialization", @@ -170,6 +179,8 @@ pub enum CommandInner { Case(CaseClause), #[error("Invalid arithmetic expression")] ArithmeticExpression(Arithmetic), + #[error("Invalid function definition")] + FunctionType(Function), } impl From for Sequence { @@ -910,13 +921,70 @@ fn parse_command(pair: Pair) -> Result { match inner.as_rule() { Rule::simple_command => parse_simple_command(inner), Rule::compound_command => parse_compound_command(inner), - Rule::function_definition => { - Err(miette!("Function definitions are not supported yet")) - } + Rule::function_definition => parse_function_definition(inner), _ => Err(miette!("Unexpected rule in command: {:?}", inner.as_rule())), } } +fn parse_function_definition(pair: Pair) -> Result { + let mut inner = pair.into_inner(); + + // Handle both styles: + // 1. name() { body } + // 2. function name { body } or function name() { body } + let (name, body_pair) = if inner.peek().unwrap().as_rule() == Rule::fname { + // Style 1: name() { body } + let name = inner.next().unwrap().as_str().to_string(); + // Skip the () part + if inner.peek().is_some() { + let next = inner.peek().unwrap(); + if next.as_str() == "(" || next.as_str() == ")" { + inner.next(); // skip ( + inner.next(); // skip ) + } + } + (name, inner.next().unwrap()) + } else { + // Style 2: function name [()] { body } + // Skip "function" keyword + inner.next(); + let name = inner.next().unwrap().as_str().to_string(); + // Skip optional () + if inner.peek().is_some() { + let next = inner.peek().unwrap(); + if next.as_str() == "(" || next.as_str() == ")" { + inner.next(); // skip ( + inner.next(); // skip ) + } + } + (name, inner.next().unwrap()) + }; + + // Parse the function body + let mut body_inner = body_pair.into_inner(); + // Skip Lbrace + if let Some(lbrace) = body_inner.next() { + if lbrace.as_str() != "{" { + return Err(miette!("Expected Lbrace to start function body")); + } + } + // Parse the actual compound_list + let compound_list = body_inner + .next() + .ok_or_else(|| miette!("Expected compound list in function body"))?; + let mut body_items = Vec::new(); + + parse_compound_list(compound_list, &mut body_items)?; + + Ok(Command { + inner: CommandInner::FunctionType(Function { + name, + body: SequentialList { items: body_items }, + }), + redirect: None, + }) +} + fn parse_simple_command(pair: Pair) -> Result { let mut env_vars = Vec::new(); let mut args = Vec::new(); diff --git a/crates/deno_task_shell/src/shell/command.rs b/crates/deno_task_shell/src/shell/command.rs index bac6a99..5261013 100644 --- a/crates/deno_task_shell/src/shell/command.rs +++ b/crates/deno_task_shell/src/shell/command.rs @@ -205,6 +205,7 @@ async fn parse_shebang_args( CommandInner::While(_) => return err_unsupported(text), CommandInner::ArithmeticExpression(_) => return err_unsupported(text), CommandInner::Case(_) => return err_unsupported(text), + CommandInner::FunctionType(_) => return err_unsupported(text), }; if !cmd.env_vars.is_empty() { return err_unsupported(text); diff --git a/crates/deno_task_shell/src/shell/execute.rs b/crates/deno_task_shell/src/shell/execute.rs index ed790a0..da98201 100644 --- a/crates/deno_task_shell/src/shell/execute.rs +++ b/crates/deno_task_shell/src/shell/execute.rs @@ -667,6 +667,13 @@ async fn execute_command( } } } + CommandInner::FunctionType(function) => { + changes.push(EnvChange::AddFunction( + function.name.clone(), + std::sync::Arc::new(function.clone()), + )); + ExecuteResult::Continue(0, changes, Vec::new()) + } } } @@ -1521,6 +1528,47 @@ async fn execute_simple_command( } }; + if !args.is_empty() { + let command_name = &args[0]; + if let Some(body) = state.get_function(command_name).cloned() { + // Set $0 to function name and $1, $2, etc. to arguments + let mut function_changes = vec![EnvChange::SetShellVar( + "0".to_string(), + command_name.clone().to_string(), + )]; + for (i, arg) in args.iter().skip(1).enumerate() { + function_changes.push(EnvChange::SetShellVar( + (i + 1).to_string(), + arg.clone().to_string(), + )); + } + + state.apply_changes(&function_changes); + changes.extend(function_changes); + + let result = execute_sequential_list( + body.body.clone(), + state.clone(), + stdin, + stdout, + stderr, + AsyncCommandBehavior::Yield, + ) + .await; + + match result { + ExecuteResult::Exit(code, env_changes, handles) => { + changes.extend(env_changes); + return ExecuteResult::Exit(code, changes, handles); + } + ExecuteResult::Continue(code, env_changes, handles) => { + changes.extend(env_changes); + return ExecuteResult::Continue(code, changes, handles); + } + } + } + } + let mut state = state.clone(); for env_var in command.env_vars { let word_result = evaluate_word( diff --git a/crates/deno_task_shell/src/shell/types.rs b/crates/deno_task_shell/src/shell/types.rs index aa4eb95..cfa90fb 100644 --- a/crates/deno_task_shell/src/shell/types.rs +++ b/crates/deno_task_shell/src/shell/types.rs @@ -13,6 +13,7 @@ use std::path::Path; use std::path::PathBuf; use std::rc::Rc; use std::str::FromStr; +use std::sync::Arc; use futures::future::LocalBoxFuture; use miette::Error; @@ -25,6 +26,7 @@ use crate::shell::fs_util; use super::commands::builtin_commands; use super::commands::ShellCommand; +use crate::parser::Function; #[derive(Clone)] pub struct ShellState { @@ -52,6 +54,7 @@ pub struct ShellState { last_command_exit_code: i32, // Exit code of the last command // The shell options to be modified using `set` command shell_options: HashMap, + pub functions: HashMap, } #[allow(clippy::print_stdout)] @@ -92,6 +95,7 @@ impl ShellState { map.insert(ShellOptions::ExitOnError, true); map }, + functions: HashMap::new(), }; // ensure the data is normalized for (name, value) in env_vars { @@ -293,6 +297,9 @@ impl ShellState { EnvChange::SetShellOptions(option, value) => { self.set_shell_option(*option, *value); } + EnvChange::AddFunction(name, func) => { + self.add_function(name.clone(), (**func).clone()); + } } } @@ -353,9 +360,17 @@ impl ShellState { pub fn reset_cancellation_token(&mut self) { self.token = CancellationToken::default(); } + + pub fn add_function(&mut self, name: String, func: Function) { + self.functions.insert(name, func); + } + + pub fn get_function(&self, name: &str) -> Option<&Function> { + return self.functions.get(name); + } } -#[derive(Debug, PartialEq, Eq, Clone, PartialOrd)] +#[derive(Debug, Clone)] pub enum EnvChange { /// `export ENV_VAR=VALUE` SetEnvVar(String, String), @@ -371,6 +386,77 @@ pub enum EnvChange { Cd(PathBuf), /// `set -ex` SetShellOptions(ShellOptions, bool), + /// Add a user-defined function + AddFunction(String, Arc), +} + +// Manual implementations for PartialEq and PartialOrd to handle Arc +impl PartialEq for EnvChange { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (EnvChange::SetEnvVar(a1, b1), EnvChange::SetEnvVar(a2, b2)) => { + a1 == a2 && b1 == b2 + } + ( + EnvChange::SetShellVar(a1, b1), + EnvChange::SetShellVar(a2, b2), + ) => a1 == a2 && b1 == b2, + ( + EnvChange::AliasCommand(a1, b1), + EnvChange::AliasCommand(a2, b2), + ) => a1 == a2 && b1 == b2, + (EnvChange::UnAliasCommand(a1), EnvChange::UnAliasCommand(a2)) => { + a1 == a2 + } + (EnvChange::UnsetVar(a1), EnvChange::UnsetVar(a2)) => a1 == a2, + (EnvChange::Cd(a1), EnvChange::Cd(a2)) => a1 == a2, + ( + EnvChange::SetShellOptions(a1, b1), + EnvChange::SetShellOptions(a2, b2), + ) => a1 == a2 && b1 == b2, + ( + EnvChange::AddFunction(a1, b1), + EnvChange::AddFunction(a2, b2), + ) => a1 == a2 && **b1 == **b2, + _ => false, + } + } +} + +impl Eq for EnvChange {} + +impl PartialOrd for EnvChange { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for EnvChange { + fn cmp(&self, other: &Self) -> Ordering { + // Simple ordering based on variant - using discriminant comparison + use EnvChange::*; + let self_idx = match self { + SetEnvVar(..) => 0, + SetShellVar(..) => 1, + AliasCommand(..) => 2, + UnAliasCommand(..) => 3, + UnsetVar(..) => 4, + Cd(..) => 5, + SetShellOptions(..) => 6, + AddFunction(..) => 7, + }; + let other_idx = match other { + SetEnvVar(..) => 0, + SetShellVar(..) => 1, + AliasCommand(..) => 2, + UnAliasCommand(..) => 3, + UnsetVar(..) => 4, + Cd(..) => 5, + SetShellOptions(..) => 6, + AddFunction(..) => 7, + }; + self_idx.cmp(&other_idx) + } } #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug, PartialOrd)] diff --git a/crates/shell/src/commands/which.rs b/crates/shell/src/commands/which.rs index 6ddb114..310d9ab 100644 --- a/crates/shell/src/commands/which.rs +++ b/crates/shell/src/commands/which.rs @@ -28,6 +28,11 @@ fn execute_which(context: &mut ShellCommandContext) -> Result<(), i32> { return Ok(()); } + if context.state.get_function(arg).is_some() { + context.stdout.write_line("").ok(); + return Ok(()); + } + if context.state.resolve_custom_command(arg).is_some() { context.stdout.write_line("").ok(); return Ok(()); diff --git a/crates/tests/src/lib.rs b/crates/tests/src/lib.rs index 3123f38..3f388e9 100644 --- a/crates/tests/src/lib.rs +++ b/crates/tests/src/lib.rs @@ -1463,6 +1463,116 @@ async fn test_set() { .await; } +#[tokio::test] +async fn functions() { + // Basic function definition and call + TestBuilder::new() + .command( + r#" +greet() { + echo "Hello, World!" +} +greet +"#, + ) + .assert_stdout("Hello, World!\n") + .run() + .await; + + // Function with parameters + TestBuilder::new() + .command( + r#" +show_params() { + echo "First: $1" + echo "Second: $2" +} +show_params "foo" "bar" +"#, + ) + .assert_stdout("First: foo\nSecond: bar\n") + .run() + .await; + + // Function with 'function' keyword + TestBuilder::new() + .command( + r#" +function my_function { + echo "Using function keyword" +} +my_function +"#, + ) + .assert_stdout("Using function keyword\n") + .run() + .await; + + // Function with 'function' keyword and parentheses + TestBuilder::new() + .command( + r#" +function another_function() { + echo "Function with parens" +} +another_function +"#, + ) + .assert_stdout("Function with parens\n") + .run() + .await; + + // Multiple functions in sequence + TestBuilder::new() + .command( + r#" +myfunc1() { + echo "First function" +} +myfunc2() { + echo "Second function" +} +myfunc1 +myfunc2 +"#, + ) + .assert_stdout("First function\nSecond function\n") + .run() + .await; + + // Function overriding + TestBuilder::new() + .command( + r#" +test_override() { + echo "First version" +} +test_override +test_override() { + echo "Second version" +} +test_override +"#, + ) + .assert_stdout("First version\nSecond version\n") + .run() + .await; + + // Test 'which' command with functions + TestBuilder::new() + .command( + r#" +myfunc() { + echo "test" +} +which myfunc +"#, + ) + .assert_stdout("\n") + .run() + .await; +} + #[tokio::test] async fn test_reserved_substring() { // Test that there is no panic (prefix-dev/shell#256) diff --git a/crates/tests/test-data/functions.sh b/crates/tests/test-data/functions.sh new file mode 100644 index 0000000..61d0d70 --- /dev/null +++ b/crates/tests/test-data/functions.sh @@ -0,0 +1,164 @@ +# Test basic function definition and call +> greet() { +> echo "Hello, World!" +> } +> greet +Hello, World! + +# Test function with parameters +> show_params() { +> echo "First: $1" +> echo "Second: $2" +> } +> show_params "foo" "bar" +First: foo +Second: bar + +# Test function with multiple parameters +> add_values() { +> echo "A=$1, B=$2, C=$3" +> } +> add_values "10" "20" "30" +A=10, B=20, C=30 + +# Test function keyword syntax +> function my_function { +> echo "Using function keyword" +> } +> my_function +Using function keyword + +# Test function keyword with parentheses +> function another_function() { +> echo "Function with parens" +> } +> another_function +Function with parens + +# Test multiple functions in sequence +> myfunc1() { +> echo "First function" +> } +> myfunc2() { +> echo "Second function" +> } +> myfunc1 +> myfunc2 +First function +Second function + +# Test function overriding +> test_override() { +> echo "First version" +> } +> test_override +> test_override() { +> echo "Second version" +> } +> test_override +First version +Second version + +# Test function with multiple commands +> multi_cmd() { +> echo "Line 1" +> echo "Line 2" +> echo "Line 3" +> } +> multi_cmd +Line 1 +Line 2 +Line 3 + +# Test function with variable expansion +> greet_name() { +> echo "Hello, $1!" +> echo "Welcome, $1" +> } +> greet_name "Alice" +Hello, Alice! +Welcome, Alice + +# Test function with empty parameters +> test_empty() { +> echo "A=$1" +> echo "B=$2" +> } +> test_empty +A= +B= + +# Test 'which' command with functions +> myfunc() { +> echo "test" +> } +> which myfunc + + +# Test function with exported variable +> export_test() { +> export MY_VAR="exported" +> echo "Set MY_VAR" +> } +> export_test +> echo $MY_VAR +Set MY_VAR +exported + +# Test function with local variable scope (basic) +> set_local() { +> FOO="local value" +> echo "Inside: $FOO" +> } +> FOO="global" +> set_local +> echo "Outside: $FOO" +Inside: local value +Outside: local value + +# Test function calling built-in commands +> use_builtins() { +> pwd > /dev/null +> echo "Working with builtins" +> } +> use_builtins +Working with builtins + +# Test function with command substitution in parameters +> echo_twice() { +> echo "$1 $1" +> } +> echo_twice "$(echo 'hello')" +hello hello + +# Real-world example: Simple logger function +> log() { +> echo "[LOG] $1" +> } +> log "Application started" +> log "Processing data" +[LOG] Application started +[LOG] Processing data + +# Real-world example: Error handling wrapper +> run_with_msg() { +> echo "Running: $1" +> echo "Status: $2" +> } +> run_with_msg "backup.sh" "success" +Running: backup.sh +Status: success + +# Real-world example: Path manipulation +> make_path() { +> echo "$1/$2" +> } +> make_path "/home/user" "documents" +/home/user/documents + +# Test function name that's not a reserved word +> my_custom_func() { +> echo "Custom function works" +> } +> my_custom_func +Custom function works