File tree Expand file tree Collapse file tree 1 file changed +5
-12
lines changed Expand file tree Collapse file tree 1 file changed +5
-12
lines changed Original file line number Diff line number Diff line change @@ -387,24 +387,17 @@ def dist_params(self, node) -> Sequence[Variable]:
387387 return node .inputs [2 :]
388388
389389 def perform (self , node , inputs , outputs ):
390- rng_var_out , smpl_out = outputs
391-
392390 rng , size , * args = inputs
393391
394392 # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
395393 if not self .inplace :
396394 rng = copy (rng )
397395
398- rng_var_out [0 ] = rng
399-
400- if size is not None :
401- size = tuple (size )
402- smpl_val = self .rng_fn (rng , * ([* args , size ]))
403-
404- if not isinstance (smpl_val , np .ndarray ) or str (smpl_val .dtype ) != self .dtype :
405- smpl_val = np .asarray (smpl_val , dtype = self .dtype )
406-
407- smpl_out [0 ] = smpl_val
396+ outputs [0 ][0 ] = rng
397+ outputs [1 ][0 ] = np .asarray (
398+ self .rng_fn (rng , * args , None if size is None else tuple (size )),
399+ dtype = self .dtype ,
400+ )
408401
409402 def grad (self , inputs , outputs ):
410403 return [
You can’t perform that action at this time.
0 commit comments