@@ -47,12 +47,12 @@ private extension HTMLElement {
4747 }
4848 return ( string, false )
4949
50- if isRoot {
50+ /* if isRoot {
5151 return dynamicVariables.isEmpty ? (string, false) : ("DynamicString(string: \"" + string + "\").test", true)
5252 //return dynamicVariables.isEmpty ? (string, false) : ("DynamicString(string: \"" + string + "\", values: [" + dynamicVariables.map({ $0.contains("\\(") ? "\"\($0)\"" : $0 }).joined(separator: ",") + "]).test", true)
5353 } else {
5454 return (string, false)
55- }
55+ }*/
5656 }
5757 static func parse_arguments( context: some MacroExpansionContext , elementType: HTMLElementType , children: Slice < SyntaxChildren > , dynamicVariables: inout [ String ] ) -> ElementData {
5858 var attributes : [ String ] = [ ] , innerHTML : [ String ] = [ ]
@@ -65,7 +65,7 @@ private extension HTMLElement {
6565 if key == " acceptCharset " {
6666 key = " accept-charset "
6767 }
68- if var string: String = parse_attribute ( context: context, elementType: elementType, key: key, argument: child, dynamicVariables: & dynamicVariables) {
68+ if let string: String = parse_attribute ( context: context, elementType: elementType, key: key, argument: child, dynamicVariables: & dynamicVariables) {
6969 attributes. append ( key + ( string. isEmpty ? " " : " = \\ \" " + string + " \\ \" " ) )
7070 }
7171 }
@@ -86,7 +86,7 @@ private extension HTMLElement {
8686 case " custom " , " data " :
8787 var ( literalValue, returnType) : ( String , LiteralReturnType ) = parse_literal_value ( context: context, elementType: elementType, key: key, argument: function. arguments. last!, dynamicVariables: & dynamicVariables) !
8888 if returnType == . string {
89- literalValue. escapeHTML ( attribute : true )
89+ literalValue. escapeHTML ( escapeAttributes : true )
9090 }
9191 value = literalValue
9292 if key == " custom " {
@@ -97,7 +97,15 @@ private extension HTMLElement {
9797 break
9898 case " event " :
9999 key = " on " + key_element. memberAccess!. declName. baseName. text
100- value = function. arguments. last!. expression. stringLiteral!. string. escapingHTML ( attribute: true )
100+ if var ( literalValue, returnType) : ( String , LiteralReturnType ) = parse_literal_value ( context: context, elementType: elementType, key: key, argument: function. arguments. last!, dynamicVariables: & dynamicVariables) {
101+ if returnType == . string {
102+ literalValue. escapeHTML ( escapeAttributes: true )
103+ }
104+ value = literalValue
105+ } else {
106+ unallowed_expression ( context: context, node: function. arguments. last!)
107+ return [ ]
108+ }
101109 break
102110 default :
103111 if let string: String = parse_attribute ( context: context, elementType: elementType, key: key, argument: key_argument, dynamicVariables: & dynamicVariables) {
@@ -122,24 +130,27 @@ private extension HTMLElement {
122130 if let macro: MacroExpansionExprSyntax = child. expression. macroExpansion {
123131 var string : String = parse_macro ( context: context, expression: macro, isRoot: false , dynamicVariables: & dynamicVariables) . 0
124132 if elementType == . escapeHTML {
125- string. escapeHTML ( attribute : false )
133+ string. escapeHTML ( escapeAttributes : false )
126134 }
127135 return string
128- } else if var ( string, returnType ) = parse_literal_value ( context: context, elementType: elementType, key: " " , argument: child, dynamicVariables: & dynamicVariables) {
129- string. escapeHTML ( attribute : false )
136+ } else if var string: String = parse_literal_value ( context: context, elementType: elementType, key: " " , argument: child, dynamicVariables: & dynamicVariables) ? . value {
137+ string. escapeHTML ( escapeAttributes : false )
130138 return string
131139 } else {
132- context. diagnose ( Diagnostic ( node: child, message: DiagnosticMsg ( id: " unallowedExpression " , message: " Expression not allowed. String interpolation is required when encoding runtime values. " ) , fixIts: [
133- FixIt ( message: DiagnosticMsg ( id: " useStringInterpolation " , message: " Use String Interpolation. " , severity: . error) , changes: [
134- FixIt . Change. replace (
135- oldNode: Syntax ( child) ,
136- newNode: Syntax ( StringLiteralExprSyntax ( content: " \\ ( \( child) ) " ) )
137- )
138- ] )
139- ] ) )
140+ unallowed_expression ( context: context, node: child)
140141 return nil
141142 }
142143 }
144+ static func unallowed_expression( context: some MacroExpansionContext , node: LabeledExprSyntax ) {
145+ context. diagnose ( Diagnostic ( node: node, message: DiagnosticMsg ( id: " unallowedExpression " , message: " String Interpolation is required when encoding runtime values. " ) , fixIts: [
146+ FixIt ( message: DiagnosticMsg ( id: " useStringInterpolation " , message: " Use String Interpolation. " , severity: . error) , changes: [
147+ FixIt . Change. replace (
148+ oldNode: Syntax ( node) ,
149+ newNode: Syntax ( StringLiteralExprSyntax ( content: " \\ ( \( node) ) " ) )
150+ )
151+ ] )
152+ ] ) )
153+ }
143154
144155 struct ElementData {
145156 let attributes : String , innerHTML : String
@@ -166,7 +177,7 @@ private extension HTMLElement {
166177 switch returnType {
167178 case . boolean: return string. elementsEqual ( " true " ) ? " " : nil
168179 case . string:
169- string. escapeHTML ( attribute : true )
180+ string. escapeHTML ( escapeAttributes : true )
170181 return string
171182 case . interpolation: return string
172183 }
@@ -213,7 +224,7 @@ private extension HTMLElement {
213224 if function. calledExpression. as ( DeclReferenceExprSyntax . self) ? . baseName. text == " StaticString " {
214225 return ( function. arguments. first!. expression. stringLiteral!. string, . string)
215226 }
216- return ( " \( function) " , . interpolation)
227+ return ( " \\ ( \( function) ) " , . interpolation)
217228 }
218229 }
219230 if let member: MemberAccessExprSyntax = expression. memberAccess {
@@ -227,7 +238,7 @@ private extension HTMLElement {
227238 return (integer, .interpolation)
228239 }
229240 } else {*/
230- return ( " \( member) " , . interpolation)
241+ return ( " \\ ( \( member) ) " , . interpolation)
231242 //}
232243 } else {
233244 return ( HTMLElementAttribute . Extra. htmlValue ( enumName: enumName ( elementType: elementType, key: key) , for: decl) , . string)
@@ -263,22 +274,56 @@ private extension HTMLElement {
263274 //context.diagnose(Diagnostic(node: expression, message: DiagnosticMsg(id: "somethingWentWrong", message: "Something went wrong. (" + expression.debugDescription + ")", severity: .warning)))
264275 return nil
265276 }
266- if returnType == . interpolation || string. contains ( " \\ ( " ) {
267- context. diagnose ( Diagnostic ( node: expression, message: DiagnosticMsg ( id: " unsafeInterpolation " , message: " Interpolation may introduce raw HTML elements. " , severity: . warning) ) )
277+ var remaining_interpolation : Int = 0
278+ if let list: StringLiteralSegmentListSyntax = expression. stringLiteral? . segments {
279+ for segment in list {
280+ if let expr: ExpressionSegmentSyntax = segment. as ( ExpressionSegmentSyntax . self) {
281+ remaining_interpolation += 1
282+ if flatten_interpolation ( string: & string, remaining_interpolation: & remaining_interpolation, expr: expr) {
283+ remaining_interpolation -= 1
284+ }
285+ }
286+ }
287+ }
288+ if returnType == . interpolation || remaining_interpolation > 0 {
268289 //dynamicVariables.append(string)
269- string = string. contains ( " \\ ( " ) ? string : " \\ ( " + string + " ) "
290+ if !string. contains ( " \\ ( " ) {
291+ string = " \\ ( " + string + " ) "
292+ }
270293 returnType = . interpolation
294+ context. diagnose ( Diagnostic ( node: expression, message: DiagnosticMsg ( id: " unsafeInterpolation " , message: " Interpolation may introduce raw HTML. " , severity: . warning) ) )
271295 }
272296 return ( string, returnType)
273297 }
298+ static func flatten_interpolation( string: inout String , remaining_interpolation: inout Int , expr: ExpressionSegmentSyntax ) -> Bool { // TODO: can still be improved ("\(description \(title))" doesn't get flattened)
299+ let expression : ExprSyntax = expr. expressions. first!. expression
300+ if let list: StringLiteralSegmentListSyntax = expression. stringLiteral? . segments {
301+ for segment in list {
302+ if let expr: ExpressionSegmentSyntax = segment. as ( ExpressionSegmentSyntax . self) {
303+ remaining_interpolation += 1
304+ if flatten_interpolation ( string: & string, remaining_interpolation: & remaining_interpolation, expr: expr) {
305+ remaining_interpolation -= 1
306+ }
307+ } else if let fix: String = segment. as ( StringSegmentSyntax . self) ? . content. text {
308+ string. replace ( " \( expr) " , with: fix)
309+ remaining_interpolation -= 1
310+ }
311+ }
312+ }
313+ if let fix: String = expression. integerLiteral? . literal. text ?? expression. floatLiteral? . literal. text {
314+ string. replace ( " \( expr) " , with: fix)
315+ return true
316+ }
317+ return false
318+ }
274319}
275320
276321enum LiteralReturnType {
277322 case boolean, string, interpolation
278323}
279324
280325// MARK: HTMLElementType
281- enum HTMLElementType : String {
326+ enum HTMLElementType : String , CaseIterable {
282327 case escapeHTML
283328 case html
284329 case custom
@@ -424,6 +469,7 @@ enum HTMLElementType : String {
424469 }
425470}
426471
472+ // MARK: Misc
427473extension ExprSyntax {
428474 var booleanLiteral : BooleanLiteralExprSyntax ? { self . as ( BooleanLiteralExprSyntax . self) }
429475 var stringLiteral : StringLiteralExprSyntax ? { self . as ( StringLiteralExprSyntax . self) }
0 commit comments