File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
bayesflow/approximators/backend_approximators Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -56,6 +56,20 @@ def stateless_compute_metrics(
5656 variables and returns both the loss and auxiliary information for
5757 further updates.
5858
59+ Things we do for specifically jax:
60+
61+ 1. Accept trainable variables as the first argument
62+ (can be at any position as indicated by the argnum parameter
63+ in autograd, but needs to be an explicit arg)
64+ 2. Accept, potentially modify, and return other state variables
65+ 3. Return just the loss tensor as the first value
66+ 4. Return all other values in a tuple as the second value
67+
68+ This ensures:
69+
70+ 1. The function is stateless
71+ 2. The function can be differentiated with jax autograd
72+
5973 Parameters
6074 ----------
6175 trainable_variables : Any
You can’t perform that action at this time.
0 commit comments