Skip to content

Commit bbcffd9

Browse files
committed
feat: implement new functions and update API
1 parent ae84bc1 commit bbcffd9

File tree

6 files changed

+62
-23
lines changed

6 files changed

+62
-23
lines changed

examples/pi.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use auto_rust::auto_implement;
22

33
#[auto_implement]
4-
/// Algoritmo basdo en trignometria para calcular el valor de pi
54
fn calculate_pi_with_n_iterations(n: u64) -> f64 {
65
todo!()
76
}

examples/quine.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use auto_rust::auto_implement;
2+
3+
#[auto_implement]
4+
/// Only return the string literal of the quine
5+
/// Remember to use `to_string()` to convert it to a string
6+
fn calculate_rust_quine() -> String {
7+
todo!()
8+
}
9+
10+
fn main() {
11+
let quine = calculate_rust_quine().to_string();
12+
13+
print!("{}", quine)
14+
}

examples/with_context.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
use auto_rust::auto_implement;
2+
3+
#[auto_implement(
4+
context = "This algorithm works by calculating the area of a circle using the trigonometric formula for the area of a triangle. Don't
5+
use powi, use powf instead. Don't use the sqrt function, use the hypot function instead. Don't use the sin function, use the sin_cos function instead. Don't use the cos function, use the sin_cos function instead"
6+
)]
7+
fn calculate_pi_with_n_iterations(n: u64) -> f64 {
8+
todo!()
9+
}
10+
11+
fn main() {
12+
let result = calculate_pi_with_n_iterations(100_000);
13+
println!("pi: {}", result);
14+
// assert_eq!(result, 3.0418396189294032);
15+
}

src/api.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ pub fn open_ai_chat_completions(
88
system_message: String,
99
user_message: String,
1010
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
11-
let open_ai_key = env::var("OPENAI_API_KEY").unwrap_or("".to_string());
11+
let open_ai_key = env::var("OPENAI_API_KEY").unwrap();
12+
let model_name = env::var("OPENAI_MODEL_NAME").unwrap_or("gpt-3.5-turbo".to_string());
1213

1314
let mut headers = header::HeaderMap::new();
1415

@@ -23,23 +24,20 @@ pub fn open_ai_chat_completions(
2324
.build()
2425
.unwrap();
2526

27+
let body = json!({
28+
"model": model_name,
29+
"messages": [
30+
{"role": "system", "content": system_message},
31+
{"role": "user", "content": user_message}
32+
]
33+
});
34+
2635
let res = client
2736
.post("https://api.openai.com/v1/chat/completions")
2837
.headers(headers)
29-
.body(
30-
json!({
31-
"model": "gpt-3.5-turbo",
32-
"messages": [
33-
{"role": "system", "content": system_message},
34-
{"role": "user", "content": user_message}
35-
]
36-
})
37-
.to_string(),
38-
)
38+
.body(body.to_string())
3939
.send()?
4040
.json::<ChatCompletionResponse>()?;
4141

42-
// println!("{:?}", res);
43-
4442
Ok(res)
4543
}

src/generator.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use std::error::Error;
22

33
use crate::api::open_ai_chat_completions;
44

5-
pub fn generate_body_function_from_head(head: String) -> Result<String, Box<dyn Error>> {
5+
pub fn generate_body_function_from_head(
6+
head: String,
7+
extra_context: Option<String>,
8+
) -> Result<String, Box<dyn Error>> {
69
let system_message = "You are an AI code assistant trained on the GPT-4 architecture. Your task is to generate Rust function body implementations based only on the provided function signatures. When the user provides a function signature using the command '/complete', your response must be the plain text function body, without any explanations, formatting, or code blocks. Do not include the function signature, function name, or any other information in your response. Triple backticks (```) and function signatures are strictly prohibited in your response. Responding with any prohibited content will result in a penalty.
710
example 1:
811
INPUT: /implement fn my_ip() -> String
@@ -16,13 +19,20 @@ pub fn generate_body_function_from_head(head: String) -> Result<String, Box<dyn
1619
ip_addr.to_string()
1720
example 2:
1821
INPUT: /implement fn hello_world() -> String
19-
OUTPUT: \"Hello World\".to_string()
22+
OUTPUT:
23+
\"Hello World\".to_string()
2024
example 3:
2125
INPUT: /implement fn hello_world(name: String) -> String
22-
OUTPUT: format!(\"Hello {}!\", name)
26+
OUTPUT:
27+
format!(\"Hello {}!\", name)
2328
".to_string();
2429

25-
let user_message = format!("/implement {}", head);
30+
let user_message = extra_context
31+
.map(|c| format!("Extra context: {}\n", c))
32+
.unwrap_or("".to_string())
33+
+ &format!("/implement {}", head);
34+
35+
// println!("User message: {}", user_message);
2636

2737
let res = open_ai_chat_completions(system_message, user_message).unwrap();
2838

src/lib.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,20 @@ pub fn implement(_item: TokenStream) -> TokenStream {
1414
// TODO: Evaluate the use of dotenv in this crate
1515
dotenv().ok();
1616

17-
let implemented_fn = generate_body_function_from_head(_item.to_string()).unwrap();
17+
let implemented_fn = generate_body_function_from_head(_item.to_string(), None).unwrap();
1818

19-
println!("{}", implemented_fn);
19+
// println!("{}", implemented_fn);
2020

2121
implemented_fn.parse().unwrap()
2222
}
2323

2424
#[proc_macro_attribute]
25-
pub fn auto_implement(_args: TokenStream, input: TokenStream) -> TokenStream {
25+
pub fn auto_implement(args: TokenStream, input: TokenStream) -> TokenStream {
2626
let ast: ItemFn = syn::parse(input).expect("Failed to parse input as a function");
2727

28-
// Search for the information within the attributes.
28+
let context = args.to_string();
29+
30+
// println!("Context: {}", context);
2931

3032
let mut prompt_input = String::new();
3133

@@ -46,11 +48,12 @@ pub fn auto_implement(_args: TokenStream, input: TokenStream) -> TokenStream {
4648

4749
dotenv().ok();
4850

49-
let implemented_fn = generate_body_function_from_head(prompt_input).unwrap();
51+
let implemented_fn = generate_body_function_from_head(prompt_input, Some(context)).unwrap();
5052

5153
// println!("\n{}\n", implemented_fn);
5254

5355
// #[allow(long_running_const_eval)]
56+
5457
// loop {
5558
// let mut line = String::new();
5659
// let b1 = std::io::stdin().read_line(&mut line).unwrap();

0 commit comments

Comments
 (0)