33 changes: 23 additions & 10 deletions snn-dt/scripts/train.py
@@ -23,11 +23,15 @@
import warnings
warnings.filterwarnings('ignore')

-from src.models.cql import CQL
-from src.models.dt import DecisionTransformer
-from src.models.dsformer import DsFormer
-from src.models.iql import IQL
-from src.models.snn_dt import SnnDt
+#
+# MODIFICATION: Removed direct model imports to prevent silent crashes from C-extension conflicts.
+# The get_model utility lazily imports the required model, which is a safer pattern.
+#
+# from src.models.cql import CQL
+# from src.models.dt import DecisionTransformer
+# from src.models.dsformer import DsFormer
+# from src.models.iql import IQL
+# from src.models.snn_dt import SnnDt
from src.utils.config import AttrDict
from src.utils.models import get_model
from src.utils.seed import seed_everything
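
For reference, here is a minimal sketch of how a lazy loader like get_model can work. Only the get_model name and the five model classes above come from this diff; the registry dict and the importlib mechanics below are assumptions for illustration, not the repo's actual src/utils/models.py.

# Assumed sketch of a lazy model loader. Importing inside the function
# means each model's C-extension dependencies load only when that model
# is actually requested, instead of at script startup.
import importlib

_MODEL_REGISTRY = {
    "cql": ("src.models.cql", "CQL"),
    "dt": ("src.models.dt", "DecisionTransformer"),
    "dsformer": ("src.models.dsformer", "DsFormer"),
    "iql": ("src.models.iql", "IQL"),
    "snn_dt": ("src.models.snn_dt", "SnnDt"),
}

def get_model(name):
    """Import and return the model class for `name` on first use."""
    module_path, class_name = _MODEL_REGISTRY[name]
    module = importlib.import_module(module_path)
    return getattr(module, class_name)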
@@ -205,7 +209,9 @@ def train(cfg, logger):
logger.info("--- Checkpoint: Starting main training loop ---")
for epoch in range(cfg.training.epochs):
start_time = time.time()
epoch_losses = []
# MODIFIED: Accumulate loss in a tensor to avoid CPU-GPU sync for non-IQL/CQL models
total_loss = torch.tensor(0.0, device=cfg.training.device)
epoch_losses = [] # Kept for IQL/CQL models for simplicity

if hasattr(model, "reset_spike_counts"):
model.reset_spike_counts()
@@ -249,9 +255,11 @@
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
-            epoch_losses.append(loss.item())
+            # MODIFIED: Accumulate loss as a tensor.
+            total_loss += loss.detach()
            if i % 10 == 0:
-                pbar.set_postfix(loss=f"{np.mean(epoch_losses):.4f}")
+                # MODIFIED: Call .item() only periodically for logging.
+                pbar.set_postfix(loss=f"{(total_loss / (i + 1)).item():.4f}")

        # Evaluation, Checkpointing, and Logging
        if (epoch + 1) % cfg.training.eval_every == 0:
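
The accumulation change above follows a standard PyTorch pattern: every loss.item() call forces a device-to-host synchronization, while adding detached tensors keeps the running total on the GPU. A self-contained sketch of the idea (the loop and names are illustrative, not the repo's code):

# Illustrative sketch of sync-free loss accumulation in a generic
# PyTorch training loop; not the repo's actual code.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
total_loss = torch.tensor(0.0, device=device)

for step in range(100):
    loss = (torch.randn(32, device=device) ** 2).mean()  # stand-in for a real loss
    # .detach() drops the autograd history so the accumulator does not
    # keep the whole graph alive; staying on-device avoids a host sync.
    total_loss += loss.detach()
    if step % 10 == 0:
        # .item() is the only point that forces a CPU-GPU synchronization.
        print(f"running mean loss: {(total_loss / (step + 1)).item():.4f}")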
@@ -260,7 +268,11 @@
            env = gym.make(cfg.env)
            eval_results = evaluate_policy(model, env, cfg, episodes=10)
            epoch_time = time.time() - start_time
-            avg_loss = np.mean(epoch_losses)
+            # MODIFIED: Calculate avg_loss from tensor or list based on model type.
+            if cfg.model.name in ['iql', 'cql']:
+                avg_loss = np.mean(epoch_losses) if epoch_losses else 0.0
+            else:
+                avg_loss = (total_loss / cfg.training.batches_per_epoch).item()

            log_str = f"Epoch {epoch+1}/{cfg.training.epochs} | Time: {epoch_time:.2f}s | Loss: {avg_loss:.4f}"

@@ -282,7 +294,8 @@
logger.info(f"New best eval return: {best_eval_return:.2f}. Saved best model.")

# Apply plasticity for SNN models
if isinstance(model, SnnDt) and model.use_plasticity:
# MODIFICATION: Use class name to avoid direct import of SnnDt.
if model.__class__.__name__ == 'SnnDt' and hasattr(model, 'use_plasticity') and model.use_plasticity:
model.apply_plasticity(eval_results["return_mean"])

# Periodic checkpointing
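
Comparing model.__class__.__name__ keeps SnnDt unimported in this module, at the cost of isinstance's subclass awareness. A small illustrative sketch of that trade-off (the classes here are stand-ins, not the repo's):

# Stand-in classes for illustration only.
class SnnDt:
    use_plasticity = True

    def apply_plasticity(self, return_mean):
        print(f"plasticity update with return {return_mean:.2f}")

class SnnDtVariant(SnnDt):  # a hypothetical subclass
    pass

for model in (SnnDt(), SnnDtVariant()):
    # The name check matches only the exact class, so SnnDtVariant is
    # skipped here even though isinstance(model, SnnDt) would be True.
    if model.__class__.__name__ == "SnnDt" and getattr(model, "use_plasticity", False):
        model.apply_plasticity(123.4)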