Skip to content

Commit 93b9432

Browse files
authored
Add support for Google AI models (#1612)
1 parent 8dbb462 commit 93b9432

File tree

8 files changed

+224
-38
lines changed

8 files changed

+224
-38
lines changed

docs/docs/ai-presets.mdx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,23 @@ To use Perplexity's models:
127127
}
128128
```
129129

130+
### Google (Gemini)
131+
132+
To use Google's Gemini models from [Google AI Studio](https://aistudio.google.com):
133+
134+
```json
135+
{
136+
"ai@gemini-2.0": {
137+
"display:name": "Gemini 2.0",
138+
"display:order": 5,
139+
"ai:*": true,
140+
"ai:apitype": "google",
141+
"ai:model": "gemini-2.0-flash-exp",
142+
"ai:apitoken": "<your Google AI API key>"
143+
}
144+
}
145+
```
146+
130147
## Multiple Presets Example
131148

132149
You can define multiple presets in your `ai.json` file:

frontend/app/view/waveai/waveai.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ export class WaveAiModel implements ViewModel {
347347
// Add a typing indicator
348348
globalStore.set(this.addMessageAtom, typingMessage);
349349
const history = await this.fetchAiData();
350-
const beMsg: OpenAiStreamRequest = {
350+
const beMsg: WaveAIStreamRequest = {
351351
clientid: clientId,
352352
opts: opts,
353353
prompt: [...history, newPrompt],

go.mod

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ require (
88
github.com/fsnotify/fsnotify v1.8.0
99
github.com/golang-jwt/jwt/v5 v5.2.1
1010
github.com/golang-migrate/migrate/v4 v4.18.1
11+
github.com/google/generative-ai-go v0.19.0
1112
github.com/google/uuid v1.6.0
1213
github.com/gorilla/handlers v1.5.2
1314
github.com/gorilla/mux v1.8.1
@@ -27,12 +28,24 @@ require (
2728
golang.org/x/crypto v0.31.0
2829
golang.org/x/sys v0.28.0
2930
golang.org/x/term v0.27.0
31+
google.golang.org/api v0.214.0
3032
)
3133

3234
require (
35+
cloud.google.com/go v0.115.0 // indirect
36+
cloud.google.com/go/ai v0.8.0 // indirect
37+
cloud.google.com/go/auth v0.13.0 // indirect
38+
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
39+
cloud.google.com/go/compute/metadata v0.6.0 // indirect
40+
cloud.google.com/go/longrunning v0.5.7 // indirect
3341
github.com/ebitengine/purego v0.8.1 // indirect
3442
github.com/felixge/httpsnoop v1.0.4 // indirect
43+
github.com/go-logr/logr v1.4.2 // indirect
44+
github.com/go-logr/stdr v1.2.2 // indirect
3545
github.com/go-ole/go-ole v1.2.6 // indirect
46+
github.com/google/s2a-go v0.1.8 // indirect
47+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
48+
github.com/googleapis/gax-go/v2 v2.14.0 // indirect
3649
github.com/hashicorp/errwrap v1.1.0 // indirect
3750
github.com/hashicorp/go-multierror v1.1.1 // indirect
3851
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -44,8 +57,21 @@ require (
4457
github.com/tklauser/numcpus v0.6.1 // indirect
4558
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
4659
github.com/yusufpapurcu/wmi v1.2.4 // indirect
60+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
61+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
62+
go.opentelemetry.io/otel v1.29.0 // indirect
63+
go.opentelemetry.io/otel/metric v1.29.0 // indirect
64+
go.opentelemetry.io/otel/trace v1.29.0 // indirect
4765
go.uber.org/atomic v1.7.0 // indirect
4866
golang.org/x/net v0.33.0 // indirect
67+
golang.org/x/oauth2 v0.24.0 // indirect
68+
golang.org/x/sync v0.10.0 // indirect
69+
golang.org/x/text v0.21.0 // indirect
70+
golang.org/x/time v0.8.0 // indirect
71+
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect
72+
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect
73+
google.golang.org/grpc v1.67.1 // indirect
74+
google.golang.org/protobuf v1.35.2 // indirect
4975
)
5076

5177
replace github.com/kevinburke/ssh_config => github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34

go.sum

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14=
2+
cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU=
3+
cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w=
4+
cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE=
5+
cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs=
6+
cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q=
7+
cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU=
8+
cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8=
9+
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
10+
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
11+
cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU=
12+
cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng=
113
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
214
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
315
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
@@ -14,6 +26,11 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
1426
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
1527
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
1628
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
29+
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
30+
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
31+
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
32+
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
33+
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
1734
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
1835
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
1936
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
@@ -22,11 +39,19 @@ github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17w
2239
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
2340
github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y=
2441
github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks=
42+
github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg=
43+
github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
2544
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
2645
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
2746
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
47+
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
48+
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
2849
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
2950
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
51+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
52+
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
53+
github.com/googleapis/gax-go/v2 v2.14.0 h1:f+jMrjBPl+DL9nI4IQzLUxMq7XrAqFYB7hBPqMNIe8o=
54+
github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk=
3055
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
3156
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
3257
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -94,12 +119,26 @@ github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34 h1:I8VZVTZE
94119
github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
95120
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
96121
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
122+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
123+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
124+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
125+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
126+
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
127+
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
128+
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
129+
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
130+
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
131+
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
97132
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
98133
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
99134
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
100135
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
101136
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
102137
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
138+
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
139+
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
140+
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
141+
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
103142
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
104143
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
105144
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -111,7 +150,21 @@ golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
111150
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
112151
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
113152
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
153+
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
154+
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
155+
golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg=
156+
golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
114157
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
158+
google.golang.org/api v0.214.0 h1:h2Gkq07OYi6kusGOaT/9rnNljuXmqPnaig7WGPmKbwA=
159+
google.golang.org/api v0.214.0/go.mod h1:bYPpLG8AyeMWwDU6NXoB00xC0DFkikVvd5MfwoxjLqE=
160+
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g=
161+
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4=
162+
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 h1:8ZmaLZE4XWrtU3MyClkYqqtl6Oegr3235h7jxsDyqCY=
163+
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU=
164+
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
165+
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
166+
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
167+
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
115168
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
116169
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
117170
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

pkg/waveai/cloudbackend.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"log"
12+
"time"
1213

1314
"github.com/gorilla/websocket"
1415
"github.com/wavetermdev/waveterm/pkg/panichandler"
@@ -20,6 +21,10 @@ type WaveAICloudBackend struct{}
2021

2122
var _ AIBackend = WaveAICloudBackend{}
2223

24+
const CloudWebsocketConnectTimeout = 1 * time.Minute
25+
const OpenAICloudReqStr = "openai-cloudreq"
26+
const PacketEOFStr = "EOF"
27+
2328
type WaveAICloudReqPacketType struct {
2429
Type string `json:"type"`
2530
ClientId string `json:"clientid"`

pkg/waveai/googlebackend.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package waveai
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
8+
"github.com/google/generative-ai-go/genai"
9+
"github.com/wavetermdev/waveterm/pkg/wshrpc"
10+
"google.golang.org/api/iterator"
11+
"google.golang.org/api/option"
12+
)
13+
14+
type GoogleBackend struct{}
15+
16+
var _ AIBackend = GoogleBackend{}
17+
18+
func (GoogleBackend) StreamCompletion(ctx context.Context, request wshrpc.WaveAIStreamRequest) chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType] {
19+
client, err := genai.NewClient(ctx, option.WithAPIKey(request.Opts.APIToken))
20+
if err != nil {
21+
log.Fatalf("failed to create client: %v", err)
22+
return nil
23+
}
24+
25+
model := client.GenerativeModel(request.Opts.Model)
26+
if model == nil {
27+
log.Fatal("model not found")
28+
client.Close()
29+
return nil
30+
}
31+
32+
cs := model.StartChat()
33+
cs.History = extractHistory(request.Prompt)
34+
iter := cs.SendMessageStream(ctx, extractPrompt(request.Prompt))
35+
36+
rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType])
37+
38+
go func() {
39+
defer client.Close()
40+
defer close(rtn)
41+
for {
42+
// Check for context cancellation
43+
select {
44+
case <-ctx.Done():
45+
rtn <- makeAIError(fmt.Errorf("request cancelled: %v", ctx.Err()))
46+
break
47+
default:
48+
}
49+
50+
resp, err := iter.Next()
51+
if err == iterator.Done {
52+
break
53+
}
54+
if err != nil {
55+
rtn <- makeAIError(fmt.Errorf("Google API error: %v", err))
56+
break
57+
}
58+
59+
rtn <- wshrpc.RespOrErrorUnion[wshrpc.WaveAIPacketType]{Response: wshrpc.WaveAIPacketType{Text: convertCandidatesToText(resp.Candidates)}}
60+
}
61+
}()
62+
return rtn
63+
}
64+
65+
func extractHistory(history []wshrpc.WaveAIPromptMessageType) []*genai.Content {
66+
var rtn []*genai.Content
67+
for _, h := range history[:len(history)-1] {
68+
if h.Role == "user" || h.Role == "model" {
69+
rtn = append(rtn, &genai.Content{
70+
Role: h.Role,
71+
Parts: []genai.Part{genai.Text(h.Content)},
72+
})
73+
}
74+
}
75+
return rtn
76+
}
77+
78+
func extractPrompt(prompt []wshrpc.WaveAIPromptMessageType) genai.Part {
79+
p := prompt[len(prompt)-1]
80+
return genai.Text(p.Content)
81+
}
82+
83+
func convertCandidatesToText(candidates []*genai.Candidate) string {
84+
var rtn string
85+
for _, c := range candidates {
86+
for _, p := range c.Content.Parts {
87+
rtn += fmt.Sprintf("%v", p)
88+
}
89+
}
90+
return rtn
91+
}

pkg/waveai/openaibackend.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ type OpenAIBackend struct{}
2020

2121
var _ AIBackend = OpenAIBackend{}
2222

23+
const DefaultAzureAPIVersion = "2023-05-15"
24+
2325
// copied from go-openai/config.go
2426
func defaultAzureMapperFn(model string) string {
2527
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")

0 commit comments

Comments
 (0)