Safetensors
llama
File size: 4,982 Bytes
48db3f3
 
 
 
 
 
 
 
c303c10
 
ad02931
48db3f3
0b0b9bb
48db3f3
 
 
29adb43
 
6469e50
48db3f3
 
 
 
 
 
 
 
 
45a9390
48db3f3
 
ad02931
d5b69bb
48db3f3
 
d5b69bb
317e0de
d5b69bb
 
317e0de
ad02931
 
 
 
6ec00f0
ad02931
 
 
 
 
 
96c4f8e
ad02931
db0cbd6
5de84a6
 
ad02931
 
78cc042
ad02931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763bd6e
ad02931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44e4bf1
48db3f3
8a05cd6
 
 
8eafdae
8a05cd6
8eafdae
8a05cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbfb2e7
 
8a05cd6
cbfb2e7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
---
datasets:
- amaai-lab/MidiCaps
- projectlosangeles/Los-Angeles-MIDI-Dataset
base_model:
- meta-llama/Llama-3.2-1B-Instruct
---

### Write music scores with llama

### Try the model online: https://huggingface.co/spaces/dx2102/llama-midi

This model is fine-tuned from the `Llama-3.2-1B` language model.

It writes MIDI music scores using a plain-text representation: a header line `pitch duration wait velocity instrument`, followed by one note per line.

Optionally, a score title (such as `Bach` in the example below) can be placed on the line before the header as a text prompt.

To use this model, take any existing code that loads `meta-llama/Llama-3.2-1B` and replace the model id with `dx2102/llama-midi`:

```python
import torch
from transformers import pipeline

# Wrap the fine-tuned checkpoint in a text-generation pipeline.
pipe = pipeline(
    "text-generation", 
    model="dx2102/llama-midi", 
    torch_dtype=torch.bfloat16, 
    device="cuda", # cuda/mps/cpu
)

# Prompt: an optional title line followed by the column header of the
# text score format; the model continues with the notes.
prompt = '''
Bach
pitch duration wait velocity instrument
'''.strip()

# max_new_tokens caps how much is generated; raise it for longer scores.
outputs = pipe(
    prompt,
    max_new_tokens=10,
    temperature=1.0,
    top_p=1.0,
)
txt = outputs[0]['generated_text']
print(txt)
```


To convert the text representation back to a MIDI file, try this:

```bash
# install this midi library
pip install symusic
```

[symusic](https://github.com/Yikai-Liao/symusic) is a fast C++/Python library for efficient MIDI manipulation.

```python
import symusic

# Example text score: a header line, a blank line, then one note per line as
# 'pitch duration wait velocity instrument'. `duration` and `wait` are in
# milliseconds (`wait` is the delay before the next note starts), and
# `postprocess` below multiplies `velocity` by 4.
txt = '''pitch duration wait velocity instrument

71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''

def postprocess(txt, path):
    """Parse a text score and write it to a MIDI file at ``path``.

    Only the part of ``txt`` after its last blank line is parsed, so a prompt
    (title and header line) in front of the notes is ignored. Each remaining
    non-blank line must hold five whitespace-separated fields,
    ``pitch duration wait velocity instrument``:

    - ``duration`` and ``wait`` are in milliseconds; ``wait`` is the delay
      before the next note starts.
    - ``velocity`` is multiplied by 4 and clamped to the MIDI range 0-127.
    - ``instrument`` is a program number (0-127) or ``drum``. Notes that share
      an ``instrument`` value go into the same track.

    Model output is often truncated or malformed, so parsing stops at the
    first invalid line and keeps the notes before it. Errors are printed
    rather than raised.
    """
    body = txt.split('\n\n')[-1]

    tracks = {}
    now = 0  # start time of the next note, in milliseconds
    for line in body.split('\n'):
        if not line.strip():
            # Blank lines (e.g. the trailing newline of a complete output)
            # carry no note; skip them instead of reporting an error.
            continue
        # Validate the whole line before touching `tracks`, so a bad line
        # never leaves a half-initialised track behind.
        try:
            pitch, duration, wait, velocity, instrument = line.split()
            pitch, duration, wait, velocity = map(int, (pitch, duration, wait, velocity))
            program = None if instrument == 'drum' else int(instrument)
            if not 0 <= pitch <= 127:
                raise ValueError(f'pitch out of range: {pitch}')
            if program is not None and not 0 <= program <= 127:
                raise ValueError(f'program out of range: {program}')
            if duration < 0 or wait < 0:
                raise ValueError(f'negative time in line: {line!r}')
        except ValueError as e:
            # We need to ignore the invalid output by the model.
            print('Postprocess: Ignored error:', e)
            break

        if instrument not in tracks:
            track = symusic.core.TrackSecond()
            if program is None:
                track.is_drum = True
            else:
                track.program = program
            tracks[instrument] = track
        # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Second')
        tracks[instrument].notes.append(symusic.core.NoteSecond(
            time=now / 1000,
            duration=duration / 1000,
            pitch=pitch,
            # Velocities are stored divided by 4; clamp to the 7-bit MIDI range.
            velocity=min(127, max(0, velocity * 4)),
        ))
        now += wait

    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')

    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        # Best effort: report MIDI writing failures instead of raising.
        print('Postprocess: Ignored postprocessing error:', e)

postprocess(txt, './result.mid')
```



Similarly, to convert a MIDI file to the text representation:

```python
def preprocess(path):
    """Convert a MIDI file at ``path`` into the text score format.

    Returns the notes, one per line, as
    ``pitch duration wait velocity instrument`` joined with newlines; the
    header line is not included. The fields are:

    - ``duration`` and ``wait`` are in milliseconds. ``wait`` is the gap
      until the next note starts, and 0 for the last note.
    - ``velocity`` is divided by 4, mirroring the ``* 4`` in ``postprocess``.
    - ``instrument`` is the track's program number as a string, or
      ``'drum'`` for drum tracks.

    MIDI files may be broken: if loading fails, the error is printed and
    ``''`` is returned.
    """
    try:
        score = symusic.Score(path, ttype='Second')
    except Exception as e:
        print('Ignored midi loading error:', e)
        return ''

    # Prolong each note to the end of the pedal that is down when it starts,
    # then drop the pedals since their effect is baked into the durations.
    # NOTE(review): this two-pointer scan assumes notes and pedals are sorted
    # by time within each track - confirm symusic guarantees that on load.
    score = score.copy()
    for track in score.tracks:
        notes = track.notes
        pedals = track.pedals
        track.pedals = []
        j = 0
        for note in notes:
            while j < len(pedals) and pedals[j].time + pedals[j].duration < note.time:
                j += 1
            if j < len(pedals) and pedals[j].time <= note.time <= pedals[j].time + pedals[j].duration:
                note.duration = max(
                    note.duration,
                    pedals[j].time + pedals[j].duration - note.time,
                )

    notes = []
    for track in score.tracks:
        # `instrument` is always a string: the program id, or 'drum'.
        instrument = 'drum' if track.is_drum else str(track.program)
        for note in track.notes:
            notes.append((note.time, note.duration, note.pitch, note.velocity, instrument))
    # Dedup notes with identical (start, duration, pitch) across tracks;
    # the last one seen wins.
    notes = list({
        (time, duration, pitch): (time, duration, pitch, velocity, instrument)
        for time, duration, pitch, velocity, instrument in notes
    }.values())
    # Merge tracks and sort by start time; for equal start times the higher
    # pitch comes first.
    notes.sort(key=lambda x: (x[0], -x[2]))

    # Translate absolute start times (seconds) to the delta-time format
    # 'pitch duration wait', in milliseconds.
    # NOTE(review): 10 ms quantization matches the granularity of the example
    # text score in this README (every value is a multiple of 10); confirm it
    # matches the format used for training.
    step_ms = 10

    def to_ms(seconds):
        return round(seconds * 1000 / step_ms) * step_ms

    # Quantize absolute start times first and difference those, so rounding
    # errors do not accumulate over the piece.
    starts = [to_ms(note[0]) for note in notes]
    lines = []
    for idx, (_, duration, pitch, velocity, instrument) in enumerate(notes):
        wait = starts[idx + 1] - starts[idx] if idx + 1 < len(notes) else 0
        lines.append(f'{pitch} {to_ms(duration)} {wait} {velocity // 4} {instrument}')
    return '\n'.join(lines)

txt = preprocess('./test.mid')
```