Skip to content
Snippets Groups Projects
Select Git revision
  • 51b4268b79ac5b9575c96a0e5b792655167a0870
  • main default protected
  • fix_prediction_combine_tokens
  • test_difficulty
  • fix_prediction_learn_constant
  • window_spectrograms
6 results

preliminary_experiment.py

Blame
  • preliminary_experiment.py 10.07 KiB
    import simfile
    from simfile.notes import Note, NoteType, NoteData
    from simfile.notes.group import OrphanedNotes, group_notes
    from simfile.notes.count import *
    from simfile.timing import Beat, TimingData
    from simfile.timing.engine import TimingEngine
    from simfile.notes.timed import time_notes
    from simfile.timing.displaybpm import displaybpm
    from simfile.notes.group import OrphanedNotes, group_notes
    from typing import Iterator, Optional, Sequence
    from simfile.types import Chart
    import io
    import os
    import tarfile
    import tempfile
    import numpy as np
    
    import torch
    import torchaudio
    import matplotlib.pyplot as plt
    
    #print(torch.__version__)
    #print(torchaudio.__version__)
    
    
    # Timing & note data from https://simfile.readthedocs.io/en/latest/timing-note-data.html#reading-note-data
    def read_file_meta_data(path):
        """Print exploratory statistics for one simfile and compare BPM to note spacing.

        Prints the simfile metadata, the selected chart's step/jump counts and its
        timing data. It then measures the time between consecutive notes of that
        chart. It compares the most frequent gap with the chart's first BPM.

        Args:
            path: Path to the simfile. ``plot_histogram`` saves the histogram
                image alongside it.

        Returns:
            The ratio computed by ``calculate_bpm_dist``: seconds-per-beat derived
            from the first BPM divided by the histogram's peak gap.

        Raises:
            ValueError: If the simfile contains no charts.
        """
        # Parse once; simfile.open takes care of detecting the text encoding.
        opened_file = simfile.open(path)

        keys = list(opened_file.keys())
        print(f"{keys[12:24]}")
        print(f"{keys[24:]}")
        print(f"Artist: {opened_file.artist}")
        print(f"Title: {opened_file.title}")
        print(f"BPMs: {opened_file.bpms}")

        # Alternatives: get_hardest_chart(...) or get_lowest_chart(...).
        chart = get_middle_chart(opened_file.charts)
        if chart is None:
            raise ValueError(f"No charts found in {path}")
        print(f"Stepstype: {chart.stepstype}")
        print(f"Meter: {chart.meter}")

        note_data = NoteData(chart)
        print(f"Number steps: {count_steps(note_data)}")
        print(f"Number jumps: {count_jumps(note_data)}")

        # Passing the chart makes per-chart (SSC split) timing apply, so the
        # reference BPM matches the notes being analysed.
        timing_data = TimingData(opened_file, chart)
        print(f"BPMs splitted: {timing_data.bpms}")
        print(f"Offset: {opened_file.offset} \n")

        note_dist, time_diffs = find_note_distances(chart, opened_file)
        print(f"Number time differences between notes: {len(note_dist)}")

        most_dist = plot_histogram(time_diffs, path)
        # bpms holds (beat, value) entries; compare against the first BPM value.
        first_bpm = timing_data.bpms[0].value
        print(first_bpm)
        dist = calculate_bpm_dist(most_dist, first_bpm)
        print(f"Distance between peak value and bpm is: {dist}")

        return dist
    
    
    # Get charts for one game mode: https://simfile.readthedocs.io/en/latest/examples.html
    def charts_for_stepstype(charts, stepstype='dance-single') -> Iterator[Chart]:
        """Lazily yield the charts whose ``stepstype`` equals ``stepstype``.

        Args:
            charts: Iterable of charts to filter.
            stepstype: Game mode to keep (default ``'dance-single'``).

        Yields:
            Every matching chart, in input order.
        """
        yield from (candidate for candidate in charts if candidate.stepstype == stepstype)
    
    
    # Get the hardest chart: https://simfile.readthedocs.io/en/latest/examples.html
    def get_hardest_chart(charts) -> Optional[Chart]:
        """Return the chart with the largest numeric meter, or None for no charts.

        A missing or empty ``meter`` counts as 1. On ties the earliest chart wins.
        """
        def numeric_meter(candidate) -> int:
            # Meters are stored as strings; compare them as integers.
            return int(candidate.meter or "1")

        # max() keeps the first maximal element, matching a strict ">" scan.
        return max(charts, key=numeric_meter, default=None)
    
    
    def get_lowest_chart(charts) -> Optional[Chart]:
        """Return the chart with the smallest numeric meter (earliest on ties).

        Charts with a missing or blank ``meter`` are treated as unrated. They
        are returned only when no chart has a meter. Previously ``int()``
        crashed on them.

        Args:
            charts: Iterable of charts, each exposing a ``meter`` string.

        Returns:
            The lowest-rated chart; otherwise the first unrated chart; otherwise
            ``None`` for an empty iterable.

        Raises:
            ValueError: If a non-blank meter is not an integer.
        """
        lowest_chart: Optional[Chart] = None
        lowest_meter: Optional[int] = None
        first_unrated: Optional[Chart] = None

        for chart in charts:
            raw_meter = (chart.meter or "").strip()
            if not raw_meter:
                if first_unrated is None:
                    first_unrated = chart
                continue
            meter = int(raw_meter)
            if lowest_meter is None or meter < lowest_meter:
                lowest_chart = chart
                lowest_meter = meter

        return lowest_chart if lowest_chart is not None else first_unrated
    
    
    def get_middle_chart(charts, middle_meters=(4, 5, 6)) -> Optional[Chart]:
        """Return the first chart whose meter is a "middle" difficulty.

        Args:
            charts: Iterable of charts, each exposing a ``meter`` string.
            middle_meters: Meter values counted as middle difficulty
                (default ``(4, 5, 6)``).

        Returns:
            The first chart, in iteration order, whose integer meter is in
            ``middle_meters``. Falls back to ``get_lowest_chart(charts)`` when
            there is no such chart. Charts with a missing or blank meter are
            skipped instead of crashing ``int()``.
        """
        # Materialise so the fallback still sees every chart when a one-shot
        # iterator is passed in.
        charts = list(charts)

        for chart in charts:
            raw_meter = (chart.meter or "").strip()
            if raw_meter and int(raw_meter) in middle_meters:
                return chart

        return get_lowest_chart(charts)
    
    
    def find_note_distances(chart, opened_file):
        """Return the time gaps between consecutive timed notes of ``chart``.

        Args:
            chart: The chart whose notes are timed.
            opened_file: The simfile ``chart`` belongs to; supplies timing data.

        Returns:
            A pair ``(note_dist, time_diffs)`` over consecutive notes in the order
            yielded by ``time_notes``:

            - ``time_diffs[i]`` is ``time`` of note ``i + 1`` minus ``time`` of note ``i``.
            - ``note_dist[i]`` is a one-element list holding the tuple
              ``(previous_note, following_note, time_diff)`` for the same pair.
        """
        note_data = NoteData(chart)
        # Pass the chart so per-chart (SSC split) BPMs/offset are used; without
        # it such charts were timed with the song-level timing.
        timing_data = TimingData(opened_file, chart)

        timed = list(time_notes(note_data, timing_data))
        note_dist = []
        time_diffs = []
        for prev_note, next_note in zip(timed, timed[1:]):
            time_diff = next_note.time - prev_note.time
            # One-element list kept for compatibility with existing callers.
            note_dist.append([(prev_note.note, next_note.note, time_diff)])
            time_diffs.append(time_diff)

        return note_dist, time_diffs
    
    
    def plot_histogram(time_diffs, path, filename="histogram.png"):
        """Save a 77-bin histogram of ``time_diffs`` next to ``path``; return its peak.

        Args:
            time_diffs: Sequence of time differences between consecutive notes.
            path: File path whose directory receives the image.
            filename: Image file name (default ``"histogram.png"``).

        Returns:
            The left edge of the most populated bin (the first one on ties).
        """
        # Use a dedicated figure so axes left open elsewhere (e.g. by the
        # waveform plots) are not drawn into.
        figure, axis = plt.subplots()
        try:
            counts, edges, _ = axis.hist(time_diffs, bins=77)
            # argmax returns the first maximal bin, same as np.where(...)[0].
            most_dist = edges[np.argmax(counts)]

            axis.set_xlabel('Time Difference')
            axis.set_ylabel('Frequency')
            axis.set_title('Time Difference of consecutive notes')
            # os.path copes with bare file names and OS-specific separators;
            # rsplit('/') produced "<file>/histogram.png" for a bare name.
            figure.savefig(os.path.join(os.path.dirname(path), filename))
        finally:
            plt.close(figure)
        return most_dist
    
    
    def plot_end_histogram(time_diffs, path, filename="histogram.png"):
        """Save a 77-bin histogram of BPM ratios next to ``path``; return its peak.

        Args:
            time_diffs: Sequence of values to histogram. The plot labels treat
                them as ratios of given to found seconds per beat.
            path: File path whose directory receives the image.
            filename: Image file name. The default ``"histogram.png"`` is the
                same name ``plot_histogram`` uses by default. Pass another name
                if both plots share a directory.

        Returns:
            The left edge of the most populated bin (the first one on ties).
        """
        # Dedicated figure so no previously opened axes get drawn into.
        figure, axis = plt.subplots()
        try:
            counts, edges, _ = axis.hist(time_diffs, bins=77)
            most_dist = edges[np.argmax(counts)]

            axis.set_xlabel('Ratio')
            axis.set_ylabel('Frequency')
            axis.set_title('Ratio between given and found seconds per beat')
            figure.savefig(os.path.join(os.path.dirname(path), filename))
        finally:
            plt.close(figure)
        return most_dist
    
    
    # method from https://pytorch.org/audio/stable/tutorials/audio_io_tutorial.html
    def plot_waveform(waveform, sample_rate, out_path="./outputs/wave.png"):
        """Plot every channel of ``waveform`` against time and save the figure.

        Args:
            waveform: Tensor convertible with ``.numpy()`` to a 2-D array unpacked
                as ``(num_channels, num_frames)``.
            sample_rate: Samples per second; converts frame indices to seconds.
            out_path: Destination image path (default ``./outputs/wave.png``).
                Its directory is created if missing.
        """
        waveform = waveform.numpy()

        num_channels, num_frames = waveform.shape
        time_axis = torch.arange(0, num_frames) / sample_rate

        figure, axes = plt.subplots(num_channels, 1)
        try:
            if num_channels == 1:
                axes = [axes]
            for c in range(num_channels):
                axes[c].plot(time_axis, waveform[c], linewidth=1)
                axes[c].grid(True)
                if num_channels > 1:
                    axes[c].set_ylabel(f"Channel {c+1}")
            figure.suptitle("waveform")
            os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
            figure.savefig(out_path)
        finally:
            # Close so the figure is freed and later pyplot calls don't draw on it.
            plt.close(figure)
    
    
    # method from https://pytorch.org/audio/stable/tutorials/audio_io_tutorial.html
    def plot_spectogram(waveform, sample_rate, title="Spectrogram",
                        out_path="./outputs/spectogram.png"):
        """Plot a spectrogram for every channel of ``waveform`` and save the figure.

        Args:
            waveform: Tensor convertible with ``.numpy()`` to a 2-D array unpacked
                as ``(num_channels, num_frames)``.
            sample_rate: Samples per second, passed to ``specgram`` as ``Fs``.
            title: Figure title.
            out_path: Destination image path (default ``./outputs/spectogram.png``).
                Its directory is created if missing.
        """
        waveform = waveform.numpy()

        num_channels, num_frames = waveform.shape

        figure, axes = plt.subplots(num_channels, 1)
        try:
            if num_channels == 1:
                axes = [axes]
            for c in range(num_channels):
                axes[c].specgram(waveform[c], Fs=sample_rate)
                if num_channels > 1:
                    axes[c].set_ylabel(f"Channel {c+1}")
            figure.suptitle(title)
            os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
            figure.savefig(out_path)
        finally:
            # Close so the figure is freed and later pyplot calls don't draw on it.
            plt.close(figure)
    
    
    def read_audio(SAMPLE_WAV):
        """Print audio metadata, load the file, and save waveform + spectrogram plots.

        Args:
            SAMPLE_WAV: Path (or file-like object) accepted by ``torchaudio``.
        """
        print(torchaudio.info(SAMPLE_WAV))
        signal, rate = torchaudio.load(SAMPLE_WAV)
        for plotter in (plot_waveform, plot_spectogram):
            plotter(signal, rate)
    
    
    def calculate_bpm_dist(most_dist, bpm):
        """Return the ratio of the BPM's seconds-per-beat to ``most_dist``.

        Args:
            most_dist: Measured seconds per beat (e.g. a histogram peak).
            bpm: Declared beats per minute. ``0`` is used as-is instead of
                being divided into 60.

        Returns:
            ``(60 / bpm) / most_dist``. When ``most_dist`` is zero,
            ``most_dist`` itself is returned unchanged.
        """
        if bpm == 0:
            print("Case 1")
            given_spb = bpm
        else:
            given_spb = 60.0 / float(bpm)

        print(f"Found seconds per beat: {most_dist}")
        print(f"Given seconds per beat: {given_spb}")

        if most_dist == 0:
            print("Case 2")
            return most_dist
        return given_spb / float(most_dist)