@@ -87,6 +87,19 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
8787 const expected = [ 3 , 5 ] ;
8888 expectArraysClose ( await result . data ( ) , expected ) ;
8989 } ) ;
90+
91+ it ( 'concat complex input' , async ( ) => {
92+ // [1+1j, 2+2j]
93+ const c1 = tf . complex ( [ 1 , 2 ] , [ 1 , 2 ] ) ;
94+ // [3+3j, 4+4j]
95+ const c2 = tf . complex ( [ 3 , 4 ] , [ 3 , 4 ] ) ;
96+
97+ const axis = 0 ;
98+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
99+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 ] ;
100+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
101+ expectArraysClose ( await result . data ( ) , expected ) ;
102+ } ) ;
90103} ) ;
91104
92105describeWithFlags ( 'concat2d' , ALL_ENVS , ( ) => {
@@ -220,6 +233,32 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
220233 expect ( res2 . shape ) . toEqual ( [ 0 , 15 ] ) ;
221234 expectArraysEqual ( await res2 . data ( ) , [ ] ) ;
222235 } ) ;
236+
237+ it ( 'concat complex input axis=0' , async ( ) => {
238+ // [[1+1j, 2+2j], [3+3j, 4+4j]]
239+ const c1 = tf . complex ( [ [ 1 , 2 ] , [ 3 , 4 ] ] , [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
240+ // [[5+5j, 6+6j], [7+7j, 8+8j]]
241+ const c2 = tf . complex ( [ [ 5 , 6 ] , [ 7 , 8 ] ] , [ [ 5 , 6 ] , [ 7 , 8 ] ] ) ;
242+
243+ const axis = 0 ;
244+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
245+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 8 , 8 ] ;
246+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
247+ expectArraysClose ( await result . data ( ) , expected ) ;
248+ } ) ;
249+
250+ it ( 'concat complex input axis=1' , async ( ) => {
251+ // [[1+1j, 2+2j], [3+3j, 4+4j]]
252+ const c1 = tf . complex ( [ [ 1 , 2 ] , [ 3 , 4 ] ] , [ [ 1 , 2 ] , [ 3 , 4 ] ] ) ;
253+ // [[5+5j, 6+6j], [7+7j, 8+8j]]
254+ const c2 = tf . complex ( [ [ 5 , 6 ] , [ 7 , 8 ] ] , [ [ 5 , 6 ] , [ 7 , 8 ] ] ) ;
255+
256+ const axis = 1 ;
257+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
258+ const expected = [ 1 , 1 , 2 , 2 , 5 , 5 , 6 , 6 , 3 , 3 , 4 , 4 , 7 , 7 , 8 , 8 ] ;
259+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
260+ expectArraysClose ( await result . data ( ) , expected ) ;
261+ } ) ;
223262} ) ;
224263
225264describeWithFlags ( 'concat3d' , ALL_ENVS , ( ) => {
@@ -460,6 +499,54 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
460499 expect ( values . shape ) . toEqual ( [ 2 , 3 , 1 ] ) ;
461500 expectArraysClose ( await values . data ( ) , [ 1 , 2 , 3 , 4 , 5 , 6 ] ) ;
462501 } ) ;
502+
503+ it ( 'concat complex input axis=0' , async ( ) => {
504+ // [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
505+ const c1 = tf . complex (
506+ [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] , [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] ) ;
507+ // [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
508+ const c2 = tf . complex (
509+ [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] , [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] ) ;
510+
511+ const axis = 0 ;
512+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
513+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 ,
514+ 7 , 7 , 8 , 8 , 9 , 9 , 10 , 10 , 11 , 11 , 12 , 12 ] ;
515+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
516+ expectArraysClose ( await result . data ( ) , expected ) ;
517+ } ) ;
518+
519+ it ( 'concat complex input axis=1' , async ( ) => {
520+ // [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
521+ const c1 = tf . complex (
522+ [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] , [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] ) ;
523+ // [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
524+ const c2 = tf . complex (
525+ [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] , [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] ) ;
526+
527+ const axis = 1 ;
528+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
529+ const expected = [ 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 ,
530+ 7 , 7 , 8 , 8 , 9 , 9 , 10 , 10 , 11 , 11 , 12 , 12 ] ;
531+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
532+ expectArraysClose ( await result . data ( ) , expected ) ;
533+ } ) ;
534+
535+ it ( 'concat complex input axis=1' , async ( ) => {
536+ // [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
537+ const c1 = tf . complex (
538+ [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] , [ [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ] ) ;
539+ // [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
540+ const c2 = tf . complex (
541+ [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] , [ [ [ 7 , 8 ] , [ 9 , 10 ] , [ 11 , 12 ] ] ] ) ;
542+
543+ const axis = 2 ;
544+ const result = tf . concat ( [ c1 , c2 ] , axis ) ;
545+ const expected = [ 1 , 1 , 2 , 2 , 7 , 7 , 8 , 8 , 3 , 3 , 4 , 4 ,
546+ 9 , 9 , 10 , 10 , 5 , 5 , 6 , 6 , 11 , 11 , 12 , 12 ] ;
547+ expect ( result . dtype ) . toEqual ( 'complex64' ) ;
548+ expectArraysClose ( await result . data ( ) , expected ) ;
549+ } ) ;
463550} ) ;
464551
465552describeWithFlags ( 'concat throws for non-tensors' , ALL_ENVS , ( ) => {
0 commit comments