@@ -3,20 +3,21 @@
 import warnings
 from functools import reduce
 from types import ModuleType
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
 from captum._utils.models.linear_model.model import LinearModel
 from torch.utils.data import DataLoader
 
 
-# pyre-fixme[2]: Parameter must be annotated.
-def l2_loss(x1, x2, weights=None) -> torch.Tensor:
+def l2_loss(
+    x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     if weights is None:
-        return torch.mean((x1 - x2) ** 2) / 2.0
+        return torch.mean(torch.pow(x1 - x2, 2)) / 2.0
     else:
-        return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0
+        return torch.sum((weights / weights.norm(p=1)) * torch.pow(x1 - x2, 2)) / 2.0
 
 
 class ConvergenceTracker:
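
The `l2_loss` refactor is behavior-preserving: `torch.pow(x1 - x2, 2)` is elementwise, exactly like `(x1 - x2) ** 2`, and the weighted branch normalizes `weights` by its L1 norm before summing, so it is a weighted mean. A minimal standalone sketch with made-up values:

```python
import torch

x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([1.5, 2.0, 2.0])
w = torch.tensor([1.0, 1.0, 2.0])

# Unweighted: mean of squared errors, halved -> (0.25 + 0 + 1) / 3 / 2 = 0.2083...
unweighted = torch.mean(torch.pow(x1 - x2, 2)) / 2.0
# Weighted: w / ||w||_1 = [0.25, 0.25, 0.5] -> (0.0625 + 0 + 0.5) / 2 = 0.28125
weighted = torch.sum((w / w.norm(p=1)) * torch.pow(x1 - x2, 2)) / 2.0
```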
@@ -60,20 +61,19 @@ def average(self) -> torch.Tensor:
 
 
 def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
-    assert model.linear is not None
+    linear_layer = model.linear
+    assert linear_layer is not None
     if init_scheme is not None:
         assert init_scheme in ["xavier", "zeros"]
 
         with torch.no_grad():
             if init_scheme == "xavier":
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                torch.nn.init.xavier_uniform_(model.linear.weight)
+                torch.nn.init.xavier_uniform_(linear_layer.weight)
             else:
-                model.linear.weight.zero_()
+                linear_layer.weight.zero_()
 
-            # pyre-fixme[16]: `Optional` has no attribute `bias`.
-            if model.linear.bias is not None:
-                model.linear.bias.zero_()
+            if linear_layer.bias is not None:
+                linear_layer.bias.zero_()
 
 
 def _get_point(
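
Binding `model.linear` to a local is what lets the `pyre-fixme[16]` comments go: an `assert x is not None` reliably narrows a local variable, whereas narrowing of an attribute access can be dropped by the checker across intervening statements. A small sketch of the pattern, with hypothetical names:

```python
from typing import Optional
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear: Optional[nn.Linear] = nn.Linear(4, 1)

def init_weights(model: Wrapper) -> None:
    layer = model.linear      # bind the Optional attribute to a local
    assert layer is not None  # the local stays narrowed to nn.Linear
    nn.init.xavier_uniform_(layer.weight)

init_weights(Wrapper())
```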
@@ -103,8 +103,9 @@ def sgd_train_linear_model(
     reduce_lr: bool = True,
     initial_lr: float = 0.01,
     alpha: float = 1.0,
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    loss_fn: Callable = l2_loss,
+    loss_fn: Callable[
+        [torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor
+    ] = l2_loss,
     reg_term: Optional[int] = 1,
     patience: int = 10,
     threshold: float = 1e-4,
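
With the full `Callable[...]` annotation, any custom loss passed as `loss_fn` must accept `(targets, predictions, optional sample weights)` and return a scalar tensor, like the default `l2_loss`. A hypothetical drop-in replacement, sketched under that assumption:

```python
from typing import Optional
import torch

def l1_loss(
    x1: torch.Tensor, x2: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Same signature as l2_loss, so it satisfies the new loss_fn annotation.
    if weights is None:
        return torch.mean(torch.abs(x1 - x2))
    return torch.sum((weights / weights.norm(p=1)) * torch.abs(x1 - x2))
```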
@@ -224,8 +225,8 @@ def sgd_train_linear_model(
 
             loss = loss_fn(y, out, w)
             if reg_term is not None:
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                reg = torch.norm(model.linear.weight, p=reg_term)  # type: ignore
+                assert model.linear is not None
+                reg = torch.norm(model.linear.weight, p=reg_term)
                 loss += reg.sum() * alpha
 
             loss_window.append(loss.clone().detach())
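
Since `torch.norm` is called without `dim`, it treats the weight matrix as one flattened vector, so `reg_term=1` adds a lasso-style L1 penalty scaled by `alpha`. A standalone sketch with a made-up weight matrix:

```python
import torch

weight = torch.tensor([[1.0, -2.0], [0.5, 0.0]])
# p=1 over the flattened tensor: |1| + |-2| + |0.5| + |0| = 3.5
reg = torch.norm(weight, p=1)
penalty = reg.sum() * 1.0  # .sum() is a no-op here; reg is already a scalar
```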
@@ -269,18 +270,19 @@ def sgd_train_linear_model(
 
 
 class NormLayer(nn.Module):
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, mean, std, n=None, eps: float = 1e-8) -> None:
+    def __init__(
+        self,
+        mean: torch.Tensor,
+        std: torch.Tensor,
+        n: Optional[int] = None,
+        eps: float = 1e-8,
+    ) -> None:
         super().__init__()
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.mean = mean
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.std = std
+        self.mean: torch.Tensor = mean
+        self.std: torch.Tensor = std
         self.eps = eps
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return (x - self.mean) / (self.std + self.eps)
 
 
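
`NormLayer` standardizes each feature with precomputed statistics; `eps` guards against division by zero for constant features. A usage sketch, assuming the class is importable from this module's path:

```python
import torch
from captum._utils.models.linear_model.train import NormLayer

x = torch.randn(100, 5)
norm = NormLayer(mean=x.mean(0), std=x.std(0))
z = norm(x)
# Each column of z now has roughly zero mean and unit standard deviation.
```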
@@ -371,16 +373,23 @@ def sklearn_train_linear_model(
     else:
         w = None
 
+    mean, std = None, None
     if norm_input:
         mean, std = x.mean(0), x.std(0)
         x -= mean
         x /= std
 
     t1 = time.time()
-    # pyre-fixme[29]: `str` is not a function.
-    sklearn_model = reduce(  # type: ignore
-        lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")  # type: ignore # noqa: E501
-    )(**construct_kwargs)
+    # Start with the sklearn module and navigate through the attribute path
+    sklearn_cls = cast(
+        Type[Any],
+        reduce(
+            lambda obj, attr: getattr(obj, attr),
+            sklearn_trainer.split("."),
+            sklearn,
+        ),
+    )
+    sklearn_model = sklearn_cls(**construct_kwargs)
     try:
         sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
     except TypeError:
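
Two things happen in this hunk: `mean, std = None, None` is initialized up front so the later `norm_input` branch no longer trips pyre's "not always defined" check, and the class lookup passes `sklearn` as `reduce`'s initializer instead of prepending it to the list, which gives the fold a well-typed starting value. A standalone sketch of the attribute walk, assuming scikit-learn is installed:

```python
from functools import reduce
import sklearn
import sklearn.linear_model  # make the submodule reachable as an attribute

trainer = "linear_model.Lasso"
# Fold over ["linear_model", "Lasso"], starting from the sklearn module.
cls = reduce(lambda obj, attr: getattr(obj, attr), trainer.split("."), sklearn)
model = cls(alpha=1.0)  # equivalent to sklearn.linear_model.Lasso(alpha=1.0)
```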
@@ -417,8 +426,7 @@ def sklearn_train_linear_model(
     )
 
     if norm_input:
-        # pyre-fixme[61]: `mean` is undefined, or not always defined.
-        # pyre-fixme[61]: `std` is undefined, or not always defined.
+        assert mean is not None and std is not None
         model.norm = NormLayer(mean, std)
 
     return {"train_time": t2 - t1}