|
7 | 7 | from typing import Any, Optional |
8 | 8 |
|
9 | 9 | import httpx |
| 10 | +import nest_asyncio |
10 | 11 | from dotenv import load_dotenv |
11 | 12 | from playwright.async_api import ( |
12 | 13 | BrowserContext, |
|
16 | 17 | from playwright.async_api import Page as PlaywrightPage |
17 | 18 |
|
18 | 19 | from .agent import Agent |
19 | | -from .api import _create_session, _execute |
| 20 | +from .api import _create_session, _execute, _get_replay_metrics |
20 | 21 | from .browser import ( |
21 | 22 | cleanup_browser_resources, |
22 | 23 | connect_browserbase_browser, |
@@ -206,7 +207,7 @@ def __init__( |
206 | 207 | ) |
207 | 208 |
|
208 | 209 | # Initialize metrics tracking |
209 | | - self.metrics = StagehandMetrics() |
| 210 | + self._local_metrics = StagehandMetrics() # Internal storage for local metrics |
210 | 211 | self._inference_start_time = 0 # To track inference time |
211 | 212 |
|
212 | 213 | # Validate env |
@@ -372,26 +373,26 @@ def update_metrics( |
372 | 373 | inference_time_ms: Time taken for inference in milliseconds |
373 | 374 | """ |
374 | 375 | if function_name == StagehandFunctionName.ACT: |
375 | | - self.metrics.act_prompt_tokens += prompt_tokens |
376 | | - self.metrics.act_completion_tokens += completion_tokens |
377 | | - self.metrics.act_inference_time_ms += inference_time_ms |
| 376 | + self._local_metrics.act_prompt_tokens += prompt_tokens |
| 377 | + self._local_metrics.act_completion_tokens += completion_tokens |
| 378 | + self._local_metrics.act_inference_time_ms += inference_time_ms |
378 | 379 | elif function_name == StagehandFunctionName.EXTRACT: |
379 | | - self.metrics.extract_prompt_tokens += prompt_tokens |
380 | | - self.metrics.extract_completion_tokens += completion_tokens |
381 | | - self.metrics.extract_inference_time_ms += inference_time_ms |
| 380 | + self._local_metrics.extract_prompt_tokens += prompt_tokens |
| 381 | + self._local_metrics.extract_completion_tokens += completion_tokens |
| 382 | + self._local_metrics.extract_inference_time_ms += inference_time_ms |
382 | 383 | elif function_name == StagehandFunctionName.OBSERVE: |
383 | | - self.metrics.observe_prompt_tokens += prompt_tokens |
384 | | - self.metrics.observe_completion_tokens += completion_tokens |
385 | | - self.metrics.observe_inference_time_ms += inference_time_ms |
| 384 | + self._local_metrics.observe_prompt_tokens += prompt_tokens |
| 385 | + self._local_metrics.observe_completion_tokens += completion_tokens |
| 386 | + self._local_metrics.observe_inference_time_ms += inference_time_ms |
386 | 387 | elif function_name == StagehandFunctionName.AGENT: |
387 | | - self.metrics.agent_prompt_tokens += prompt_tokens |
388 | | - self.metrics.agent_completion_tokens += completion_tokens |
389 | | - self.metrics.agent_inference_time_ms += inference_time_ms |
| 388 | + self._local_metrics.agent_prompt_tokens += prompt_tokens |
| 389 | + self._local_metrics.agent_completion_tokens += completion_tokens |
| 390 | + self._local_metrics.agent_inference_time_ms += inference_time_ms |
390 | 391 |
|
391 | 392 | # Always update totals |
392 | | - self.metrics.total_prompt_tokens += prompt_tokens |
393 | | - self.metrics.total_completion_tokens += completion_tokens |
394 | | - self.metrics.total_inference_time_ms += inference_time_ms |
| 393 | + self._local_metrics.total_prompt_tokens += prompt_tokens |
| 394 | + self._local_metrics.total_completion_tokens += completion_tokens |
| 395 | + self._local_metrics.total_inference_time_ms += inference_time_ms |
395 | 396 |
|
396 | 397 | def update_metrics_from_response( |
397 | 398 | self, |
@@ -426,9 +427,9 @@ def update_metrics_from_response( |
426 | 427 | f"{completion_tokens} completion tokens, {time_ms}ms" |
427 | 428 | ) |
428 | 429 | self.logger.debug( |
429 | | - f"Total metrics: {self.metrics.total_prompt_tokens} prompt tokens, " |
430 | | - f"{self.metrics.total_completion_tokens} completion tokens, " |
431 | | - f"{self.metrics.total_inference_time_ms}ms" |
| 430 | + f"Total metrics: {self._local_metrics.total_prompt_tokens} prompt tokens, " |
| 431 | + f"{self._local_metrics.total_completion_tokens} completion tokens, " |
| 432 | + f"{self._local_metrics.total_inference_time_ms}ms" |
432 | 433 | ) |
433 | 434 | else: |
434 | 435 | # Try to extract from _hidden_params or other locations |
@@ -736,7 +737,50 @@ def page(self) -> Optional[StagehandPage]: |
736 | 737 |
|
737 | 738 | return self._live_page_proxy |
738 | 739 |
|
| 740 | + def __getattribute__(self, name): |
| 741 | + """ |
| 742 | + Intercept access to 'metrics' to fetch from API when use_api=True. |
| 743 | + """ |
| 744 | + if name == "metrics": |
| 745 | + use_api = ( |
| 746 | + object.__getattribute__(self, "use_api") |
| 747 | + if hasattr(self, "use_api") |
| 748 | + else False |
| 749 | + ) |
| 750 | + |
| 751 | + if use_api: |
| 752 | + # Need to fetch from API |
| 753 | + try: |
| 754 | + # Get the _get_replay_metrics method |
| 755 | + get_replay_metrics = object.__getattribute__( |
| 756 | + self, "_get_replay_metrics" |
| 757 | + ) |
| 758 | + |
| 759 | + # Try to get current event loop |
| 760 | + try: |
| 761 | + asyncio.get_running_loop() |
| 762 | + # We're in an async context, need to handle this carefully |
| 763 | + # Create a new task and wait for it |
| 764 | + nest_asyncio.apply() |
| 765 | + return asyncio.run(get_replay_metrics()) |
| 766 | + except RuntimeError: |
| 767 | + # No event loop running, we can use asyncio.run directly |
| 768 | + return asyncio.run(get_replay_metrics()) |
| 769 | + except Exception as e: |
| 770 | + # Log error and return empty metrics |
| 771 | + logger = object.__getattribute__(self, "logger") |
| 772 | + if logger: |
| 773 | + logger.error(f"Failed to fetch metrics from API: {str(e)}") |
| 774 | + return StagehandMetrics() |
| 775 | + else: |
| 776 | + # Return local metrics |
| 777 | + return object.__getattribute__(self, "_local_metrics") |
| 778 | + |
| 779 | + # For all other attributes, use normal behavior |
| 780 | + return object.__getattribute__(self, name) |
| 781 | + |
739 | 782 |
|
740 | 783 | # Bind the imported API methods to the Stagehand class |
741 | 784 | Stagehand._create_session = _create_session |
742 | 785 | Stagehand._execute = _execute |
| 786 | +Stagehand._get_replay_metrics = _get_replay_metrics |
0 commit comments