From 25a507f7eee8d9b0e19952af1b0a954b823b0c29 Mon Sep 17 00:00:00 2001
From: FrederickPu
Date: Thu, 16 Oct 2025 12:16:26 -0400
Subject: [PATCH 1/2] Move generation logic into Common

---
 LLMlean/API/Common.lean    |  96 +++++++++++++++++
 LLMlean/API/ProofGen.lean  | 204 +++----------------------------------
 LLMlean/API/TacticGen.lean | 196 ++--------------------------------
 3 files changed, 118 insertions(+), 378 deletions(-)

diff --git a/LLMlean/API/Common.lean b/LLMlean/API/Common.lean
index f8fb496..0d6611a 100644
--- a/LLMlean/API/Common.lean
+++ b/LLMlean/API/Common.lean
@@ -270,4 +270,100 @@ def getMarkdownLeanCodeBlocks (markdown : String) : List String := Id.run do
     blocks := blocks ++ [part.headD ""]
   return blocks
+
+/--
+Parses a proof out of a response from the LLM.
+The proof is expected to be enclosed in `[PROOF]...[/PROOF]` tags.
+-/
+def splitProof (text : String) : String :=
+  let text := ((text.splitOn "[PROOF]").tailD [text]).headD text
+  match (text.splitOn "[/PROOF]").head? with
+  | some s => s.trim
+  | none => text.trim
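+
+-- Illustrative checks of the tag handling (a minimal sketch; the inputs are
+-- invented): the tags are stripped and the payload trimmed, and untagged
+-- text is simply trimmed.
+#guard splitProof "[PROOF]\nintro h\nexact h\n[/PROOF]" = "intro h\nexact h"
+#guard splitProof "  simp  " = "simp"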
+
+def generateOpenAI (prompts : List String)
+(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
+  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
+  for prompt in prompts do
+    let req : OpenAIGenerationRequest := {
+      model := api.model,
+      messages := [
+        {
+          role := "user",
+          content := prompt
+        }
+      ],
+      n := options.numSamples,
+      temperature := options.temperature,
+      max_tokens := options.maxTokens,
+      stop := options.stopSequences
+    }
+    let res : OpenAIResponse ← post req api.baseUrl api.key
+    for result in res.choices do
+      results := results.insert result.message.content
+
+  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
+  return finalResults
+
+def generateAnthropic (prompts : List String)
+(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
+  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
+  for prompt in prompts do
+    for i in List.range options.numSamples do
+      let temperature := if i == 1 then 0.0 else options.temperature
+      let req : AnthropicGenerationRequest := {
+        model := api.model,
+        messages := [
+          {
+            role := "user",
+            content := prompt
+          }
+        ],
+        temperature := temperature,
+        max_tokens := options.maxTokens,
+        stop_sequences := options.stopSequences
+      }
+      let res : AnthropicResponse ← post req api.baseUrl api.key
+      for result in res.content do
+        results := results.insert result.text
+
+  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
+  return finalResults
+
+def generateOllama (prompts : List String)
+(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
+  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
+  for prompt in prompts do
+    for i in List.range options.numSamples do
+      let temperature := if i == 1 then 0.0 else options.temperature
+      let req : OllamaGenerationRequest := {
+        model := api.model,
+        prompt := prompt,
+        stream := false,
+        options := {
+          temperature := temperature,
+          stop := options.stopSequences,
+          num_predict := options.maxTokens
+        }
+      }
+      let res : OllamaResponse ← post req api.baseUrl api.key
+      results := results.insert res.response
+
+  -- Filter here too, so Ollama results go through the same quality gate as
+  -- the other backends.
+  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
+  return finalResults
+
+/-!
+## Main Handler
+-/
+def Config.API.generate
+  (api : API) (prompts : List String) (options : ChatGenerationOptions) : CoreM $ Array (String × Float) := do
+  match api.kind with
+  | APIKind.Ollama =>
+    generateOllama prompts api options
+  | APIKind.TogetherAI =>
+    generateOpenAI prompts api options
+  | APIKind.OpenAI =>
+    generateOpenAI prompts api options
+  | APIKind.Anthropic =>
+    generateAnthropic prompts api options
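+
+-- Hypothetical call site (sketch; `myPrompt` and the sample count are made
+-- up): every backend deduplicates its completions and pairs each one with a
+-- placeholder score of 1.0.
+--
+--   let options : ChatGenerationOptions := { numSamples := 4 }
+--   let candidates ← api.generate [myPrompt] options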
+
 end LLMlean
diff --git a/LLMlean/API/ProofGen.lean b/LLMlean/API/ProofGen.lean
index 808d1a0..778de20 100644
--- a/LLMlean/API/ProofGen.lean
+++ b/LLMlean/API/ProofGen.lean
@@ -6,108 +6,17 @@ open Lean LLMlean.Config
 
 namespace LLMlean
 
-/--
-Parses a proof out of a response from the LLM.
-The proof is expected to be enclosed in `[PROOF]...[/PROOF]` tags.
--/
-def splitProof (text : String) : String :=
-  let text := ((text.splitOn "[PROOF]").tailD [text]).headD text
-  match (text.splitOn "[/PROOF]").head? with
-  | some s => s.trim
-  | none => text.trim
-
-/-!
-## OpenAI
--/
-def parseResponseQedOpenAI (res: OpenAIResponse) : Array String :=
-  (res.choices.map fun x => (splitProof x.message.content)).toArray
-
-def qedOpenAI (prompts : List String)
-(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    let req : OpenAIGenerationRequest := {
-      model := api.model,
-      messages := [
-        {
-          role := "user",
-          content := prompt
-        }
-      ],
-      n := options.numSamples,
-      temperature := options.temperature,
-      max_tokens := options.maxTokens,
-      stop := options.stopSequences
-    }
-    let res : OpenAIResponse ← post req api.baseUrl api.key
-    for result in (parseResponseQedOpenAI res) do
-      results := results.insert result
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Anthropic
--/
-def parseResponseQedAnthropic (res: AnthropicResponse) : Array String :=
-  (res.content.map fun x => (splitProof x.text)).toArray
+open API
 
-def qedAnthropic (prompts : List String)
-(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : AnthropicGenerationRequest := {
-        model := api.model,
-        messages := [
-          {
-            role := "user",
-            content := prompt
-          }
-        ],
-        temperature := temperature,
-        max_tokens := options.maxTokens,
-        stop_sequences := options.stopSequences
-      }
-      let res : AnthropicResponse ← post req api.baseUrl api.key
-      for result in (parseResponseQedAnthropic res) do
-        results := results.insert result
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Ollama
--/
-def parseResponseQedOllama (res: OllamaResponse) : String :=
-  splitProof res.response
-
-def qedOllama (prompts : List String)
-(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : OllamaGenerationRequest := {
-        model := api.model,
-        prompt := prompt,
-        stream := false,
-        options := {
-          temperature := temperature,
-          stop := options.stopSequences,
-          num_predict := options.maxTokens
-        }
-      }
-      let res : OllamaResponse ← post req api.baseUrl api.key
-      results := results.insert (parseResponseQedOllama res)
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Ollama with markdown output (e.g., Kimina Prover)
--/
 /--
 Extracts proof from markdown response by finding the last code block
 and extracting the portion after the context.
 -/
@@ -138,80 +47,6 @@ def extractProofFromMarkdownResponse (context : String) (response : String) : Op
   -- If we can't find the context, return the whole block
   some lastBlock.trim
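+
+-- Illustrative behavior (hypothetical response), following the doc string
+-- above: with context "theorem foo : 1 = 1 := by" and a response whose last
+-- ```lean block reads "theorem foo : 1 = 1 := by\n  rfl", the extracted
+-- proof is "rfl".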
 
-def qedOllamaMarkdown (prompts : List String) (context : String)
-(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : OllamaGenerationRequest := {
-        model := api.model,
-        prompt := prompt,
-        stream := false,
-        options := {
-          temperature := temperature,
-          stop := options.stopSequences,
-          num_predict := options.maxTokens
-        }
-      }
-      let res : OllamaResponse ← post req api.baseUrl api.key
-      match extractProofFromMarkdownResponse context res.response with
-      | some proof => results := results.insert proof
-      | none => results := results
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Ollama with tactic output (e.g., BFS-Prover)
--/
-def qedOllamaTactic (prompts : List String)
-(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : OllamaGenerationRequest := {
-        model := api.model,
-        prompt := prompt,
-        stream := false,
-        options := {
-          temperature := temperature,
-          num_predict := options.maxTokens,
-          stop := options.stopSequences
-        }
-      }
-      let res : OllamaResponse ← post req api.baseUrl api.key
-      let tactic := res.response
-      results := results.insert tactic
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Main Handler
--/
-
-/--
-Generates proof completions using the LLM API.
--/
-def LLMlean.Config.API.proofCompletion
-  (api : API) (tacticState : String) (context : String) : CoreM $ Array (String × Float) := do
-  let prompts := makeQedPrompts api.promptKind context tacticState
-  let options ← getChatGenerationOptionsQed api TacticKind.LLMQed
-  match api.kind with
-  | APIKind.Ollama =>
-    match api.responseFormat with
-    | ResponseFormat.Markdown =>
-      qedOllamaMarkdown prompts context api options
-    | _ =>
-      qedOllama prompts api options
-  | APIKind.TogetherAI =>
-    qedOpenAI prompts api options
-  | APIKind.OpenAI =>
-    qedOpenAI prompts api options
-  | APIKind.Anthropic =>
-    qedAnthropic prompts api options
+/--
+Generates proof completions using the LLM API.
+-/
+def LLMlean.Config.API.proofCompletion
+  (api : API) (tacticState : String) (context : String) : CoreM $ Array (String × Float) := do
+  let prompts := makeQedPrompts api.promptKind context tacticState
+  let options ← getChatGenerationOptions api TacticKind.LLMQed
+  -- Markdown models (e.g., Kimina Prover) return the proof inside a code
+  -- block; all other models use the `[PROOF]` tag convention.
+  let parse (response : String) : Option String :=
+    match api.responseFormat with
+    | ResponseFormat.Markdown => extractProofFromMarkdownResponse context response
+    | _ => some (splitProof response)
+  let responses := (← generate api prompts options).filterMap
+    (fun (x, p) => (parse x).map (fun proof => (proof, p)))
+  return responses.filter (fun (x, _) => filterGeneration x)
 
 /--
 Generates proof completions with refinement context using the LLM API.
 -/
 def LLMlean.Config.API.proofCompletionRefinement
   (api : API) (tacticState : String) (context : String)
   (previousAttempt : String) (errorMsg : String) : CoreM $ Array (String × Float) := do
   let prompts := makeQedRefinementPrompts api.promptKind context tacticState previousAttempt errorMsg
-  let options ← getChatGenerationOptionsQed api TacticKind.LLMQed
-  match api.kind with
-  | APIKind.Ollama =>
-    match api.responseFormat with
-    | ResponseFormat.Markdown =>
-      qedOllamaMarkdown prompts context api options
-    | _ =>
-      qedOllama prompts api options
-  | APIKind.TogetherAI =>
-    qedOpenAI prompts api options
-  | APIKind.OpenAI =>
-    qedOpenAI prompts api options
-  | APIKind.Anthropic =>
-    qedAnthropic prompts api options
+  let options ← getChatGenerationOptions api TacticKind.LLMQed
+  let responses ← generate api prompts options
+  -- Same split as in `proofCompletion`: code-block extraction for Markdown
+  -- models, `[PROOF]` tags for everything else.
+  let parse (response : String) : Option String :=
+    match api.responseFormat with
+    | ResponseFormat.Markdown => extractProofFromMarkdownResponse context response
+    | _ => some (splitProof response)
+  return Std.HashMap.toArray (responses.foldl (fun results (response, prob) =>
+    match parse response with
+    | some proof => results.insert proof prob
+    | none => results
+  ) {})
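+
+-- Hypothetical refinement round (identifiers invented for illustration):
+--
+--   let retries ← api.proofCompletionRefinement goalState fileContext
+--     "by simp"                 -- previous attempt
+--     "simp made no progress"   -- error reported by Lean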
 
 end LLMlean
diff --git a/LLMlean/API/TacticGen.lean b/LLMlean/API/TacticGen.lean
index 39b0836..3a714f4 100644
--- a/LLMlean/API/TacticGen.lean
+++ b/LLMlean/API/TacticGen.lean
@@ -16,181 +16,7 @@ def splitTac (text : String) : String :=
   | some s => s.trim
   | none => text.trim
 
-/-!
-## Open AI
--/
-def parseTacticResponseOpenAI (res: OpenAIResponse) (pfx : String) : Array String :=
-  (res.choices.map fun x => pfx ++ (splitTac x.message.content)).toArray
-
-def tacticGenerationOpenAI (pfx : String) (prompts : List String)
-(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    let req : OpenAIGenerationRequest := {
-      model := api.model,
-      messages := [
-        {
-          role := "user",
-          content := prompt
-        }
-      ],
-      n := options.numSamples,
-      temperature := options.temperature,
-      max_tokens := options.maxTokens,
-      stop := options.stopSequences
-    }
-    let res : OpenAIResponse ← post req api.baseUrl api.key
-    for result in (parseTacticResponseOpenAI res pfx) do
-      results := results.insert result
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-
-/-!
-## Anthropic
--/
-def parseTacticResponseAnthropic (res: AnthropicResponse) (pfx : String) : Array String :=
-  (res.content.map fun x => pfx ++ (splitTac x.text)).toArray
-
-def tacticGenerationAnthropic (pfx : String) (prompts : List String)
-(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : AnthropicGenerationRequest := {
-        model := api.model,
-        messages := [
-          {
-            role := "user",
-            content := prompt
-          }
-        ],
-        temperature := temperature,
-        max_tokens := options.maxTokens,
-        stop_sequences := options.stopSequences
-      }
-      let res : AnthropicResponse ← post req api.baseUrl api.key
-      for result in (parseTacticResponseAnthropic res pfx) do
-        results := results.insert result
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Ollama
--/
-def parseResponseOllama (res: OllamaResponse) : String :=
-  splitTac res.response
-
-def tacticGenerationOllama (pfx : String) (prompts : List String)
-(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : OllamaGenerationRequest := {
-        model := api.model,
-        prompt := prompt,
-        stream := false,
-        options := {
-          temperature := temperature,
-          num_predict := options.maxTokens,
-          stop := options.stopSequences
-        }
-      }
-      let res : OllamaResponse ← post req api.baseUrl api.key
-      results := results.insert (pfx ++ (parseResponseOllama res))
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-
-/-!
-## Ollama with markdown output (e.g., Kimina-Prover)
--/
-
-/--
-Given a code block and a context, returns the first line of the code block after the context is written out.
--/
-def getTacticFromBlockContext (context : String) (block : String) : String := Id.run do
-  -- Get the trimmed last nonempty nonwhitespace line of the context
-  let last_context := (((context.splitOn "\n").filter (fun x => x.trim.length > 0)).getLast?.getD "").trim
-
-  -- Trim every line of the block
-  let block := "\n".intercalate ((block.splitOn "\n").map (fun x => x.trim))
-
-  let post_context := (block.splitOn last_context)[1]?.getD ""
-  if post_context.length > 0 then
-    -- get the first nonempty nonwhitespace line of the post_context
-    let tactic := ((post_context.splitOn "\n").filter (fun x => x.trim.length > 0)).getLast?.getD ""
-    return tactic.trim
-  else
-    return s!"Did not find context: \n\n{context}\n\n in \n\n{block}\n\n"
-
-def parseTacticResponseOllamaMarkdown (_context : String) (res: OllamaResponse) : List String := Id.run do
-  let blocks := getMarkdownLeanCodeBlocks res.response
-  let mut results : List String := []
-  for block in blocks do
-    for line in (block.splitOn "\n") do
-      if line.trim.length > 0 then
-        results := results ++ [line.trim]
-  return results
-
-def tacticGenerationOllamaMarkdown (_pfx : String) (context : String) (prompts : List String)
-(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : OllamaGenerationRequest := {
-        model := api.model,
-        prompt := prompt,
-        stream := false,
-        options := {
-          temperature := temperature,
-          num_predict := options.maxTokens,
-          stop := options.stopSequences
-        }
-      }
-      let res : OllamaResponse ← post req api.baseUrl api.key
-      for tactic in (parseTacticResponseOllamaMarkdown context res) do
-        results := results.insert (tactic)
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Ollama with tactic output (e.g., BFS-Prover)
--/
-def tacticGenerationOllamaTactic (pfx : String) (prompts : List String)
-(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
-  let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
-  for prompt in prompts do
-    for i in List.range options.numSamples do
-      let temperature := if i == 1 then 0.0 else options.temperature
-      let req : OllamaGenerationRequest := {
-        model := api.model,
-        prompt := prompt,
-        stream := false,
-        options := {
-          temperature := temperature,
-          num_predict := options.maxTokens,
-          stop := options.stopSequences
-        }
-      }
-      let res : OllamaResponse ← post req api.baseUrl api.key
-      let tactic := res.response
-      if tactic.startsWith pfx.trim then
-        results := results.insert tactic
-
-  let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
-  return finalResults
-
-/-!
-## Main Handler
--/
+open Config API
 
 /--
 Generates a list of tactics using the LLM API.
@@ -200,20 +26,10 @@
 -/
 def LLMlean.Config.API.tacticGeneration
   (api : API) (tacticState : String) (context : String)
   («prefix» : String) : CoreM $ Array (String × Float) := do
   let prompts := makePrompts api.promptKind context tacticState «prefix»
   let options ← getChatGenerationOptions api TacticKind.LLMStep
-  match api.kind with
-  | APIKind.Ollama =>
-    match api.responseFormat with
-    | ResponseFormat.Markdown =>
-      tacticGenerationOllamaMarkdown «prefix» context prompts api options
-    | ResponseFormat.Tactic =>
-      tacticGenerationOllamaTactic «prefix» prompts api options
-    | _ =>
-      tacticGenerationOllama «prefix» prompts api options
-  | APIKind.TogetherAI =>
-    tacticGenerationOpenAI «prefix» prompts api options
-  | APIKind.OpenAI =>
-    tacticGenerationOpenAI «prefix» prompts api options
-  | APIKind.Anthropic =>
-    tacticGenerationAnthropic «prefix» prompts api options
+  let mut results : Std.HashMap String Float := {}
+  for (response, prob) in (← generate api prompts options) do
+    -- Tactic-format models echo the whole tactic (prefix included); other
+    -- models return a completion that continues the prefix.
+    let tactic :=
+      if response.startsWith «prefix».trim then response
+      else «prefix» ++ splitTac response
+    results := results.insert tactic prob
+  return results.toArray
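+
+-- Illustrative prefix handling (hypothetical completions): with
+-- «prefix» = "exact ", a tactic-style response "exact Nat.zero_le n" is kept
+-- as-is, while a bare completion "Nat.zero_le n" becomes
+-- "exact Nat.zero_le n".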
@@ -200,20 +26,10 @@ def LLMlean.Config.API.tacticGeneration («prefix» : String) : CoreM $ Array (String × Float) := do let prompts := makePrompts api.promptKind context tacticState «prefix» let options ← getChatGenerationOptions api TacticKind.LLMStep - match api.kind with - | APIKind.Ollama => - match api.responseFormat with - | ResponseFormat.Markdown => - tacticGenerationOllamaMarkdown «prefix» context prompts api options - | ResponseFormat.Tactic => - tacticGenerationOllamaTactic «prefix» prompts api options - | _ => - tacticGenerationOllama «prefix» prompts api options - | APIKind.TogetherAI => - tacticGenerationOpenAI «prefix» prompts api options - | APIKind.OpenAI => - tacticGenerationOpenAI «prefix» prompts api options - | APIKind.Anthropic => - tacticGenerationAnthropic «prefix» prompts api options + let mut results : Std.HashMap String Float := {} + for (tactic, prob) in ← generate api prompts options do + if tactic.startsWith prefix.trim then + results := results.insert («prefix» ++ splitTac tactic) prob + return results.toArray end LLMlean From 73f69290dc75d4cbf02caea2b561c57043ece952 Mon Sep 17 00:00:00 2001 From: FrederickPu Date: Thu, 16 Oct 2025 12:17:54 -0400 Subject: [PATCH 2/2] get rid of qed options --- LLMlean/API/Common.lean | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/LLMlean/API/Common.lean b/LLMlean/API/Common.lean index 0d6611a..b46548f 100644 --- a/LLMlean/API/Common.lean +++ b/LLMlean/API/Common.lean @@ -50,13 +50,6 @@ structure ChatGenerationOptions where numSamples : Nat := defaultSamples deriving ToJson, FromJson -structure ChatGenerationOptionsQed where - temperature : Float := defaultTemperature - maxTokens : Nat := defaultMaxTokens - stopSequences : List String := defaultStopProof - numSamples : Nat := defaultSamples -deriving ToJson, FromJson - structure OpenAIMessage where role : String content : String @@ -240,18 +233,6 @@ def getChatGenerationOptions (api : API) (tacticKind : TacticKind): CoreM ChatGe stopSequences := defaultStopTactic } -def getChatGenerationOptionsQed (api : API) (tacticKind : TacticKind) : CoreM ChatGenerationOptionsQed := do - let numSamples ← getNumSamples api tacticKind - let maxTokens ← getMaxTokens api tacticKind - -- Print configuration in verbose mode - printConfiguration api tacticKind numSamples maxTokens - return { - numSamples := numSamples - temperature := defaultTemperature, - maxTokens := maxTokens, - stopSequences := defaultStopProof - } - /-- Parses a string consisting of Markdown text, and extracts the Lean code blocks. The code blocks are enclosed in triple backticks.