@@ -9,102 +9,67 @@ import (
99)
1010
1111func TestGenerateModelsIndexFromFile (t * testing.T ) {
12- testdataPath := paths .New ("testdata" )
13-
14- t .Run ("Valid Model list" , func (t * testing.T ) {
15- modelsIndex , err := GenerateModelsIndexFromFile (testdataPath )
12+ t .Run ("it parses a valid model-list.yaml" , func (t * testing.T ) {
13+ modelsIndex , err := GenerateModelsIndexFromFile (paths .New ("testdata" ))
1614 require .NoError (t , err )
1715 require .NotNil (t , modelsIndex )
1816
1917 models := modelsIndex .GetModels ()
20- assert .Len (t , models , 3 , "Expected 3 models to be parsed" )
21-
22- // Test first model
23- model1 , found := modelsIndex .GetModelByID ("face-detection" )
24- assert .Equal (t , "brick" , model1 .Runner )
25- require .True (t , found , "face-detection should be found" )
26- assert .Equal (t , "face-detection:" , model1 .ID )
27- assert .Equal (t , "Lightweight-Face-Detection" , model1 .Name )
28- assert .Equal (t , "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images." , model1 .ModuleDescription )
29- assert .Equal (t , []string {"arduino:object_detection" , "arduino:video_object_detection" }, model1 .L )
30- assert .Equal (t , []string {"arduino:object_detection" , "arduino:video_object_detection" }, model1 .Bricks )
31- assert .Equal (t , "1.0.0" , model1 .Metadata ["version" ])
32- assert .Equal (t , "Test Author" , model1 .Metadata ["author" ])
33- assert .Equal (t , "1000" , model1 .ModelConfiguration ["max_tokens" ])
34- assert .Equal (t , "0.7" , model1 .ModelConfiguration ["temperature" ])
18+ assert .Len (t , models , 2 , "Expected 2 models to be parsed" )
19+ })
3520
36- // // Test second model
37- // model2, found := modelsIndex.GetModelByID("test_model_2")
38- // // require.True(t, found, "test_model_2 should be found")
39- // // assert.Equal(t, "test_model_2", model2.ID)
40- // // assert.Equal(t, "Test Model 2", model2.Name)
41- // // assert.Equal(t, "Another test AI model", model2.ModuleDescription)
42- // // assert.Equal(t, "another_runner", model2.Runner)
43- // // assert.Equal(t, []string{"brick2", "brick3"}, model2.Bricks)
44- // // assert.Equal(t, "2.0.0", model2.Metadata["version"])
45- // // assert.Equal(t, "MIT", model2.Metadata["license"])
21+ t .Run ("it gets a model by ID" , func (t * testing.T ) {
22+ modelsIndex , err := GenerateModelsIndexFromFile (paths .New ("testdata" ))
23+ require .NoError (t , err )
4624
47- // // Test minimal model
48- // model3, found := modelsIndex.GetModelByID("minimal_model")
49- // require.True(t, found, "minimal_model should be found")
50- // assert.Equal(t, "minimal_model", model3.ID)
51- // assert.Equal(t, "Minimal Model", model3.Name)
52- // assert.Equal(t, "Minimal model with no optional fields", model3.ModuleDescription)
53- // assert.Equal(t, "minimal_runner", model3.Runner)
54- // assert.Empty(t, model3.Bricks)
55- // assert.Empty(t, model3.Metadata)
56- // assert.Empty(t, model3.ModelConfiguration)
25+ model , found := modelsIndex .GetModelByID ("face-detection" )
26+ assert .Equal (t , "brick" , model .Runner )
27+ require .True (t , found , "face-detection should be found" )
28+ assert .Equal (t , "face-detection" , model .ID )
29+ assert .Equal (t , "Lightweight-Face-Detection" , model .Name )
30+ assert .Equal (t , "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images." , model .ModuleDescription )
31+ assert .Equal (t , []string {"face" }, model .ModelLabels )
32+ assert .Equal (t , "/models/ootb/ei/lw-face-det.eim" , model .ModelConfiguration ["EI_OBJ_DETECTION_MODEL" ])
33+ assert .Equal (t , []string {"arduino:object_detection" , "arduino:video_object_detection" }, model .Bricks )
34+ assert .Equal (t , "qualcomm-ai-hub" , model .Metadata ["source" ])
35+ assert .Equal (t , "false" , model .Metadata ["ei-gpu-mode" ])
36+ assert .Equal (t , "face-det-lite" , model .Metadata ["source-model-id" ])
37+ assert .Equal (t , "https://aihub.qualcomm.com/models/face_det_lite" , model .Metadata ["source-model-url" ])
5738 })
5839
59- // Test file not found error
60- t .Run ("FileNotFound" , func (t * testing.T ) {
61- nonExistentPath := paths .New ("nonexistent" )
40+ t .Run ("it fails if model-list.yaml does not exist" , func (t * testing.T ) {
41+ nonExistentPath := paths .New ("nonexistent.yaml" )
6242 modelsIndex , err := GenerateModelsIndexFromFile (nonExistentPath )
6343 assert .Error (t , err )
6444 assert .Nil (t , modelsIndex )
6545 })
6646
67- // Test invalid YAML parsing
68- t .Run ("InvalidYAML" , func (t * testing.T ) {
69- // Create a temporary invalid YAML file
70- invalidPath := testdataPath .Join ("invalid-models.yaml" )
47+ t .Run ("it filters models by a single brick" , func (t * testing.T ) {
48+ modelsIndex , err := GenerateModelsIndexFromFile (paths .New ("testdata" ))
49+ require .NoError (t , err )
50+
51+ brick1Models := modelsIndex .GetModelsByBrick ("arduino:object_detection" )
52+ assert .Len (t , brick1Models , 1 )
53+ assert .Equal (t , "face-detection" , brick1Models [0 ].ID )
7154
72- // We expect this to either fail parsing or handle gracefully
73- // Since the current implementation may be lenient with missing fields
74- modelsIndex , err := GenerateModelsIndexFromFile (testdataPath .Parent ().Join ("testdata-invalid" ))
75- if err != nil {
76- // If it fails, that's expected for invalid files
77- assert .Error (t , err )
78- assert .Nil (t , modelsIndex )
79- }
80- // Note: Some invalid YAML might still parse successfully depending on the YAML library's behavior
81- _ = invalidPath // Avoid unused variable warning
55+ brick1Models = modelsIndex .GetModelsByBrick ("not-existing-brick" )
56+ assert .Nil (t , brick1Models )
8257 })
8358
84- // Test brick filtering functionality
85- t .Run ("BrickFiltering" , func (t * testing.T ) {
86- modelsIndex , err := GenerateModelsIndexFromFile (testdataPath )
59+ t .Run ("it filters models by multiple bricks" , func (t * testing.T ) {
60+ modelsIndex , err := GenerateModelsIndexFromFile (paths .New ("testdata" ))
8761 require .NoError (t , err )
8862
89- // Test GetModelsByBrick
90- brick1Models := modelsIndex .GetModelsByBrick ("brick1" )
91- assert .Len (t , brick1Models , 1 )
92- assert .Equal (t , "test_model_1" , brick1Models [0 ].ID )
93-
94- brick2Models := modelsIndex .GetModelsByBrick ("brick2" )
63+ brick2Models := modelsIndex .GetModelsByBrick ("arduino:video_object_detection" )
9564 assert .Len (t , brick2Models , 2 )
96- modelIDs := []string {brick2Models [0 ].ID , brick2Models [1 ].ID }
97- assert .Contains (t , modelIDs , "test_model_1" )
98- assert .Contains (t , modelIDs , "test_model_2" )
65+ assert .Equal (t , "face-detection" , brick2Models [0 ].ID )
66+ assert .Equal (t , "yolox-object-detection" , brick2Models [1 ].ID )
9967
100- // Test GetModelsByBricks
101- multiModels := modelsIndex .GetModelsByBricks ([]string {"brick1" , "brick3" })
102- assert .Len (t , multiModels , 2 )
103- multiModelIDs := []string {multiModels [0 ].ID , multiModels [1 ].ID }
104- assert .Contains (t , multiModelIDs , "test_model_1" )
105- assert .Contains (t , multiModelIDs , "test_model_2" )
68+ bricks2Models := modelsIndex .GetModelsByBricks ([]string {"arduino:object_detection" , "arduino:video_object_detection" })
69+ assert .Len (t , bricks2Models , 2 )
70+ assert .Equal (t , "face-detection" , bricks2Models [0 ].ID )
71+ assert .Equal (t , "yolox-object-detection" , bricks2Models [1 ].ID )
10672
107- // Test non-existent brick
10873 nonExistentModels := modelsIndex .GetModelsByBrick ("nonexistent_brick" )
10974 assert .Nil (t , nonExistentModels )
11075 })
0 commit comments