@@ -6,8 +6,13 @@
 
 from einops import rearrange, repeat
 
-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_cuda = None
+
 import selective_scan_cuda
 
 
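Note: the try/except guard turns causal-conv1d into an optional dependency, so this module still imports on systems where the extension is not built; the fused code paths then check for None and fail with an install hint only when they are actually used. A minimal sketch of the same pattern (the fast_path name below is illustrative, not from this file):

    # Optional compiled dependency: import if present, otherwise remember it is missing.
    try:
        import causal_conv1d_cuda
    except ImportError:
        causal_conv1d_cuda = None

    def fast_path(x):
        # Only the accelerated path needs the extension; error out with an actionable message.
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        ...
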
@@ -163,6 +168,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         """
             xz: (batch, dim, seqlen)
         """
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
@@ -178,7 +184,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
         x, z = xz.chunk(2, dim=1)
         conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+            x, conv1d_weight, conv1d_bias, None, None, None, True
+        )
         # We're being very careful here about the layout, to avoid extra transposes.
         # We want delta to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
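The raw CUDA binding grew extra positional parameters in newer causal-conv1d releases, which is why three None arguments now sit between the bias and the final silu flag (my reading is that they correspond to seq_idx, initial_states, and final_states_out, but treat that as an assumption rather than documented API). If one code base has to run against both the old 5-argument and the new 7-argument binding, a small compatibility shim is one option:

    def conv1d_fwd_compat(x, weight, bias):
        # Hypothetical helper: try the newer 7-argument binding first,
        # then fall back to the older 5-argument form on a signature mismatch.
        try:
            return causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, None, None, None, True)
        except TypeError:
            return causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, None, True)
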
@@ -231,6 +239,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
     @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
          conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
         L = xz.shape[-1]
@@ -240,7 +249,9 @@ def backward(ctx, dout):
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+                x, conv1d_weight, conv1d_bias, None, None, None, True
+            )
             delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                               "d (b l) -> b d l", l=L)
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
@@ -285,8 +296,8 @@ def backward(ctx, dout):
         dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
+        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
         )
         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
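The backward binding changed along the same lines: extra positional arguments now appear between dconv1d_out and the pre-allocated dx (presumably the seq_idx/initial-state plumbing) plus an extra flag before the silu flag, and the call can return more than three tensors. The starred unpacking keeps only the three gradients this autograd function needs and discards whatever follows, e.g.:

    # `a, b, c, *rest = seq` takes the first three items and collects the remainder,
    # so the same unpacking works for both 3-tuple and longer return values.
    def first_three(outputs):
        dx, dweight, dbias, *_ = outputs
        return dx, dweight, dbias

    assert first_three((1, 2, 3)) == (1, 2, 3)
    assert first_three((1, 2, 3, 4)) == (1, 2, 3)
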
@@ -314,11 +325,12 @@ def mamba_inner_ref(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
+    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
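In the reference path, passing activation as a keyword keeps the call working regardless of how many optional positional parameters (seq_idx, initial_states, and so on) the installed causal_conv1d_fn exposes before it. Roughly how the call is used, assuming a CUDA device and causal-conv1d installed (shapes below are illustrative):

    import torch
    from einops import rearrange
    from causal_conv1d import causal_conv1d_fn

    x = torch.randn(2, 4, 16, device="cuda")      # (batch, dim, seqlen)
    weight = torch.randn(4, 1, 3, device="cuda")  # depthwise conv weight, (dim, 1, width)
    bias = torch.randn(4, device="cuda")

    # Keyword argument avoids relying on the position of `activation` in the signature.
    y = causal_conv1d_fn(x, rearrange(weight, "d 1 w -> d w"), bias, activation="silu")
    print(y.shape)  # (2, 4, 16)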