     {
         "tok_embeddings": RowwiseParallel(
             input_layouts=Replicate(),
+            output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(), None),
-            desired_input_layouts=(Shard(1), None),
-            use_local_output=True,
-        ),
     }
 )

 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
         ),
         "attention.wq": ColwiseParallel(),
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
|
     # Adjust attention module to use the local number of heads
|
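Context, not part of this diff: plan dicts like the ones above are consumed by torch.distributed.tensor.parallel.parallelize_module together with a tensor-parallel DeviceMesh. A minimal sketch of that call on a toy block; the mesh size (2) and the nn.Sequential module here are assumptions for illustration, not the model under review.

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Sketch only: launch with torchrun, e.g. `torchrun --nproc-per-node=2 tp_sketch.py`.
# The 2-way mesh and the toy two-layer MLP stand in for the diffed Llama model.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (2,))  # 1-D mesh: a single tensor-parallel group

block = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
block = parallelize_module(
    module=block,
    device_mesh=tp_mesh,
    parallelize_plan={
        "0": ColwiseParallel(),  # first Linear: shard column-wise, output stays sharded
        "2": RowwiseParallel(),  # second Linear: shard row-wise, output is replicated
    },
)

The "0"/"2" keys are the submodule names inside the nn.Sequential, mirroring how the diff keys its plans by module path ("attention.wq", "feed_forward.w1", and so on).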