NOTE(review): line breaks and indentation were reconstructed from the hunk
headers (old/new line counts) and Python syntax after the patch had been
collapsed onto one line. The indentation of the docstring context line
"- weights: ..." is assumed; confirm context whitespace against the target
file before applying (or apply with --ignore-whitespace).

diff --git a/tinker_cookbook/renderers.py b/tinker_cookbook/renderers.py
index 349c17c..45b6ecb 100644
--- a/tinker_cookbook/renderers.py
+++ b/tinker_cookbook/renderers.py
@@ -42,6 +42,7 @@ class TrainOnWhat(StrEnum):
     ALL_MESSAGES = "all_messages"
     ALL_TOKENS = "all_tokens"
     ALL_USER_AND_SYSTEM_MESSAGES = "all_user_and_system_messages"
+    ALL_BUT_FIRST_USER_AND_SYSTEM_MESSAGES = "all_but_first_user_and_system_messages"
 
 
 class Renderer:
@@ -109,7 +110,10 @@ def build_supervised_example(
             - weights: a tensor of weights
     """
     tokens_weights = [(token, 0) for token in start_tokens]
+    first_user_turn_ended = False
     for idx, message in enumerate(messages[:-1]):
+        if message["role"] == "assistant":
+            first_user_turn_ended = True
         ob_part, action_part, action_tail = render_message(idx, message)
         if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE:
             tokens_weights.extend([(token, 0) for token in ob_part + action_part])
@@ -128,6 +132,10 @@ def build_supervised_example(
             tokens_weights += [(token, 0) for token in ob_part]
             is_user_or_system = message["role"] in ["user", "system"]
             tokens_weights += [(token, int(is_user_or_system)) for token in action_part]
+        elif train_on_what == TrainOnWhat.ALL_BUT_FIRST_USER_AND_SYSTEM_MESSAGES:
+            tokens_weights += [(token, 0) for token in ob_part]
+            action_weights = int((message["role"] in ["user", "system"]) and first_user_turn_ended)
+            tokens_weights += [(token, action_weights) for token in action_part]
         else:
             raise ValueError(f"Unknown train_on_what: {train_on_what}")
     ob_part, action_part, action_tail = render_message(len(messages) - 1, messages[-1])