1515using libsignalservicedotnet . crypto ;
1616using System ;
1717using System . Collections . Generic ;
18+ using System . Threading . Tasks ;
1819using System . Linq ;
1920
2021namespace libsignalservice . crypto
@@ -72,77 +73,108 @@ public OutgoingPushMessage Encrypt(SignalProtocolAddress destination, Unidentifi
7273 /// Decrypt a received <see cref="SignalServiceEnvelope"/>
7374 /// </summary>
7475 /// <param name="envelope">The received SignalServiceEnvelope</param>
76+ /// <param name="callback">Optional callback to call during the decrypt process before it is acked</param>
7577 /// <returns>a decrypted SignalServiceContent</returns>
76- public SignalServiceContent ? Decrypt ( SignalServiceEnvelope envelope )
78+ public async Task < SignalServiceContent ? > Decrypt ( SignalServiceEnvelope envelope , Func < SignalServiceContent ? , Task > callback = null )
7779 {
80+ Func < Plaintext , Task > callback_func = null ;
81+ if ( callback != null )
82+ {
83+ callback_func = async ( data ) => await callback ( await DecryptComplete ( envelope , data ) ) ;
84+ }
7885 try
7986 {
87+ Plaintext plaintext = null ;
8088 if ( envelope . HasLegacyMessage ( ) )
8189 {
82- Plaintext plaintext = Decrypt ( envelope , envelope . GetLegacyMessage ( ) ) ;
83- DataMessage message = DataMessage . Parser . ParseFrom ( plaintext . Data ) ;
90+ plaintext = await Decrypt ( envelope , envelope . GetLegacyMessage ( ) , callback_func ) ;
91+ }
92+ else if ( envelope . HasContent ( ) )
93+ {
94+ plaintext = await Decrypt ( envelope , envelope . GetContent ( ) , callback_func ) ;
95+ }
96+ if ( callback_func != null )
97+ {
98+ return null ;
99+ }
100+ return await DecryptComplete ( envelope , plaintext ) ;
101+ }
102+ catch ( InvalidProtocolBufferException e )
103+ {
104+ throw new InvalidMessageException ( e ) ;
105+ }
106+ }
107+ private async Task < SignalServiceContent > DecryptComplete ( SignalServiceEnvelope envelope , Plaintext plaintext )
108+ {
109+ if ( envelope . HasLegacyMessage ( ) )
110+ {
111+ DataMessage message = DataMessage . Parser . ParseFrom ( plaintext . Data ) ;
112+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
113+ plaintext . Metadata . SenderDevice ,
114+ plaintext . Metadata . Timestamp ,
115+ plaintext . Metadata . NeedsReceipt )
116+ {
117+ Message = CreateSignalServiceMessage ( plaintext . Metadata , message )
118+ } ;
119+ }
120+ else if ( envelope . HasContent ( ) )
121+ {
122+ Content message = Content . Parser . ParseFrom ( plaintext . Data ) ;
123+ if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
124+ {
84125 return new SignalServiceContent ( plaintext . Metadata . Sender ,
85- plaintext . Metadata . SenderDevice ,
86- plaintext . Metadata . Timestamp ,
87- plaintext . Metadata . NeedsReceipt )
126+ plaintext . Metadata . SenderDevice ,
127+ plaintext . Metadata . Timestamp ,
128+ plaintext . Metadata . NeedsReceipt )
88129 {
89- Message = CreateSignalServiceMessage ( plaintext . Metadata , message )
130+ Message = CreateSignalServiceMessage ( plaintext . Metadata , message . DataMessage )
90131 } ;
91132 }
92- else if ( envelope . HasContent ( ) )
133+ else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage )
93134 {
94- Plaintext plaintext = Decrypt ( envelope , envelope . Envelope . Content . ToByteArray ( ) ) ;
95- Content message = Content . Parser . ParseFrom ( plaintext . Data ) ;
96- if ( message . DataMessageOneofCase == Content . DataMessageOneofOneofCase . DataMessage )
97- {
98- return new SignalServiceContent ( plaintext . Metadata . Sender ,
99- plaintext . Metadata . SenderDevice ,
100- plaintext . Metadata . Timestamp ,
101- plaintext . Metadata . NeedsReceipt )
102- {
103- Message = CreateSignalServiceMessage ( plaintext . Metadata , message . DataMessage )
104- } ;
105- }
106- else if ( message . SyncMessageOneofCase == Content . SyncMessageOneofOneofCase . SyncMessage )
135+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
136+ plaintext . Metadata . SenderDevice ,
137+ plaintext . Metadata . Timestamp ,
138+ plaintext . Metadata . NeedsReceipt )
107139 {
108- return new SignalServiceContent ( plaintext . Metadata . Sender ,
109- plaintext . Metadata . SenderDevice ,
110- plaintext . Metadata . Timestamp ,
111- plaintext . Metadata . NeedsReceipt )
112- {
113- SynchronizeMessage = CreateSynchronizeMessage ( plaintext . Metadata , message . SyncMessage )
114- } ;
115- }
116- else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
140+ SynchronizeMessage = CreateSynchronizeMessage ( plaintext . Metadata , message . SyncMessage )
141+ } ;
142+ }
143+ else if ( message . CallMessageOneofCase == Content . CallMessageOneofOneofCase . CallMessage )
144+ {
145+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
146+ plaintext . Metadata . SenderDevice ,
147+ plaintext . Metadata . Timestamp ,
148+ plaintext . Metadata . NeedsReceipt )
117149 {
118- return new SignalServiceContent ( plaintext . Metadata . Sender ,
119- plaintext . Metadata . SenderDevice ,
120- plaintext . Metadata . Timestamp ,
121- plaintext . Metadata . NeedsReceipt )
122- {
123- CallMessage = CreateCallMessage ( message . CallMessage )
124- } ;
125- }
126- else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
150+ CallMessage = CreateCallMessage ( message . CallMessage )
151+ } ;
152+ }
153+ else if ( message . ReceiptMessageOneofCase == Content . ReceiptMessageOneofOneofCase . ReceiptMessage )
154+ {
155+ return new SignalServiceContent ( plaintext . Metadata . Sender ,
156+ plaintext . Metadata . SenderDevice ,
157+ plaintext . Metadata . Timestamp ,
158+ plaintext . Metadata . NeedsReceipt )
127159 {
128- return new SignalServiceContent ( plaintext . Metadata . Sender ,
129- plaintext . Metadata . SenderDevice ,
130- plaintext . Metadata . Timestamp ,
131- plaintext . Metadata . NeedsReceipt )
132- {
133- ReadMessage = CreateReceiptMessage ( plaintext . Metadata , message . ReceiptMessage )
134- } ;
135- }
160+ ReadMessage = CreateReceiptMessage ( plaintext . Metadata , message . ReceiptMessage )
161+ } ;
136162 }
137- return null ;
138163 }
139- catch ( InvalidProtocolBufferException e )
164+ return null ;
165+ }
166+ private class DecryptionCallbackHandler : DecryptionCallback
167+ {
168+ public Task handlePlaintext ( byte [ ] data , uint sessionVersion )
140169 {
141- throw new InvalidMetadataMessageException ( e ) ;
170+ data = GetStrippedMessage ( sessionVersion , data ) ;
171+ return callback ( new Plaintext ( metadata , data ) ) ;
142172 }
173+ public SessionCipher sessionCipher ;
174+ public Metadata metadata ;
175+ public Func < Plaintext , Task > callback ;
143176 }
144-
145- private Plaintext Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext )
177+ private async Task < Plaintext > Decrypt ( SignalServiceEnvelope envelope , byte [ ] ciphertext , Func < Plaintext , Task > callback = null )
146178 {
147179 try
148180 {
@@ -153,15 +185,27 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
153185 byte [ ] paddedMessage ;
154186 Metadata metadata ;
155187 uint sessionVersion ;
156-
188+ DecryptionCallbackHandler callback_handler = null ;
189+ if ( callback != null )
190+ callback_handler = new DecryptionCallbackHandler { callback = callback , sessionCipher = sessionCipher } ;
157191 if ( envelope . IsPreKeySignalMessage ( ) )
158192 {
159- paddedMessage = sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) ) ;
160193 metadata = new Metadata ( envelope . GetSource ( ) , envelope . GetSourceDevice ( ) , envelope . GetTimestamp ( ) , false ) ;
194+ if ( callback_handler != null )
195+ {
196+ await sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) , callback_handler ) ;
197+ return null ;
198+ }
199+ paddedMessage = sessionCipher . decrypt ( new PreKeySignalMessage ( ciphertext ) ) ;
161200 sessionVersion = sessionCipher . getSessionVersion ( ) ;
162201 }
163202 else if ( envelope . IsSignalMessage ( ) )
164203 {
204+ if ( callback_handler != null )
205+ {
206+ await sessionCipher . decrypt ( new SignalMessage ( ciphertext ) , callback_handler ) ;
207+ return null ;
208+ }
165209 paddedMessage = sessionCipher . decrypt ( new SignalMessage ( ciphertext ) ) ;
166210 metadata = new Metadata ( envelope . GetSource ( ) , envelope . GetSourceDevice ( ) , envelope . GetTimestamp ( ) , false ) ;
167211 sessionVersion = sessionCipher . getSessionVersion ( ) ;
@@ -170,16 +214,14 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
170214 {
171215 var results = sealedSessionCipher . Decrypt ( CertificateValidator , ciphertext , ( long ) envelope . Envelope . ServerTimestamp ) ;
172216 paddedMessage = results . Item2 ;
173- metadata = new Metadata ( results . Item1 . Name , ( int ) results . Item1 . DeviceId , ( long ) envelope . Envelope . Timestamp , true ) ;
174- sessionVersion = ( uint ) sealedSessionCipher . GetSessionVersion ( new SignalProtocolAddress ( metadata . Sender , ( uint ) metadata . SenderDevice ) ) ;
217+ metadata = new Metadata ( results . Item1 . Name , ( int ) results . Item1 . DeviceId , ( long ) envelope . Envelope . Timestamp , true ) ;
218+ sessionVersion = ( uint ) sealedSessionCipher . GetSessionVersion ( new SignalProtocolAddress ( metadata . Sender , ( uint ) metadata . SenderDevice ) ) ;
175219 }
176220 else
177221 {
178222 throw new InvalidMessageException ( "Unknown type: " + envelope . GetEnvelopeType ( ) + " from " + envelope . GetSource ( ) ) ;
179223 }
180-
181- PushTransportDetails transportDetails = new PushTransportDetails ( sessionVersion ) ;
182- byte [ ] data = transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
224+ var data = GetStrippedMessage ( sessionVersion , paddedMessage ) ;
183225 return new Plaintext ( metadata , data ) ;
184226 }
185227 catch ( DuplicateMessageException e )
@@ -214,7 +256,15 @@ private Plaintext Decrypt(SignalServiceEnvelope envelope, byte[] ciphertext)
214256 {
215257 throw new ProtocolNoSessionException ( e , envelope . GetSource ( ) , envelope . GetSourceDevice ( ) ) ;
216258 }
259+
217260 }
261+ private static byte [ ] GetStrippedMessage ( uint sessionVersion , byte [ ] paddedMessage )
262+ {
263+ PushTransportDetails transportDetails = new PushTransportDetails ( sessionVersion ) ;
264+ byte [ ] data = transportDetails . GetStrippedPaddingMessageBody ( paddedMessage ) ;
265+ return data ;
266+ }
267+
218268
219269 private SignalServiceDataMessage CreateSignalServiceMessage ( Metadata metadata , DataMessage content )
220270 {
0 commit comments