|
|
--- |
|
|
datasets: |
|
|
- amaai-lab/MidiCaps |
|
|
- projectlosangeles/Los-Angeles-MIDI-Dataset |
|
|
base_model: |
|
|
- meta-llama/Llama-3.2-1B-Instruct |
|
|
--- |
|
|
|
|
|
### Write music scores with llama |
|
|
|
|
|
### Try the model online: https://huggingface.co/spaces/dx2102/llama-midi |
|
|
|
|
|
This model is fine-tuned from the `Llama-3.2-1B` language model (the metadata above lists `meta-llama/Llama-3.2-1B-Instruct` as the base model).
|
|
|
|
|
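It learns to write MIDI music scores in a plain-text representation: one note per line, written as `pitch duration wait velocity instrument`. `duration` and `wait` (the delay before the next note starts) are in milliseconds, `velocity` is the MIDI velocity divided by 4, and `instrument` is the MIDI program number, or `drum` for a drum track.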
|
|
|
|
|
Optionally, a score title (for example `Bach`, as in the prompt below) can be placed on the line before the column header and used as a text prompt.
|
|
|
|
|
To use this model, you can simply take existing code and replace `meta-llama/Llama-3.2-1B` with `dx2102/llama-midi`. |
|
|
|
|
|
```python |
|
|
import torch |
|
|
from transformers import pipeline |
|
|
|
|
|
# Load the fine-tuned checkpoint as an ordinary text-generation pipeline.
pipe = pipeline(
    "text-generation",
    model="dx2102/llama-midi",
    torch_dtype=torch.bfloat16,
    device="cuda",  # cuda/mps/cpu
)

# The prompt is an optional title line followed by the column header;
# the model continues it with one note per line.
prompt = '''
Bach
pitch duration wait velocity instrument
'''.strip()

generation_kwargs = dict(
    max_new_tokens=10,
    temperature=1.0,
    top_p=1.0,
)

outputs = pipe(prompt, **generation_kwargs)
txt = outputs[0]['generated_text']
print(txt)
|
|
``` |
|
|
|
|
|
|
|
|
To convert the text representation back to a MIDI file, install `symusic` and use the `postprocess` function below:
|
|
|
|
|
```bash |
|
|
# install this midi library |
|
|
pip install symusic |
|
|
``` |
|
|
|
|
|
[symusic](https://github.com/Yikai-Liao/symusic) is a fast C++/Python library for efficient MIDI manipulation. |
|
|
|
|
|
```python |
|
|
import symusic |
|
|
|
|
|
# For example |
|
|
txt = '''pitch duration wait velocity instrument |
|
|
|
|
|
71 1310 0 20 0 |
|
|
48 330 350 20 0 |
|
|
55 330 350 20 0 |
|
|
64 1310 690 20 0 |
|
|
74 660 690 20 0 |
|
|
69 1310 0 20 0 |
|
|
48 330 350 20 0 |
|
|
57 330 350 20 0 |
|
|
66 1310 690 20 0 |
|
|
67 330 350 20 0 |
|
|
69 330 350 20 0 |
|
|
71 1310 0 20 0 |
|
|
48 330 350 20 0 |
|
|
55 330 350 20 0 |
|
|
64 1310 690 20 0 |
|
|
74 660 690 20 0 |
|
|
69 1970 0 20 0 |
|
|
48 330 350 20 0 |
|
|
''' |
|
|
|
|
|
def _parse_note_line(line):
    """Parse one `pitch duration wait velocity instrument` line.

    Returns a `(pitch, duration, wait, velocity, instrument)` tuple where the
    first four items are ints and `instrument` is kept as the original string
    (a program number or the literal 'drum').

    Raises ValueError when the line does not have exactly five fields, when a
    numeric field is not an integer, or when the pitch / program number falls
    outside the 7-bit MIDI range 0..127.
    """
    fields = line.split()
    if len(fields) != 5:
        raise ValueError(f'expected 5 fields, got {len(fields)}: {line!r}')
    pitch, duration, wait, velocity = (int(x) for x in fields[:4])
    instrument = fields[4]
    if not 0 <= pitch <= 127:
        raise ValueError(f'pitch out of range 0..127: {line!r}')
    # Validate the program number here, before any track is created for it,
    # so a bad instrument never leaves a half-initialised track behind.
    if instrument != 'drum' and not 0 <= int(instrument) <= 127:
        raise ValueError(f'program number out of range 0..127: {line!r}')
    return pitch, duration, wait, velocity, instrument


def postprocess(txt, path):
    """Convert the text representation in `txt` to a MIDI file at `path`.

    Only the text after the last blank line is parsed, so a prompt (title and
    column header) followed by a blank line is dropped. Leading and trailing
    whitespace is stripped first, so a trailing newline or blank line does not
    hide the notes.

    Each remaining line is `pitch duration wait velocity instrument`, with
    `duration` and `wait` in milliseconds. Blank lines are skipped silently;
    malformed lines (typically a truncated last line from the model) are
    reported and skipped, and the valid lines around them are kept.

    This is best-effort: errors while building or writing the score are
    printed, not raised. Returns None.
    """
    # assert txt.startswith(prompt)
    body = txt.strip().split('\n\n')[-1]

    # we need to ignore the invalid output by the model
    notes = []
    for line in body.split('\n'):
        if not line.strip():
            continue
        try:
            notes.append(_parse_note_line(line))
        except ValueError as e:
            print('Postprocess: Ignored error:', e)

    print(f'Postprocess: Got {len(notes)} notes')

    try:
        tracks = {}
        now = 0  # start time of the current note, in milliseconds
        for pitch, duration, wait, velocity, instrument in notes:
            if instrument not in tracks:
                track = symusic.core.TrackSecond()
                if instrument != 'drum':
                    track.program = int(instrument)
                else:
                    track.is_drum = True
                tracks[instrument] = track
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Second')
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now / 1000,
                duration=duration / 1000,
                pitch=pitch,
                # The text stores velocity / 4; clamp so a large value from
                # the model cannot leave the 7-bit MIDI velocity range.
                velocity=min(127, max(0, velocity * 4)),
            ))
            # `wait` is the delay until the next note starts.
            now += wait

        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)
|
|
|
|
|
postprocess(txt, './result.mid') |
|
|
``` |
|
|
|
|
|
|
|
|
|
|
|
Similarly, to convert a MIDI file to the text representation (this snippet relies on the `import symusic` from the previous one):
|
|
|
|
|
```python |
|
|
def _extend_notes_to_pedal_end(track):
    """Prolong notes played under a pedal to the end of that pedal.

    The track's pedals are removed afterwards. The single forward pass over
    `pedals` assumes notes and pedals are both in ascending time order
    (TODO confirm that symusic guarantees this after loading).
    """
    notes = track.notes
    pedals = track.pedals
    track.pedals = []
    j = 0
    for note in notes:
        # Skip pedals that were already released before this note starts.
        while j < len(pedals) and pedals[j].time + pedals[j].duration < note.time:
            j += 1
        if j < len(pedals) and pedals[j].time <= note.time <= pedals[j].time + pedals[j].duration:
            # adjust the duration
            note.duration = max(
                note.duration,
                pedals[j].time + pedals[j].duration - note.time,
            )


def _to_delta_rows(notes):
    """Turn time-sorted `(time, duration, pitch, velocity, instrument)` tuples
    (times in seconds) into `(pitch, duration, wait, velocity, instrument)`
    rows, i.e. the inverse of what `postprocess` reads.

    `duration` and `wait` are integer milliseconds and `velocity` is divided
    by 4. `wait` is the gap to the next note's start, computed from rounded
    absolute times so rounding errors do not accumulate.
    """
    starts_ms = [round(time * 1000) for time, *_ in notes]
    rows = []
    for i, (_, duration, pitch, velocity, instrument) in enumerate(notes):
        duration_ms = round(duration * 1000)
        if i + 1 < len(notes):
            wait_ms = starts_ms[i + 1] - starts_ms[i]
        else:
            # NOTE(review): the MIDI file does not define a wait for the last
            # note; its own duration is used so a continuation starts after
            # it ends. Confirm against the format used for training.
            wait_ms = duration_ms
        rows.append((pitch, duration_ms, wait_ms, velocity // 4, instrument))
    return rows


def preprocess(path):
    """Convert the MIDI file at `path` to the text representation.

    Returns one `pitch duration wait velocity instrument` line per note
    (no header line), joined with newlines. Returns '' when the file cannot
    be loaded, after printing the error.
    """
    # midi files may be broken
    try:
        score = symusic.Score(path, ttype='Second')
    except Exception as e:
        print('Ignored midi loading error:', e)
        return ''

    # prolong notes to the end of the current pedal
    score = score.copy()
    for track in score.tracks:
        _extend_notes_to_pedal_end(track)

    notes = []
    for track in score.tracks:
        # program id. `instrument` is always a string.
        instrument = 'drum' if track.is_drum else str(track.program)
        for note in track.notes:
            notes.append((note.time, note.duration, note.pitch, note.velocity, instrument))

    # dedup: notes sharing (time, duration, pitch) collapse to the last one
    # seen, regardless of instrument.
    notes = list({
        (time, duration, pitch): (time, duration, pitch, velocity, instrument)
        for time, duration, pitch, velocity, instrument in notes
    }.values())

    # merge channels, sort by start time. If notes start at the same time, the higher pitch comes first.
    notes.sort(key=lambda x: (x[0], -x[2]))

    # Translate start time to the delta time format:
    # ie. 'pitch duration wait', in milliseconds.
    rows = _to_delta_rows(notes)

    return '\n'.join(' '.join(map(str, row)) for row in rows)
|
|
|
|
|
txt = preprocess('./test.mid') |
|
|
``` |
|
|
|
|
|
|
|
|
|