|
6 | 6 | use std::fmt::{self, Display, Formatter}; |
7 | 7 | use std::str::FromStr; |
8 | 8 |
|
| 9 | +use crate::expand::typetree::TypeTree; |
9 | 10 | use crate::expand::{Decodable, Encodable, HashStable_Generic}; |
10 | 11 | use crate::ptr::P; |
11 | 12 | use crate::{Ty, TyKind}; |
@@ -85,6 +86,9 @@ pub struct AutoDiffItem { |
85 | 86 | /// The name of the function being generated |
86 | 87 | pub target: String, |
87 | 88 | pub attrs: AutoDiffAttrs, |
| 89 | + // Type Tree support |
| 90 | + pub inputs: Vec<TypeTree>, |
| 91 | + pub output: TypeTree, |
88 | 92 | } |
89 | 93 |
|
90 | 94 | #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] |
@@ -276,14 +280,23 @@ impl AutoDiffAttrs { |
276 | 280 | !matches!(self.mode, DiffMode::Error | DiffMode::Source) |
277 | 281 | } |
278 | 282 |
|
279 | | - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { |
280 | | - AutoDiffItem { source, target, attrs: self } |
| 283 | + pub fn into_item( |
| 284 | + self, |
| 285 | + source: String, |
| 286 | + target: String, |
| 287 | + inputs: Vec<TypeTree>, |
| 288 | + output: TypeTree, |
| 289 | + ) -> AutoDiffItem { |
| 290 | + AutoDiffItem { source, target, inputs, output, attrs: self } |
281 | 291 | } |
282 | 292 | } |
283 | 293 |
|
284 | 294 | impl fmt::Display for AutoDiffItem { |
285 | 295 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
286 | 296 | write!(f, "Differentiating {} -> {}", self.source, self.target)?; |
287 | | - write!(f, " with attributes: {:?}", self.attrs) |
| 297 | + write!(f, " with attributes: {:?}", self.attrs)?; |
| 298 | + write!(f, " with attributes: {:?}", self.attrs)?; |
| 299 | + write!(f, " with inputs: {:?}", self.inputs)?; |
| 300 | + write!(f, " with output: {:?}", self.output) |
288 | 301 | } |
289 | 302 | } |
0 commit comments