Skip to content
Snippets Groups Projects
Commit 20b2bb4b authored by Cassandra Grzonkowski's avatar Cassandra Grzonkowski
Browse files

device

parent 6ba8b072
1 merge request!1Instant spectograms
...@@ -117,14 +117,12 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device, ...@@ -117,14 +117,12 @@ def train_one_epoch(danceformer, train_loader, epoch, model_params_path, device,
output_v_2_flat = torch.cat((output_flat[:, :empty_token], output_flat[:, empty_token + 1:]), dim=1).to(device) output_v_2_flat = torch.cat((output_flat[:, :empty_token], output_flat[:, empty_token + 1:]), dim=1).to(device)
output_var_2_softmax = m(output_v_2_flat).to(device) output_var_2_softmax = m(output_v_2_flat).to(device)
# insert empty token with 0 prob as placeholer to have right index # insert empty token with 0 prob as placeholer to have right index
print(output_var_2_softmax[:, :empty_token]) tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1, device=device)), dim=1).to(device)
print(torch.zeros((output_var_2_softmax.shape[0])))
tmp = torch.cat((output_var_2_softmax[:, :empty_token], torch.zeros((output_var_2_softmax.shape[0]), 1)), dim=1).to(device)
output_var_2_softmax = torch.cat((tmp[:, :empty_token+1], output_var_2_softmax[:, empty_token:]), dim=1).to(device) output_var_2_softmax = torch.cat((tmp[:, :empty_token+1], output_var_2_softmax[:, empty_token:]), dim=1).to(device)
# richtiges token siehte target/chart an entsprechender stelle # richtiges token siehte target/chart an entsprechender stelle
right_token_probs = output_var_2_softmax[torch.arange(output_var_2_softmax.size(0)), charts_flat].unsqueeze(1).to(device) right_token_probs = output_var_2_softmax[torch.arange(output_var_2_softmax.size(0)), charts_flat].unsqueeze(1).to(device)
# take entries where non empty chart values # take entries where non empty chart values
non_empty_ind = torch.where(charts_flat != empty_token).to(device) non_empty_ind = torch.where(charts_flat != empty_token)
# only right tokens prob ohne empty token # only right tokens prob ohne empty token
right_token_probs = right_token_probs[non_empty_ind] right_token_probs = right_token_probs[non_empty_ind]
if len(right_token_probs) != 0: if len(right_token_probs) != 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment