66// option. This file may not be copied, modified, or distributed
77// except according to those terms.
88
9- use core:: cmp:: { max , min} ;
9+ use core:: cmp:: min;
1010
1111use num_traits:: Zero ;
1212
13- use crate :: { dimension:: is_layout_f, Array , ArrayBase , Axis , Data , Dimension , IntoDimension , Zip } ;
13+ use crate :: {
14+ dimension:: { is_layout_c, is_layout_f} ,
15+ Array ,
16+ ArrayBase ,
17+ Axis ,
18+ Data ,
19+ Dimension ,
20+ Zip ,
21+ } ;
1422
1523impl < S , A , D > ArrayBase < S , D >
1624where
1725 S : Data < Elem = A > ,
1826 D : Dimension ,
1927 A : Clone + Zero ,
20- D :: Smaller : Copy ,
2128{
2229 /// Upper triangular of an array.
2330 ///
@@ -30,38 +37,56 @@ where
3037 /// ```
3138 /// use ndarray::array;
3239 ///
33- /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
34- /// let res = arr.triu(0);
35- /// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
40+ /// let arr = array![
41+ /// [1, 2, 3],
42+ /// [4, 5, 6],
43+ /// [7, 8, 9]
44+ /// ];
45+ /// assert_eq!(
46+ /// arr.triu(0),
47+ /// array![
48+ /// [1, 2, 3],
49+ /// [0, 5, 6],
50+ /// [0, 0, 9]
51+ /// ]
52+ /// );
3653 /// ```
3754 pub fn triu ( & self , k : isize ) -> Array < A , D >
3855 {
3956 if self . ndim ( ) <= 1 {
4057 return self . to_owned ( ) ;
4158 }
42- match is_layout_f ( & self . dim , & self . strides ) {
43- true => {
44- let n = self . ndim ( ) ;
45- let mut x = self . view ( ) ;
46- x. swap_axes ( n - 2 , n - 1 ) ;
47- let mut tril = x. tril ( -k) ;
48- tril. swap_axes ( n - 2 , n - 1 ) ;
49-
50- tril
51- }
52- false => {
53- let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
54- Zip :: indexed ( self . rows ( ) )
55- . and ( res. rows_mut ( ) )
56- . for_each ( |i, src, mut dst| {
57- let row_num = i. into_dimension ( ) . last_elem ( ) ;
58- let lower = max ( row_num as isize + k, 0 ) ;
59- dst. slice_mut ( s ! [ lower..] ) . assign ( & src. slice ( s ! [ lower..] ) ) ;
60- } ) ;
61-
62- res
63- }
59+
60+ // Performance optimization for F-order arrays.
61+ // C-order array check prevents infinite recursion in edge cases like [[1]].
62+ // k-size check prevents underflow when k == isize::MIN
63+ let n = self . ndim ( ) ;
64+ if is_layout_f ( & self . dim , & self . strides ) && !is_layout_c ( & self . dim , & self . strides ) && k > isize:: MIN {
65+ let mut x = self . view ( ) ;
66+ x. swap_axes ( n - 2 , n - 1 ) ;
67+ let mut tril = x. tril ( -k) ;
68+ tril. swap_axes ( n - 2 , n - 1 ) ;
69+
70+ return tril;
6471 }
72+
73+ let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
74+ let ncols = self . len_of ( Axis ( n - 1 ) ) ;
75+ let nrows = self . len_of ( Axis ( n - 2 ) ) ;
76+ let indices = Array :: from_iter ( 0 ..nrows) ;
77+ Zip :: from ( self . rows ( ) )
78+ . and ( res. rows_mut ( ) )
79+ . and_broadcast ( & indices)
80+ . for_each ( |src, mut dst, row_num| {
81+ let mut lower = match k >= 0 {
82+ true => row_num. saturating_add ( k as usize ) , // Avoid overflow
83+ false => row_num. saturating_sub ( k. unsigned_abs ( ) ) , // Avoid underflow, go to 0
84+ } ;
85+ lower = min ( lower, ncols) ;
86+ dst. slice_mut ( s ! [ lower..] ) . assign ( & src. slice ( s ! [ lower..] ) ) ;
87+ } ) ;
88+
89+ res
6590 }
6691
6792 /// Lower triangular of an array.
@@ -75,45 +100,65 @@ where
75100 /// ```
76101 /// use ndarray::array;
77102 ///
78- /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
79- /// let res = arr.tril(0);
80- /// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
103+ /// let arr = array![
104+ /// [1, 2, 3],
105+ /// [4, 5, 6],
106+ /// [7, 8, 9]
107+ /// ];
108+ /// assert_eq!(
109+ /// arr.tril(0),
110+ /// array![
111+ /// [1, 0, 0],
112+ /// [4, 5, 0],
113+ /// [7, 8, 9]
114+ /// ]
115+ /// );
81116 /// ```
82117 pub fn tril ( & self , k : isize ) -> Array < A , D >
83118 {
84119 if self . ndim ( ) <= 1 {
85120 return self . to_owned ( ) ;
86121 }
87- match is_layout_f ( & self . dim , & self . strides ) {
88- true => {
89- let n = self . ndim ( ) ;
90- let mut x = self . view ( ) ;
91- x. swap_axes ( n - 2 , n - 1 ) ;
92- let mut tril = x. triu ( -k) ;
93- tril. swap_axes ( n - 2 , n - 1 ) ;
94-
95- tril
96- }
97- false => {
98- let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
99- let ncols = self . len_of ( Axis ( self . ndim ( ) - 1 ) ) as isize ;
100- Zip :: indexed ( self . rows ( ) )
101- . and ( res. rows_mut ( ) )
102- . for_each ( |i, src, mut dst| {
103- let row_num = i. into_dimension ( ) . last_elem ( ) ;
104- let upper = min ( row_num as isize + k, ncols) + 1 ;
105- dst. slice_mut ( s ! [ ..upper] ) . assign ( & src. slice ( s ! [ ..upper] ) ) ;
106- } ) ;
107-
108- res
109- }
122+
123+ // Performance optimization for F-order arrays.
124+ // C-order array check prevents infinite recursion in edge cases like [[1]].
125+ // k-size check prevents underflow when k == isize::MIN
126+ let n = self . ndim ( ) ;
127+ if is_layout_f ( & self . dim , & self . strides ) && !is_layout_c ( & self . dim , & self . strides ) && k > isize:: MIN {
128+ let mut x = self . view ( ) ;
129+ x. swap_axes ( n - 2 , n - 1 ) ;
130+ let mut tril = x. triu ( -k) ;
131+ tril. swap_axes ( n - 2 , n - 1 ) ;
132+
133+ return tril;
110134 }
135+
136+ let mut res = Array :: zeros ( self . raw_dim ( ) ) ;
137+ let ncols = self . len_of ( Axis ( n - 1 ) ) ;
138+ let nrows = self . len_of ( Axis ( n - 2 ) ) ;
139+ let indices = Array :: from_iter ( 0 ..nrows) ;
140+ Zip :: from ( self . rows ( ) )
141+ . and ( res. rows_mut ( ) )
142+ . and_broadcast ( & indices)
143+ . for_each ( |src, mut dst, row_num| {
144+ // let row_num = i.into_dimension().last_elem();
145+ let mut upper = match k >= 0 {
146+ true => row_num. saturating_add ( k as usize ) . saturating_add ( 1 ) , // Avoid overflow
147+ false => row_num. saturating_sub ( ( k + 1 ) . unsigned_abs ( ) ) , // Avoid underflow
148+ } ;
149+ upper = min ( upper, ncols) ;
150+ dst. slice_mut ( s ! [ ..upper] ) . assign ( & src. slice ( s ! [ ..upper] ) ) ;
151+ } ) ;
152+
153+ res
111154 }
112155}
113156
114157#[ cfg( test) ]
115158mod tests
116159{
160+ use core:: isize;
161+
117162 use crate :: { array, dimension, Array0 , Array1 , Array2 , Array3 , ShapeBuilder } ;
118163 use alloc:: vec;
119164
@@ -188,6 +233,19 @@ mod tests
188233 assert_eq ! ( res, array![ [ 1 , 0 , 0 ] , [ 4 , 5 , 0 ] , [ 7 , 8 , 9 ] ] ) ;
189234 }
190235
236+ #[ test]
237+ fn test_2d_single ( )
238+ {
239+ let x = array ! [ [ 1 ] ] ;
240+
241+ assert_eq ! ( x. triu( 0 ) , array![ [ 1 ] ] ) ;
242+ assert_eq ! ( x. tril( 0 ) , array![ [ 1 ] ] ) ;
243+ assert_eq ! ( x. triu( 1 ) , array![ [ 0 ] ] ) ;
244+ assert_eq ! ( x. tril( 1 ) , array![ [ 1 ] ] ) ;
245+ assert_eq ! ( x. triu( -1 ) , array![ [ 1 ] ] ) ;
246+ assert_eq ! ( x. tril( -1 ) , array![ [ 0 ] ] ) ;
247+ }
248+
191249 #[ test]
192250 fn test_3d ( )
193251 {
@@ -285,8 +343,25 @@ mod tests
285343 let res = x. triu ( 0 ) ;
286344 assert_eq ! ( res, array![ [ 1 , 2 , 3 ] , [ 0 , 5 , 6 ] ] ) ;
287345
346+ let res = x. tril ( 0 ) ;
347+ assert_eq ! ( res, array![ [ 1 , 0 , 0 ] , [ 4 , 5 , 0 ] ] ) ;
348+
288349 let x = array ! [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ;
289350 let res = x. triu ( 0 ) ;
290351 assert_eq ! ( res, array![ [ 1 , 2 ] , [ 0 , 4 ] , [ 0 , 0 ] ] ) ;
352+
353+ let res = x. tril ( 0 ) ;
354+ assert_eq ! ( res, array![ [ 1 , 0 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ) ;
355+ }
356+
357+ #[ test]
358+ fn test_odd_k ( )
359+ {
360+ let x = array ! [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] , [ 7 , 8 , 9 ] ] ;
361+ let z = Array2 :: zeros ( [ 3 , 3 ] ) ;
362+ assert_eq ! ( x. triu( isize :: MIN ) , x) ;
363+ assert_eq ! ( x. tril( isize :: MIN ) , z) ;
364+ assert_eq ! ( x. triu( isize :: MAX ) , z) ;
365+ assert_eq ! ( x. tril( isize :: MAX ) , x) ;
291366 }
292367}
0 commit comments