@@ -310,8 +310,6 @@ def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool =
310310 Batch of flattened nonzero matrix elements for triangular matrix.
311311 upper : bool
312312 Return upper triangular matrix if True, else lower triangular matrix. Default is False.
313- positive_diag : bool
314- Whether to apply a softplus operation to diagonal elements. Default is False.
315313
316314 Returns
317315 -------
@@ -327,47 +325,70 @@ def fill_triangular_matrix(x: Tensor, upper: bool = False, positive_diag: bool =
327325 batch_shape = x .shape [:- 1 ]
328326 m = x .shape [- 1 ]
329327
330- if m == 1 :
331- y = keras .ops .reshape (x , (- 1 , 1 , 1 ))
332- if positive_diag :
333- y = keras .activations .softplus (y )
334- return y
335-
336- # Calculate matrix shape
337- n = (0.25 + 2 * m ) ** 0.5 - 0.5
338- if not np .isclose (np .floor (n ), n ):
339- raise ValueError (f"Input right-most shape ({ m } ) does not correspond to a triangular matrix." )
340- else :
341- n = int (n )
342-
343- # Trick: Create triangular matrix by concatenating with a flipped version of its tail, then reshape.
344- x_tail = keras .ops .take (x , indices = list (range ((m - (n ** 2 - m )), x .shape [- 1 ])), axis = - 1 )
345- if not upper :
346- y = keras .ops .concatenate ([x_tail , keras .ops .flip (x , axis = - 1 )], axis = len (batch_shape ))
347- y = keras .ops .reshape (y , (- 1 , n , n ))
348- y = keras .ops .tril (y )
349-
350- if positive_diag :
351- y_offdiag = keras .ops .tril (y , k = - 1 )
352- # carve out diagonal, by setting upper and lower offdiagonals to zero
353- y_diag = keras .ops .tril (
354- keras .ops .triu (keras .activations .softplus (y )), # apply softplus to enforce positivity
328+ if m > 1 : # Matrix is larger than 1x1
329+ # Calculate matrix shape
330+ n = (0.25 + 2 * m ) ** 0.5 - 0.5
331+ if not np .isclose (np .floor (n ), n ):
332+ raise ValueError (f"Input right-most shape ({ m } ) does not correspond to a triangular matrix." )
333+ else :
334+ n = int (n )
335+
336+ # Trick: Create triangular matrix by concatenating with a flipped version of itself, then reshape.
337+ if not upper :
338+ x_list = [x , keras .ops .flip (x [..., n :], axis = - 1 )]
339+
340+ y = keras .ops .concatenate (x_list , axis = len (batch_shape ))
341+ y = keras .ops .reshape (y , (- 1 , n , n ))
342+ y = keras .ops .tril (y )
343+
344+ else :
345+ x_list = [x [..., n :], keras .ops .flip (x , axis = - 1 )]
346+
347+ y = keras .ops .concatenate (x_list , axis = len (batch_shape ))
348+ y = keras .ops .reshape (y , (- 1 , n , n ))
349+ y = keras .ops .triu (
350+ y ,
355351 )
356- y = y_diag + y_offdiag
357352
358- else :
359- y = keras .ops .concatenate ([x , keras .ops .flip (x_tail , axis = - 1 )], axis = len (batch_shape ))
360- y = keras .ops .reshape (y , (- 1 , n , n ))
361- y = keras .ops .triu (
362- y ,
363- )
364-
365- if positive_diag :
366- y_offdiag = keras .ops .triu (y , k = 1 )
367- # carve out diagonal, by setting upper and lower offdiagonals to zero
368- y_diag = keras .ops .tril (
369- keras .ops .triu (keras .activations .softplus (y )), # apply softplus to enforce positivity
370- )
371- y = y_diag + y_offdiag
353+ else : # Matrix is 1x1
354+ y = keras .ops .reshape (x , (- 1 , 1 , 1 ))
372355
373356 return y
357+
358+
359+ def positive_diag (x : Tensor , method = "default" ) -> Tensor :
360+ """
361+ Ensures that matrix elements on diagonal are positive.
362+
363+ Parameters
364+ ----------
365+ x : Tensor of shape (batch_size, n, n)
366+ Batch of matrices.
367+ method : str, optional
368+ Method by which to ensure positivity of diagonal entries. Choose from
369+ - "shifted_softplus": softplus(x + 0.5413)
370+ - "exp": exp(x)
371+ Both methods map a matrix filled with zeros to the unit matrix.
372+ Default is "shifted_softplus".
373+
374+ Returns
375+ -------
376+ Tensor of shape (batch_size, n, n)
377+ """
378+ # ensure positivity
379+ match method :
380+ case "default" | "shifted_softplus" :
381+ x_positive = keras .activations .softplus (x + 0.5413 )
382+ case "exp" :
383+ x_positive = keras .ops .exp (x )
384+
385+ # zero all offdiagonals
386+ x_diag_positive = keras .ops .tril (keras .ops .triu (x_positive ))
387+
388+ # zero diagonal entries
389+ x_offdiag = keras .ops .triu (x , k = 1 ) + keras .ops .tril (x , k = - 1 )
390+
391+ # sum to get full matrices with softplus applied only to diagonal entries
392+ x = x_diag_positive + x_offdiag
393+
394+ return x
0 commit comments