
Commit 089460b (1 parent: ccde933)

first version working of SVMSMO + color training example

File tree

4 files changed: +402 -0 lines changed

4 files changed

+402
-0
lines changed
Lines changed: 103 additions & 0 deletions
#include <EloquentSVMSMO.h>
#include "RGB.h"

#define MAX_TRAINING_SAMPLES 20
#define FEATURES_DIM 3


int numSamples;
RGB rgb(2, 3, 4);
float X_train[MAX_TRAINING_SAMPLES][FEATURES_DIM];
int y_train[MAX_TRAINING_SAMPLES];
Eloquent::TinyML::SVMSMO<FEATURES_DIM> classifier(linearKernel);


void setup() {
    Serial.begin(115200);
    rgb.begin();

    classifier.setC(5);
    classifier.setTol(1e-5);
    classifier.setMaxIter(10000);
}

void loop() {
    if (!Serial.available()) {
        delay(100);
        return;
    }

    String command = Serial.readStringUntil('\n');

    if (command == "help") {
        Serial.println("Available commands:");
        Serial.println("\tfit: train the classifier on a new set of samples");
        Serial.println("\tpredict: classify a new sample");
        Serial.println("\tinspect: print X_train and y_train");
    }
    else if (command == "fit") {
        Serial.print("How many samples will you record? ");
        numSamples = readSerialNumber();

        Serial.print("You will record ");
        Serial.print(numSamples);
        Serial.println(" samples");

        for (int i = 0; i < numSamples; i++) {
            Serial.println("Which class does the sample belong to, 1 or -1?");
            y_train[i] = readSerialNumber() > 0 ? 1 : -1;
            getFeatures(X_train[i]);
        }

        Serial.print("Start training... ");
        classifier.fit(X_train, y_train, numSamples);
        Serial.println("Done");
    }
    else if (command == "predict") {
        float x[FEATURES_DIM];

        getFeatures(x);
        Serial.print("Predicted label is ");
        Serial.println(classifier.predict(X_train, x));
    }
    else if (command == "inspect") {
        for (int i = 0; i < numSamples; i++) {
            Serial.print("[");
            Serial.print(y_train[i]);
            Serial.print("] ");

            for (int j = 0; j < FEATURES_DIM; j++) {
                Serial.print(X_train[i][j]);
                Serial.print(", ");
            }

            Serial.println();
        }
    }
}

/**
 * Read an integer from the serial monitor, blocking until input is available
 * @return the parsed number
 */
int readSerialNumber() {
    while (!Serial.available()) delay(1);

    return Serial.readStringUntil('\n').toInt();
}

/**
 * Get features for new sample
 * @param x destination array for the color reading
 */
void getFeatures(float x[FEATURES_DIM]) {
    rgb.read(x);

    for (int i = 0; i < FEATURES_DIM; i++) {
        Serial.print(x[i]);
        Serial.print(", ");
    }

    Serial.println();
}
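
For reference, the classifier introduced in this commit can also be exercised without the interactive serial protocol. The sketch below is a minimal illustration only: the feature values are invented, and it assumes linearKernel is the kernel exported by kernels.h, as used in the example above.

#include <EloquentSVMSMO.h>

#define FEATURES_DIM 3

// made-up training data: two "low" samples (class 1) and two "high" samples (class -1)
float X_train[4][FEATURES_DIM] = {
    {10, 12, 11},
    {12, 14, 10},
    {60, 65, 70},
    {58, 70, 66}
};
int y_train[4] = {1, 1, -1, -1};

Eloquent::TinyML::SVMSMO<FEATURES_DIM> classifier(linearKernel);

void setup() {
    Serial.begin(115200);

    classifier.setC(5);
    classifier.fit(X_train, y_train, 4);

    float x[FEATURES_DIM] = {11, 13, 12};

    // predict() takes the training samples as argument because the classifier
    // stores only the alphas and labels, not the support vectors themselves
    Serial.print("Predicted label: ");
    Serial.println(classifier.predict(X_train, x));
}

void loop() {
}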
Lines changed: 52 additions & 0 deletions
#pragma once

#include <Arduino.h>

/**
 * Wrapper for RGB color sensor
 */
class RGB {
public:
    RGB(uint8_t s2, uint8_t s3, uint8_t out) :
        _s2(s2),
        _s3(s3),
        _out(out) {

    }

    /**
     * Configure the sensor pins
     */
    void begin() {
        pinMode(_s2, OUTPUT);
        pinMode(_s3, OUTPUT);
        pinMode(_out, INPUT);
    }

    /**
     * Read the three color components into x
     * @param x destination array
     */
    void read(float x[3]) {
        x[0] = readComponent(LOW, LOW);
        x[1] = readComponent(HIGH, HIGH);
        x[2] = readComponent(LOW, HIGH);
    }

protected:
    uint8_t _s2;
    uint8_t _s3;
    uint8_t _out;

    /**
     * Read a single color component by selecting it via the S2/S3 pins
     * @param s2 state of the S2 pin
     * @param s3 state of the S3 pin
     * @return duration of the LOW pulse on the output pin
     */
    int readComponent(bool s2, bool s3) {
        delay(10);
        digitalWrite(_s2, s2);
        digitalWrite(_s3, s3);

        return pulseIn(_out, LOW);
    }
};
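
A minimal usage sketch for the wrapper, assuming the sensor's S2, S3 and OUT pins are wired to digital pins 2, 3 and 4 as in the training example:

#include "RGB.h"

RGB rgb(2, 3, 4);   // S2, S3, OUT
float color[3];

void setup() {
    Serial.begin(115200);
    rgb.begin();
}

void loop() {
    // each component is the LOW pulse duration measured by readComponent()
    rgb.read(color);

    Serial.print(color[0]);
    Serial.print(", ");
    Serial.print(color[1]);
    Serial.print(", ");
    Serial.println(color[2]);

    delay(500);
}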

src/EloquentSVMSMO.h

Lines changed: 224 additions & 0 deletions
#pragma once

#include "kernels.h"


namespace Eloquent {
    namespace TinyML {

        /**
         * SVM classifier trained with Sequential Minimal Optimization (SMO)
         * @tparam D number of features
         */
        template<unsigned int D>
        class SVMSMO {
        public:
            SVMSMO(kernelFunction kernel) :
                _kernel(kernel) {
                _params = {
                    .C = 1,
                    .tol = 1e-4,
                    .alphaTol = 1e-7,
                    .maxIter = 10000,
                    .passes = 10
                };
            }

            /**
             * Set the regularization parameter
             * @param C
             */
            void setC(float C) {
                _params.C = C;
            }

            /**
             * Set the numerical tolerance for KKT violations
             * @param tol
             */
            void setTol(float tol) {
                _params.tol = tol;
            }

            /**
             * Set the threshold below which alphas are ignored at predict time
             * @param alphaTol
             */
            void setAlphaTol(float alphaTol) {
                _params.alphaTol = alphaTol;
            }

            /**
             * Set the maximum number of training iterations
             * @param maxIter
             */
            void setMaxIter(unsigned int maxIter) {
                _params.maxIter = maxIter;
            }

            /**
             * Set how many consecutive passes without alpha updates end the training
             * @param passes
             */
            void setPasses(unsigned int passes) {
                _params.passes = passes;
            }

            /**
             * Train the classifier with SMO
             * @param X training samples
             * @param y training labels (+1 / -1)
             * @param N num samples
             */
            void fit(float X[][D], int *y, unsigned int N) {
                _alphas = (float *) malloc(sizeof(float) * N);

                for (unsigned int i = 0; i < N; i++)
                    _alphas[i] = 0;

                unsigned int iter = 0;
                unsigned int passes = 0;

                while (passes < _params.passes && iter < _params.maxIter) {
                    float alphaChanged = 0;

                    for (unsigned int i = 0; i < N; i++) {
                        float Ei = margin(X, y, X[i], N) - y[i];

                        if ((y[i] * Ei < -_params.tol && _alphas[i] < _params.C) || (y[i] * Ei > _params.tol && _alphas[i] > 0)) {
                            // alpha_i needs updating! Pick a j to update it with
                            unsigned int j = i;

                            while (j == i)
                                j = random(0, N);

                            float Ej = margin(X, y, X[j], N) - y[j];

                            // calculate L and H bounds for j to ensure we're in [0 _params.C]x[0 _params.C] box
                            float ai = _alphas[i];
                            float aj = _alphas[j];
                            float L = 0;
                            float H = 0;

                            if (y[i] == y[j]) {
                                L = max(0, ai + aj - _params.C);
                                H = min(_params.C, ai + aj);
                            } else {
                                L = max(0, aj - ai);
                                H = min(_params.C, _params.C + aj - ai);
                            }

                            if (abs(L - H) < 1e-4)
                                continue;

                            float eta = 2 * _kernel(X[i], X[j], D) - _kernel(X[i], X[i], D) - _kernel(X[j], X[j], D);

                            if (eta >= 0)
                                continue;

                            // compute new alpha_j and clip it inside [0 _params.C]x[0 _params.C] box
                            // then compute alpha_i based on it.
                            float newaj = aj - y[j] * (Ei - Ej) / eta;

                            if (newaj > H)
                                newaj = H;
                            if (newaj < L)
                                newaj = L;
                            if (abs(aj - newaj) < 1e-4)
                                continue;

                            float newai = ai + y[i] * y[j] * (aj - newaj);

                            _alphas[i] = newai;
                            _alphas[j] = newaj;

                            // update the bias term
                            float b1 = _b - Ei - y[i] * (newai - ai) * _kernel(X[i], X[i], D)
                                       - y[j] * (newaj - aj) * _kernel(X[i], X[j], D);
                            float b2 = _b - Ej - y[i] * (newai - ai) * _kernel(X[i], X[j], D)
                                       - y[j] * (newaj - aj) * _kernel(X[j], X[j], D);

                            _b = 0.5 * (b1 + b2);

                            if (newai > 0 && newai < _params.C)
                                _b = b1;
                            if (newaj > 0 && newaj < _params.C)
                                _b = b2;

                            alphaChanged++;
                        } // end alpha_i needed updating
                    } // end for i=1..N

                    iter++;

                    if (alphaChanged == 0)
                        passes++;
                    else passes = 0;
                }

                _y = y;
                _numSamples = N;
            }

            /**
             * Predict the label of a new sample
             * @param X_train the samples used for training
             * @param x the sample to classify
             * @return 1 or -1
             */
            int predict(float X_train[][D], float x[D]) {
                return margin(X_train, _y, x, _numSamples, true) > 0 ? 1 : -1;
            }

            /**
             * Evaluate the accuracy of the classifier
             * @param X_train
             * @param X_test
             * @param y_test
             * @param testSize
             * @return the fraction of correctly classified test samples
             */
            float score(float X_train[][D], float X_test[][D], int y_test[], unsigned int testSize) {
                unsigned int correct = 0;

                for (unsigned int i = 0; i < testSize; i++)
                    if (predict(X_train, X_test[i]) == y_test[i])
                        correct += 1;

                return 1.0 * correct / testSize;
            }

        protected:
            kernelFunction _kernel;
            struct {
                float C;
                float tol;
                float alphaTol;
                unsigned int maxIter;
                unsigned int passes;
            } _params;
            float _b = 0;
            unsigned int _numSamples;
            int *_y;
            float *_alphas;

            /**
             * Compute the decision function for sample x
             * @param X training samples
             * @param y training labels
             * @param x the sample to evaluate
             * @param N number of training samples
             * @param skipSmallAlfas whether to skip alphas below alphaTol (used at predict time)
             * @return the signed margin
             */
            float margin(float X[][D], int *y, float x[D], unsigned int N, bool skipSmallAlfas = false) {
                float sum = _b;

                for (unsigned int i = 0; i < N; i++)
                    if ((!skipSmallAlfas && _alphas[i] != 0) || (skipSmallAlfas && _alphas[i] > _params.alphaTol))
                        sum += _alphas[i] * y[i] * _kernel(x, X[i], D);

                return sum;
            }
        };
    }
}
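
The kernel is injected through the kernelFunction pointer declared in kernels.h, which is not part of this commit. Judging from the calls _kernel(X[i], X[j], D), a kernel receives two feature vectors and their dimension and returns a float, so a custom kernel could plausibly look like the sketch below; treat the exact signature as an assumption, and check kernels.h for the real typedef.

// hypothetical custom kernel: polynomial of degree 2
// (signature assumed to match kernelFunction from kernels.h)
float polynomialKernel(float x[], float w[], unsigned int dim) {
    float dot = 0;

    for (unsigned int i = 0; i < dim; i++)
        dot += x[i] * w[i];

    return (1 + dot) * (1 + dot);
}

Eloquent::TinyML::SVMSMO<3> classifier(polynomialKernel);

After fitting, score(X_train, X_test, y_test, testSize) returns the accuracy on a held-out test set as a fraction between 0 and 1.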
