@@ -340,6 +340,21 @@ class HandshakePattern {
340340 }
341341
342342 record MessagePattern (NoiseHandshake .Role sender , Token [] tokens ) {
343+
344+ MessagePattern withAddedToken (final Token token , final int insertionIndex ) {
345+ if (insertionIndex < 0 || insertionIndex >= this .tokens ().length + 1 ) {
346+ throw new IllegalArgumentException ("Illegal insertion index" );
347+ }
348+
349+ final Token [] modifiedTokens = new Token [this .tokens ().length + 1 ];
350+ System .arraycopy (this .tokens (), 0 , modifiedTokens , 0 , insertionIndex );
351+ modifiedTokens [insertionIndex ] = token ;
352+ System .arraycopy (this .tokens (), insertionIndex , modifiedTokens ,
353+ insertionIndex + 1 , this .tokens ().length - insertionIndex );
354+
355+ return new MessagePattern (this .sender (), modifiedTokens );
356+ }
357+
343358 @ Override
344359 public String toString () {
345360 final String prefix = switch (sender ()) {
@@ -375,18 +390,24 @@ enum Token {
375390 ES ,
376391 SE ,
377392 SS ,
378- PSK ;
393+ PSK ,
394+ E1 ,
395+ EKEM1 ;
379396
380397 static Token fromString (final String string ) {
381- return switch (string ) {
382- case "e" , "E" -> E ;
383- case "s" , "S" -> S ;
384- case "ee" , "EE" -> EE ;
385- case "es" , "ES" -> ES ;
386- case "se" , "SE" -> SE ;
387- case "ss" , "SS" -> SS ;
388- case "psk" , "PSK" -> PSK ;
389- default -> throw new IllegalArgumentException ("Unrecognized token: " + string );
398+ for (final Token token : Token .values ()) {
399+ if (token .name ().equalsIgnoreCase (string )) {
400+ return token ;
401+ }
402+ }
403+
404+ throw new IllegalArgumentException ("Unrecognized token: " + string );
405+ }
406+
407+ boolean isKeyAgreementToken () {
408+ return switch (this ) {
409+ case EE , ES , SE , SS -> true ;
410+ default -> false ;
390411 };
391412 }
392413 }
@@ -482,6 +503,8 @@ HandshakePattern withModifier(final String modifier) {
482503 modifiedMessagePatterns = getPatternsWithFallbackModifier ();
483504 } else if (modifier .startsWith ("psk" )) {
484505 modifiedMessagePatterns = getPatternsWithPskModifier (modifier );
506+ } else if ("hfs" .equals (modifier )) {
507+ modifiedMessagePatterns = getPatternsWithHfsModifier ();
485508 } else {
486509 throw new IllegalArgumentException ("Unrecognized modifier: " + modifier );
487510 }
@@ -538,6 +561,74 @@ private MessagePattern[][] getPatternsWithPskModifier(final String modifier) {
538561 return new MessagePattern [][] { modifiedPreMessagePatterns , modifiedHandshakeMessagePatterns };
539562 }
540563
564+ private MessagePattern [][] getPatternsWithHfsModifier () {
565+ // Temporarily combine the pre-messages and "normal" messages to make iteration/state management easier
566+ final MessagePattern [] messagePatterns =
567+ new MessagePattern [getPreMessagePatterns ().length + getHandshakeMessagePatterns ().length ];
568+
569+ System .arraycopy (getPreMessagePatterns (), 0 , messagePatterns , 0 , getPreMessagePatterns ().length );
570+ System .arraycopy (getHandshakeMessagePatterns (), 0 , messagePatterns ,
571+ getPreMessagePatterns ().length , getHandshakeMessagePatterns ().length );
572+
573+ boolean insertedE1Token = false ;
574+ boolean insertedEkem1Token = false ;
575+
576+ for (int i = 0 ; i < messagePatterns .length ; i ++) {
577+ if (!insertedE1Token && Arrays .stream (messagePatterns [i ].tokens ()).anyMatch (token -> token == Token .E )) {
578+ // We haven't inserted an E1 token yet, and this message pattern needs one. Exactly where it should go depends
579+ // on whether this message pattern also contains a key agreement token, but either way, this pattern will wind
580+ // up one token longer than it was when it started.
581+ int insertionIndex = -1 ;
582+
583+ for (int t = 0 ; t < messagePatterns [i ].tokens ().length ; t ++) {
584+ final Token token = messagePatterns [i ].tokens ()[t ];
585+
586+ // TODO Prove that E must come before key agreement tokens
587+ if (token == Token .E || token .isKeyAgreementToken ()) {
588+ insertionIndex = t + 1 ;
589+
590+ if (token .isKeyAgreementToken ()) {
591+ break ;
592+ }
593+ }
594+ }
595+
596+ messagePatterns [i ] = messagePatterns [i ].withAddedToken (Token .E1 , insertionIndex );
597+ insertedE1Token = true ;
598+ }
599+
600+ if (!insertedEkem1Token && Arrays .stream (messagePatterns [i ].tokens ()).anyMatch (token -> token == Token .EE )) {
601+ // We haven't inserted an EKEM1 token yet, and this pattern needs one. EKEM1 tokens always go after the first
602+ // EE token.
603+ int insertionIndex = -1 ;
604+
605+ for (int t = 0 ; t < messagePatterns [i ].tokens ().length ; t ++) {
606+ if (messagePatterns [i ].tokens ()[t ] == Token .EE ) {
607+ insertionIndex = t + 1 ;
608+ break ;
609+ }
610+ }
611+
612+ messagePatterns [i ] = messagePatterns [i ].withAddedToken (Token .EKEM1 , insertionIndex );
613+ insertedEkem1Token = true ;
614+ }
615+
616+ if (insertedE1Token && insertedEkem1Token ) {
617+ // No need to inspect the rest of the message patterns if we've already inserted both of the HFS tokens
618+ break ;
619+ }
620+ }
621+
622+ final MessagePattern [] modifiedPreMessagePatterns = new MessagePattern [getPreMessagePatterns ().length ];
623+ final MessagePattern [] modifiedHandshakeMessagePatterns = new MessagePattern [getHandshakeMessagePatterns ().length ];
624+
625+ System .arraycopy (messagePatterns , 0 , modifiedPreMessagePatterns , 0 , getPreMessagePatterns ().length );
626+ System .arraycopy (messagePatterns , getPreMessagePatterns ().length ,
627+ modifiedHandshakeMessagePatterns , 0 , getHandshakeMessagePatterns ().length );
628+
629+ return new MessagePattern [][] { modifiedPreMessagePatterns , modifiedHandshakeMessagePatterns };
630+ }
631+
541632 private String getModifiedName (final String modifier ) {
542633 final String modifiedName ;
543634
0 commit comments