From 20b2bb4b74a69896002a64af6f331e1a8a554c8b Mon Sep 17 00:00:00 2001
From: Cassandra Grzonkowski <c_grzonkow18@cs.uni-kl.de>
Date: Tue, 27 Feb 2024 06:18:02 +0100
Subject: [PATCH] device

---
 main.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index 2b3dc3e..c98ee4c 100644
--- a/main.py
+++ b/main.py
@@ -117,14 +117,12 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device,
         output_v_2_flat = torch.cat((output_flat[:, :empty_token], output_flat[:, empty_token + 1:]), dim=1).to(device)
         output_var_2_softmax = m(output_v_2_flat).to(device)
         # insert empty token with 0 prob as placeholer to have right index
-        print(output_var_2_softmax[:, :empty_token])
-        print(torch.zeros((output_var_2_softmax.shape[0])))
-        tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1)), dim=1).to(device)
+        tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1, device=device)), dim=1).to(device)
         output_var_2_softmax = torch.cat((tmp[:, :empty_token+1], output_var_2_softmax[:, empty_token:]), dim=1).to(device)
         # richtiges token siehte target/chart an entsprechender stelle
         right_token_probs = output_var_2_softmax[torch.arange(output_var_2_softmax.size(0)), charts_flat].unsqueeze(1).to(device)
         # take entries where non empty chart values
-        non_empty_ind = torch.where(charts_flat != empty_token).to(device)
+        non_empty_ind = torch.where(charts_flat != empty_token)
         # only right tokens prob ohne empty token
         right_token_probs = right_token_probs[non_empty_ind]
         if len(right_token_probs) != 0:
-- 
GitLab