Skip to content
Snippets Groups Projects
Select Git revision
  • 1665de6fdd94185b3e3472f5f5ef3aac8fb30e7d
  • main default protected
2 results

emah.fingui

Blame
  • load_vocab.py 4.15 KiB
    import pickle
    import re
    import simfile
    import os
    from io import StringIO
    from simfile.notes import NoteData
    from simfile.notes.timed import time_notes
    from simfile.timing import TimingData
    
    
    def filter_whitespaces(notes):
        """Collapse each run of blank or whitespace-only lines into one newline.

        Every substring that begins and ends with a newline and holds only
        whitespace (newlines included) in between is replaced by a single
        newline character.

        Args:
            notes: Text to clean.

        Returns:
            The cleaned text.
        """
        return re.sub(r'\n\s*\n', '\n', notes)
    
    
    def index_tokens(token_sequence):
        """Map each distinct token to a contiguous integer index.

        Indices are assigned in order of first appearance, starting at 0, so
        the values of the result are exactly ``range(len(result))``. A repeated
        token keeps the index of its first occurrence. This prevents a repeat
        from overwriting the earlier index and leaving gaps in the numbering.

        Args:
            token_sequence: Any iterable of hashable tokens.

        Returns:
            dict mapping token -> index.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        unique_tokens = dict.fromkeys(token_sequence)
        return {token: index for index, token in enumerate(unique_tokens)}
    
    
    def _chart_note_rows(note_data, timing_data):
        """Return one note-row string per distinct note time of a single chart.

        Each row starts as ``"0" * note_data.columns``. For every note, the
        character at that note's column is replaced by ``str(note)``. Notes
        are grouped by the ``time`` value reported by ``time_notes``.
        """
        cols = note_data.columns
        rows = {}
        for timed_note in time_notes(note_data, timing_data):
            column = timed_note.note.column
            # NOTE(review): the splice assumes str(note) is a single character
            # (the note type) -- confirm against simfile's Note.__str__.
            row = rows.get(timed_note.time, "0" * cols)
            rows[timed_note.time] = row[:column] + str(timed_note.note) + row[column + 1:]
        return rows.values()


    def built_chart_vocabulary(song_dirs):
        """Build an indexed vocabulary of note rows from all ``.sm`` charts.

        Every ``.sm`` file directly inside each song directory is parsed
        (ISO-8859-1). Each chart then contributes one token per distinct note
        time: a string with one character per column, "0" where no note sits.
        The all-zero 4-column row "0000" is always included.

        Args:
            song_dirs: Iterable of song directory paths.

        Returns:
            dict mapping row string -> integer index. Tokens are sorted before
            indexing so the same charts always produce the same mapping; the
            iteration order of a set of strings changes between interpreter
            runs because of hash randomization.
        """
        cols_init = 4
        # Seed the vocabulary with an empty (all-zero) row.
        vocabulary_set = {"0" * cols_init}
        for num, song_dir in enumerate(song_dirs):
            if num % 1000 == 0:
                print(f"Song: {num}")
            for item in os.listdir(song_dir):
                if not item.endswith('.sm'):
                    continue
                with open(os.path.join(song_dir, item), 'r', encoding="ISO-8859-1") as infile:
                    sm_file = simfile.load(infile)

                # Timing depends only on the simfile, so build it at most once
                # per file (and only if there is a chart to time).
                timing_data = None
                for chart in sm_file.charts:
                    if timing_data is None:
                        timing_data = TimingData(sm_file)
                    vocabulary_set.update(_chart_note_rows(NoteData(chart), timing_data))

        print("----------------------------------")
        print("Built Vocabulary:")
        print(vocabulary_set)
        indexed_vocabulary = index_tokens(sorted(vocabulary_set))
        print("Indexing:")
        print(indexed_vocabulary)
        print("Length vocabulary:")
        print(len(indexed_vocabulary))
        return indexed_vocabulary
    
    
    if __name__ == '__main__':
        # Expected layout: <folder>/<package>/<song>/ with .sm files in each
        # song folder.
        folder = r'/work/MLShare/StepMania/data/cleaned/allowed_meter_difference_of_2/'
        pkgs = os.listdir(folder)

        songs = []
        for pkg in pkgs:
            pkg_dir = os.path.join(folder, pkg)
            # Skip stray files so os.listdir is only ever called on folders.
            if not os.path.isdir(pkg_dir):
                continue
            for song in os.listdir(pkg_dir):
                song_dir = os.path.join(pkg_dir, song)
                if os.path.isdir(song_dir):
                    songs.append(song_dir)

        print(f"Songs in given folder: {songs}")
        print(f"Number songs: {len(songs)}")

        # preprocessing
        indexed_vocabulary = built_chart_vocabulary(songs)
        # save vocabulary for reusing
        with open(r'/scratch/grzonkow/vocabulary.pkl', 'wb') as fp:
            pickle.dump(indexed_vocabulary, fp)
            print('dictionary saved successfully to file')

        # Human-readable copy; one field per line.
        with open(r'/scratch/grzonkow/vocabulary.txt', 'w', encoding='utf-8') as f:
            f.write(f"Length vocabulary: {len(indexed_vocabulary)}\n")
            f.write(f"Vocabulary: {indexed_vocabulary}\n")