Skip to content

Commit 9030803

Browse files
committed
reintroduce comment in jax approximator [no ci]
1 parent 5659773 commit 9030803

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ def stateless_compute_metrics(
5656
variables and returns both the loss and auxiliary information for
5757
further updates.
5858
59+
Things we do for specifically jax:
60+
61+
1. Accept trainable variables as the first argument
62+
(can be at any position as indicated by the argnum parameter
63+
in autograd, but needs to be an explicit arg)
64+
2. Accept, potentially modify, and return other state variables
65+
3. Return just the loss tensor as the first value
66+
4. Return all other values in a tuple as the second value
67+
68+
This ensures:
69+
70+
1. The function is stateless
71+
2. The function can be differentiated with jax autograd
72+
5973
Parameters
6074
----------
6175
trainable_variables : Any

0 commit comments

Comments
 (0)