1use std::collections::BTreeMap;
2
3use crate::{
4 defaults::{FloatType, IntegerType},
5 duration::Duration,
6 error::{Error, Result},
7 note::Note,
8 pitch::Pitch,
9 stream::{Stream, StreamElement},
10};
11
/// Timing resolution, in ticks per quarter note, used when encoding.
///
/// [`write_midi_bytes`] stores this value in the header's division field and
/// scales quarter-note offsets by it when converting them to ticks.
pub const DEFAULT_TICKS_PER_QUARTER: u16 = 480;
14
/// One sounding note in MIDI terms: key, velocity and channel, plus timing
/// expressed in quarter notes.
///
/// The fields are public, so a value built with a struct literal bypasses the
/// range checks performed by [`MidiNote::new`] and [`MidiNote::with_channel`];
/// [`write_midi_bytes`] re-validates every note before encoding it.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MidiNote {
    /// MIDI key number; the constructors reject values above 127.
    pub pitch: u8,
    /// Note-on velocity; the constructors reject values above 127.
    pub velocity: u8,
    /// Zero-based MIDI channel; the constructors reject values above 15.
    pub channel: u8,
    /// Onset measured in quarter notes; must be finite and non-negative.
    pub start: FloatType,
    /// Length measured in quarter notes; must be finite and non-negative.
    pub duration: FloatType,
}
30
31impl MidiNote {
32 pub fn new(pitch: u8, start: FloatType, duration: FloatType, velocity: u8) -> Result<Self> {
34 Self::with_channel(pitch, start, duration, velocity, 0)
35 }
36
37 pub fn with_channel(
39 pitch: u8,
40 start: FloatType,
41 duration: FloatType,
42 velocity: u8,
43 channel: u8,
44 ) -> Result<Self> {
45 if pitch > 127 {
46 return Err(Error::Midi(format!("MIDI pitch out of range: {pitch}")));
47 }
48 if velocity > 127 {
49 return Err(Error::Midi(format!(
50 "MIDI velocity out of range: {velocity}"
51 )));
52 }
53 if channel > 15 {
54 return Err(Error::Midi(format!("MIDI channel out of range: {channel}")));
55 }
56 if !start.is_finite() || start < 0.0 {
57 return Err(Error::Midi(format!("invalid MIDI note start: {start}")));
58 }
59 if !duration.is_finite() || duration < 0.0 {
60 return Err(Error::Midi(format!(
61 "invalid MIDI note duration: {duration}"
62 )));
63 }
64
65 Ok(Self {
66 pitch,
67 velocity,
68 channel,
69 start,
70 duration,
71 })
72 }
73}
74
75pub fn midi_notes_from_stream(stream: &Stream) -> Result<Vec<MidiNote>> {
77 let mut notes = Vec::new();
78 for event in stream.events() {
79 let start = event.offset();
80 let duration = event.element().quarter_length();
81 match event.element() {
82 StreamElement::Note(note) => {
83 notes.push(note_to_midi_note(note, start, duration)?);
84 }
85 StreamElement::Chord(chord) => {
86 for note in chord.notes() {
87 notes.push(note_to_midi_note(note, start, duration)?);
88 }
89 }
90 StreamElement::Rest(_) => {}
91 }
92 }
93 Ok(notes)
94}
95
96pub fn stream_from_midi_notes(notes: &[MidiNote]) -> Result<Stream> {
98 let mut stream = Stream::new();
99 for midi_note in notes {
100 let note = Note::from_pitch(Pitch::from_midi(midi_note.pitch as IntegerType)?)?
101 .with_duration(Duration::new(midi_note.duration)?);
102 stream.insert(midi_note.start, note);
103 }
104 Ok(stream)
105}
106
107pub fn write_midi_bytes(notes: &[MidiNote], tempo_bpm: FloatType) -> Result<Vec<u8>> {
109 if !tempo_bpm.is_finite() || tempo_bpm <= 0.0 {
110 return Err(Error::Midi(format!("invalid tempo: {tempo_bpm}")));
111 }
112
113 let mut events = Vec::new();
114 for note in notes {
115 validate_note(*note)?;
116 let start_tick = quarter_to_tick(note.start)?;
117 let end_tick = quarter_to_tick(note.start + note.duration)?;
118 events.push((
119 start_tick,
120 1_u8,
121 [0x90 | note.channel, note.pitch, note.velocity],
122 ));
123 events.push((end_tick, 0_u8, [0x80 | note.channel, note.pitch, 0]));
124 }
125 events.sort_by_key(|event| (event.0, event.1));
126
127 let mut track = Vec::new();
128 write_vlq(0, &mut track);
129 track.extend([0xFF, 0x51, 0x03]);
130 let micros_per_quarter = (60_000_000.0 / tempo_bpm).round() as u32;
131 track.extend([
132 ((micros_per_quarter >> 16) & 0xFF) as u8,
133 ((micros_per_quarter >> 8) & 0xFF) as u8,
134 (micros_per_quarter & 0xFF) as u8,
135 ]);
136
137 let mut last_tick = 0_u32;
138 for (tick, _, bytes) in events {
139 write_vlq(tick.saturating_sub(last_tick), &mut track);
140 track.extend(bytes);
141 last_tick = tick;
142 }
143 write_vlq(0, &mut track);
144 track.extend([0xFF, 0x2F, 0x00]);
145
146 let mut out = Vec::new();
147 out.extend(b"MThd");
148 out.extend(6_u32.to_be_bytes());
149 out.extend(0_u16.to_be_bytes());
150 out.extend(1_u16.to_be_bytes());
151 out.extend(DEFAULT_TICKS_PER_QUARTER.to_be_bytes());
152 out.extend(b"MTrk");
153 out.extend((track.len() as u32).to_be_bytes());
154 out.extend(track);
155 Ok(out)
156}
157
158pub fn read_midi_bytes(bytes: &[u8]) -> Result<Vec<MidiNote>> {
160 read_midi_bytes_with_tempo(bytes).map(|(notes, _tempo)| notes)
161}
162
163pub fn read_midi_bytes_with_tempo(bytes: &[u8]) -> Result<(Vec<MidiNote>, Option<FloatType>)> {
165 let mut pos = 0;
166 expect(bytes, &mut pos, b"MThd")?;
167 let header_len = read_u32(bytes, &mut pos)?;
168 if header_len < 6 {
169 return Err(Error::Midi("MIDI header is too short".to_string()));
170 }
171 let _format = read_u16(bytes, &mut pos)?;
172 let tracks = read_u16(bytes, &mut pos)?;
173 let division = read_u16(bytes, &mut pos)?;
174 pos += (header_len - 6) as usize;
175
176 if division & 0x8000 != 0 {
177 return Err(Error::Midi(
178 "SMPTE MIDI time division is not supported".to_string(),
179 ));
180 }
181
182 let mut all_notes = Vec::new();
183 let mut first_tempo = None;
184 for _ in 0..tracks {
185 expect(bytes, &mut pos, b"MTrk")?;
186 let len = read_u32(bytes, &mut pos)? as usize;
187 let end = pos
188 .checked_add(len)
189 .ok_or_else(|| Error::Midi("MIDI track length overflow".to_string()))?;
190 if end > bytes.len() {
191 return Err(Error::Midi("MIDI track exceeds file length".to_string()));
192 }
193 let (mut notes, tempo) = read_track(&bytes[pos..end], division)?;
194 if first_tempo.is_none() {
195 first_tempo = tempo;
196 }
197 all_notes.append(&mut notes);
198 pos = end;
199 }
200 all_notes.sort_by(|left, right| {
201 left.start
202 .partial_cmp(&right.start)
203 .unwrap_or(std::cmp::Ordering::Equal)
204 });
205 Ok((all_notes, first_tempo))
206}
207
208fn note_to_midi_note(note: &Note, start: FloatType, duration: FloatType) -> Result<MidiNote> {
209 let pitch = note.pitch().ps().round() as IntegerType;
210 if !(0..=127).contains(&pitch) {
211 return Err(Error::Midi(format!("pitch {pitch} is outside MIDI range")));
212 }
213 MidiNote::new(pitch as u8, start, duration, 64)
214}
215
216fn validate_note(note: MidiNote) -> Result<()> {
217 MidiNote::with_channel(
218 note.pitch,
219 note.start,
220 note.duration,
221 note.velocity,
222 note.channel,
223 )
224 .map(|_| ())
225}
226
227fn quarter_to_tick(value: FloatType) -> Result<u32> {
228 if !value.is_finite() || value < 0.0 {
229 return Err(Error::Midi(format!("invalid quarter offset: {value}")));
230 }
231 Ok((value * DEFAULT_TICKS_PER_QUARTER as FloatType).round() as u32)
232}
233
234fn tick_to_quarter(value: u32, division: u16) -> FloatType {
235 value as FloatType / division as FloatType
236}
237
/// Appends `value` to `out` as a MIDI variable-length quantity.
///
/// The value is split into 7-bit groups written most significant first; every
/// group except the last has its high bit set. A `u32` needs between one and
/// five bytes.
fn write_vlq(mut value: u32, out: &mut Vec<u8>) {
    // Groups are produced least significant first and reversed on output.
    let mut groups = [0_u8; 5];
    let mut used = 0;

    loop {
        // Only the final byte on the wire (the first group produced) has a
        // clear continuation bit.
        let continuation = if used == 0 { 0x00 } else { 0x80 };
        groups[used] = (value & 0x7F) as u8 | continuation;
        used += 1;
        value >>= 7;
        if value == 0 {
            break;
        }
    }

    out.extend(groups[..used].iter().rev());
}
250
251fn read_vlq(bytes: &[u8], pos: &mut usize) -> Result<u32> {
252 let mut value = 0_u32;
253 for _ in 0..4 {
254 let byte = *bytes
255 .get(*pos)
256 .ok_or_else(|| Error::Midi("unexpected end of VLQ".to_string()))?;
257 *pos += 1;
258 value = (value << 7) | (byte & 0x7F) as u32;
259 if byte & 0x80 == 0 {
260 return Ok(value);
261 }
262 }
263 Err(Error::Midi("VLQ is too long".to_string()))
264}
265
266fn read_track(track: &[u8], division: u16) -> Result<(Vec<MidiNote>, Option<FloatType>)> {
267 let mut pos = 0;
268 let mut tick = 0_u32;
269 let mut running_status = None;
270 let mut active: BTreeMap<(u8, u8), Vec<(u32, u8)>> = BTreeMap::new();
271 let mut notes = Vec::new();
272 let mut tempo = None;
273
274 while pos < track.len() {
275 tick = tick.saturating_add(read_vlq(track, &mut pos)?);
276 let byte = *track
277 .get(pos)
278 .ok_or_else(|| Error::Midi("unexpected end of MIDI event".to_string()))?;
279 let status = if byte & 0x80 != 0 {
280 pos += 1;
281 running_status = Some(byte);
282 byte
283 } else {
284 running_status
285 .ok_or_else(|| Error::Midi("running status without status byte".to_string()))?
286 };
287
288 match status {
289 0xFF => {
290 let meta_type = read_byte(track, &mut pos)?;
291 let len = read_vlq(track, &mut pos)? as usize;
292 if pos + len > track.len() {
293 return Err(Error::Midi("meta event exceeds track length".to_string()));
294 }
295 if meta_type == 0x51 && len == 3 {
296 let micros = ((track[pos] as u32) << 16)
297 | ((track[pos + 1] as u32) << 8)
298 | track[pos + 2] as u32;
299 tempo = Some(60_000_000.0 / micros as FloatType);
300 } else if meta_type == 0x2F {
301 break;
302 }
303 pos += len;
304 }
305 0xF0 | 0xF7 => {
306 let len = read_vlq(track, &mut pos)? as usize;
307 pos = pos
308 .checked_add(len)
309 .ok_or_else(|| Error::Midi("sysex length overflow".to_string()))?;
310 if pos > track.len() {
311 return Err(Error::Midi("sysex event exceeds track length".to_string()));
312 }
313 }
314 _ => {
315 let event_type = status & 0xF0;
316 let channel = status & 0x0F;
317 let data_len = match event_type {
318 0xC0 | 0xD0 => 1,
319 0x80 | 0x90 | 0xA0 | 0xB0 | 0xE0 => 2,
320 _ => return Err(Error::Midi(format!("unsupported MIDI status {status:#X}"))),
321 };
322 let data1 = read_byte(track, &mut pos)?;
323 let data2 = if data_len == 2 {
324 read_byte(track, &mut pos)?
325 } else {
326 0
327 };
328
329 if event_type == 0x90 && data2 > 0 {
330 active
331 .entry((channel, data1))
332 .or_default()
333 .push((tick, data2));
334 } else if (event_type == 0x80 || event_type == 0x90)
335 && let Some(stack) = active.get_mut(&(channel, data1))
336 && let Some((start_tick, velocity)) = stack.pop()
337 {
338 notes.push(MidiNote::with_channel(
339 data1,
340 tick_to_quarter(start_tick, division),
341 tick_to_quarter(tick.saturating_sub(start_tick), division),
342 velocity,
343 channel,
344 )?);
345 }
346 }
347 }
348 }
349
350 Ok((notes, tempo))
351}
352
353fn expect(bytes: &[u8], pos: &mut usize, expected: &[u8]) -> Result<()> {
354 if bytes.get(*pos..(*pos).saturating_add(expected.len())) == Some(expected) {
355 *pos += expected.len();
356 Ok(())
357 } else {
358 Err(Error::Midi(format!(
359 "expected MIDI chunk {:?}",
360 String::from_utf8_lossy(expected)
361 )))
362 }
363}
364
365fn read_byte(bytes: &[u8], pos: &mut usize) -> Result<u8> {
366 let byte = *bytes
367 .get(*pos)
368 .ok_or_else(|| Error::Midi("unexpected end of MIDI data".to_string()))?;
369 *pos += 1;
370 Ok(byte)
371}
372
373fn read_u16(bytes: &[u8], pos: &mut usize) -> Result<u16> {
374 let start = *pos;
375 *pos += 2;
376 let data = bytes
377 .get(start..*pos)
378 .ok_or_else(|| Error::Midi("unexpected end of MIDI u16".to_string()))?;
379 Ok(u16::from_be_bytes([data[0], data[1]]))
380}
381
382fn read_u32(bytes: &[u8], pos: &mut usize) -> Result<u32> {
383 let start = *pos;
384 *pos += 4;
385 let data = bytes
386 .get(start..*pos)
387 .ok_or_else(|| Error::Midi("unexpected end of MIDI u32".to_string()))?;
388 Ok(u32::from_be_bytes([data[0], data[1], data[2], data[3]]))
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// One note written at 120 BPM must decode to the same note and tempo.
    #[test]
    fn midi_roundtrip_bytes() {
        let notes = vec![MidiNote::new(60, 0.0, 1.0, 90).unwrap()];
        let bytes = write_midi_bytes(&notes, 120.0).unwrap();
        assert!(bytes.starts_with(b"MThd"));
        let (roundtrip, tempo) = read_midi_bytes_with_tempo(&bytes).unwrap();
        assert_eq!(tempo, Some(120.0));
        assert_eq!(roundtrip, notes);
    }

    /// A half-note C4 in a stream becomes MIDI key 60 lasting two quarters.
    #[test]
    fn stream_converts_to_midi_notes() {
        let mut stream = Stream::new();
        stream.push(
            Note::from_name("C4")
                .unwrap()
                .with_duration(Duration::half()),
        );
        let notes = midi_notes_from_stream(&stream).unwrap();
        assert_eq!(notes[0].pitch, 60);
        assert_eq!(notes[0].duration, 2.0);
    }

    /// Each constructor rule, plus the writer's tempo rule, rejects bad input.
    #[test]
    fn midi_note_validation_rejects_invalid_values() {
        assert!(MidiNote::with_channel(128, 0.0, 1.0, 64, 0).is_err());
        assert!(MidiNote::with_channel(60, 0.0, 1.0, 128, 0).is_err());
        assert!(MidiNote::with_channel(60, 0.0, 1.0, 64, 16).is_err());
        assert!(MidiNote::with_channel(60, -0.25, 1.0, 64, 0).is_err());
        assert!(MidiNote::with_channel(60, 0.0, FloatType::INFINITY, 64, 0).is_err());
        assert!(write_midi_bytes(&[], 0.0).is_err());
    }

    /// Garbage input, a truncated header and SMPTE timing are all refused.
    #[test]
    fn midi_reader_rejects_invalid_files() {
        assert!(read_midi_bytes(b"not midi").is_err());

        // Header that declares only four bytes of content.
        let mut short_header = Vec::new();
        short_header.extend(b"MThd");
        short_header.extend(4_u32.to_be_bytes());
        short_header.extend([0, 0, 0, 1]);
        assert!(read_midi_bytes(&short_header).is_err());

        // Division with the high bit set selects SMPTE timing.
        let mut smpte = Vec::new();
        smpte.extend(b"MThd");
        smpte.extend(6_u32.to_be_bytes());
        smpte.extend(0_u16.to_be_bytes());
        smpte.extend(0_u16.to_be_bytes());
        smpte.extend(0x8000_u16.to_be_bytes());
        assert!(read_midi_bytes(&smpte).is_err());
    }
}
446}