Skip to content

Commit ea58fce

Browse files
authored
[kto] fix kto apo_zero_unpaired (#6601)
1 parent 82a2b22 commit ea58fce

File tree

3 files changed

+29
-16
lines changed

swift/llm/template/base.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,16 @@ def get_base_model(model):
357357
else:
358358
return model
359359

360-
def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
360+
def _rlhf_encode(self, inputs: TemplateInputs, check_rejected=True) -> Dict[str, Any]:
361361
chosen = inputs.chosen
362362
margin = chosen.margin
363363
chosen_encoded = self._encode_truncated(chosen)
364-
rejected_encoded = self._encode_truncated(inputs.rejected)
364+
if inputs.rejected is None:
365+
if check_rejected:
366+
raise ValueError('inputs.rejected is None')
367+
rejected_encoded = {}
368+
else:
369+
rejected_encoded = self._encode_truncated(inputs.rejected)
365370

366371
encoded = {}
367372
for prefix in ['chosen', 'rejected']:
@@ -373,7 +378,7 @@ def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
373378
return encoded
374379

375380
def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
376-
encoded = self._rlhf_encode(inputs)
381+
encoded = self._rlhf_encode(inputs, check_rejected=False)
377382
encoded['label'] = bool(inputs.chosen.label)
378383
return encoded
379384

@@ -1485,7 +1490,10 @@ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona
14851490
kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')
14861491

14871492
res = self._data_collator(new_batch, padding_to=padding_to)
1488-
kl_res = self._data_collator(kl_batch, padding_to=padding_to)
1493+
if any(kl_batch):
1494+
kl_res = self._data_collator(kl_batch, padding_to=padding_to)
1495+
else:
1496+
kl_res = {}
14891497
res = {
14901498
**{f'completion_{k}': v
14911499
for k, v in res.items()},

swift/llm/train/kto.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,17 @@ def _get_kl_dataset(dataset: Optional[HfDataset],
4141

4242

4343
def prepare_kto_dataset(args, train_dataset, val_dataset):
44-
world_size = get_dist_setting()[2]
45-
if hasattr(args, 'global_batch_size') and args.global_batch_size is not None:
46-
total_batch_size = args.global_batch_size
47-
else:
48-
total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
49-
if total_batch_size <= 1:
50-
raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
51-
'will be equivalent to the implied reward.')
52-
train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
53-
val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
44+
if args.loss_type != 'apo_zero_unpaired':
45+
world_size = get_dist_setting()[2]
46+
if hasattr(args, 'global_batch_size') and args.global_batch_size is not None:
47+
total_batch_size = args.global_batch_size
48+
else:
49+
total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
50+
if total_batch_size <= 1:
51+
raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
52+
'will be equivalent to the implied reward.')
53+
train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
54+
val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
5455

5556
label = train_dataset['label']
5657
num_desirable = max(sum(label), 1)

swift/megatron/trainers/kto_trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def loss_func(self, output_tensor, *, data, kl_data, label):
7676
loss = loss.mean()
7777
mean_metric = {
7878
'loss': loss.detach().clone(),
79-
'kl': kl.detach(),
79+
'kl': kl.squeeze().detach(),
8080
}
8181
metric = self._all_reduce_metric(mean_metric)
8282
sum_metric = {
@@ -159,7 +159,11 @@ def _prepare_batch(self, data, vp_stage):
159159
num_samples = data.pop('num_samples')
160160
for key in ['completion_', 'KL_completion_']:
161161
_data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)}
162-
res.append(super()._prepare_batch(_data, vp_stage, num_samples))
162+
if not self.args.calculate_KL and key == 'KL_completion_':
163+
_data = {}
164+
else:
165+
_data = super()._prepare_batch(_data, vp_stage, num_samples)
166+
res.append(_data)
163167
res[0]['label'] = data['label']
164168
return res
165169

0 commit comments

Comments (0)