File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ def fast_collate(batch):
3333 if isinstance (batch [0 ][0 ], tuple ):
3434 # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
3535 # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
36+ is_np = isinstance (batch [0 ][0 ], np .ndarray )
3637 inner_tuple_size = len (batch [0 ][0 ])
3738 flattened_batch_size = batch_size * inner_tuple_size
3839 targets = torch .zeros (flattened_batch_size , dtype = torch .int64 )
@@ -41,7 +42,10 @@ def fast_collate(batch):
4142 assert len (batch [i ][0 ]) == inner_tuple_size # all input tensor tuples must be same length
4243 for j in range (inner_tuple_size ):
4344 targets [i + j * batch_size ] = batch [i ][1 ]
44- tensor [i + j * batch_size ] += torch .from_numpy (batch [i ][0 ][j ])
45+ if is_np :
46+ tensor [i + j * batch_size ] += torch .from_numpy (batch [i ][0 ][j ])
47+ else :
48+ tensor [i + j * batch_size ] += batch [i ][0 ][j ]
4549 return tensor , targets
4650 elif isinstance (batch [0 ][0 ], np .ndarray ):
4751 targets = torch .tensor ([b [1 ] for b in batch ], dtype = torch .int64 )
You can’t perform that action at this time.
0 commit comments