From cbbd9f83096b378f76c4dd4d76eade92405b2712 Mon Sep 17 00:00:00 2001
From: Cassandra Grzonkowski <c_grzonkow18@cs.uni-kl.de>
Date: Tue, 27 Feb 2024 05:18:23 +0100
Subject: [PATCH] device

---
 main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/main.py b/main.py
index ad97ede..65eb1fb 100644
--- a/main.py
+++ b/main.py
@@ -109,7 +109,7 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device,
         # check how probability per token develops
         #outputs.append(output_flat)
         #outputs_softmax.append(torch.nn.Softmax(output_flat))
-        m = torch.nn.Softmax(dim=1)
+        m = torch.nn.Softmax(dim=1).to(device)
         # 1. variante, verhältnis nicht leere zu leere tokens
         output_softmax = m(output_flat)
         # 2.variante, verhältnis richtiges token zu nicht leere tokens
-- 
GitLab