11from __future__ import annotations
22from pydantic import BaseModel
3+ from typing import Optional
34import torch
45
56
67class StandardSampler (BaseModel , torch .utils .data .Sampler ):
78 proportion : float
89 replacement : bool
910 sampler : torch .utils .data .WeightedRandomSampler
11+ seed : Optional [int ]
12+ generator : Optional [torch .Generator ]
1013
1114 class Config :
1215 arbitrary_types_allowed = True
1316 allow_mutation = False
1417
15- def __init__ (self , length , proportion = 1.0 , replacement = False ):
18+ def __init__ (self , length , proportion = 1.0 , replacement = False , seed = None ):
19+ if seed is not None :
20+ generator = torch .Generator ()
21+ generator .manual_seed (seed )
22+ else :
23+ generator = None
1624 BaseModel .__init__ (
1725 self ,
1826 proportion = proportion ,
@@ -21,13 +29,18 @@ def __init__(self, length, proportion=1.0, replacement=False):
2129 torch .ones (length ).double (),
2230 num_samples = int (max (1 , min (length , length * proportion ))),
2331 replacement = replacement ,
24- )
32+ generator = generator ,
33+ ),
34+ seed = seed ,
35+ generator = generator ,
2536 )
2637
2738 def __len__ (self ):
2839 return len (self .sampler )
2940
3041 def __iter__ (self ):
42+ if self .generator is not None :
43+ self .generator .manual_seed (self .seed )
3144 return iter (self .sampler )
3245
3346 @property
@@ -51,6 +64,7 @@ def sample_proportion(self, proportion):
5164 len (self ),
5265 proportion ,
5366 self .replacement ,
67+ self .seed ,
5468 )
5569 sampler .sampler .weights = self .sampler .weights
5670 return sampler
0 commit comments