File tree Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -119,8 +119,8 @@ def convolve1d(
119119 if mode == "same" :
120120 # We implement "same" as "valid" with padded `in1`.
121121 in1_batch_shape = tuple (in1 .shape )[:- 1 ]
122- zeros_left = in2 .shape [0 ] // 2
123- zeros_right = (in2 .shape [0 ] - 1 ) // 2
122+ zeros_left = in2 .shape [- 1 ] // 2
123+ zeros_right = (in2 .shape [- 1 ] - 1 ) // 2
124124 in1 = join (
125125 - 1 ,
126126 zeros ((* in1_batch_shape , zeros_left ), dtype = in2 .dtype ),
Original file line number Diff line number Diff line change @@ -47,3 +47,16 @@ def test_convolve1d_batch():
4747 res_np = np .convolve (x_test [0 ], y_test [0 ])
4848 np .testing .assert_allclose (res [0 ], res_np , rtol = rtol )
4949 np .testing .assert_allclose (res [1 ], res_np , rtol = rtol )
50+
51+
52+ def test_convolve1d_batch_same ():
53+ x = matrix ("data" )
54+ y = matrix ("kernel" )
55+ out = convolve1d (x , y , mode = "same" )
56+
57+ rng = np .random .default_rng (38 )
58+ x_test = rng .normal (size = (2 , 8 )).astype (x .dtype )
59+ y_test = rng .normal (size = (2 , 8 )).astype (x .dtype )
60+
61+ res = out .eval ({x : x_test , y : y_test })
62+ assert res .shape == (2 , 8 )
You can’t perform that action at this time.
0 commit comments