Skip to content

Commit e095daf

Browse files
committed
fix: handle catastrophic cancellation
1 parent c9f12c0 commit e095daf

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

pandas/_libs/window/aggregations.pyx

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,12 @@ cdef void add_var(
359359
float64_t *compensation,
360360
int64_t *num_consecutive_same_value,
361361
float64_t *prev_value,
362+
bint *numerically_unstable,
362363
) noexcept nogil:
363364
""" add a value from the var calc """
364365
cdef:
365366
float64_t delta, prev_mean, y, t
367+
float64_t prev_m2 = ssqdm_x[0]
366368

367369
# GH#21813, if msvc 2017 bug is resolved, we should be OK with != instead of `isnan`
368370
if val != val:
@@ -392,17 +394,23 @@ cdef void add_var(
392394
mean_x[0] = 0
393395
ssqdm_x[0] = ssqdm_x[0] + (val - prev_mean) * (val - mean_x[0])
394396

397+
if prev_m2 * InvCondTol > ssqdm_x[0]:
398+
# possible catastrophic cancellation
399+
numerically_unstable[0] = True
400+
395401

396402
cdef void remove_var(
397403
float64_t val,
398404
float64_t *nobs,
399405
float64_t *mean_x,
400406
float64_t *ssqdm_x,
401-
float64_t *compensation
407+
float64_t *compensation,
408+
bint *numerically_unstable,
402409
) noexcept nogil:
403410
""" remove a value from the var calc """
404411
cdef:
405412
float64_t delta, prev_mean, y, t
413+
float64_t prev_m2 = ssqdm_x[0]
406414
if val == val:
407415
nobs[0] = nobs[0] - 1
408416
if nobs[0]:
@@ -416,9 +424,14 @@ cdef void remove_var(
416424
delta = t
417425
mean_x[0] = mean_x[0] - delta / nobs[0]
418426
ssqdm_x[0] = ssqdm_x[0] - (val - prev_mean) * (val - mean_x[0])
427+
428+
if prev_m2 * InvCondTol > ssqdm_x[0]:
429+
# possible catastrophic cancellation
430+
numerically_unstable[0] = True
419431
else:
420432
mean_x[0] = 0
421433
ssqdm_x[0] = 0
434+
numerically_unstable[0] = False
422435

423436

424437
def roll_var(const float64_t[:] values, ndarray[int64_t] start,
@@ -433,6 +446,7 @@ def roll_var(const float64_t[:] values, ndarray[int64_t] start,
433446
Py_ssize_t i, j, N = len(start)
434447
ndarray[float64_t] output
435448
bint is_monotonic_increasing_bounds
449+
bint requires_recompute, numerically_unstable
436450

437451
minp = max(minp, 1)
438452
is_monotonic_increasing_bounds = is_monotonic_increasing_start_end_bounds(
@@ -449,30 +463,38 @@ def roll_var(const float64_t[:] values, ndarray[int64_t] start,
449463

450464
# Over the first window, observations can only be added
451465
# never removed
452-
if i == 0 or not is_monotonic_increasing_bounds or s < end[i]:
453-
454-
prev_value = values[s]
455-
num_consecutive_same_value = 0
456-
457-
mean_x = ssqdm_x = nobs = compensation_add = compensation_remove = 0
458-
for j in range(s, e):
459-
add_var(values[j], &nobs, &mean_x, &ssqdm_x, &compensation_add,
460-
&num_consecutive_same_value, &prev_value)
461-
462-
else:
466+
requires_recompute = (
467+
i == 0
468+
or not is_monotonic_increasing_bounds
469+
or s >= end[i - 1]
470+
)
463471

472+
if not requires_recompute:
464473
# After the first window, observations can both be added
465474
# and removed
466475

467476
# calculate deletes
468477
for j in range(start[i - 1], s):
469478
remove_var(values[j], &nobs, &mean_x, &ssqdm_x,
470-
&compensation_remove)
479+
&compensation_remove, &numerically_unstable)
471480

472481
# calculate adds
473482
for j in range(end[i - 1], e):
474483
add_var(values[j], &nobs, &mean_x, &ssqdm_x, &compensation_add,
475-
&num_consecutive_same_value, &prev_value)
484+
&num_consecutive_same_value, &prev_value,
485+
&numerically_unstable)
486+
487+
if requires_recompute or numerically_unstable:
488+
489+
prev_value = values[s]
490+
num_consecutive_same_value = 0
491+
492+
mean_x = ssqdm_x = nobs = compensation_add = compensation_remove = 0
493+
for j in range(s, e):
494+
add_var(values[j], &nobs, &mean_x, &ssqdm_x, &compensation_add,
495+
&num_consecutive_same_value, &prev_value,
496+
&numerically_unstable)
497+
numerically_unstable = False
476498

477499
output[i] = calc_var(minp, ddof, nobs, ssqdm_x, num_consecutive_same_value)
478500

0 commit comments

Comments
 (0)