From 2ab9277c025683ba936d238c5e26aec9f0516efa Mon Sep 17 00:00:00 2001 From: jacobthebanana Date: Thu, 15 May 2025 15:17:19 +0000 Subject: [PATCH 1/2] Added misinfo verification web-retrieval agent eval. --- README.md | 10 ++ misinfo_data_eval/entrypoint.py | 51 +++++- misinfo_data_eval/generation_utils.py | 122 ++++++++++++- misinfo_data_eval/tasks/keyword_analysis.py | 42 +++-- .../tasks/temporal_correlation.py | 1 - misinfo_data_eval/tasks/web_search.py | 169 ++++++++++++++++++ pyproject.toml | 3 + 7 files changed, 378 insertions(+), 20 deletions(-) create mode 100644 misinfo_data_eval/tasks/web_search.py diff --git a/README.md b/README.md index 0076004..bdc571a 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,16 @@ Evaluating Feasibility: 100%|████████████████| 2 "feasible, no search required": 4 } ``` +## Run LLM Web Retrieval Accuracy Analysis + +```bash +source .env && \ +uv run -m misinfo_data_eval.entrypoint \ +--source_dataset_path hf://ComplexDataLab/Misinfo_Datasets@ce06269:liar_new:test \ +--evaluate_factuality \ +--max_concurrency 32 \ +--limit 72 +``` ## Run Evaluation on a dataset where tweet_id is available diff --git a/misinfo_data_eval/entrypoint.py b/misinfo_data_eval/entrypoint.py index 93162b5..92a284e 100644 --- a/misinfo_data_eval/entrypoint.py +++ b/misinfo_data_eval/entrypoint.py @@ -7,9 +7,11 @@ parser = argparse.ArgumentParser() parser.add_argument("--evaluator_model_name") +parser.add_argument("--web_search_model_name") parser.add_argument("--source_dataset_path", required=True, help=DATA_INSTRUCTIONS) parser.add_argument("--max_concurrency", type=int, default=1) parser.add_argument("--evaluate_feasibility", action="store_true", default=False) +parser.add_argument("--evaluate_factuality", action="store_true", default=False) parser.add_argument( "--evaluate_temporal_correlation", action="store_true", default=False ) @@ -21,7 +23,7 @@ async def main(): args = parser.parse_args() - dataset = load_data(args.source_dataset_path) + dataset = list(load_data(args.source_dataset_path)) print("len(dataset):", len(dataset)) # Feasibility Evaluation @@ -55,6 +57,53 @@ async def main(): # Cache previous generations if interrupted. 
             cache.write()
 
+    if args.evaluate_factuality:
+        from .generation_utils import AsyncElasticsearchCache
+        from .tasks.web_search import batch_evaluate
+
+        async_semaphore = asyncio.Semaphore(args.max_concurrency)
+        es_cache = await AsyncElasticsearchCache.maybe_from_env_var(
+            "cache_misinfo_eval_web_retrieval"
+        )
+        assert es_cache is not None
+
+        data_rows = [
+            _row
+            for _row in dataset[: args.limit]
+            if _row["veracity"] in ("true", "false")
+        ]
+        eval_output = await batch_evaluate(
+            data_rows,
+            cache=es_cache,
+            async_semaphore=async_semaphore,
+            total=len(data_rows),
+        )
+        await es_cache.close()
+
+        num_total_valid = 0
+        num_total_correct = 0
+
+        for output in eval_output:
+            if output.is_correct is not None:
+                num_total_valid += 1
+                num_total_correct += output.is_correct
+
+        print(
+            json.dumps(
+                {
+                    "num_total": len(eval_output),
+                    "num_valid": num_total_valid,
+                    "num_correct": num_total_correct,
+                    "accuracy": (
+                        num_total_correct / num_total_valid
+                        if num_total_valid > 0
+                        else None
+                    ),
+                },
+                indent=2,
+            )
+        )
+
     # Temporal Correlation Evaluation, if "tweet_id" data is available
     if args.evaluate_temporal_correlation:
         from .tasks.temporal_correlation import evaluate_temporal_correlations
diff --git a/misinfo_data_eval/generation_utils.py b/misinfo_data_eval/generation_utils.py
index 9358bd2..6bb9926 100644
--- a/misinfo_data_eval/generation_utils.py
+++ b/misinfo_data_eval/generation_utils.py
@@ -4,14 +4,17 @@
 
 import asyncio
 import gzip
+import hashlib
+import logging
 import os
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, TypeVar
+from typing import Any, Callable, Coroutine, Optional, TypeVar
 
 import backoff
 import openai
 import pydantic
+from elasticsearch import AsyncElasticsearch
 from tqdm.asyncio import tqdm
 
 client = openai.AsyncOpenAI()
@@ -60,9 +63,124 @@ def _add_cache_callback(response: str):
         yield None, _add_cache_callback
 
 
+class AsyncElasticsearchCache:
+    def __init__(self, es: AsyncElasticsearch, index_name: str) -> None:
+        self.es = es
+        self.index_name = index_name
+        logging.info(f"Elastic Index Name: {self.index_name}")
+
+        self.is_refresh_required = bool(os.environ.get("FORCE_CACHE_REFRESH"))
+        if self.is_refresh_required:
+            logging.warning("FORCE_CACHE_REFRESH is enabled. All queries will miss.")
+
+    @staticmethod
+    async def maybe_from_env_var(index_name: str) -> "AsyncElasticsearchCache | None":
+        """Initialize from env var. Returns None if any of the required env vars are missing."""
+        required_env_keys = ["ELASTIC_SEARCH_HOST", "ELASTIC_SEARCH_API_KEY"]
+        if not all((_key in os.environ) for _key in required_env_keys):
+            logging.warning(
+                "All of these are required to enable ElasticsearchCache: "
+                f"{required_env_keys}."
+                " Not enabling ElasticsearchCache since some keys are not set."
+            )
+            return None
+
+        es = AsyncElasticsearch(
+            os.environ["ELASTIC_SEARCH_HOST"],
+            api_key=os.environ["ELASTIC_SEARCH_API_KEY"],
+            request_timeout=None,
+        )
+
+        # Ensure the index exists at startup (normalize the name if '/' is present)
+        index_name = index_name.lower().replace("/", "_")
+        if not await es.indices.exists(index=index_name):
+            await es.indices.create(index=index_name)
+        return AsyncElasticsearchCache(es=es, index_name=index_name)
+
+    async def get(self, query: str, nonce: Optional[str] = None) -> str | None:
+        """Try reading response from cache.
+
+        Args:
+            query (str): The query to fetch from cache.
+            nonce (str, optional): An optional nonce to differentiate cache entries.
+ + Returns: + str | None: Cached result if available. + """ + if self.is_refresh_required: + return None + + # Cache lookup + query_hash = self._get_query_hash(query=query, nonce=nonce) + try: + response = await self.es.get(index=self.index_name, id=query_hash) + if response.get("found"): + logging.info(f"Cache hit: {query_hash}") + return response["_source"]["result"] + except Exception: + logging.debug(f"Cache miss: {query_hash}") + + return None # Cache miss or index doesn't exist + + async def set(self, query: str, value: str, nonce: Optional[str] = None) -> None: + """Set/Update cache. + + Args: + query (str): The query whose result is to be cached. + value (str): The value to store in cache. + nonce (str, optional): An optional nonce to differentiate cache entries. + """ + query_hash = self._get_query_hash(query=query, nonce=nonce) + doc = {"query": query, "result": value, "nonce": nonce} + await self.es.index(index=self.index_name, id=query_hash, document=doc) + + async def close(self) -> None: + """Close Elasticsearch connection.""" + await self.es.close() + + @staticmethod + def _get_query_hash(query: str, nonce: Optional[str] = None) -> str: + query_key = query + if nonce is not None: + query_key += f"\n{nonce}" + return hashlib.sha256(query_key.encode()).hexdigest() + + Data = TypeVar("Data") +V = TypeVar("V") +Serializer = TypeVar("Serializer", bound=pydantic.BaseModel) + + +async def cached( + _fn: Callable[[], Coroutine[None, None, Serializer]], + _key: str, + output_serializer_class: type[Serializer], + cache: AsyncElasticsearchCache, +) -> Serializer: + """Run _fn only if cache is missed.""" + cached_data = await cache.get(_key) + if (cached_data is not None) and not bool(os.getenv("IGNORE_CACHE")): + # Cache hit + return output_serializer_class.model_validate_json(cached_data) + + # Cache miss + output = await _fn() + cached_data = output.model_dump_json() + await cache.set(_key, value=cached_data) + + return output + + +async def rate_limited( + _fn: Callable[[], Coroutine[None, None, V]], semaphore: asyncio.Semaphore +) -> V: + """Run _fn with semaphore rate limit.""" + async with semaphore: + return await _fn() + + @backoff.on_exception(backoff.expo, (openai.RateLimitError,)) async def generate( prompt: str, @@ -156,6 +274,6 @@ async def evalute_on_template( for task in tqdm(asyncio.as_completed(coros), ncols=75, total=len(coros)): _text_output, _data = await task _index = _data["_index"] - output[_index] = extract_answer_fn(_text_output) + output[_index] = extract_answer_fn(_text_output) # type: ignore return output diff --git a/misinfo_data_eval/tasks/keyword_analysis.py b/misinfo_data_eval/tasks/keyword_analysis.py index 65360e9..62e705d 100644 --- a/misinfo_data_eval/tasks/keyword_analysis.py +++ b/misinfo_data_eval/tasks/keyword_analysis.py @@ -1,30 +1,32 @@ -import pandas as pd -import numpy as np import string from collections import Counter -from sklearn.model_selection import train_test_split -from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import confusion_matrix, f1_score -from nltk.corpus import stopwords -from nltk.tokenize import word_tokenize # Ensure NLTK stopwords are downloaded import nltk +import numpy as np +import pandas as pd +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import confusion_matrix, f1_score +from sklearn.model_selection import train_test_split + nltk.download("punkt") -nltk.download('punkt_tab') 
+nltk.download("punkt_tab") nltk.download("stopwords") from typing import Any LABEL_MAP = {"true": 1, "false": 0, "unknown": None} + def keyword_analysis(dataset: list[dict[str, Any]]) -> dict[str, float]: """ Processes a dataset to analyze veracity distribution, top keywords, and train a classifier. - + Parameters: data (pd.DataFrame or np.ndarray): A 2D dataset with "dataset", "claim", and "veracity" columns. dataset_name (str): The name of the dataset to process. - + Returns: dict: A dictionary containing veracity counts, top keywords, confusion matrix, and F1 scores. """ @@ -32,8 +34,7 @@ def keyword_analysis(dataset: list[dict[str, Any]]) -> dict[str, float]: # Convert to DataFrame if input is NumPy array temp_df = pd.DataFrame(dataset) - - # exclude all but T/F (TODO: Add if statement to include 'mixed' values) + # exclude all but T/F (TODO: Add if statement to include 'mixed' values) temp_df = temp_df[temp_df.veracity != 3] temp_df["veracity"] = temp_df["veracity"].apply(LABEL_MAP.get) temp_df = temp_df[temp_df.veracity.notna()] @@ -49,6 +50,7 @@ def keyword_analysis(dataset: list[dict[str, Any]]) -> dict[str, float]: # Text Preprocessing Function stop_words = set(stopwords.words("english")) + def preprocess_text(text): text = text.lower() text = text.translate(str.maketrans("", "", string.punctuation)) @@ -65,6 +67,7 @@ def preprocess_text(text): # Convert Claims to Features (Bag-of-Words) word_features = list(top_keywords.keys()) + def claim_to_vector(words): return [1 if word in words else 0 for word in word_features] @@ -72,7 +75,9 @@ def claim_to_vector(words): y = df["veracity"].values # Train-Test Split - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.25, random_state=42 + ) # Train Random Forest Model rf_model = RandomForestClassifier(n_estimators=100, max_depth=20, random_state=42) @@ -80,11 +85,15 @@ def claim_to_vector(words): # Predict & Evaluate y_pred = rf_model.predict(X_test) - conf_matrix = confusion_matrix(y_test, y_pred).tolist() # Convert to list for JSON compatibility + conf_matrix = confusion_matrix( + y_test, y_pred + ).tolist() # Convert to list for JSON compatibility macro_f1 = f1_score(y_test, y_pred, average="macro") # Baseline Random Predictions - random_preds = np.random.choice(np.unique(y_train), size=len(y_test), p=np.bincount(y_train) / len(y_train)) + random_preds = np.random.choice( + np.unique(y_train), size=len(y_test), p=np.bincount(y_train) / len(y_train) + ) baseline_f1 = f1_score(y_test, random_preds, average="macro") # Return results as a dictionary @@ -94,9 +103,10 @@ def claim_to_vector(words): "top_keywords": top_keywords, "confusion_matrix": conf_matrix, "macro_f1_random_forest": macro_f1, - "macro_f1_random_baseline": baseline_f1 + "macro_f1_random_baseline": baseline_f1, } + # Example Usage: # data = pd.read_csv("dat_claims.csv") # Load dataset # results = analyze_dataset(data, "checkcovid") diff --git a/misinfo_data_eval/tasks/temporal_correlation.py b/misinfo_data_eval/tasks/temporal_correlation.py index 78cf2f7..fdc0386 100644 --- a/misinfo_data_eval/tasks/temporal_correlation.py +++ b/misinfo_data_eval/tasks/temporal_correlation.py @@ -5,7 +5,6 @@ from sklearn.metrics import classification_report from sklearn.model_selection import train_test_split - LABEL_MAP = {"true": 1, "false": 0, "unknown": None} diff --git a/misinfo_data_eval/tasks/web_search.py b/misinfo_data_eval/tasks/web_search.py new file mode 100644 
index 0000000..b396725
--- /dev/null
+++ b/misinfo_data_eval/tasks/web_search.py
@@ -0,0 +1,169 @@
+import asyncio
+import re
+from typing import Any, Coroutine, Iterable
+
+from agents import Agent, ItemHelpers, Runner, RunResult, set_trace_processors
+from agents.tool import WebSearchTool
+from opik.integrations.openai.agents import OpikTracingProcessor
+from pydantic import BaseModel
+from tqdm.asyncio import tqdm
+
+from ..generation_utils import AsyncElasticsearchCache, cached, rate_limited
+
+set_trace_processors(processors=[OpikTracingProcessor("20250514-CDL-misinfo-data-1a")])
+
+
+class DatasetRow(BaseModel):
+    claim: str
+    veracity: str
+
+
+class Output(BaseModel):
+    """Represents factuality evaluation output."""
+
+    statement: str
+    factuality_prediction: int | None
+    output_text: str
+    is_output_invalid: bool
+    trace: list[str]
+    data_source: dict[str, Any] | None = None
+    is_correct: bool | None = None
+
+
+def serialize_run_result(run_result: RunResult) -> list[str]:
+    """Serialize Agent trace."""
+    return [str(_result) for _result in run_result.new_items]
+
+
+INSTRUCTIONS = """\
+Your task is to analyze the factuality of the given statement.
+
+You may invoke the "search" tool as many times as needed to retrieve \
+up-to-date information. However, before invoking the tool, you must \
+explain your rationale for doing so.
+
+After providing all your analysis steps, summarize your analysis \
+and state "True statement; Factuality: 1" if you think the statement \
+is factual, or "False statement; Factuality: 0" otherwise.
+"""
+
+RETRIEVAL_INSTRUCTIONS = """\
+Answer the question by summarizing the search results.
+"""
+
+
+search_agent = Agent(
+    name="SearchAgent",
+    instructions=(
+        "You are a search agent. You receive a single search query as input. "
+        "Use the WebSearchTool to perform a web search, then produce a concise "
+        "'search summary' of the key findings. Do NOT return raw search results."
+    ),
+    tools=[WebSearchTool(search_context_size="low")],
+    # a faster, smaller model for quick searches
+    model="gpt-4o-mini",
+)
+
+main_agent = Agent(
+    name="MainAgent",
+    instructions=INSTRUCTIONS,
+    tools=[
+        search_agent.as_tool(
+            tool_name="search",
+            tool_description="Perform a web search for a query and return a concise summary.",
+        )
+    ],
+    # a larger, more capable model for reasoning over summaries
+    model="gpt-4o",
+)
+
+
+async def evaluate(statement: str, data_source: dict[str, Any] | None) -> Output:
+    """Evaluate on a single statement."""
+    result = await Runner.run(main_agent, statement)
+    prediction_match = re.search(r"Factuality:\s*(\d)", result.final_output)
+
+    prediction = None
+    try:
+        if prediction_match is not None:
+            prediction_str = prediction_match.group(1)
+            prediction = int(prediction_str)
+    except ValueError:
+        pass
+
+    is_correct: bool | None = None
+    if data_source is not None:
+        label = data_source.get("veracity")
+        if label == "true":
+            is_correct = prediction == 1
+        if label == "false":
+            is_correct = prediction == 0
+
+    return Output(
+        statement=statement,
+        factuality_prediction=prediction,
+        output_text=result.final_output,
+        is_output_invalid=prediction is None,
+        trace=serialize_run_result(result),
+        data_source=data_source,
+        is_correct=is_correct,
+    )
+
+
+async def batch_evaluate(
+    dataset_rows: Iterable[dict[str, Any]],
+    cache: AsyncElasticsearchCache,
+    async_semaphore: asyncio.Semaphore,
+    total: int | None = None,
+) -> list[Output]:
+    """Evaluate veracity, reusing cache whenever possible."""
+    coros: list[Coroutine[None, None, Output]] = [
+        rate_limited(
+            lambda _row=_row: cached(
+                _fn=lambda: evaluate(DatasetRow(**_row).claim, data_source=_row),
+                _key=DatasetRow(**_row).claim,
+                output_serializer_class=Output,
+                cache=cache,
+            ),
+            semaphore=async_semaphore,
+        )
+        for _row in dataset_rows
+    ]
+
+    outputs: list[Output] = []
+    for coro in tqdm(asyncio.as_completed(coros), ncols=75, total=total):
+        outputs.append(await coro)
+
+    return outputs
+
+
+async def main():
+    result = Runner.run_streamed(
+        main_agent,
+        " Comparing the price of oil and gas in June 2008 to March 2022 shows that oil companies are price gouging.",
+    )
+
+    async for event in result.stream_events():
+        # We'll ignore the raw responses event deltas
+        if event.type == "raw_response_event":
+            continue
+        # When the agent updates, print that
+        elif event.type == "agent_updated_stream_event":
+            print(f"Agent updated: {event.new_agent.name}")
+            continue
+        # When items are generated, print them
+        elif event.type == "run_item_stream_event":
+            if event.item.type == "tool_call_item":
+                print("-- Tool was called")
+            elif event.item.type == "tool_call_output_item":
+                print(f"-- Tool output: {event.item.output}\n")
+            elif event.item.type == "message_output_item":
+                print(
+                    f"-- Message output:\n {ItemHelpers.text_message_output(event.item)}\n"
+                )
+            else:
+                pass  # Ignore other event types
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index d01db49..b7eaf3e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,4 +13,7 @@ dependencies = [
     "isort",
     "black",
     "nltk",
+    "openai-agents>=0.0.14",
+    "opik>=1.7.20",
+    "elasticsearch[async]==8.17.2",
 ]

From 3304e6e3502ea83812f107c33cfb179138fe1ee8 Mon Sep 17 00:00:00 2001
From: jacobthebanana
Date: Mon, 26 May 2025 19:51:16 +0000
Subject: [PATCH 2/2] Added web search and feasibility eval logic.
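
A sketch of one invocation of the new LangFuse entrypoint (the dataset name
and source paths follow the examples recorded in scripts.md in this patch):

    source .env && \
    uv run -m misinfo_data_eval.entrypoint_langfuse \
    --langfuse_dataset_name cdl-misinfo-ce06269-dry-run-2a \
    --langfuse_populate_dataset \
    --evaluator_model_name gpt-4o-mini-2024-07-18 \
    --web_search_variant oai \
    --source_dataset_paths "hf://ComplexDataLab/Misinfo_Datasets@ce06269:liar_new:test" \
    --max_concurrency 32 \
    --limit 9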
--- data/.gitignore | 3 +- misinfo_data_eval/entrypoint.py | 66 ++----- misinfo_data_eval/entrypoint_langfuse.py | 178 +++++++++++++++++ misinfo_data_eval/generation_utils.py | 149 +++++++------- .../feasibility_accuracy_correlation.py | 117 +++++++++++ misinfo_data_eval/tasks/feasibility_eval.py | 12 +- misinfo_data_eval/tasks/web_search.py | 185 +++++++++++++++--- misinfo_data_eval/tracing_utils/__init__.py | 0 .../tracing_utils/langfuse_otlp.py | 63 ++++++ pyproject.toml | 4 + scripts.md | 146 ++++++++++++++ 11 files changed, 760 insertions(+), 163 deletions(-) create mode 100644 misinfo_data_eval/entrypoint_langfuse.py create mode 100644 misinfo_data_eval/metrics/feasibility_accuracy_correlation.py create mode 100644 misinfo_data_eval/tracing_utils/__init__.py create mode 100644 misinfo_data_eval/tracing_utils/langfuse_otlp.py create mode 100644 scripts.md diff --git a/data/.gitignore b/data/.gitignore index b49a22a..b718fca 100644 --- a/data/.gitignore +++ b/data/.gitignore @@ -1 +1,2 @@ -/cache \ No newline at end of file +/cache +/* \ No newline at end of file diff --git a/misinfo_data_eval/entrypoint.py b/misinfo_data_eval/entrypoint.py index 92a284e..cadb7cd 100644 --- a/misinfo_data_eval/entrypoint.py +++ b/misinfo_data_eval/entrypoint.py @@ -28,7 +28,7 @@ async def main(): # Feasibility Evaluation if args.evaluate_feasibility: - from .generation_utils import AsyncLLMEvaluator, Cache + from .generation_utils import AsyncElasticsearchCache, AsyncLLMEvaluator from .tasks.feasibility_eval import evaluate_feasibility if args.evaluator_model_name is None: @@ -37,17 +37,25 @@ async def main(): makedirs("data/cache", exist_ok=True) async_semaphore = asyncio.Semaphore(args.max_concurrency) - cache = Cache(f"data/cache/{args.evaluator_model_name}.jsonl.gz") + es_cache = await AsyncElasticsearchCache.maybe_from_env_var( + f"cache_misinfo_eval_feasibility" + ) + if es_cache is None: + print( + "Warning: es_cache is not available. " + "See AsyncElasticsearchCache.maybe_from_env_var on how to enable." + ) + llm_evaluator = AsyncLLMEvaluator( model_name=args.evaluator_model_name, - cache=cache, + cache=es_cache, async_semaphore=async_semaphore, assert_cached=args.assert_cached, max_completion_tokens=args.max_generation_tokens, ) try: - feasibility_metrics = await evaluate_feasibility( + feasibility_metrics, _ = await evaluate_feasibility( statements=[_row["claim"] for _row in dataset][: args.limit], llm_evaluator=llm_evaluator, ) @@ -55,54 +63,8 @@ async def main(): finally: # Cache previous generations if interrupted. 
-            cache.write()
-
-    if args.evaluate_factuality:
-        from .generation_utils import AsyncElasticsearchCache
-        from .tasks.web_search import batch_evaluate
-
-        async_semaphore = asyncio.Semaphore(args.max_concurrency)
-        es_cache = await AsyncElasticsearchCache.maybe_from_env_var(
-            "cache_misinfo_eval_web_retrieval"
-        )
-        assert es_cache is not None
-
-        data_rows = [
-            _row
-            for _row in dataset[: args.limit]
-            if _row["veracity"] in ("true", "false")
-        ]
-        eval_output = await batch_evaluate(
-            data_rows,
-            cache=es_cache,
-            async_semaphore=async_semaphore,
-            total=len(data_rows),
-        )
-        await es_cache.close()
-
-        num_total_valid = 0
-        num_total_correct = 0
-
-        for output in eval_output:
-            if output.is_correct is not None:
-                num_total_valid += 1
-                num_total_correct += output.is_correct
-
-        print(
-            json.dumps(
-                {
-                    "num_total": len(eval_output),
-                    "num_valid": num_total_valid,
-                    "num_correct": num_total_correct,
-                    "accuracy": (
-                        num_total_correct / num_total_valid
-                        if num_total_valid > 0
-                        else None
-                    ),
-                },
-                indent=2,
-            )
-        )
+        if es_cache:
+            await es_cache.close()
 
     # Temporal Correlation Evaluation, if "tweet_id" data is available
     if args.evaluate_temporal_correlation:
diff --git a/misinfo_data_eval/entrypoint_langfuse.py b/misinfo_data_eval/entrypoint_langfuse.py
new file mode 100644
index 0000000..c72decb
--- /dev/null
+++ b/misinfo_data_eval/entrypoint_langfuse.py
@@ -0,0 +1,178 @@
+"""
+Evaluate web retrieval pipeline using LangFuse
+
+- Run Agent SDK to produce traces
+- Fetch and score traces
+- Update traces to include eval scores.
+"""
+
+import argparse
+import asyncio
+import hashlib
+
+from langfuse import Langfuse
+from tqdm.auto import tqdm
+
+from .data_loading_utils import DATA_INSTRUCTIONS, load_data
+from .generation_utils import AsyncElasticsearchCache, AsyncLLMEvaluator
+from .tasks.feasibility_eval import FEASIBILITY_LEVEL_MAPS, evaluate_feasibility
+from .tasks.web_search import DatasetRow
+from .tasks.web_search import batch_evaluate as evaluate_veracity
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--evaluator_model_name")
+parser.add_argument("--max_generation_tokens", type=int, default=4096)
+parser.add_argument("--web_search_variant", choices=("es", "oai"), required=True)
+parser.add_argument("--langfuse_dataset_name", default="cdl-misinfo")
+parser.add_argument("--langfuse_populate_dataset", action="store_true", default=False)
+parser.add_argument("--langfuse_always_link", action="store_true", default=False)
+parser.add_argument(
+    "--source_dataset_paths", required=True, nargs="+", help=DATA_INSTRUCTIONS
+)
+parser.add_argument("--max_concurrency", type=int, default=1)
+parser.add_argument("--evaluate_feasibility", action="store_true", default=False)
+parser.add_argument("--evaluate_factuality", action="store_true", default=False)
+parser.add_argument(
+    "--evaluate_temporal_correlation", action="store_true", default=False
+)
+parser.add_argument("--keyword_analysis", action="store_true", default=False)
+parser.add_argument("--assert_cached", action="store_true", default=False)
+parser.add_argument("--limit", type=int, default=-1, help="Per-source limit")
+
+
+async def main():
+
+    args = parser.parse_args()
+    run_name = f"openai-agent-{args.web_search_variant}-search"
+
+    async_semaphore = asyncio.Semaphore(args.max_concurrency)
+    langfuse = Langfuse()
+    es_cache = await AsyncElasticsearchCache.maybe_from_env_var(
+        f"cache_misinfo_eval_{run_name}"
+    )
+    llm_evaluator = AsyncLLMEvaluator(
+        model_name=args.evaluator_model_name,
+        cache=es_cache,
+        async_semaphore=async_semaphore,
+        assert_cached=args.assert_cached,
+        max_completion_tokens=args.max_generation_tokens,
+    )
+
+    langfuse_dataset_name = args.langfuse_dataset_name
+    print(f"langfuse_dataset_name: {langfuse_dataset_name}")
+    dataset = langfuse.create_dataset(
+        langfuse_dataset_name, metadata={"type": "benchmark"}
+    )
+
+    assert es_cache is not None
+
+    if args.langfuse_populate_dataset:
+        for source_dataset_path in tqdm(args.source_dataset_paths, ncols=75):
+            dataset = load_data(source_dataset_path)
+            source_path_hash = hashlib.sha256(source_dataset_path.encode()).hexdigest()
+            dataset_name_hash = hashlib.sha256(
+                langfuse_dataset_name.encode()
+            ).hexdigest()
+
+            for index, row in tqdm(
+                enumerate(dataset),
+                ncols=75,
+                total=args.limit,
+                desc="Populating LangFuse dataset",
+            ):
+                # Skip "unknown" rows.
+                if (row["veracity"] not in ["true", "false"]) or (
+                    len(row["claim"]) < 10
+                ):
+                    continue
+
+                if (args.limit > 0) and (index >= args.limit):
+                    break
+
+                langfuse.create_dataset_item(
+                    id=f"{source_path_hash[:6]}-{dataset_name_hash}-{index:05}",
+                    dataset_name=langfuse_dataset_name,
+                    input=row["claim"],
+                    expected_output=row["veracity"],
+                    metadata={"source": source_dataset_path, **row},
+                )
+
+    # Evaluate in batch, and then pair up dataset (input) and output traces.
+    dataset_langfuse = langfuse.get_dataset(langfuse_dataset_name)
+    input_items = [_item.metadata for _item in dataset_langfuse.items]
+    input_items = [_row for _row in input_items if _row is not None]
+    if len(input_items) == 0:
+        raise FileNotFoundError(
+            f"Langfuse Dataset {args.langfuse_dataset_name} is empty. "
+            "Maybe re-run with the following? --langfuse_populate_dataset"
+        )
+
+    _, feasibility_outputs = await evaluate_feasibility(
+        [DatasetRow(**_item).claim for _item in input_items], llm_evaluator
+    )
+    eval_outputs = await evaluate_veracity(
+        input_items,  # type: ignore[dict]
+        cache=es_cache,
+        async_semaphore=async_semaphore,
+        variant=args.web_search_variant,
+        total=len(input_items),
+    )
+    for index, (item, eval_output, feasibility_output) in enumerate(
+        zip(
+            dataset_langfuse.items,
+            tqdm(eval_outputs, desc="Scoring", ncols=75),
+            feasibility_outputs,
+        )
+    ):
+        # Link only on the first run.
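+        # A cached eval_output reuses the trace id recorded when it was first
+        # produced, so there is nothing new to link on later runs.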
+ if (not eval_output.cache_hit) or (args.langfuse_always_link): + item.link( + None, + run_name=run_name, + trace_id=eval_output.langfuse_trace_id, + ) + + langfuse_trace = langfuse.trace(id=eval_output.langfuse_trace_id) + + is_correct = eval_output.is_correct + if is_correct is not None: + langfuse_trace.score( + id=f"correct_{eval_output.langfuse_trace_id}", + name="Correct", + value=is_correct, + data_type="BOOLEAN", + ) + + langfuse_trace.score( + id=f"valid_{eval_output.langfuse_trace_id}", + name="Valid", + value=not eval_output.is_output_invalid, + data_type="BOOLEAN", + ) + langfuse_trace.update(tags=[langfuse_dataset_name]) + + if feasibility_output is not None: + feasibility_level = FEASIBILITY_LEVEL_MAPS[feasibility_output] + langfuse_trace.score( + id=f"feasibility_search_{eval_output.langfuse_trace_id}", + name="Feasibility with Search", + value=feasibility_level >= 1, + data_type="BOOLEAN", + ) + langfuse_trace.score( + id=f"feasibility_no_search_{eval_output.langfuse_trace_id}", + name="Feasibility, no Search", + value=feasibility_level == 2, + data_type="BOOLEAN", + ) + + if (index - 1) % 10 == 0: + langfuse.flush() + + print("Waiting for langfuse.flush()") + langfuse.flush() + await es_cache.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/misinfo_data_eval/generation_utils.py b/misinfo_data_eval/generation_utils.py index 6bb9926..be81761 100644 --- a/misinfo_data_eval/generation_utils.py +++ b/misinfo_data_eval/generation_utils.py @@ -9,12 +9,13 @@ import os from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Coroutine, Optional, TypeVar +from typing import Awaitable, Callable, Coroutine, Optional, TypeVar import backoff import openai import pydantic from elasticsearch import AsyncElasticsearch +from opentelemetry import context from tqdm.asyncio import tqdm client = openai.AsyncOpenAI() @@ -27,42 +28,6 @@ class CacheEntry(pydantic.BaseModel): response: str -class Cache: - def __init__(self, cache_path: str | Path): - self.cache: dict[str, str] = {} - self.new_entries: dict[str, str] = {} - self.cache_path = cache_path - if not os.path.exists(cache_path): - gzip.open(cache_path, "wt") - - with gzip.open(cache_path, "rt") as cache_file: - for row in cache_file.readlines(): - if len(row.strip()) > 0: - entry = CacheEntry.model_validate_json(row.strip()) - self.cache[entry.prompt] = entry.response - - def write(self): - with gzip.open(self.cache_path, "at") as cache_file: - cache_file.write("\n") - for key, value in self.cache.items(): - line = CacheEntry(prompt=key, response=value).model_dump_json() + "\n" - cache_file.write(line) - - @contextmanager - def cache_response(self, prompt: str): - """Cache response.""" - - def _add_cache_callback(response: str): - self.new_entries[prompt] = response - self.cache[prompt] = response - - if prompt in self.cache: - yield self.cache[prompt], _add_cache_callback - return - - yield None, _add_cache_callback - - class AsyncElasticsearchCache: def __init__(self, es: AsyncElasticsearch, index_name: str) -> None: self.es = es @@ -153,28 +118,59 @@ def _get_query_hash(query: str, nonce: Optional[str] = None) -> str: Serializer = TypeVar("Serializer", bound=pydantic.BaseModel) +@contextmanager +def disable_auto_tracing(): + """ + Context manager that sets the `suppress_instrumentation` flag + so that any auto-instrumentation which checks it will bail out. 
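+
+    A minimal usage sketch (this is how `cached` below wraps its cache reads):
+
+        with disable_auto_tracing():
+            cached_data = await cache.get(_key)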
+ """ + # Create a new Context with suppression enabled + new_ctx = context.set_value("suppress_instrumentation", True) + # Attach it, saving the token for later + token = context.attach(new_ctx) + try: + yield + finally: + # Detach to restore the previous context + context.detach(token) + + async def cached( - _fn: Callable[[], Coroutine[None, None, Serializer]], + _fn: Callable[[], Awaitable[Serializer]], _key: str, output_serializer_class: type[Serializer], - cache: AsyncElasticsearchCache, + cache: AsyncElasticsearchCache | None, + assert_cached: bool = False, ) -> Serializer: - """Run _fn only if cache is missed.""" - cached_data = await cache.get(_key) + """Run _fn only if cache is missed or if Cache is None (not provided).""" + with disable_auto_tracing(): + cached_data = (await cache.get(_key)) if (cache is not None) else None + if (cached_data is not None) and not bool(os.getenv("IGNORE_CACHE")): # Cache hit - return output_serializer_class.model_validate_json(cached_data) + output = output_serializer_class.model_validate_json(cached_data) + return output_serializer_class(**{**output.model_dump(), "cache_hit": True}) + + if assert_cached: + raise RuntimeError(f"assert_cached is set and cache missed on key: {_key}") # Cache miss output = await _fn() cached_data = output.model_dump_json() - await cache.set(_key, value=cached_data) + if cache is not None: + with disable_auto_tracing(): + await cache.set(_key, value=cached_data) return output +async def indexed(index: int, coro: Coroutine[None, None, V]) -> tuple[int, V]: + """Returns (index, await coro).""" + return index, (await coro) + + async def rate_limited( - _fn: Callable[[], Coroutine[None, None, V]], semaphore: asyncio.Semaphore + _fn: Callable[[], Awaitable[V]], semaphore: asyncio.Semaphore ) -> V: """Run _fn with semaphore rate limit.""" async with semaphore: @@ -184,49 +180,35 @@ async def rate_limited( @backoff.on_exception(backoff.expo, (openai.RateLimitError,)) async def generate( prompt: str, - data: Data, model_name: str, - async_semaphore: asyncio.Semaphore, async_client: openai.AsyncOpenAI, - cache: Cache, - assert_cached: bool = False, max_completion_tokens: int = 4096, -) -> tuple[str, Data]: +) -> CacheEntry: """Generate using ChatCompletion generation. Params: prompt: str data: to be returned verbatim model_name: str - async_semaphore: to limit number of concurrent requests. async_client: async OpenAI client. - cache: Cache """ - async with async_semaphore: - with cache.cache_response(prompt) as (cached_output, _callback): - if cached_output is not None: - return cached_output, data - - assert not assert_cached, "Cache miss. Maybe run without --assert_cached?" 
- - response = await async_client.chat.completions.create( - model=model_name, - messages=[{"role": "user", "content": prompt}], - max_completion_tokens=max_completion_tokens, - ) + response = await async_client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": prompt}], + max_completion_tokens=max_completion_tokens, + ) - output = response.choices[0].message.content - assert output is not None - _callback(output) + output = response.choices[0].message.content + assert output is not None - return output, data + return CacheEntry(prompt=prompt, response=output) class AsyncLLMEvaluator: def __init__( self, model_name: str, - cache: Cache, + cache: AsyncElasticsearchCache | None, async_semaphore: asyncio.Semaphore, assert_cached: bool = False, max_completion_tokens: int = 4096, @@ -256,15 +238,23 @@ async def evalute_on_template( list of extracted answers, same length as data. """ coros = [ - generate( - prompt=apply_template_fn(row), - data={**row, "_index": index}, - model_name=self.model_name, - async_semaphore=self.async_semaphore, - async_client=self.async_client, - cache=self.cache, - assert_cached=self.assert_cached, - max_completion_tokens=self.max_completion_tokens, + indexed( + index, + coro=rate_limited( + lambda row=row: cached( + lambda row=row: generate( + prompt=apply_template_fn(row), + model_name=self.model_name, + async_client=self.async_client, + max_completion_tokens=self.max_completion_tokens, + ), + _key=apply_template_fn(row), + output_serializer_class=CacheEntry, + cache=self.cache, + assert_cached=self.assert_cached, + ), + semaphore=self.async_semaphore, + ), ) for index, row in enumerate(rows) ] @@ -272,8 +262,7 @@ async def evalute_on_template( output: list[str | None] = [None for _ in range(len(coros))] for task in tqdm(asyncio.as_completed(coros), ncols=75, total=len(coros)): - _text_output, _data = await task - _index = _data["_index"] - output[_index] = extract_answer_fn(_text_output) # type: ignore + _index, _cache_entry = await task + output[_index] = extract_answer_fn(_cache_entry.response) return output diff --git a/misinfo_data_eval/metrics/feasibility_accuracy_correlation.py b/misinfo_data_eval/metrics/feasibility_accuracy_correlation.py new file mode 100644 index 0000000..7119b50 --- /dev/null +++ b/misinfo_data_eval/metrics/feasibility_accuracy_correlation.py @@ -0,0 +1,117 @@ +"""Evaluate correlation between feasibility and veracity prediction accuracy.""" + +import argparse +import json +from typing import Any, Dict + +import numpy as np +import pandas as pd +from sklearn.metrics import ( + accuracy_score, + confusion_matrix, + f1_score, + mutual_info_score, + precision_score, + recall_score, +) + +keys = ["Correct", "Feasibility with Search"] + + +def quantify_predictive_power(y_true: pd.Series, y_pred: pd.Series) -> Dict[str, Any]: + """ + Compute a suite of binary-classification metrics between two 0/1 series. 
+
+    :param y_true: ground-truth booleans (or 0/1 ints)
+    :param y_pred: predicted booleans (or 0/1 ints)
+    :returns: dict with
+        - accuracy, precision, recall, f1, mutual_info, phi_coefficient,
+        - tn, fp, fn, tp counts
+    """
+    y_true_i = y_true.astype(int)
+    y_pred_i = y_pred.astype(int)
+
+    # basic scores
+    acc = accuracy_score(y_true_i, y_pred_i)
+    prec = precision_score(y_true_i, y_pred_i, zero_division=0)
+    rec = recall_score(y_true_i, y_pred_i, zero_division=0)
+    f1 = f1_score(y_true_i, y_pred_i, zero_division=0)
+    mi = mutual_info_score(y_true_i, y_pred_i)
+
+    # force 2×2 CM even if one class is missing
+    cm = confusion_matrix(y_true_i, y_pred_i, labels=[0, 1])
+    tn, fp, fn, tp = cm.ravel()
+
+    # φ-coefficient (Pearson ρ for binary)
+    denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
+    phi = (tp * tn - fp * fn) / np.sqrt(denom) if denom > 0 else 0.0
+
+    return {
+        "accuracy": acc,
+        "precision": prec,
+        "recall": rec,
+        "f1": f1,
+        "mutual_info": mi,
+        "phi_coefficient": phi,
+        "tn": tn,
+        "fp": fp,
+        "fn": fn,
+        "tp": tp,
+        "total": len(y_true),
+        f"mean of {keys[0]}": y_true_i.mean(),
+        f"mean of {keys[1]}": y_pred_i.mean(),
+    }
+
+
+def summarize_by_source(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Group df by 'data_source', compute how well keys[1] (feasibility) predicts
+    keys[0] (correctness) per source, and add an 'ALL' row for the aggregate.
+
+    :param df: DataFrame with columns ['data_source', *keys]
+    :returns: DataFrame indexed by data_source, columns are the metrics.
+    """
+    records: list[Dict[str, Any]] = []
+
+    # per-source
+    for source, group in df.groupby("data_source"):
+        metrics = quantify_predictive_power(group[keys[0]], group[keys[1]])
+        metrics["data_source"] = source
+        records.append(metrics)
+
+    # aggregate
+    total = quantify_predictive_power(df[keys[0]], df[keys[1]])
+    print(json.dumps(total, indent=2, default=str))
+    total["data_source"] = "ALL"
+    records.append(total)
+
+    result = pd.DataFrame.from_records(records).set_index("data_source")
+    return result
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("jsonl_path")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    rows = []
+    with open(args.jsonl_path) as jsonl_file:
+        for _line in jsonl_file:
+            _row = json.loads(_line)
+
+            if not all(_row.get(key) for key in keys):
+                continue
+
+            _labels = {k: _row[k][0] == "True" for k in keys}
+            _attribution = {"data_source": _row["metadata"]["data_source"]["dataset"]}
+            rows.append({**_labels, **_attribution})
+
+    df = pd.DataFrame(rows)
+
+    # Quantify predictive power
+    summary = summarize_by_source(df)
+
+    # Display results
+    print(summary.to_latex())
+    print(summary.to_markdown())
+    print(summary.to_csv())
diff --git a/misinfo_data_eval/tasks/feasibility_eval.py b/misinfo_data_eval/tasks/feasibility_eval.py
index 191d915..4353025 100644
--- a/misinfo_data_eval/tasks/feasibility_eval.py
+++ b/misinfo_data_eval/tasks/feasibility_eval.py
@@ -23,10 +23,16 @@
     ("1", "1"): "feasible, no search required",
 }
 
+FEASIBILITY_LEVEL_MAPS = {
+    "not feasible even with search": 0,
+    "feasible, requires search": 1,
+    "feasible, no search required": 2,
+}
+
 
 async def evaluate_feasibility(
     statements: list[str], llm_evaluator: "AsyncLLMEvaluator"
-) -> dict[str | None, int]:
+) -> tuple[dict[str | None, int], list[str | None]]:
     """Evaluate feasibility of the given data.
 
     Params:
@@ -39,6 +45,8 @@
         - "feasible, requires search": int
        - "not feasible even with search": int
         - None: int
+
+        list of raw predictions.
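+
+        For example (illustrative values only), two statements could both
+        project to "feasible, requires search", yielding:
+        (Counter({"feasible, requires search": 2}),
+        ["feasible, requires search", "feasible, requires search"])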
""" per_template_predictions: list[list[str | None]] = [] @@ -59,4 +67,4 @@ async def evaluate_feasibility( for _paired_prediction in zip(*per_template_predictions): projected_predictions.append(PROJECTIONS.get(_paired_prediction)) - return Counter(projected_predictions) + return Counter(projected_predictions), projected_predictions diff --git a/misinfo_data_eval/tasks/web_search.py b/misinfo_data_eval/tasks/web_search.py index b396725..aac5983 100644 --- a/misinfo_data_eval/tasks/web_search.py +++ b/misinfo_data_eval/tasks/web_search.py @@ -1,21 +1,32 @@ import asyncio import re -from typing import Any, Coroutine, Iterable - -from agents import Agent, ItemHelpers, Runner, RunResult, set_trace_processors +from os import getenv +from typing import TYPE_CHECKING, Any, Coroutine, Iterable, Literal + +from agents import ( + Agent, + ItemHelpers, + OpenAIChatCompletionsModel, + Runner, + RunResult, + function_tool, +) from agents.tool import WebSearchTool -from opik.integrations.openai.agents import OpikTracingProcessor +from elasticsearch import AsyncElasticsearch +from openai import AsyncOpenAI from pydantic import BaseModel from tqdm.asyncio import tqdm -from ..generation_utils import AsyncElasticsearchCache, cached, rate_limited +from ..generation_utils import AsyncElasticsearchCache, cached, indexed, rate_limited +from ..tracing_utils.langfuse_otlp import get_langfuse_trace_id, langfuse, tracer -set_trace_processors(processors=[OpikTracingProcessor("20250514-CDL-misinfo-data-1a")]) +SearchVariant = Literal["es", "oai"] class DatasetRow(BaseModel): claim: str veracity: str + dataset: str class Output(BaseModel): @@ -26,9 +37,12 @@ class Output(BaseModel): output_text: str is_output_invalid: bool trace: list[str] - data_source: dict[str, Any] | None = None + langfuse_trace_id: str + data_source: DatasetRow | None = None is_correct: bool | None = None + cache_hit: bool = False + def serialize_run_result(run_result: RunResult) -> list[str]: """Serialize Agent trace.""" @@ -51,36 +65,141 @@ def serialize_run_result(run_result: RunResult) -> list[str]: Answer the question by summarizing the search results. """ +openai_client = AsyncOpenAI() +es_client = AsyncElasticsearch(getenv("KB_ES_HOST"), api_key=getenv("KB_ES_API_KEY")) + + +async def news_search(keyword: str) -> list[dict]: + """Search News Database. + + Returns: + a list of results. Empty list of no match is found. + """ + + title_match = { + "match": { + "title.fuzzy": { + "query": keyword, + "fuzziness": "AUTO", + "operator": "and", + } + } + } + text_match = { + "match": { + "text": { + "query": keyword, + "fuzziness": "AUTO", + "operator": "and", + } + } + } + + # Return first match. + for requirements_option in [[text_match]]: + response = await es_client.search( + index="ccnews", + body={ + "query": {"bool": {"must": requirements_option}}, + "highlight": { + "fields": { + "text": { + # snippet length in characters + "fragment_size": 1000, + "number_of_fragments": 5, + } + }, + }, + "size": 5, + # Do not return full document source + "_source": False, + }, + ) + hits = response["hits"]["hits"] + if len(hits) > 0: + return hits + + return [] -search_agent = Agent( + +search_agent_oai = Agent( name="SearchAgent", instructions=( "You are a search agent. You receive a single search query as input. " "Use the WebSearchTool to perform a web search, then produce a concise " "'search summary' of the key findings. Do NOT return raw search results." 
), - tools=[WebSearchTool(search_context_size="low")], + tools=[ + WebSearchTool(search_context_size="low"), + ], # a faster, smaller model for quick searches model="gpt-4o-mini", ) -main_agent = Agent( +search_agent_es = Agent( + name="SearchAgent", + instructions=( + "You are a search agent. You receive a single search query as input. " + "Use the WebSearchTool to perform a web search, then produce a concise " + "'search summary' of the key findings. Do NOT return raw search results." + ), + tools=[ + function_tool(news_search), + ], + # a faster, smaller model for quick searches + model="gpt-4o-mini", +) + +main_agent_oai = Agent( + name="MainAgent", + instructions=INSTRUCTIONS, + tools=[ + search_agent_oai.as_tool( + tool_name="search", + tool_description="Perform a web search for a query and return a concise summary.", + ) + ], + # a larger, more capable model for reasoning over summaries + model=OpenAIChatCompletionsModel(model="gpt-4o", openai_client=openai_client), +) + +main_agent_es = Agent( name="MainAgent", instructions=INSTRUCTIONS, tools=[ - search_agent.as_tool( + search_agent_es.as_tool( tool_name="search", tool_description="Perform a web search for a query and return a concise summary.", ) ], # a larger, more capable model for reasoning over summaries - model="gpt-4o", + model=OpenAIChatCompletionsModel(model="gpt-4o", openai_client=openai_client), ) -async def evaluate(statement: str, data_source: dict[str, Any] | None) -> Output: +async def evaluate( + statement: str, data_source: DatasetRow | None, variant: SearchVariant +) -> Output: """Evaluate on a single statement.""" - result = await Runner.run(main_agent, statement) + if variant == "es": + main_agent = main_agent_es + else: + main_agent = main_agent_oai + + with tracer.start_as_current_span(f"Agent-SDK-{variant}-search"): + result = await Runner.run(main_agent, statement) + langfuse_trace_id = get_langfuse_trace_id() + langfuse.trace( + id=langfuse_trace_id, + input=statement, + output=result.final_output, + metadata={ + "data_source": ( + data_source.model_dump() if data_source is not None else None + ) + }, + ) + prediction_match = re.search(r"Factuality:\s*(\d)", result.final_output) prediction = None @@ -93,7 +212,7 @@ async def evaluate(statement: str, data_source: dict[str, Any] | None) -> Output is_correct: bool | None = None if data_source is not None: - label = data_source.get("veracity") + label = data_source.veracity if label == "true": is_correct = prediction == 1 if label == "false": @@ -105,6 +224,7 @@ async def evaluate(statement: str, data_source: dict[str, Any] | None) -> Output output_text=result.final_output, is_output_invalid=prediction is None, trace=serialize_run_result(result), + langfuse_trace_id=langfuse_trace_id, data_source=data_source, is_correct=is_correct, ) @@ -114,32 +234,41 @@ async def batch_evaluate( dataset_rows: Iterable[dict[str, Any]], cache: AsyncElasticsearchCache, async_semaphore: asyncio.Semaphore, + variant: SearchVariant, total: int | None = None, ) -> list[Output]: """Evaluate veracity, reusing cache whenever possible.""" - coros: list[Coroutine[None, None, Output]] = [ - rate_limited( - lambda: cached( - _fn=lambda: evaluate(DatasetRow(**_row).claim, data_source=_row), - _key=DatasetRow(**_row).claim, - output_serializer_class=Output, - cache=cache, + rows = [DatasetRow(**_row) for _row in dataset_rows] + coros: list[Coroutine[None, None, tuple[int, Output]]] = [ + indexed( + index=_index, + coro=rate_limited( + _fn=lambda _row=_row: cached( + _fn=lambda _row=_row: 
+                    _fn=lambda _row=_row: evaluate(
+                        _row.claim, data_source=_row, variant=variant
+                    ),
+                    _key=_row.claim,
+                    output_serializer_class=Output,
+                    cache=cache,
+                ),
+                semaphore=async_semaphore,
             ),
-            semaphore=async_semaphore,
         )
-        for _row in dataset_rows
+        for _index, _row in enumerate(rows)
     ]
 
-    outputs: list[Output] = []
+    # Map original position to output
+    outputs: dict[int, Output] = {}
     for coro in tqdm(asyncio.as_completed(coros), ncols=75, total=total):
-        outputs.append(await coro)
+        _index, output = await coro
+        outputs[_index] = output
 
-    return outputs
+    return [outputs[_index] for _index in range(len(coros))]
 
 
 async def main():
     result = Runner.run_streamed(
-        main_agent,
+        main_agent_es,
         " Comparing the price of oil and gas in June 2008 to March 2022 shows that oil companies are price gouging.",
     )
 
diff --git a/misinfo_data_eval/tracing_utils/__init__.py b/misinfo_data_eval/tracing_utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/misinfo_data_eval/tracing_utils/langfuse_otlp.py b/misinfo_data_eval/tracing_utils/langfuse_otlp.py
new file mode 100644
index 0000000..5dc8a90
--- /dev/null
+++ b/misinfo_data_eval/tracing_utils/langfuse_otlp.py
@@ -0,0 +1,63 @@
+import base64
+import os
+
+import logfire
+import nest_asyncio
+from langfuse import Langfuse
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.trace import format_trace_id
+
+# Build Basic Auth header.
+LANGFUSE_AUTH = base64.b64encode(
+    f"{os.environ.get('LANGFUSE_PUBLIC_KEY')}:{os.environ.get('LANGFUSE_SECRET_KEY')}".encode()
+).decode()
+
+# Configure OpenTelemetry endpoint & headers
+os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = (
+    os.environ.get("LANGFUSE_HOST", "") + "/api/public/otel"
+)
+os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {LANGFUSE_AUTH}"
+
+
+nest_asyncio.apply()
+
+
+# Configure logfire instrumentation.
+logfire.configure(
+    service_name="openai_agent_sdk",
+    send_to_logfire=False,
+    scrubbing=False,
+)
+# This method automatically patches the OpenAI Agents SDK to send logs via OTLP to Langfuse.
+logfire.instrument_openai_agents()
+
+
+# Create a TracerProvider for OpenTelemetry
+trace_provider = TracerProvider()
+
+# Add a SimpleSpanProcessor with the OTLPSpanExporter to send traces
+trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter()))
+
+# Set the global default tracer provider
+trace.set_tracer_provider(trace_provider)
+tracer = trace.get_tracer(__name__)
+
+langfuse = Langfuse()
+
+
+def get_langfuse_trace_id() -> str:
+    """Get the Langfuse trace id from the current span context.
+
+    Run this function within the tracer context.
+
+    Allows the current run to be linked to e.g., a LangFuse dataset item.
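+
+    A minimal usage sketch (mirroring `evaluate` in tasks/web_search.py):
+
+        with tracer.start_as_current_span("Agent-SDK-oai-search"):
+            langfuse_trace_id = get_langfuse_trace_id()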
+ """ + current_span = trace.get_current_span() + span_context = current_span.get_span_context() + trace_id = span_context.trace_id + formatted_trace_id = format_trace_id(trace_id) + + return formatted_trace_id diff --git a/pyproject.toml b/pyproject.toml index b7eaf3e..d61f2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,4 +16,8 @@ dependencies = [ "openai-agents>=0.0.14", "opik>=1.7.20", "elasticsearch[async]==8.17.2", + "langfuse>=2.60.5", + "nest-asyncio>=1.6.0", + "pydantic-ai[logfire]>=0.2.4", + "tabulate>=0.9.0", ] diff --git a/scripts.md b/scripts.md new file mode 100644 index 0000000..3d91641 --- /dev/null +++ b/scripts.md @@ -0,0 +1,146 @@ +## 20250519- LangFuse Example + +```bash +REPO=hf://ComplexDataLab/Misinfo_Datasets@ce06269 + +source .env && \ +IGNORE_CACHE=1 uv run -m misinfo_data_eval.entrypoint_langfuse \ +--langfuse_dataset_name cdl-misinfo-ce06269-dry-run-2a \ +--langfuse_populate_dataset \ +--evaluator_model_name gpt-4o-mini-2024-07-18 \ +--web_search_variant oai \ +--source_dataset_paths \ +"$REPO:IFND:test" \ +"$REPO:antivax:test" \ +"$REPO:checkcovid:test" \ +"$REPO:claimskg:test" \ +"$REPO:climate_fever:test" \ +"$REPO:cmu_miscov19:test" \ +"$REPO:coaid:test" \ +"$REPO:covid-19-disinformation:test" \ +"$REPO:covid_19_rumor:test" \ +"$REPO:covid_vaccine_misinfo_mic:test" \ +"$REPO:covidfact:test" \ +"$REPO:defakts:test" \ +"$REPO:esoc:test" \ +"$REPO:fakecovid:test" \ +"$REPO:faviq:test" \ +"$REPO:fever:test" \ +"$REPO:feverous:test" \ +"$REPO:fibvid:test" \ +"$REPO:hover:test" \ +"$REPO:liar:test" \ +"$REPO:liar_new:test" \ +"$REPO:mediaeval:test" \ +"$REPO:mm-covid:test" \ +"$REPO:multiclaim:test" \ +"$REPO:nlp4if:test" \ +"$REPO:pheme:test" \ +"$REPO:pubhealthtab:test" \ +"$REPO:rumors:test" \ +"$REPO:snopes:test" \ +"$REPO:truthseeker2023:test" \ +"$REPO:twitter15:test" \ +"$REPO:twitter16:test" \ +"$REPO:verite:test" \ +"$REPO:wico:test" \ +"$REPO:x_fact:test" \ +--max_concurrency 32 \ +--limit 9 +``` + +## 20250527- Feasibility Evaluation + +```bash +REPO=hf://ComplexDataLab/Misinfo_Datasets@ce06269 + +source .env && \ +IGNORE_CACHE=0 uv run -m misinfo_data_eval.entrypoint_langfuse \ +--langfuse_dataset_name cdl-misinfo-ce06269-1a \ +--evaluator_model_name gpt-4o-mini-2024-07-18 \ +--web_search_variant oai \ +--source_dataset_paths \ +"$REPO:IFND:test" \ +"$REPO:antivax:test" \ +"$REPO:checkcovid:test" \ +"$REPO:claimskg:test" \ +"$REPO:climate_fever:test" \ +"$REPO:cmu_miscov19:test" \ +"$REPO:coaid:test" \ +"$REPO:covid-19-disinformation:test" \ +"$REPO:covid_19_rumor:test" \ +"$REPO:covid_vaccine_misinfo_mic:test" \ +"$REPO:covidfact:test" \ +"$REPO:defakts:test" \ +"$REPO:esoc:test" \ +"$REPO:fakecovid:test" \ +"$REPO:faviq:test" \ +"$REPO:fever:test" \ +"$REPO:feverous:test" \ +"$REPO:fibvid:test" \ +"$REPO:hover:test" \ +"$REPO:liar:test" \ +"$REPO:liar_new:test" \ +"$REPO:mediaeval:test" \ +"$REPO:mm-covid:test" \ +"$REPO:multiclaim:test" \ +"$REPO:nlp4if:test" \ +"$REPO:pheme:test" \ +"$REPO:pubhealthtab:test" \ +"$REPO:rumors:test" \ +"$REPO:snopes:test" \ +"$REPO:truthseeker2023:test" \ +"$REPO:twitter15:test" \ +"$REPO:twitter16:test" \ +"$REPO:verite:test" \ +"$REPO:wico:test" \ +"$REPO:x_fact:test" \ +--max_concurrency 32 \ +--limit 324 +``` + +```bash +# wget -O data/export.jsonl \ +# "https://langfuse-mtrl5-s3.j7n.me/langfuse/exports/1748265146124-lf-traces-export-cmavfjk0b0006p3064utgd9ax.jsonl" +# wget -O data/export.jsonl \ +# 
"https://langfuse-mtrl5-s3.j7n.me/langfuse/exports/1748267827463-lf-traces-export-cmavfjk0b0006p3064utgd9ax.jsonl" +wget -O data/export.jsonl \ +"https://langfuse-mtrl5-s3.j7n.me/langfuse/exports/1748287226769-lf-traces-export-cmavfjk0b0006p3064utgd9ax.jsonl" + +uv run -m misinfo_data_eval.metrics.feasibility_accuracy_correlation \ +data/export.jsonl +``` + +```javascript + accuracy precision recall f1 mutual_info phi_coefficient tn fp fn tp total +data_source +IFND 0.592476 0.709402 0.728070 0.718615 0.000193 -0.019592 23 68 62 166 319 +checkcovid 0.754545 0.825641 0.889503 0.856383 0.000221 0.021308 5 34 20 161 220 +claimskg 0.602410 0.771242 0.648352 0.704478 0.006497 0.114772 32 35 64 118 249 +climate_fever 0.622378 0.700000 0.785714 0.740385 0.001638 0.057735 12 33 21 77 143 +coaid 0.947368 1.000000 0.944444 0.971429 0.133229 0.687184 1 0 1 17 19 +covid_19_rumor 0.698565 0.815029 0.819767 0.817391 0.001092 -0.045587 5 32 31 141 209 +covidfact 0.582555 0.657025 0.757143 0.703540 0.000054 0.010374 28 83 51 159 321 +defakts 0.619497 0.704698 0.576923 0.634441 0.032041 0.251226 92 44 77 105 318 +esoc 0.511111 0.570681 0.602210 0.586022 0.000049 -0.009847 52 82 72 109 315 +fakecovid 0.715385 0.936842 0.741667 0.827907 0.003362 0.085106 4 6 31 89 130 +faviq 0.799383 0.817891 0.969697 0.887348 0.000808 0.042248 3 57 8 256 324 +fever 0.845833 0.889868 0.943925 0.916100 0.000320 -0.024185 1 25 12 202 240 +feverous 0.738983 0.758007 0.959459 0.846918 0.001489 0.056734 5 68 9 213 295 +fibvid 0.618056 0.797872 0.675676 0.731707 0.003813 0.088209 14 19 36 75 144 +hover 0.642405 0.710744 0.800000 0.752735 0.006726 0.117749 31 70 43 172 316 +liar 0.576324 0.610169 0.765957 0.679245 0.003408 0.082872 41 92 44 144 321 +liar_new 0.638365 0.858407 0.700361 0.771372 0.001838 -0.059219 9 32 83 194 318 +mm-covid 0.661392 0.870536 0.714286 0.784708 0.000444 0.030089 14 29 78 195 316 +multiclaim 0.761364 0.904412 0.809211 0.854167 0.021007 0.219090 11 13 29 123 176 +nlp4if 0.548223 0.520833 0.789474 0.627615 0.008180 0.127331 33 69 20 75 197 +pubhealthtab 0.636000 0.669065 0.673913 0.671480 0.034996 0.263429 66 46 45 93 250 +rumors 0.817143 0.952055 0.847561 0.896774 0.007742 0.137858 4 7 25 139 175 +snopes 0.779167 0.978610 0.788793 0.873508 0.006489 0.124973 4 4 49 183 240 +truthseeker2023 0.800000 1.000000 0.800000 0.888889 0.000000 0.000000 0 0 1 4 5 +twitter15 0.580882 0.738095 0.639175 0.685083 0.002419 0.069865 17 22 35 62 136 +twitter16 0.500000 0.510638 0.571429 0.539326 0.000007 -0.003609 17 23 18 24 82 +verite 0.545946 0.661765 0.703125 0.681818 0.006147 -0.108707 11 46 38 90 185 +x_fact 0.540000 0.619355 0.548571 0.581818 0.002855 0.075542 66 59 79 96 300 +ALL 0.651924 0.760262 0.762927 0.761592 0.006617 0.117013 601 1098 1082 3482 6263 +``` \ No newline at end of file