---
datasets:
- amaai-lab/MidiCaps
- projectlosangeles/Los-Angeles-MIDI-Dataset
base_model:
- meta-llama/Llama-3.2-1B-Instruct
---

### Write music scores with llama

### Try the model online: https://huggingface.co/spaces/dx2102/llama-midi

This model is finetuned from the `Llama-3.2-1B` language model.
It learns to write MIDI music scores in a text representation.
Optionally, the score title can also be used as a text prompt.

To use this model, you can simply take existing code and replace `meta-llama/Llama-3.2-1B` with `dx2102/llama-midi`.

```python
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="dx2102/llama-midi",
    torch_dtype=torch.bfloat16,
    device="cuda",  # cuda/mps/cpu
)

# The prompt is an optional title line followed by the column header.
txt = pipe(
    '''
Bach
pitch duration wait velocity instrument
'''.strip(),
    max_new_tokens=10,  # raise this to generate more than a few tokens
    temperature=1.0,
    top_p=1.0,
)[0]['generated_text']
print(txt)
```

To convert the text representation back to a MIDI file, try this:

```bash
# install this midi library
pip install symusic
```

[symusic](https://github.com/Yikai-Liao/symusic) is a fast C++/Python library for efficient MIDI manipulation.

```python
import symusic

# For example
txt = '''pitch duration wait velocity instrument

71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''

def postprocess(txt, path):
    # assert txt.startswith(prompt)
    # keep only the note lines, i.e. everything after the last blank line
    txt = txt.split('\n\n')[-1]

    tracks = {}

    now = 0  # start time of the current note, in milliseconds
    # we need to ignore the invalid output by the model:
    # parsing stops at the first malformed line (a trailing empty line included)
    try:
        for line in txt.split('\n'):
            pitch, duration, wait, velocity, instrument = line.split()
            pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
            if instrument not in tracks:
                tracks[instrument] = symusic.core.TrackSecond()
                if instrument != 'drum':
                    tracks[instrument].program = int(instrument)
                else:
                    tracks[instrument].is_drum = True
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Second')
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now/1000,
                duration=duration/1000,
                pitch=int(pitch),
                velocity=int(velocity * 4),
            ))
            now += wait
    except Exception as e:
        print('Postprocess: Ignored error:', e)

    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')

    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)

postprocess(txt, './result.mid')
```

Similarly, to convert a MIDI file to the text representation:

```python
def preprocess(path):
    # turn the MIDI file into the text format and return it as a string
    # midi files may be broken
    try:
        score = symusic.Score(path, ttype='Second')
    except Exception as e:
        print('Ignored midi loading error:', e)
        return ''

    # prolong notes to the end of the current pedal
    score = score.copy()
    for track in score.tracks:
        notes = track.notes
        pedals = track.pedals
        track.pedals = []
        j = 0
        for i, note in enumerate(notes):
            while j < len(pedals) and pedals[j].time + pedals[j].duration < note.time:
                j += 1
            if j < len(pedals) and pedals[j].time <= note.time <= pedals[j].time + pedals[j].duration:
                # adjust the duration
                note.duration = max(
                    note.duration,
                    pedals[j].time + pedals[j].duration - note.time,
                )

    notes = []
    for track in score.tracks:
        instrument = str(track.program)  # program id. `instrument` is always a string.
        if track.is_drum:
            instrument = 'drum'
        for note in track.notes:
            notes.append((note.time, note.duration, note.pitch, note.velocity, instrument))

    # dedup
    notes = list({
        (time, duration, pitch): (time, duration, pitch, velocity, instrument)
        for time, duration, pitch, velocity, instrument in notes
    }.values())

    # merge channels, sort by start time. If notes start at the same time, the higher pitch comes first.
    notes.sort(key=lambda x: (x[0], -x[2]))

    # Translate start time to the delta time format:
    # ie. 'pitch duration wait', in milliseconds.
    # NOTE: this loop was missing from the snippet and is reconstructed here as an
    # assumption, chosen to mirror `postprocess` (ms units, velocity scaled by 4).
    notes1 = []
    for i, (start, duration, pitch, velocity, instrument) in enumerate(notes):
        # wait = time until the next note starts; the last note has no successor,
        # so its own duration is used (an assumption)
        wait = notes[i + 1][0] - start if i + 1 < len(notes) else duration
        notes1.append((
            pitch,
            round(duration * 1000),
            round(wait * 1000),
            velocity // 4,
            instrument,
        ))

    txt = []
    for note in notes1:
        txt.append(' '.join(map(str, note)))
    txt = '\n'.join(txt)

    return txt

txt = preprocess('./test.mid')
```
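
The text representation can be read as a delta-time event list: each line is `pitch duration wait velocity instrument`, with `duration` and `wait` in milliseconds, and `wait` counting the time until the next note starts (so `wait = 0` stacks notes into a chord). The following is a minimal sketch, independent of symusic, that expands such lines into absolute start times. The helper name `to_absolute` is hypothetical and not part of the model's code, and the velocity scaling by 4 is taken from `postprocess`.

```python
def to_absolute(body):
    """Expand 'pitch duration wait velocity instrument' lines into absolute-time notes.

    Hypothetical helper for illustration only; all times are in milliseconds.
    """
    now = 0  # running start time in ms
    notes = []
    for line in body.strip().split('\n'):
        fields = line.split()
        if len(fields) != 5:
            break  # stop at the first malformed line, as postprocess does
        try:
            pitch, duration, wait, velocity = (int(x) for x in fields[:4])
        except ValueError:
            break  # non-numeric field: treat the rest as invalid model output
        notes.append({
            'start_ms': now,
            'duration_ms': duration,
            'pitch': pitch,
            'velocity': velocity * 4,   # same scaling as in postprocess
            'instrument': fields[4],    # program id as a string, or 'drum'
        })
        now += wait  # the next note starts `wait` ms after this one
    return notes


example = '''71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0'''

for note in to_absolute(example):
    print(note['start_ms'], note['pitch'], note['duration_ms'])
```

In this example the first two notes (pitches 71 and 48) share start time 0 because the first line has `wait = 0`, and the following notes start at 350 ms and 700 ms.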