|
12 | 12 |
|
13 | 13 | import dataclasses |
14 | 14 | from dataclasses import dataclass |
15 | | -from typing import Callable |
| 15 | +from typing import Callable, Optional |
16 | 16 | from domino.timer import Timers |
17 | 17 | from megatron.tokenizer import build_tokenizer |
18 | 18 |
|
@@ -206,6 +206,8 @@ def parse_args(): |
206 | 206 | help='Report loss and timing interval.') |
207 | 207 | parser.add_argument('--save-interval', type=int, default=None, |
208 | 208 | help='Number of iterations between checkpoint saves.') |
| 209 | + parser.add_argument('--fused-linear-loss', action='store_true', |
| 210 | + help='whether to use LigerFusedLinearCrossEntropyFunction()') |
209 | 211 |
|
210 | 212 | args = parser.parse_args() |
211 | 213 |
|
@@ -359,6 +361,8 @@ class TransformerConfig(): |
359 | 361 | no_sync_func: Callable = None |
360 | 362 | # grad_sync_func: Callable = None |
361 | 363 | # param_sync_func: Callable = None |
| 364 | + |
| 365 | + fused_linear_loss: bool = False |
362 | 366 |
|
363 | 367 | def __post_init__(self): |
364 | 368 | """ Python dataclass method that is used to modify attributes after initialization. |
@@ -400,5 +404,6 @@ def core_transformer_config_from_args(args): |
400 | 404 | kw_args['init_method'] = args.init_method |
401 | 405 | kw_args['output_layer_init_method'] = args.init_method |
402 | 406 | kw_args['params_dtype'] = args.params_dtype |
| 407 | + kw_args['fused_linear_loss'] = args.fused_linear_loss |
403 | 408 |
|
404 | 409 | return TransformerConfig(**kw_args) |
0 commit comments