@@ -206,54 +206,69 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
206206 weights : torch.Tensor of shape (n_timepoints,), optional
207207 Only included if weights are not `None`.
208208 """
209- group_id = self ._group_ids [index ]
210-
211- if self ._group :
212- mask = self ._groups [group_id ]
209+ time = self .time
210+ feature_cols = self .feature_cols
211+ _target = self ._target
212+ _known = self ._known
213+ _static = self ._static
214+ _group = self ._group
215+ _groups = self ._groups
216+ _group_ids = self ._group_ids
217+ weight = self .weight
218+ data_future = self .data_future
219+
220+ group_id = _group_ids [index ]
221+
222+ if _group :
223+ mask = _groups [group_id ]
213224 data = self .data .loc [mask ]
214225 else :
215226 data = self .data
216227
217- cutoff_time = data [self .time ].max ()
228+ cutoff_time = data [time ].max ()
229+
230+ data_vals = data [time ].values
231+ data_tgt_vals = data [_target ].values
232+ data_feat_vals = data [feature_cols ].values
218233
219234 result = {
220- "t" : data [ self . time ]. values ,
221- "y" : torch .tensor (data [ self . _target ]. values ),
222- "x" : torch .tensor (data [ self . feature_cols ]. values ),
235+ "t" : data_vals ,
236+ "y" : torch .tensor (data_tgt_vals ),
237+ "x" : torch .tensor (data_feat_vals ),
223238 "group" : torch .tensor ([hash (str (group_id ))]),
224- "st" : torch .tensor (data [self . _static ].iloc [0 ].values if self . _static else []),
239+ "st" : torch .tensor (data [_static ].iloc [0 ].values if _static else []),
225240 "cutoff_time" : cutoff_time ,
226241 }
227242
228- if self . data_future is not None :
229- if self . _group :
230- future_mask = self .data_future .groupby (self . _group ).groups [group_id ]
243+ if data_future is not None :
244+ if _group :
245+ future_mask = self .data_future .groupby (_group ).groups [group_id ]
231246 future_data = self .data_future .loc [future_mask ]
232247 else :
233248 future_data = self .data_future
234249
235- combined_times = np . concatenate (
236- [ data [ self . time ]. values , future_data [ self . time ]. values ]
237- )
250+ data_fut_vals = future_data [ time ]. values
251+
252+ combined_times = np . concatenate ([ data_vals , data_fut_vals ] )
238253 combined_times = np .unique (combined_times )
239254 combined_times .sort ()
240255
241256 num_timepoints = len (combined_times )
242- x_merged = np .full ((num_timepoints , len (self . feature_cols )), np .nan )
243- y_merged = np .full ((num_timepoints , len (self . _target )), np .nan )
257+ x_merged = np .full ((num_timepoints , len (feature_cols )), np .nan )
258+ y_merged = np .full ((num_timepoints , len (_target )), np .nan )
244259
245260 current_time_indices = {t : i for i , t in enumerate (combined_times )}
246- for i , t in enumerate (data [ self . time ]. values ):
261+ for i , t in enumerate (data_vals ):
247262 idx = current_time_indices [t ]
248- x_merged [idx ] = data [ self . feature_cols ]. values [i ]
249- y_merged [idx ] = data [ self . _target ]. values [i ]
263+ x_merged [idx ] = data_feat_vals [i ]
264+ y_merged [idx ] = data_tgt_vals [i ]
250265
251- for i , t in enumerate (future_data [ self . time ]. values ):
266+ for i , t in enumerate (data_fut_vals ):
252267 if t in current_time_indices :
253268 idx = current_time_indices [t ]
254- for j , col in enumerate (self . _known ):
255- if col in self . feature_cols :
256- feature_idx = self . feature_cols .index (col )
269+ for j , col in enumerate (_known ):
270+ if col in feature_cols :
271+ feature_idx = feature_cols .index (col )
257272 x_merged [idx , feature_idx ] = future_data [col ].values [i ]
258273
259274 result .update (
@@ -264,17 +279,17 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
264279 }
265280 )
266281
267- if self . weight :
282+ if weight :
268283 if self .data_future is not None and self .weight in self .data_future .columns :
269284 weights_merged = np .full (num_timepoints , np .nan )
270- for i , t in enumerate (data [ self . time ]. values ):
285+ for i , t in enumerate (data_vals ):
271286 idx = current_time_indices [t ]
272- weights_merged [idx ] = data [self . weight ].values [i ]
287+ weights_merged [idx ] = data [weight ].values [i ]
273288
274- for i , t in enumerate (future_data [ self . time ]. values ):
289+ for i , t in enumerate (data_fut_vals ):
275290 if t in current_time_indices and self .weight in future_data .columns :
276291 idx = current_time_indices [t ]
277- weights_merged [idx ] = future_data [self . weight ].values [i ]
292+ weights_merged [idx ] = future_data [weight ].values [i ]
278293
279294 result ["weights" ] = torch .tensor (weights_merged , dtype = torch .float32 )
280295 else :
0 commit comments