Skip to main content

music21_rs/
midi.rs

1use std::collections::BTreeMap;
2
3use crate::{
4    defaults::{FloatType, IntegerType},
5    duration::Duration,
6    error::{Error, Result},
7    note::Note,
8    pitch::Pitch,
9    stream::{Stream, StreamElement},
10};
11
/// MIDI pulses per quarter note (time division) used when *writing* files.
///
/// [`write_midi_bytes`] stores this value in the file header and converts
/// quarter lengths to ticks with it. The readers do not use it: they convert
/// ticks with the division declared in the file being read.
pub const DEFAULT_TICKS_PER_QUARTER: u16 = 480;
14
15/// A note event in quarter-length time.
16#[derive(Clone, Copy, Debug, PartialEq)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18pub struct MidiNote {
19    /// MIDI key number, from 0 to 127.
20    pub pitch: u8,
21    /// MIDI note-on velocity, from 0 to 127.
22    pub velocity: u8,
23    /// MIDI channel, from 0 to 15.
24    pub channel: u8,
25    /// Start offset measured in quarter lengths.
26    pub start: FloatType,
27    /// Duration measured in quarter lengths.
28    pub duration: FloatType,
29}
30
31impl MidiNote {
32    /// Creates a MIDI note event.
33    pub fn new(pitch: u8, start: FloatType, duration: FloatType, velocity: u8) -> Result<Self> {
34        Self::with_channel(pitch, start, duration, velocity, 0)
35    }
36
37    /// Creates a MIDI note event with an explicit channel.
38    pub fn with_channel(
39        pitch: u8,
40        start: FloatType,
41        duration: FloatType,
42        velocity: u8,
43        channel: u8,
44    ) -> Result<Self> {
45        if pitch > 127 {
46            return Err(Error::Midi(format!("MIDI pitch out of range: {pitch}")));
47        }
48        if velocity > 127 {
49            return Err(Error::Midi(format!(
50                "MIDI velocity out of range: {velocity}"
51            )));
52        }
53        if channel > 15 {
54            return Err(Error::Midi(format!("MIDI channel out of range: {channel}")));
55        }
56        if !start.is_finite() || start < 0.0 {
57            return Err(Error::Midi(format!("invalid MIDI note start: {start}")));
58        }
59        if !duration.is_finite() || duration < 0.0 {
60            return Err(Error::Midi(format!(
61                "invalid MIDI note duration: {duration}"
62            )));
63        }
64
65        Ok(Self {
66            pitch,
67            velocity,
68            channel,
69            start,
70            duration,
71        })
72    }
73}
74
75/// Extracts MIDI note events from a stream.
76pub fn midi_notes_from_stream(stream: &Stream) -> Result<Vec<MidiNote>> {
77    let mut notes = Vec::new();
78    for event in stream.events() {
79        let start = event.offset();
80        let duration = event.element().quarter_length();
81        match event.element() {
82            StreamElement::Note(note) => {
83                notes.push(note_to_midi_note(note, start, duration)?);
84            }
85            StreamElement::Chord(chord) => {
86                for note in chord.notes() {
87                    notes.push(note_to_midi_note(note, start, duration)?);
88                }
89            }
90            StreamElement::Rest(_) => {}
91        }
92    }
93    Ok(notes)
94}
95
96/// Builds a stream from MIDI note events.
97pub fn stream_from_midi_notes(notes: &[MidiNote]) -> Result<Stream> {
98    let mut stream = Stream::new();
99    for midi_note in notes {
100        let note = Note::from_pitch(Pitch::from_midi(midi_note.pitch as IntegerType)?)?
101            .with_duration(Duration::new(midi_note.duration)?);
102        stream.insert(midi_note.start, note);
103    }
104    Ok(stream)
105}
106
107/// Writes a minimal format-0 Standard MIDI File.
108pub fn write_midi_bytes(notes: &[MidiNote], tempo_bpm: FloatType) -> Result<Vec<u8>> {
109    if !tempo_bpm.is_finite() || tempo_bpm <= 0.0 {
110        return Err(Error::Midi(format!("invalid tempo: {tempo_bpm}")));
111    }
112
113    let mut events = Vec::new();
114    for note in notes {
115        validate_note(*note)?;
116        let start_tick = quarter_to_tick(note.start)?;
117        let end_tick = quarter_to_tick(note.start + note.duration)?;
118        events.push((
119            start_tick,
120            1_u8,
121            [0x90 | note.channel, note.pitch, note.velocity],
122        ));
123        events.push((end_tick, 0_u8, [0x80 | note.channel, note.pitch, 0]));
124    }
125    events.sort_by_key(|event| (event.0, event.1));
126
127    let mut track = Vec::new();
128    write_vlq(0, &mut track);
129    track.extend([0xFF, 0x51, 0x03]);
130    let micros_per_quarter = (60_000_000.0 / tempo_bpm).round() as u32;
131    track.extend([
132        ((micros_per_quarter >> 16) & 0xFF) as u8,
133        ((micros_per_quarter >> 8) & 0xFF) as u8,
134        (micros_per_quarter & 0xFF) as u8,
135    ]);
136
137    let mut last_tick = 0_u32;
138    for (tick, _, bytes) in events {
139        write_vlq(tick.saturating_sub(last_tick), &mut track);
140        track.extend(bytes);
141        last_tick = tick;
142    }
143    write_vlq(0, &mut track);
144    track.extend([0xFF, 0x2F, 0x00]);
145
146    let mut out = Vec::new();
147    out.extend(b"MThd");
148    out.extend(6_u32.to_be_bytes());
149    out.extend(0_u16.to_be_bytes());
150    out.extend(1_u16.to_be_bytes());
151    out.extend(DEFAULT_TICKS_PER_QUARTER.to_be_bytes());
152    out.extend(b"MTrk");
153    out.extend((track.len() as u32).to_be_bytes());
154    out.extend(track);
155    Ok(out)
156}
157
158/// Reads note events from a Standard MIDI File.
159pub fn read_midi_bytes(bytes: &[u8]) -> Result<Vec<MidiNote>> {
160    read_midi_bytes_with_tempo(bytes).map(|(notes, _tempo)| notes)
161}
162
163/// Reads note events and the first tempo marking from a Standard MIDI File.
164pub fn read_midi_bytes_with_tempo(bytes: &[u8]) -> Result<(Vec<MidiNote>, Option<FloatType>)> {
165    let mut pos = 0;
166    expect(bytes, &mut pos, b"MThd")?;
167    let header_len = read_u32(bytes, &mut pos)?;
168    if header_len < 6 {
169        return Err(Error::Midi("MIDI header is too short".to_string()));
170    }
171    let _format = read_u16(bytes, &mut pos)?;
172    let tracks = read_u16(bytes, &mut pos)?;
173    let division = read_u16(bytes, &mut pos)?;
174    pos += (header_len - 6) as usize;
175
176    if division & 0x8000 != 0 {
177        return Err(Error::Midi(
178            "SMPTE MIDI time division is not supported".to_string(),
179        ));
180    }
181
182    let mut all_notes = Vec::new();
183    let mut first_tempo = None;
184    for _ in 0..tracks {
185        expect(bytes, &mut pos, b"MTrk")?;
186        let len = read_u32(bytes, &mut pos)? as usize;
187        let end = pos
188            .checked_add(len)
189            .ok_or_else(|| Error::Midi("MIDI track length overflow".to_string()))?;
190        if end > bytes.len() {
191            return Err(Error::Midi("MIDI track exceeds file length".to_string()));
192        }
193        let (mut notes, tempo) = read_track(&bytes[pos..end], division)?;
194        if first_tempo.is_none() {
195            first_tempo = tempo;
196        }
197        all_notes.append(&mut notes);
198        pos = end;
199    }
200    all_notes.sort_by(|left, right| {
201        left.start
202            .partial_cmp(&right.start)
203            .unwrap_or(std::cmp::Ordering::Equal)
204    });
205    Ok((all_notes, first_tempo))
206}
207
208fn note_to_midi_note(note: &Note, start: FloatType, duration: FloatType) -> Result<MidiNote> {
209    let pitch = note.pitch().ps().round() as IntegerType;
210    if !(0..=127).contains(&pitch) {
211        return Err(Error::Midi(format!("pitch {pitch} is outside MIDI range")));
212    }
213    MidiNote::new(pitch as u8, start, duration, 64)
214}
215
216fn validate_note(note: MidiNote) -> Result<()> {
217    MidiNote::with_channel(
218        note.pitch,
219        note.start,
220        note.duration,
221        note.velocity,
222        note.channel,
223    )
224    .map(|_| ())
225}
226
227fn quarter_to_tick(value: FloatType) -> Result<u32> {
228    if !value.is_finite() || value < 0.0 {
229        return Err(Error::Midi(format!("invalid quarter offset: {value}")));
230    }
231    Ok((value * DEFAULT_TICKS_PER_QUARTER as FloatType).round() as u32)
232}
233
234fn tick_to_quarter(value: u32, division: u16) -> FloatType {
235    value as FloatType / division as FloatType
236}
237
/// Appends `value` to `out` as a MIDI variable-length quantity.
///
/// The value is written as big-endian 7-bit groups. Every byte except the
/// last has its high bit set. Zero is encoded as a single `0x00` byte. Values
/// above `0x0FFF_FFFF` take five bytes.
fn write_vlq(value: u32, out: &mut Vec<u8>) {
    // Count the 7-bit groups needed: one, plus one per non-empty higher group.
    let groups = (1..5).take_while(|&n| value >> (7 * n) != 0).count() + 1;
    for group in (0..groups).rev() {
        let septet = ((value >> (7 * group)) & 0x7F) as u8;
        let continuation = if group == 0 { 0x00 } else { 0x80 };
        out.push(septet | continuation);
    }
}
250
251fn read_vlq(bytes: &[u8], pos: &mut usize) -> Result<u32> {
252    let mut value = 0_u32;
253    for _ in 0..4 {
254        let byte = *bytes
255            .get(*pos)
256            .ok_or_else(|| Error::Midi("unexpected end of VLQ".to_string()))?;
257        *pos += 1;
258        value = (value << 7) | (byte & 0x7F) as u32;
259        if byte & 0x80 == 0 {
260            return Ok(value);
261        }
262    }
263    Err(Error::Midi("VLQ is too long".to_string()))
264}
265
266fn read_track(track: &[u8], division: u16) -> Result<(Vec<MidiNote>, Option<FloatType>)> {
267    let mut pos = 0;
268    let mut tick = 0_u32;
269    let mut running_status = None;
270    let mut active: BTreeMap<(u8, u8), Vec<(u32, u8)>> = BTreeMap::new();
271    let mut notes = Vec::new();
272    let mut tempo = None;
273
274    while pos < track.len() {
275        tick = tick.saturating_add(read_vlq(track, &mut pos)?);
276        let byte = *track
277            .get(pos)
278            .ok_or_else(|| Error::Midi("unexpected end of MIDI event".to_string()))?;
279        let status = if byte & 0x80 != 0 {
280            pos += 1;
281            running_status = Some(byte);
282            byte
283        } else {
284            running_status
285                .ok_or_else(|| Error::Midi("running status without status byte".to_string()))?
286        };
287
288        match status {
289            0xFF => {
290                let meta_type = read_byte(track, &mut pos)?;
291                let len = read_vlq(track, &mut pos)? as usize;
292                if pos + len > track.len() {
293                    return Err(Error::Midi("meta event exceeds track length".to_string()));
294                }
295                if meta_type == 0x51 && len == 3 {
296                    let micros = ((track[pos] as u32) << 16)
297                        | ((track[pos + 1] as u32) << 8)
298                        | track[pos + 2] as u32;
299                    tempo = Some(60_000_000.0 / micros as FloatType);
300                } else if meta_type == 0x2F {
301                    break;
302                }
303                pos += len;
304            }
305            0xF0 | 0xF7 => {
306                let len = read_vlq(track, &mut pos)? as usize;
307                pos = pos
308                    .checked_add(len)
309                    .ok_or_else(|| Error::Midi("sysex length overflow".to_string()))?;
310                if pos > track.len() {
311                    return Err(Error::Midi("sysex event exceeds track length".to_string()));
312                }
313            }
314            _ => {
315                let event_type = status & 0xF0;
316                let channel = status & 0x0F;
317                let data_len = match event_type {
318                    0xC0 | 0xD0 => 1,
319                    0x80 | 0x90 | 0xA0 | 0xB0 | 0xE0 => 2,
320                    _ => return Err(Error::Midi(format!("unsupported MIDI status {status:#X}"))),
321                };
322                let data1 = read_byte(track, &mut pos)?;
323                let data2 = if data_len == 2 {
324                    read_byte(track, &mut pos)?
325                } else {
326                    0
327                };
328
329                if event_type == 0x90 && data2 > 0 {
330                    active
331                        .entry((channel, data1))
332                        .or_default()
333                        .push((tick, data2));
334                } else if (event_type == 0x80 || event_type == 0x90)
335                    && let Some(stack) = active.get_mut(&(channel, data1))
336                    && let Some((start_tick, velocity)) = stack.pop()
337                {
338                    notes.push(MidiNote::with_channel(
339                        data1,
340                        tick_to_quarter(start_tick, division),
341                        tick_to_quarter(tick.saturating_sub(start_tick), division),
342                        velocity,
343                        channel,
344                    )?);
345                }
346            }
347        }
348    }
349
350    Ok((notes, tempo))
351}
352
353fn expect(bytes: &[u8], pos: &mut usize, expected: &[u8]) -> Result<()> {
354    if bytes.get(*pos..(*pos).saturating_add(expected.len())) == Some(expected) {
355        *pos += expected.len();
356        Ok(())
357    } else {
358        Err(Error::Midi(format!(
359            "expected MIDI chunk {:?}",
360            String::from_utf8_lossy(expected)
361        )))
362    }
363}
364
365fn read_byte(bytes: &[u8], pos: &mut usize) -> Result<u8> {
366    let byte = *bytes
367        .get(*pos)
368        .ok_or_else(|| Error::Midi("unexpected end of MIDI data".to_string()))?;
369    *pos += 1;
370    Ok(byte)
371}
372
373fn read_u16(bytes: &[u8], pos: &mut usize) -> Result<u16> {
374    let start = *pos;
375    *pos += 2;
376    let data = bytes
377        .get(start..*pos)
378        .ok_or_else(|| Error::Midi("unexpected end of MIDI u16".to_string()))?;
379    Ok(u16::from_be_bytes([data[0], data[1]]))
380}
381
382fn read_u32(bytes: &[u8], pos: &mut usize) -> Result<u32> {
383    let start = *pos;
384    *pos += 4;
385    let data = bytes
386        .get(start..*pos)
387        .ok_or_else(|| Error::Midi("unexpected end of MIDI u32".to_string()))?;
388    Ok(u32::from_be_bytes([data[0], data[1], data[2], data[3]]))
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `MThd` chunk with the given declared length and raw payload.
    fn header_chunk(declared_len: u32, payload: &[u8]) -> Vec<u8> {
        let mut chunk = b"MThd".to_vec();
        chunk.extend_from_slice(&declared_len.to_be_bytes());
        chunk.extend_from_slice(payload);
        chunk
    }

    #[test]
    fn midi_roundtrip_bytes() {
        let original = vec![MidiNote::new(60, 0.0, 1.0, 90).unwrap()];
        let encoded = write_midi_bytes(&original, 120.0).unwrap();
        assert!(encoded.starts_with(b"MThd"));

        let (decoded, tempo) = read_midi_bytes_with_tempo(&encoded).unwrap();
        assert_eq!(tempo, Some(120.0));
        assert_eq!(decoded, original);
    }

    #[test]
    fn stream_converts_to_midi_notes() {
        let half_note = Note::from_name("C4")
            .unwrap()
            .with_duration(Duration::half());
        let mut stream = Stream::new();
        stream.push(half_note);

        let extracted = midi_notes_from_stream(&stream).unwrap();
        let first = &extracted[0];
        assert_eq!(first.pitch, 60);
        assert_eq!(first.duration, 2.0);
    }

    #[test]
    fn midi_note_validation_rejects_invalid_values() {
        let rejected = [
            MidiNote::with_channel(128, 0.0, 1.0, 64, 0),
            MidiNote::with_channel(60, 0.0, 1.0, 128, 0),
            MidiNote::with_channel(60, 0.0, 1.0, 64, 16),
            MidiNote::with_channel(60, -0.25, 1.0, 64, 0),
            MidiNote::with_channel(60, 0.0, FloatType::INFINITY, 64, 0),
        ];
        for (case, outcome) in rejected.into_iter().enumerate() {
            assert!(outcome.is_err(), "case {case} should be rejected");
        }
        assert!(write_midi_bytes(&[], 0.0).is_err());
    }

    #[test]
    fn midi_reader_rejects_invalid_files() {
        assert!(read_midi_bytes(b"not midi").is_err());

        // Declared header length below the mandatory six bytes.
        let short_header = header_chunk(4, &[0, 0, 0, 1]);
        assert!(read_midi_bytes(&short_header).is_err());

        // Format 0, zero tracks, SMPTE time division (top bit set).
        let smpte = header_chunk(6, &[0x00, 0x00, 0x00, 0x00, 0x80, 0x00]);
        assert!(read_midi_bytes(&smpte).is_err());
    }
}