@@ -33,9 +33,6 @@ public struct StableDiffusionPipeline: ResourceManaging {
3333 /// Optional model for checking safety of generated image
3434 var safetyChecker : SafetyChecker ? = nil
3535
36- /// Controls the influence of the text prompt on sampling process (0=random images)
37- var guidanceScale : Float = 7.5
38-
3936 /// Reports whether this pipeline can perform safety checks
4037 public var canSafetyCheck : Bool {
4138 safetyChecker != nil
@@ -56,20 +53,17 @@ public struct StableDiffusionPipeline: ResourceManaging {
5653 /// - unet: Model for noise prediction on latent samples
5754 /// - decoder: Model for decoding latent sample to image
5855 /// - safetyChecker: Optional model for checking safety of generated images
59- /// - guidanceScale: Influence of the text prompt on generation process
6056 /// - reduceMemory: Option to enable reduced memory mode
6157 /// - Returns: Pipeline ready for image generation
6258 public init ( textEncoder: TextEncoder ,
6359 unet: Unet ,
6460 decoder: Decoder ,
6561 safetyChecker: SafetyChecker ? = nil ,
66- guidanceScale: Float = 7.5 ,
6762 reduceMemory: Bool = false ) {
6863 self . textEncoder = textEncoder
6964 self . unet = unet
7065 self . decoder = decoder
7166 self . safetyChecker = safetyChecker
72- self . guidanceScale = guidanceScale
7367 self . reduceMemory = reduceMemory
7468 }
7569
@@ -112,6 +106,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
112106 /// - stepCount: Number of inference steps to perform
113107 /// - imageCount: Number of samples/images to generate for the input prompt
114108 /// - seed: Random seed for generation
109+ /// - guidanceScale: Controls the influence of the text prompt on the sampling process (0=random images)
115110 /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
116111 /// - progressHandler: Callback to perform after each step, stops on receiving false response
117112 /// - Returns: An array of `imageCount` optional images.
@@ -122,6 +117,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
122117 imageCount: Int = 1 ,
123118 stepCount: Int = 50 ,
124119 seed: UInt32 = 0 ,
120+ guidanceScale: Float = 7.5 ,
125121 disableSafety: Bool = false ,
126122 scheduler: StableDiffusionScheduler = . pndmScheduler,
127123 progressHandler: ( Progress ) -> Bool = { _ in true }
@@ -173,7 +169,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
173169 hiddenStates: hiddenStates
174170 )
175171
176- noise = performGuidance ( noise)
172+ noise = performGuidance ( noise, guidanceScale )
177173
178174 // Have the scheduler compute the previous (t-1) latent
179175 // sample given the predicted noise and current sample
@@ -236,11 +232,11 @@ public struct StableDiffusionPipeline: ResourceManaging {
236232 return states
237233 }
238234
239- func performGuidance( _ noise: [ MLShapedArray < Float32 > ] ) -> [ MLShapedArray < Float32 > ] {
240- noise. map { performGuidance ( $0) }
235+ func performGuidance( _ noise: [ MLShapedArray < Float32 > ] , _ guidanceScale : Float ) -> [ MLShapedArray < Float32 > ] {
236+ noise. map { performGuidance ( $0, guidanceScale ) }
241237 }
242238
243- func performGuidance( _ noise: MLShapedArray < Float32 > ) -> MLShapedArray < Float32 > {
239+ func performGuidance( _ noise: MLShapedArray < Float32 > , _ guidanceScale : Float ) -> MLShapedArray < Float32 > {
244240
245241 let blankNoiseScalars = noise [ 0 ] . scalars
246242 let textNoiseScalars = noise [ 1 ] . scalars
0 commit comments