11import re
22import time
3+ from concurrent .futures import ThreadPoolExecutor
34from io import StringIO
45from typing import Optional
56
@@ -127,7 +128,10 @@ def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
127128 running_output ,
128129 ), running_output
129130
130- qpl ._finish_pbar (pbar )
131+ with ThreadPoolExecutor () as executor :
132+ future = executor .submit (lambda : None )
133+ qpl ._finish_pbar (future , pbar )
134+
131135 finished_output = pbarOutputStream .getvalue ().split ("\r " )[- 1 ]
132136 assert re .match (
133137 r"test task: 100%\|##########\| 100.0/100 \[00:00<00:00, \d+.\d+%/s, status: FINISHED\]" , finished_output
@@ -149,7 +153,10 @@ def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
149153
150154 qpl ._update_pbar (pbar , TaskWithProgress ("test task" , "n/a" , "PENDING" , "" ))
151155 qpl ._update_pbar (pbar , TaskWithProgress ("test task" , "" , "RUNNING" , "root 1/1::leaf" ))
152- qpl ._finish_pbar (pbar )
156+
157+ with ThreadPoolExecutor () as executor :
158+ future = executor .submit (lambda : None )
159+ qpl ._finish_pbar (future , pbar )
153160
154161 assert pbarOutputStream .getvalue ().rstrip () == "" .join (
155162 [
@@ -160,6 +167,34 @@ def simple_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
160167 )
161168
162169
170+ def test_progress_bar_with_failing_query () -> None :
171+ def simple_run_cypher (query : str , database : Optional [str ] = None ) -> DataFrame :
172+ raise NotImplementedError ("Should not be called!" )
173+
174+ def failing_runnable () -> DataFrame :
175+ raise NotImplementedError ("Should not be called!" )
176+
177+ with StringIO () as pbarOutputStream :
178+ qpl = QueryProgressLogger (
179+ simple_run_cypher ,
180+ lambda : ServerVersion (3 , 0 , 0 ),
181+ progress_bar_options = {"file" : pbarOutputStream , "mininterval" : 100 },
182+ )
183+
184+ with ThreadPoolExecutor () as executor :
185+ future = executor .submit (failing_runnable )
186+
187+ pbar = qpl ._init_pbar (TaskWithProgress ("test task" , "n/a" , "PENDING" , "" ))
188+ qpl ._finish_pbar (future , pbar )
189+
190+ assert pbarOutputStream .getvalue ().rstrip () == "" .join (
191+ [
192+ "\r test task [elapsed: 00:00 ]" ,
193+ "\r test task [elapsed: 00:00 , status: FAILED]" ,
194+ ]
195+ )
196+
197+
163198def test_uses_static_store () -> None :
164199 def fake_run_cypher (query : str , database : Optional [str ] = None ) -> DataFrame :
165200 return DataFrame ([{"progress" : "n/a" , "taskName" : "Test task" , "status" : "RUNNING" }])
0 commit comments