@@ -77,6 +77,17 @@ pub struct AutoDiffAttrs {
7777 /// e.g. in the [JAX
7878 /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979 pub mode : DiffMode ,
80+ /// A user-provided, batching width. If not given, we will default to 1 (no batching).
81+ /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
82+ /// - Calling the function 50 times with a batch size of 2
83+ /// - Calling the function 25 times with a batch size of 4,
84+ /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
85+ /// cache locality, better re-usal of primal values, and other optimizations.
86+ /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
87+ /// times, so this massively increases code size. As such, values like 1024 are unlikely to
88+ /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
89+ /// experiments for now and focus on documenting the implications of a large width.
90+ pub width : u32 ,
8091 pub ret_activity : DiffActivity ,
8192 pub input_activity : Vec < DiffActivity > ,
8293}
@@ -222,13 +233,15 @@ impl AutoDiffAttrs {
222233 pub const fn error ( ) -> Self {
223234 AutoDiffAttrs {
224235 mode : DiffMode :: Error ,
236+ width : 0 ,
225237 ret_activity : DiffActivity :: None ,
226238 input_activity : Vec :: new ( ) ,
227239 }
228240 }
229241 pub fn source ( ) -> Self {
230242 AutoDiffAttrs {
231243 mode : DiffMode :: Source ,
244+ width : 0 ,
232245 ret_activity : DiffActivity :: None ,
233246 input_activity : Vec :: new ( ) ,
234247 }
0 commit comments