import numpy as np
from numpy.linalg import norm
from numba import njit
-from numba import float64
+from numba import float64, int64, bool_

from skglm.datafits.base import BaseDatafit
from skglm.utils.sparse_ops import spectral_norm
@@ -547,90 +547,100 @@ def intercept_update_self(self, y, Xw):


class Cox(BaseDatafit):
550- r"""Cox datafit for survival analysis with Breslow estimate .
550+ r"""Cox datafit for survival analysis.
551551
-    The datafit reads [1]
-
-    .. math::
-
-        1 / n_"samples" \sum_(i=1)^(n_"samples") s_i (- \langle x_i, w \rangle
-        + \log \sum_(j | tm_j \geq tm_i) e^{\langle x_j, w \rangle})
-
-    where :math:`s_i` indicates the sample censorship and :math:`tm`
-    is the vector recording the time of event occurrences.
-
-    Defining the matrix :math:`B` with
-    :math:`B_{i,j} = 1` if :math:`tm_j \geq tm_i` and :math:`0` otherwise,
-    the datafit can be rewritten in the following compact form
-
-    .. math::
-
-        - 1 / n_"samples" \langle s, Xw \rangle
-        + 1 / n_"samples" \langle s, \log B e^{Xw} \rangle
+    Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>` for details.

+    Parameters
+    ----------
+    use_efron : bool, default=False
+        If ``True``, use the Efron estimate to handle tied observations.

    Attributes
    ----------
    B : array-like, shape (n_samples, n_samples)
        Matrix where every ``(i, j)`` entry (row, column) equals ``1``
-        if ``tm[j] >= tm[i]`` and `0` otherwise. This matrix is initialized
+        if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is initialized
        using the ``.initialize`` method.

-    References
-    ----------
-    .. [1] DY Lin. On the Breslow estimator.
-        Lifetime data analysis, 13:471–480, 2007.
+    H_indices : array-like, shape (n_uncensored_samples,)
+        Indices of uncensored observations with the same occurrence times, stacked
+        horizontally as ``[group_1, group_2, ...]``. This array is initialized
+        by the ``.initialize`` method when ``use_efron=True``.
+
+    H_indptr : array-like, shape (number of distinct uncensored occurrence times + 1,)
+        Array where two consecutive elements delimit a group of observations
+        having the same occurrence time.
    """

-    def __init__(self):
-        pass
+    def __init__(self, use_efron=False):
+        self.use_efron = use_efron

    def get_spec(self):
        return (
+            ('use_efron', bool_),
            ('B', float64[:, ::1]),
+            ('H_indptr', int64[:]),
+            ('H_indices', int64[:]),
        )

    def params_to_dict(self):
-        return dict()
+        return dict(use_efron=self.use_efron)

    def value(self, y, w, Xw):
        """Compute the value of the datafit."""
        tm, s = y
        n_samples = Xw.shape[0]

-        out = -(s @ Xw) + s @ np.log(self.B @ np.exp(Xw))
+        # compute inside log term
+        exp_Xw = np.exp(Xw)
+        B_exp_Xw = self.B @ exp_Xw
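+        # Efron correction: within each group of tied event times, part of the
+        # tied observations' own contribution is subtracted from the risk-set sum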
+        if self.use_efron:
+            B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+        out = -(s @ Xw) + s @ np.log(B_exp_Xw)
        return out / n_samples

    def raw_grad(self, y, Xw):
        r"""Compute gradient of datafit w.r.t. ``Xw``.

-        The raw gradient reads
-
-            (-s + exp_Xw * (B.T @ (s / (B @ exp_Xw)))) / n_samples
+        Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
+        equation 4 for details.
        """
        tm, s = y
        n_samples = Xw.shape[0]

        exp_Xw = np.exp(Xw)
        B_exp_Xw = self.B @ exp_Xw
+        if self.use_efron:
+            B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+        s_over_B_exp_Xw = s / B_exp_Xw
+        out = -s + exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
+        if self.use_efron:
+            out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

-        out = -s + exp_Xw * (self.B.T @ (s / B_exp_Xw))
        return out / n_samples

    def raw_hessian(self, y, Xw):
        """Compute a diagonal upper bound of the datafit's Hessian w.r.t. ``Xw``.

-        The diagonal upper bound reads
-
-            exp_Xw * (B.T @ (s / B_exp_Xw)) / n_samples
+        Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
+        equation 6 for details.
        """
        tm, s = y
        n_samples = Xw.shape[0]

        exp_Xw = np.exp(Xw)
        B_exp_Xw = self.B @ exp_Xw
+        if self.use_efron:
+            B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+        s_over_B_exp_Xw = s / B_exp_Xw
+        out = exp_Xw * (self.B.T @ (s_over_B_exp_Xw))
+        if self.use_efron:
+            out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

-        out = exp_Xw * (self.B.T @ (s / B_exp_Xw))
        return out / n_samples

    def initialize(self, X, y):
@@ -640,9 +650,58 @@ def initialize(self, X, y):
        tm_as_col = tm.reshape((-1, 1))
        self.B = (tm >= tm_as_col).astype(X.dtype)

+        if self.use_efron:
+            H_indices = np.argsort(tm)
+            # filter out censored data
+            H_indices = H_indices[s[H_indices] != 0]
+            n_uncensored_samples = H_indices.shape[0]
+
+            # build H_indptr
+            H_indptr = [0]
+            count = 1
+            for i in range(1, n_uncensored_samples):
+                if tm[H_indices[i - 1]] == tm[H_indices[i]]:
+                    count += 1
+                else:
+                    H_indptr.append(count + H_indptr[-1])
+                    count = 1
+            H_indptr.append(n_uncensored_samples)
+            H_indptr = np.asarray(H_indptr, dtype=np.int64)
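+            # illustrative example (assuming no censoring): for tm = [3, 1, 3, 2],
+            # H_indices = [1, 3, 0, 2] and H_indptr = [0, 1, 2, 4], so the tied
+            # observations {0, 2} (occurrence time 3) form the group H_indices[2:4]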
+
+            # save in instance
+            self.H_indptr = H_indptr
+            self.H_indices = H_indices
+
    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
        """Initialize the datafit attributes in sparse dataset case."""
-        tm, s = y
+        # initialize_sparse and initialize have the same implementation
+        # small hack to avoid repetitive code: pass in X_data as only its dtype is used
+        self.initialize(X_data, y)

-        tm_as_col = tm.reshape((-1, 1))
-        self.B = (tm >= tm_as_col).astype(X_data.dtype)
+    def _A_dot_vec(self, vec):
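+        # compute A @ vec where, for observations i and j in the same group H of
+        # tied (uncensored) occurrence times, A[i, j] = l_i / |H| with l_i the rank
+        # of i within H; A[i, j] = 0 otherwise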
+        out = np.zeros_like(vec)
+        n_H = self.H_indptr.shape[0] - 1
+
+        for idx in range(n_H):
+            current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx + 1]]
+            size_current_H = current_H_idx.shape[0]
+            frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H
+
+            sum_vec_H = np.sum(vec[current_H_idx])
+            out[current_H_idx] = sum_vec_H * frac_range
+
+        return out
+
+    def _AT_dot_vec(self, vec):
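+        # compute A.T @ vec: every observation in a tie group H receives the same
+        # weighted sum, sum over j in H of (l_j / |H|) * vec[j]; zero outside groups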
+        out = np.zeros_like(vec)
+        n_H = self.H_indptr.shape[0] - 1
+
+        for idx in range(n_H):
+            current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx + 1]]
+            size_current_H = current_H_idx.shape[0]
+            frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H
+
+            weighted_sum_vec_H = vec[current_H_idx] @ frac_range
+            out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)
+
+        return out
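
A minimal sketch (illustrative, not part of the patch) of exercising the Efron option, assuming the class above is used as plain Python (skglm typically compiles datafits with numba before solving) and that ``y`` is the pair ``(tm, s)`` of occurrence times and censoring indicators:

import numpy as np

# toy survival data: a tie at t=2, last observation censored
tm = np.array([1., 2., 2., 3.])
s = np.array([1., 1., 1., 0.])
X = np.random.randn(4, 2)
w = np.zeros(2)

datafit = Cox(use_efron=True)
datafit.initialize(X, (tm, s))           # builds B, H_indices and H_indptr
print(datafit.value((tm, s), w, X @ w))  # normalized negative partial log-likelihood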