1- import os
2- import tempfile
31from unittest import TestCase
42
5- import numpy as np
63import pytest
74
8- from datasets import Dataset , Features , load_dataset
9- from datasets .features import Midi , Value
5+ from datasets import Dataset , Features
6+ from datasets .features import Midi
107
118
129class TestMidiFeature (TestCase ):
@@ -17,15 +14,15 @@ def test_audio_feature_type(self):
1714
1815 def test_audio_feature_encode_example (self ):
1916 midi = Midi ()
20-
17+
2118 # Test with path
2219 encoded = midi .encode_example ("path/to/midi.mid" )
2320 assert encoded == {"bytes" : None , "path" : "path/to/midi.mid" }
24-
21+
2522 # Test with bytes
2623 encoded = midi .encode_example (b"fake_midi_bytes" )
2724 assert encoded == {"bytes" : b"fake_midi_bytes" , "path" : None }
28-
25+
2926 # Test with dict containing notes
3027 notes_data = {
3128 "notes" : [[60 , 64 , 0.0 , 1.0 ], [62 , 64 , 1.0 , 2.0 ]],
@@ -38,7 +35,7 @@ def test_audio_feature_encode_example(self):
3835
3936 def test_audio_feature_decode_example (self ):
4037 midi = Midi ()
41-
38+
4239 # Test decode with bytes
4340 fake_midi_bytes = b'MThd\x00 \x00 \x00 \x06 \x00 \x01 \x00 \x02 \x00 \xdc MTrk\x00 \x00 \x00 \x13 \x00 \xff Q\x03 \x07 \xa1 \x00 \xff X\x04 \x04 \x02 \x18 \x08 \x01 \xff /\x00 MTrk\x00 \x00 \x00 \x16 \x00 \xc0 \x00 \x00 \x90 <@\x83 8<\x00 \x00 >@\x83 8>\x00 \x01 \xff /\x00 '
4441 decoded = midi .decode_example ({"bytes" : fake_midi_bytes , "path" : None })
@@ -50,11 +47,10 @@ def test_audio_feature_decode_example(self):
5047 def test_audio_feature_with_dataset (self ):
5148 features = Features ({"midi" : Midi ()})
5249 data = {"midi" : ["fake_path1.mid" , "fake_path2.mid" ]}
53-
54- with tempfile .TemporaryDirectory () as tmp_dir :
55- dataset = Dataset .from_dict (data , features = features )
56- assert "midi" in dataset .column_names
57- assert dataset .features ["midi" ].dtype == "dict"
50+
51+ dataset = Dataset .from_dict (data , features = features )
52+ assert "midi" in dataset .column_names
53+ assert dataset .features ["midi" ].dtype == "dict"
5854
5955 def test_audio_feature_decode_false (self ):
6056 midi = Midi (decode = False )
@@ -68,10 +64,10 @@ def test_audio_feature_resolution(self):
6864 def test_audio_feature_flatten (self ):
6965 midi = Midi (decode = False )
7066 flattened = midi .flatten ()
71- assert "bytes" in flattened
72- assert "path" in flattened
67+ assert "bytes" in flattened # type: ignore
68+ assert "path" in flattened # type: ignore
7369
7470 def test_audio_feature_decode_error (self ):
7571 midi = Midi (decode = False )
7672 with pytest .raises (RuntimeError ):
77- midi .decode_example ({"bytes" : b"fake" , "path" : None })
73+ midi .decode_example ({"bytes" : b"fake" , "path" : None })
0 commit comments