File tree Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change 1616from typing import List
1717
1818import numpy as np
19+ from torch .utils .data import BatchSampler , DataLoader , Sampler
20+
1921from opacus .optimizers import DPOptimizer
2022from opacus .utils .uniform_sampler import (
2123 DistributedUniformWithReplacementSampler ,
2224 UniformWithReplacementSampler ,
2325)
24- from torch .utils .data import BatchSampler , DataLoader , Sampler
2526
2627
2728class BatchSplittingSampler (Sampler [List [int ]]):
@@ -71,13 +72,17 @@ def __iter__(self):
7172 def __len__ (self ):
7273 if isinstance (self .sampler , BatchSampler ):
7374 return int (
74- len (self .sampler ) * (self .sampler .batch_size / self .max_batch_size )
75+ np .ceil (
76+ len (self .sampler ) * (self .sampler .batch_size / self .max_batch_size )
77+ )
7578 )
7679 elif isinstance (self .sampler , UniformWithReplacementSampler ) or isinstance (
7780 self .sampler , DistributedUniformWithReplacementSampler
7881 ):
7982 expected_batch_size = self .sampler .sample_rate * self .sampler .num_samples
80- return int (len (self .sampler ) * (expected_batch_size / self .max_batch_size ))
83+ return int (
84+ np .ceil (len (self .sampler ) * (expected_batch_size / self .max_batch_size ))
85+ )
8186
8287 return len (self .sampler )
8388
You can’t perform that action at this time.
0 commit comments