Skip to content
Snippets Groups Projects
Commit 5676ab57 authored by Cassandra Grzonkowski's avatar Cassandra Grzonkowski
Browse files

device

parent 51526f9d
No related branches found
No related tags found
1 merge request!1Instant spectograms
......@@ -117,7 +117,7 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device,
output_v_2_flat = torch.cat((output_flat[:, :empty_token], output_flat[:, empty_token + 1:]), dim=1).to(device)
output_var_2_softmax = m(output_v_2_flat).to(device)
# insert empty token with 0 prob as placeholer to have right index
tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1)), dim=1).to(device)
tmp = torch.cat((output_var_2_softmax[:, :empty_token].to(device), torch.zeros((output_var_2_softmax.shape[0]).to(device), 1)), dim=1).to(device)
output_var_2_softmax = torch.cat((tmp[:, :empty_token+1], output_var_2_softmax[:, empty_token:]), dim=1).to(device)
# richtiges token siehte target/chart an entsprechender stelle
right_token_probs = output_var_2_softmax[torch.arange(output_var_2_softmax.size(0)), charts_flat].unsqueeze(1).to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment