Skip to content

Commit 535e409

Browse files
authored
Merge pull request #47 from zoq/ios-finetuning-app
iOS finetuning integration
2 parents e9825e6 + 333f9a0 commit 535e409

File tree

12 files changed

+2405
-93
lines changed

12 files changed

+2405
-93
lines changed

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ enum LlamaError: Error {
55
case couldNotInitializeContext
66
}
77

8+
/// User-tunable settings for model loading and token sampling.
/// Conforms to `Sendable` so a value can safely cross into the
/// `LlamaContext` actor (init, `updateSampler`) under strict concurrency.
struct LlamaRuntimeOptions: Sendable {
    /// Context window size (becomes `n_ctx`); also caps generation length (`n_len`).
    var contextLength: Int32
    /// Layers to offload to the GPU; negative values leave llama.cpp's default untouched.
    var nGpuLayers: Int32
    /// RNG seed for the sampler; 0 requests a randomized seed (mapped to `UInt32.max`).
    var seed: UInt32
    /// Sampling temperature; clamped to >= 0 at the use sites.
    var temperature: Float
    /// Nucleus (top-p) threshold; clamped to [0, 1] at the use sites.
    var topP: Float
    /// Top-k cutoff; values <= 0 disable the top-k sampler entirely.
    var topK: Int32
    /// Whether flash attention is enabled on the llama context.
    var flashAttention: Bool
}
17+
818
/// Marks `batch` as empty by zeroing its token count so it can be refilled;
/// the buffers allocated by `llama_batch_init` are reused, not released.
func llama_batch_clear(_ batch: inout llama_batch) {
    batch.n_tokens = 0
}
@@ -25,9 +35,10 @@ actor LlamaContext {
2535
private var model: OpaquePointer
2636
private var context: OpaquePointer
2737
private var vocab: OpaquePointer
28-
private var sampling: UnsafeMutablePointer<llama_sampler>
38+
private var sampling: UnsafeMutablePointer<llama_sampler>?
2939
private var batch: llama_batch
3040
private var tokens_list: [llama_token]
41+
private var runtimeOptions: LlamaRuntimeOptions
3142
var is_done: Bool = false
3243

3344
/// This variable is used to store temporarily invalid cchars
@@ -38,34 +49,80 @@ actor LlamaContext {
3849

3950
var n_decode: Int32 = 0
4051

41-
/// Wraps an already-loaded model/context pair and builds the initial
/// sampler chain from `options`.
/// - Parameters:
///   - model: Loaded `llama_model` handle (owned by this object; freed in `deinit`).
///   - context: Initialized `llama_context` handle (freed in `deinit`).
///   - options: Runtime settings used for batch sizing, generation length and sampling.
init(model: OpaquePointer, context: OpaquePointer, options: LlamaRuntimeOptions) {
    self.model = model
    self.context = context
    self.tokens_list = []
    // The batch must hold at least one full context worth of tokens (min 512).
    self.batch = llama_batch_init(max(Int32(512), options.contextLength), 0, 1)
    self.temporary_invalid_cchars = []
    self.runtimeOptions = options
    self.n_len = options.contextLength
    vocab = llama_model_get_vocab(model)

    // Sampler chain order: top-k (optional) -> top-p -> temperature -> seeded dist.
    let chain = llama_sampler_chain_init(llama_sampler_chain_default_params())

    if options.topK > 0 {
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(options.topK))
    }

    let topP = Float(max(0.0, min(Double(options.topP), 1.0)))
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(topP, 1))

    let temperature = Float(max(0.0, Double(options.temperature)))
    llama_sampler_chain_add(chain, llama_sampler_init_temp(temperature))

    // Seed 0 means "randomize": UInt32.max is llama.cpp's default-seed sentinel
    // (presumably triggers a random seed — confirm against llama.h).
    let seed = options.seed == 0 ? UInt32.max : options.seed
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed))

    sampling = chain
}
5380

5481
deinit {
    // Release sampler chain first; it is optional because rebuildSamplerChain
    // may have swapped it out.
    if let sampling {
        llama_sampler_free(sampling)
    }
    llama_batch_free(batch)
    llama_model_free(model)
    llama_free(context)
    // NOTE(review): llama_backend_free() tears down process-global backend state.
    // If the app ever holds more than one LlamaContext at a time, freeing the
    // backend here would break the others — confirm single-instance usage.
    llama_backend_free()
}
6190

62-
static func create_context(path: String) throws -> LlamaContext {
91+
/// Builds a fresh sampler chain from `runtimeOptions` and swaps it in,
/// releasing the previous chain only after the replacement is fully assembled.
private func rebuildSamplerChain() {
    let chain = llama_sampler_chain_init(llama_sampler_chain_default_params())

    // Same stage order as init: top-k (optional) -> top-p -> temperature -> dist.
    if runtimeOptions.topK > 0 {
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(runtimeOptions.topK))
    }

    llama_sampler_chain_add(
        chain,
        llama_sampler_init_top_p(max(0.0, min(runtimeOptions.topP, 1.0)), 1))
    llama_sampler_chain_add(
        chain,
        llama_sampler_init_temp(Float(max(0.0, Double(runtimeOptions.temperature)))))

    // Seed 0 requests randomization via llama.cpp's UInt32.max sentinel.
    let seed = runtimeOptions.seed == 0 ? UInt32.max : runtimeOptions.seed
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed))

    // Free the old chain last so a failure above cannot leave `sampling` dangling.
    if let sampling {
        llama_sampler_free(sampling)
    }
    sampling = chain
}
114+
115+
static func create_context(path: String, options: LlamaRuntimeOptions) throws -> LlamaContext {
63116
llama_backend_init()
64117
var model_params = llama_model_default_params()
65118

66119
#if targetEnvironment(simulator)
67120
model_params.n_gpu_layers = 0
68121
print("Running on simulator, force use n_gpu_layers = 0")
122+
#else
123+
if options.nGpuLayers >= 0 {
124+
model_params.n_gpu_layers = options.nGpuLayers
125+
}
69126
#endif
70127
let model = llama_model_load_from_file(path, model_params)
71128
guard let model else {
@@ -77,17 +134,24 @@ actor LlamaContext {
77134
print("Using \(n_threads) threads")
78135

79136
var ctx_params = llama_context_default_params()
80-
ctx_params.n_ctx = 2048
137+
ctx_params.n_ctx = UInt32(options.contextLength)
81138
ctx_params.n_threads = Int32(n_threads)
82139
ctx_params.n_threads_batch = Int32(n_threads)
140+
ctx_params.flash_attn_type = options.flashAttention ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED
83141

84142
let context = llama_init_from_model(model, ctx_params)
85143
guard let context else {
86144
print("Could not load context!")
87145
throw LlamaError.couldNotInitializeContext
88146
}
89147

90-
return LlamaContext(model: model, context: context)
148+
return LlamaContext(model: model, context: context, options: options)
149+
}
150+
151+
/// Applies new sampling settings: stores `options`, updates the generation
/// length cap, and rebuilds the sampler chain to match.
/// Note: `options.contextLength` only adjusts `n_len` here — the llama
/// context itself keeps the `n_ctx` it was created with in create_context.
func updateSampler(options: LlamaRuntimeOptions) {
    runtimeOptions = options
    n_len = options.contextLength
    rebuildSamplerChain()
}
92156

93157
func model_info() -> String {
@@ -151,6 +215,10 @@ actor LlamaContext {
151215
func completion_loop() -> String {
152216
var new_token_id: llama_token = 0
153217

218+
guard let sampling else {
219+
return ""
220+
}
221+
154222
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
155223

156224
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {

examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DD84C9FD2D747FED007778EC /* llama.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD84C9FC2D747FED007778EC /* llama.xcframework */; };
2121
DD84C9FE2D747FED007778EC /* llama.xcframework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = DD84C9FC2D747FED007778EC /* llama.xcframework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
2222
F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */; };
23+
D1A2F0012E0C8E8B00A1B1C0 /* FinetuneBridge.mm in Sources */ = {isa = PBXBuildFile; fileRef = D1A2F0002E0C8E8B00A1B1C0 /* FinetuneBridge.mm */; };
2324
/* End PBXBuildFile section */
2425

2526
/* Begin PBXCopyFilesBuildPhase section */
@@ -51,6 +52,9 @@
5152
DD84C9FC2D747FED007778EC /* llama.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = llama.xcframework; path = "../../build-apple/llama.xcframework"; sourceTree = "<group>"; };
5253
DF2D2FE72B4A59BE00FCB72D /* llama.cpp */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = llama.cpp; path = ../..; sourceTree = "<group>"; };
5354
F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoadCustomButton.swift; sourceTree = "<group>"; };
55+
D1A2F0002E0C8E8B00A1B1C0 /* FinetuneBridge.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = FinetuneBridge.mm; sourceTree = "<group>"; };
56+
D1A2F0022E0C8E8B00A1B1C0 /* FinetuneBridge.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = FinetuneBridge.h; sourceTree = "<group>"; };
57+
D1A2F0032E0C8E8B00A1B1C0 /* llama_swiftui-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "llama_swiftui-Bridging-Header.h"; sourceTree = "<group>"; };
5458
/* End PBXFileReference section */
5559

5660
/* Begin PBXFrameworksBuildPhase section */
@@ -93,6 +97,7 @@
9397
8A3F84102AC4BD85005E2EE8 /* Resources */,
9498
8A9F7C4B2AC332DC008AE1EA /* Models */,
9599
8A9F7C4A2AC332BF008AE1EA /* UI */,
100+
D1A2F0062E0C8E8B00A1B1C0 /* Bridging */,
96101
8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */,
97102
8A1C837A2AC328BE0096AF73 /* Assets.xcassets */,
98103
);
@@ -109,19 +114,19 @@
109114
name = Frameworks;
110115
sourceTree = "<group>";
111116
};
112-
8A3F84102AC4BD85005E2EE8 /* Resources */ = {
117+
8A3F84112AC4BD8C005E2EE8 /* models */ = {
113118
isa = PBXGroup;
114119
children = (
115-
8A3F84112AC4BD8C005E2EE8 /* models */,
116120
);
117-
path = Resources;
121+
path = models;
118122
sourceTree = "<group>";
119123
};
120-
8A3F84112AC4BD8C005E2EE8 /* models */ = {
124+
8A3F84102AC4BD85005E2EE8 /* Resources */ = {
121125
isa = PBXGroup;
122126
children = (
127+
8A3F84112AC4BD8C005E2EE8 /* models */,
123128
);
124-
path = models;
129+
path = Resources;
125130
sourceTree = "<group>";
126131
};
127132
8A907F312AC7134E006146EA /* llama.cpp.swift */ = {
@@ -132,6 +137,16 @@
132137
path = llama.cpp.swift;
133138
sourceTree = "<group>";
134139
};
140+
D1A2F0062E0C8E8B00A1B1C0 /* Bridging */ = {
141+
isa = PBXGroup;
142+
children = (
143+
D1A2F0022E0C8E8B00A1B1C0 /* FinetuneBridge.h */,
144+
D1A2F0032E0C8E8B00A1B1C0 /* llama_swiftui-Bridging-Header.h */,
145+
D1A2F0002E0C8E8B00A1B1C0 /* FinetuneBridge.mm */,
146+
);
147+
path = Bridging;
148+
sourceTree = "<group>";
149+
};
135150
8A9F7C4A2AC332BF008AE1EA /* UI */ = {
136151
isa = PBXGroup;
137152
children = (
@@ -230,6 +245,7 @@
230245
F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */,
231246
8A907F332AC7138A006146EA /* LibLlama.swift in Sources */,
232247
8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */,
248+
D1A2F0012E0C8E8B00A1B1C0 /* FinetuneBridge.mm in Sources */,
233249
8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */,
234250
8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */,
235251
7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */,
@@ -364,9 +380,10 @@
364380
buildSettings = {
365381
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
366382
CLANG_ENABLE_MODULES = YES;
383+
CODE_SIGN_IDENTITY = "Apple Development";
367384
CODE_SIGN_STYLE = Automatic;
368385
CURRENT_PROJECT_VERSION = 1;
369-
DEVELOPMENT_TEAM = K5UQJPP73A;
386+
DEVELOPMENT_TEAM = 3LT8Z8ZRCG;
370387
ENABLE_PREVIEWS = YES;
371388
GENERATE_INFOPLIST_FILE = YES;
372389
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
@@ -380,14 +397,16 @@
380397
"@executable_path/Frameworks",
381398
);
382399
MARKETING_VERSION = 1.0;
383-
PRODUCT_BUNDLE_IDENTIFIER = "com.bachittle.llama-swift";
400+
PRODUCT_BUNDLE_IDENTIFIER = "llama-collabora";
384401
PRODUCT_NAME = "$(TARGET_NAME)";
402+
PROVISIONING_PROFILE_SPECIFIER = "";
385403
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator xros xrsimulator";
386404
SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO;
387405
SWIFT_EMIT_LOC_STRINGS = YES;
388406
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
389407
SWIFT_VERSION = 5.0;
390408
TARGETED_DEVICE_FAMILY = "1,2,7";
409+
SWIFT_OBJC_BRIDGING_HEADER = "llama.swiftui/Bridging/llama_swiftui-Bridging-Header.h";
391410
};
392411
name = Debug;
393412
};
@@ -396,9 +415,10 @@
396415
buildSettings = {
397416
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
398417
CLANG_ENABLE_MODULES = YES;
418+
CODE_SIGN_IDENTITY = "Apple Development";
399419
CODE_SIGN_STYLE = Automatic;
400420
CURRENT_PROJECT_VERSION = 1;
401-
DEVELOPMENT_TEAM = K5UQJPP73A;
421+
DEVELOPMENT_TEAM = 3LT8Z8ZRCG;
402422
ENABLE_PREVIEWS = YES;
403423
GENERATE_INFOPLIST_FILE = YES;
404424
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
@@ -412,13 +432,15 @@
412432
"@executable_path/Frameworks",
413433
);
414434
MARKETING_VERSION = 1.0;
415-
PRODUCT_BUNDLE_IDENTIFIER = "com.bachittle.llama-swift";
435+
PRODUCT_BUNDLE_IDENTIFIER = "llama-collabora";
416436
PRODUCT_NAME = "$(TARGET_NAME)";
437+
PROVISIONING_PROFILE_SPECIFIER = "";
417438
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator xros xrsimulator";
418439
SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO;
419440
SWIFT_EMIT_LOC_STRINGS = YES;
420441
SWIFT_VERSION = 5.0;
421442
TARGETED_DEVICE_FAMILY = "1,2,7";
443+
SWIFT_OBJC_BRIDGING_HEADER = "llama.swiftui/Bridging/llama_swiftui-Bridging-Header.h";
422444
};
423445
name = Release;
424446
};
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// Error codes returned by llama_swift_run_lora_finetune
enum llama_swift_finetune_error {
    LLAMA_SWIFT_FINETUNE_OK = 0,
    LLAMA_SWIFT_FINETUNE_ERROR_INVALID_ARGUMENT = 1,   // a required path/option is NULL or invalid
    LLAMA_SWIFT_FINETUNE_ERROR_MODEL_LOAD = 2,         // base model failed to load
    LLAMA_SWIFT_FINETUNE_ERROR_CONTEXT_CREATE = 3,     // llama context creation failed
    LLAMA_SWIFT_FINETUNE_ERROR_DATASET = 4,            // dataset missing or unreadable
    LLAMA_SWIFT_FINETUNE_ERROR_TRAINING_INIT = 5,      // optimizer/training setup failed
    LLAMA_SWIFT_FINETUNE_ERROR_SAVE = 6,               // writing the output adapter failed
};

// Configuration for a LoRA finetuning run. Field semantics mirror the
// llama.cpp training options of the same names.
struct llama_swift_finetune_options {
    int32_t n_ctx;           // training context length
    int32_t n_threads;       // CPU threads to use
    int32_t n_batch;         // logical batch size
    int32_t n_ubatch;        // physical (micro) batch size
    int32_t epochs;          // number of passes over the dataset
    int32_t lora_rank;       // LoRA decomposition rank
    float   lora_alpha;      // LoRA scaling factor
    float   learning_rate;   // optimizer learning rate
    float   val_split;       // fraction of data held out for validation
    uint32_t target_modules; // presumably a bitmask selecting which weight
                             // matrices receive adapters — confirm against
                             // FinetuneBridge.mm
    int32_t seed;            // RNG seed (NOTE(review): int32_t here vs UInt32
                             // on the Swift sampling side — confirm intended)
    bool    flash_attn;      // enable flash attention during training
    int32_t n_gpu_layers;    // layers to offload to the GPU
};

// Invoked with human-readable progress/log messages; user_data is passed
// through unchanged from llama_swift_run_lora_finetune.
typedef void (*llama_swift_finetune_log_callback)(const char * message, void * user_data);

// Runs a LoRA finetune of the model at model_path on dataset_path and writes
// the resulting adapter to output_adapter_path. Returns LLAMA_SWIFT_FINETUNE_OK
// on success, or one of the error codes above. logger may be NULL.
enum llama_swift_finetune_error llama_swift_run_lora_finetune(
    const char * model_path,
    const char * dataset_path,
    const char * output_adapter_path,
    const struct llama_swift_finetune_options * options,
    llama_swift_finetune_log_callback logger,
    void * user_data);

#ifdef __cplusplus
}
#endif

0 commit comments

Comments
 (0)