@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
         b_coeff_flipped = b_coeffs.flip(1).contiguous()
         padded_waveform = F.pad(waveform, (n_order - 1, 0))
         output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
-        ctx.save_for_backward(waveform, b_coeffs, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, b_coeffs, output)
         return output

     @staticmethod
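The FIR stage realizes y[n] = sum_k b[k] * x[n - k] as a grouped cross-correlation: the taps are flipped, the input is left-padded by n_order - 1 samples, and groups=n_channel keeps each channel's filter independent. A minimal sketch (names and shapes are illustrative, not part of the patch) checking that against a tap-by-tap loop:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    n_batch, n_channel, n_order, n_frames = 2, 3, 4, 16
    waveform = torch.randn(n_batch, n_channel, n_frames)
    b_coeffs = torch.randn(n_channel, n_order)

    # Grouped conv1d with flipped taps, as in DifferentiableFIR.forward above.
    b_flipped = b_coeffs.flip(1).contiguous()
    out_conv = F.conv1d(F.pad(waveform, (n_order - 1, 0)), b_flipped.unsqueeze(1), groups=n_channel)

    # Reference: accumulate b[k] * x[n - k] one tap at a time.
    out_ref = torch.zeros_like(waveform)
    for k in range(n_order):
        out_ref += b_coeffs[:, k].view(1, -1, 1) * F.pad(waveform, (k, 0))[:, :, :n_frames]

    assert torch.allclose(out_conv, out_ref, atol=1e-5)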
@@ -955,21 +956,28 @@ def backward(ctx, dy):
         n_batch = x.size(0)
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
-        db = (
-            F.conv1d(
-                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-                dy.view(n_batch * n_channel, 1, -1),
-                groups=n_batch * n_channel,
-            )
-            .view(n_batch, n_channel, -1)
-            .sum(0)
-            .flip(1)
-            if b_coeffs.requires_grad
-            else None
-        )
-        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
+
+        db = F.conv1d(
+            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+            dy.view(n_batch * n_channel, 1, -1),
+            groups=n_batch * n_channel
+        ).view(
+            n_batch, n_channel, -1
+        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+        dx = F.conv1d(
+            F.pad(dy, (0, n_order - 1)),
+            b_coeffs.unsqueeze(1),
+            groups=n_channel
+        ) if x.requires_grad else None
         return (dx, db)

+    @staticmethod
+    def ts_apply(waveform, b_coeffs):
+        if torch.jit.is_scripting():
+            return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
+        else:
+            return DifferentiableFIR.apply(waveform, b_coeffs)
+

 class DifferentiableIIR(torch.autograd.Function):
     @staticmethod
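The backward pass uses the closed-form gradients of that correlation: db is the cross-correlation of the left-padded input with the incoming gradient, computed with one group per (batch, channel) pair, summed over the batch and flipped back into tap order, while dx correlates the right-padded incoming gradient with the un-flipped coefficients. A quick hand-rolled check against autograd (shapes and names are illustrative):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    n_batch, n_channel, n_order, n_frames = 2, 3, 4, 16
    x = torch.randn(n_batch, n_channel, n_frames, requires_grad=True)
    b = torch.randn(n_channel, n_order, requires_grad=True)

    # Forward as in DifferentiableFIR.forward; let autograd produce reference gradients.
    y = F.conv1d(F.pad(x, (n_order - 1, 0)), b.flip(1).unsqueeze(1), groups=n_channel)
    dy = torch.randn_like(y)
    dx_ref, db_ref = torch.autograd.grad(y, (x, b), dy)

    # Closed forms matching DifferentiableFIR.backward above.
    db = (
        F.conv1d(
            F.pad(x.detach(), (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
            dy.view(n_batch * n_channel, 1, -1),
            groups=n_batch * n_channel,
        )
        .view(n_batch, n_channel, -1)
        .sum(0)
        .flip(1)
    )
    dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b.detach().unsqueeze(1), groups=n_channel)

    assert torch.allclose(db, db_ref, atol=1e-5)
    assert torch.allclose(dx, dx_ref, atol=1e-5)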
@@ -984,7 +992,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
         )
         _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
         output = padded_output_waveform[:, :, n_order - 1:]
-        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, a_coeffs_normalized, output)
         return output

     @staticmethod
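The IIR stage left-pads its output buffer with n_order - 1 zeros and has _lfilter_core_loop fill it in place with the recursion y[n] = x[n] - sum_{k=1}^{n_order-1} a[k] * y[n - k], the coefficients having already been normalized so that a[:, 0] == 1. A pure-Python reference of that recursion (an illustrative re-implementation, not the actual _lfilter_core_loop):

    import torch

    def iir_reference(waveform, a_coeffs_normalized):
        # waveform: (n_batch, n_channel, n_frames); a_coeffs_normalized: (n_channel, n_order), a[:, 0] == 1
        n_batch, n_channel, n_frames = waveform.shape
        n_order = a_coeffs_normalized.size(1)
        a_flipped = a_coeffs_normalized.flip(1)
        padded = torch.zeros(n_batch, n_channel, n_frames + n_order - 1)
        for n in range(n_frames):
            # padded[..., n : n + n_order - 1] holds y[n - (n_order - 1)] ... y[n - 1]
            feedback = (padded[:, :, n:n + n_order - 1] * a_flipped[:, :n_order - 1]).sum(-1)
            padded[:, :, n + n_order - 1] = waveform[:, :, n] - feedback
        return padded[:, :, n_order - 1:]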
@@ -1006,10 +1015,17 @@ def backward(ctx, dy):
         )
         return (dx, da)

+    @staticmethod
+    def ts_apply(waveform, a_coeffs_normalized):
+        if torch.jit.is_scripting():
+            return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
+        else:
+            return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
+

 def _lfilter(waveform, a_coeffs, b_coeffs):
-    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
-    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])


 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
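With ts_apply dispatching around Function.apply (which TorchScript cannot call) and save_for_backward skipped under scripting, the public lfilter keeps its autograd behavior in eager mode and runs forward-only when scripted. A usage sketch, assuming a torchaudio build that includes this change:

    import torch
    import torchaudio.functional as AF

    waveform = torch.randn(1, 2, 1000)
    b_coeffs = torch.tensor([[0.4, 0.2, 0.9], [0.4, 0.2, 0.9]])
    a_coeffs = torch.tensor([[1.0, 0.1, 0.3], [1.0, 0.1, 0.3]])

    eager_out = AF.lfilter(waveform, a_coeffs, b_coeffs)

    # Scripted path: DifferentiableFIR/IIR.forward are called directly, no saved tensors.
    scripted_lfilter = torch.jit.script(AF.lfilter)
    scripted_out = scripted_lfilter(waveform, a_coeffs, b_coeffs)

    assert torch.allclose(eager_out, scripted_out)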