diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 77b4f6f8a..22aabe292 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -120,6 +120,27 @@ async def sample( entry.sample_count += 1 sampled_episodes.append(entry.data) + # Calculate and record policy age metrics for sampled episodes + sampled_policy_ages = [ + curr_policy_version - ep.policy_version for ep in sampled_episodes + ] + if sampled_policy_ages: + record_metric( + "buffer/sample/avg_sampled_policy_age", + sum(sampled_policy_ages) / len(sampled_policy_ages), + Reduce.MEAN, + ) + record_metric( + "buffer/sample/max_sampled_policy_age", + max(sampled_policy_ages), + Reduce.MAX, + ) + record_metric( + "buffer/sample/min_sampled_policy_age", + min(sampled_policy_ages), + Reduce.MIN, + ) + # Reshape into (dp_size, bsz, ...) reshaped_episodes = [ sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]