1+ """
2+ Regression test for extract_aigrant_companies functionality.
3+
4+ This test verifies that data extraction works correctly by extracting
5+ companies that received AI grants along with their batch numbers,
6+ based on the TypeScript extract_aigrant_companies evaluation.
7+ """
8+
9+ import os
10+ import pytest
11+ import pytest_asyncio
12+ from pydantic import BaseModel , Field
13+ from typing import List
14+
15+ from stagehand import Stagehand , StagehandConfig
16+ from stagehand .schemas import ExtractOptions
17+
18+
19+ class Company (BaseModel ):
20+ company : str = Field (..., description = "The name of the company" )
21+ batch : str = Field (..., description = "The batch number of the grant" )
22+
23+
24+ class Companies (BaseModel ):
25+ companies : List [Company ] = Field (..., description = "List of companies that received AI grants" )
26+
27+
28+ class TestExtractAigrantCompanies :
29+ """Regression test for extract_aigrant_companies functionality"""
30+
31+ @pytest .fixture (scope = "class" )
32+ def local_config (self ):
33+ """Configuration for LOCAL mode testing"""
34+ return StagehandConfig (
35+ env = "LOCAL" ,
36+ model_name = "gpt-4o-mini" ,
37+ headless = True ,
38+ verbose = 1 ,
39+ dom_settle_timeout_ms = 2000 ,
40+ model_client_options = {"apiKey" : os .getenv ("MODEL_API_KEY" ) or os .getenv ("OPENAI_API_KEY" )},
41+ )
42+
43+ @pytest .fixture (scope = "class" )
44+ def browserbase_config (self ):
45+ """Configuration for BROWSERBASE mode testing"""
46+ return StagehandConfig (
47+ env = "BROWSERBASE" ,
48+ api_key = os .getenv ("BROWSERBASE_API_KEY" ),
49+ project_id = os .getenv ("BROWSERBASE_PROJECT_ID" ),
50+ model_name = "gpt-4o" ,
51+ headless = False ,
52+ verbose = 2 ,
53+ model_client_options = {"apiKey" : os .getenv ("MODEL_API_KEY" ) or os .getenv ("OPENAI_API_KEY" )},
54+ )
55+
56+ @pytest_asyncio .fixture
57+ async def local_stagehand (self , local_config ):
58+ """Create a Stagehand instance for LOCAL testing"""
59+ stagehand = Stagehand (config = local_config )
60+ await stagehand .init ()
61+ yield stagehand
62+ await stagehand .close ()
63+
64+ @pytest_asyncio .fixture
65+ async def browserbase_stagehand (self , browserbase_config ):
66+ """Create a Stagehand instance for BROWSERBASE testing"""
67+ if not (os .getenv ("BROWSERBASE_API_KEY" ) and os .getenv ("BROWSERBASE_PROJECT_ID" )):
68+ pytest .skip ("Browserbase credentials not available" )
69+
70+ stagehand = Stagehand (config = browserbase_config )
71+ await stagehand .init ()
72+ yield stagehand
73+ await stagehand .close ()
74+
75+ @pytest .mark .asyncio
76+ @pytest .mark .regression
77+ @pytest .mark .local
78+ async def test_extract_aigrant_companies_local (self , local_stagehand ):
79+ """
80+ Regression test: extract_aigrant_companies
81+
82+ Mirrors the TypeScript extract_aigrant_companies evaluation:
83+ - Navigate to AI grant companies test site
84+ - Extract all companies that received AI grants with their batch numbers
85+ - Verify total count is 91
86+ - Verify first company is "Goodfire" in batch "4"
87+ - Verify last company is "Forefront" in batch "1"
88+ """
89+ stagehand = local_stagehand
90+
91+ await stagehand .page .goto ("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/" )
92+
93+ # Extract all companies with their batch numbers
94+ extract_options = ExtractOptions (
95+ instruction = (
96+ "Extract all companies that received the AI grant and group them with their "
97+ "batch numbers as an array of objects. Each object should contain the company "
98+ "name and its corresponding batch number."
99+ ),
100+ schema_definition = Companies
101+ )
102+
103+ result = await stagehand .page .extract (extract_options )
104+
105+ # Both LOCAL and BROWSERBASE modes return the Pydantic model instance directly
106+ companies = result .companies
107+
108+ # Verify total count
109+ expected_length = 91
110+ assert len (companies ) == expected_length , (
111+ f"Expected { expected_length } companies, but got { len (companies )} "
112+ )
113+
114+ # Verify first company
115+ expected_first_item = {
116+ "company" : "Goodfire" ,
117+ "batch" : "4"
118+ }
119+ assert len (companies ) > 0 , "No companies were extracted"
120+ first_company = companies [0 ]
121+ assert first_company .company == expected_first_item ["company" ], (
122+ f"Expected first company to be '{ expected_first_item ['company' ]} ', "
123+ f"but got '{ first_company .company } '"
124+ )
125+ assert first_company .batch == expected_first_item ["batch" ], (
126+ f"Expected first company batch to be '{ expected_first_item ['batch' ]} ', "
127+ f"but got '{ first_company .batch } '"
128+ )
129+
130+ # Verify last company
131+ expected_last_item = {
132+ "company" : "Forefront" ,
133+ "batch" : "1"
134+ }
135+ last_company = companies [- 1 ]
136+ assert last_company .company == expected_last_item ["company" ], (
137+ f"Expected last company to be '{ expected_last_item ['company' ]} ', "
138+ f"but got '{ last_company .company } '"
139+ )
140+ assert last_company .batch == expected_last_item ["batch" ], (
141+ f"Expected last company batch to be '{ expected_last_item ['batch' ]} ', "
142+ f"but got '{ last_company .batch } '"
143+ )
144+
145+ @pytest .mark .asyncio
146+ @pytest .mark .regression
147+ @pytest .mark .api
148+ @pytest .mark .skipif (
149+ not (os .getenv ("BROWSERBASE_API_KEY" ) and os .getenv ("BROWSERBASE_PROJECT_ID" )),
150+ reason = "Browserbase credentials not available"
151+ )
152+ async def test_extract_aigrant_companies_browserbase (self , browserbase_stagehand ):
153+ """
154+ Regression test: extract_aigrant_companies (Browserbase)
155+
156+ Same test as local but running in Browserbase environment.
157+ """
158+ stagehand = browserbase_stagehand
159+
160+ await stagehand .page .goto ("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/" )
161+
162+ # Extract all companies with their batch numbers
163+ extract_options = ExtractOptions (
164+ instruction = (
165+ "Extract all companies that received the AI grant and group them with their "
166+ "batch numbers as an array of objects. Each object should contain the company "
167+ "name and its corresponding batch number."
168+ ),
169+ schema_definition = Companies
170+ )
171+
172+ result = await stagehand .page .extract (extract_options )
173+
174+ # Both LOCAL and BROWSERBASE modes return the Pydantic model instance directly
175+ companies = result .companies
176+
177+ # Verify total count
178+ expected_length = 91
179+ assert len (companies ) == expected_length , (
180+ f"Expected { expected_length } companies, but got { len (companies )} "
181+ )
182+
183+ # Verify first company
184+ expected_first_item = {
185+ "company" : "Goodfire" ,
186+ "batch" : "4"
187+ }
188+ assert len (companies ) > 0 , "No companies were extracted"
189+ first_company = companies [0 ]
190+ assert first_company .company == expected_first_item ["company" ], (
191+ f"Expected first company to be '{ expected_first_item ['company' ]} ', "
192+ f"but got '{ first_company .company } '"
193+ )
194+ assert first_company .batch == expected_first_item ["batch" ], (
195+ f"Expected first company batch to be '{ expected_first_item ['batch' ]} ', "
196+ f"but got '{ first_company .batch } '"
197+ )
198+
199+ # Verify last company
200+ expected_last_item = {
201+ "company" : "Forefront" ,
202+ "batch" : "1"
203+ }
204+ last_company = companies [- 1 ]
205+ assert last_company .company == expected_last_item ["company" ], (
206+ f"Expected last company to be '{ expected_last_item ['company' ]} ', "
207+ f"but got '{ last_company .company } '"
208+ )
209+ assert last_company .batch == expected_last_item ["batch" ], (
210+ f"Expected last company batch to be '{ expected_last_item ['batch' ]} ', "
211+ f"but got '{ last_company .batch } '"
212+ )
0 commit comments