From cbbd9f83096b378f76c4dd4d76eade92405b2712 Mon Sep 17 00:00:00 2001 From: Cassandra Grzonkowski <c_grzonkow18@cs.uni-kl.de> Date: Tue, 27 Feb 2024 05:18:23 +0100 Subject: [PATCH] device --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index ad97ede..65eb1fb 100644 --- a/main.py +++ b/main.py @@ -109,7 +109,7 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device, # check how probability per token develops #outputs.append(output_flat) #outputs_softmax.append(torch.nn.Softmax(output_flat)) - m = torch.nn.Softmax(dim=1) + m = torch.nn.Softmax(dim=1).to(device) # 1. variante, verhältnis nicht leere zu leere tokens output_softmax = m(output_flat) # 2.variante, verhältnis richtiges token zu nicht leere tokens -- GitLab