From 20b2bb4b74a69896002a64af6f331e1a8a554c8b Mon Sep 17 00:00:00 2001 From: Cassandra Grzonkowski <c_grzonkow18@cs.uni-kl.de> Date: Tue, 27 Feb 2024 06:18:02 +0100 Subject: [PATCH] device --- main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 2b3dc3e..c98ee4c 100644 --- a/main.py +++ b/main.py @@ -117,14 +117,12 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device, output_v_2_flat = torch.cat((output_flat[:, :empty_token], output_flat[:, empty_token + 1:]), dim=1).to(device) output_var_2_softmax = m(output_v_2_flat).to(device) # insert empty token with 0 prob as placeholer to have right index - print(output_var_2_softmax[:, :empty_token]) - print(torch.zeros((output_var_2_softmax.shape[0]))) - tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1)), dim=1).to(device) + tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1, device=device)), dim=1).to(device) output_var_2_softmax = torch.cat((tmp[:, :empty_token+1], output_var_2_softmax[:, empty_token:]), dim=1).to(device) # richtiges token siehte target/chart an entsprechender stelle right_token_probs = output_var_2_softmax[torch.arange(output_var_2_softmax.size(0)), charts_flat].unsqueeze(1).to(device) # take entries where non empty chart values - non_empty_ind = torch.where(charts_flat != empty_token).to(device) + non_empty_ind = torch.where(charts_flat != empty_token) # only right tokens prob ohne empty token right_token_probs = right_token_probs[non_empty_ind] if len(right_token_probs) != 0: -- GitLab