Skip to content
Snippets Groups Projects
Commit 2b51d1ac authored by Cassandra Grzonkowski's avatar Cassandra Grzonkowski
Browse files

save dataset, remove audio if there is no chart

parent dfa263fc
1 merge request!1Instant spectograms
......@@ -280,6 +280,8 @@ def custom_collate_fn(batch, max_len_specs, vocabulary):
def setup_parser():
    """Build the command-line argument parser for the training script.

    Returns:
        argparse.ArgumentParser: parser with two optional string arguments:
            --vocabulary: path to a vocabulary file (default: None).
            --dataset: path to a previously pickled dataset; when omitted,
                the caller rebuilds the dataset from the raw song packs
                (see the `if args.dataset is None:` branch in __main__).
    """
    out = argparse.ArgumentParser()
    out.add_argument('--vocabulary', default=None, type=str, help="Path to vocabulary")
    out.add_argument('--dataset', default=None, type=str, help="Path to dataset")
    return out
......@@ -323,23 +325,14 @@ if __name__ == '__main__':
print("Number parameters:")
print(sum(p.numel() for p in danceformer.parameters() if p.requires_grad))
if args.dataset is None:
start_time_first_step = time.time()
all_paths, difficulties, charts, max_len_charts = get_paths_diff_charts(songs, indexed_vocabulary)
elapsed_first_step = time.time() - start_time_first_step
print('-' * 89)
print(f'Get paths diff and charts done, time: {elapsed_first_step:5.2f}s')
print('-' * 89)
# print(f"\nDifficulties: {difficulties}")
# difficulties = np.array(difficulties)
#print(f"\nNumber songs: {len(charts)}")
#print(f"\nNumber difficulties: {len(difficulties)}")
#print(f"\nDifficulties: {difficulties}")
#print(f"First entry length of charts: {len(charts[0])}")
#print(f"First chart first 20 entries: {charts[0][:21]}")
# Data loading.
train_dataset = SongPacksDataset(
......@@ -352,6 +345,25 @@ if __name__ == '__main__':
# shuffle = True,
)
#with open(r'C:/Users/cassi/OneDrive/Desktop/Master_Thesis/train_dataset.pkl', 'wb') as output:
with open(r'/scratch/grzonkow/train_dataset.pkl', 'wb') as output:
pickle.dump(train_dataset, output)
else:
with open(args.dataset, 'rb') as data:
train_dataset = pickle.load(data)
# print(f"\nDifficulties: {difficulties}")
# difficulties = np.array(difficulties)
#print(f"\nNumber songs: {len(charts)}")
#print(f"\nNumber difficulties: {len(difficulties)}")
#print(f"\nDifficulties: {difficulties}")
#print(f"First entry length of charts: {len(charts[0])}")
#print(f"First chart first 20 entries: {charts[0][:21]}")
# max length = self defined maximal length of token sequence in seconds
# for each element/song: path of .og file, difficulties, charts
# check batch, output, maybe change to class / transform !
......
......@@ -7,8 +7,8 @@
#SBATCH --cpus-per-task=1 # use 1 thread per taks
#SBATCH -N 1 # request slots on 1 node
#SBATCH --partition=informatik-mind # run on one of our DGX servers
#SBATCH --output=/scratch/grzonkow/model_prints_new.txt # capture output
#SBATCH --error=/scratch/grzonkow/err_model_prints_new.txt # and error streams
#SBATCH --output=/scratch/grzonkow/model_latest.txt # capture output
#SBATCH --error=/scratch/grzonkow/err_model_latest.txt # and error streams
module load anaconda3/latest
. $ANACONDA_HOME/etc/profile.d/conda.sh
......
......@@ -164,6 +164,10 @@ def get_paths_diff_charts(song_dirs, indexed_vocabulary):
for number in range(number_of_charts_diffs-1):
all_paths.append(all_paths[-1])
# if no chart found but audio, remove audio, can be at most one more audio/path according check before
# -> remove last filepath
if number_of_charts_diffs == 0 and len(all_paths) > all_paths_old_len:
all_paths = all_paths[:-1]
# verify number items is equal
#print(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment