11use ide_db:: defs:: { Definition , NameRefClass } ;
22use syntax:: {
3- ast:: { self , HasName } ,
3+ ast:: { self , HasName , Name } ,
44 ted, AstNode , SyntaxNode ,
55} ;
66
@@ -48,15 +48,15 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
4848 other => format ! ( "{{ {other} }}" ) ,
4949 } ;
5050 let extracting_arm_pat = extracting_arm. pat ( ) ?;
51- let extracted_variable = find_extracted_variable ( ctx, & extracting_arm) ?;
51+ let extracted_variable_positions = find_extracted_variable ( ctx, & extracting_arm) ?;
5252
5353 acc. add (
5454 AssistId ( "convert_match_to_let_else" , AssistKind :: RefactorRewrite ) ,
5555 "Convert match to let-else" ,
5656 let_stmt. syntax ( ) . text_range ( ) ,
5757 |builder| {
5858 let extracting_arm_pat =
59- rename_variable ( & extracting_arm_pat, extracted_variable , binding) ;
59+ rename_variable ( & extracting_arm_pat, & extracted_variable_positions , binding) ;
6060 builder. replace (
6161 let_stmt. syntax ( ) . text_range ( ) ,
6262 format ! ( "let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};" ) ,
@@ -95,14 +95,15 @@ fn find_arms(
9595}
9696
9797// Given an extracting arm, find the extracted variable.
98- fn find_extracted_variable ( ctx : & AssistContext < ' _ > , arm : & ast:: MatchArm ) -> Option < ast :: Name > {
98+ fn find_extracted_variable ( ctx : & AssistContext < ' _ > , arm : & ast:: MatchArm ) -> Option < Vec < Name > > {
9999 match arm. expr ( ) ? {
100100 ast:: Expr :: PathExpr ( path) => {
101101 let name_ref = path. syntax ( ) . descendants ( ) . find_map ( ast:: NameRef :: cast) ?;
102102 match NameRefClass :: classify ( & ctx. sema , & name_ref) ? {
103103 NameRefClass :: Definition ( Definition :: Local ( local) ) => {
104- let source = local. primary_source ( ctx. db ( ) ) . into_ident_pat ( ) ?;
105- Some ( source. name ( ) ?)
104+ let source =
105+ local. sources ( ctx. db ( ) ) . into_iter ( ) . map ( |x| x. into_ident_pat ( ) ?. name ( ) ) ;
106+ source. collect ( )
106107 }
107108 _ => None ,
108109 }
@@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
115116}
116117
117118// Rename `extracted` with `binding` in `pat`.
118- fn rename_variable ( pat : & ast:: Pat , extracted : ast :: Name , binding : ast:: Pat ) -> SyntaxNode {
119+ fn rename_variable ( pat : & ast:: Pat , extracted : & [ Name ] , binding : ast:: Pat ) -> SyntaxNode {
119120 let syntax = pat. syntax ( ) . clone_for_update ( ) ;
120- let extracted_syntax = syntax. covering_element ( extracted. syntax ( ) . text_range ( ) ) ;
121-
122- // If `extracted` variable is a record field, we should rename it to `binding`,
123- // otherwise we just need to replace `extracted` with `binding`.
124-
125- if let Some ( record_pat_field) = extracted_syntax. ancestors ( ) . find_map ( ast:: RecordPatField :: cast)
126- {
127- if let Some ( name_ref) = record_pat_field. field_name ( ) {
128- ted:: replace (
129- record_pat_field. syntax ( ) ,
130- ast:: make:: record_pat_field ( ast:: make:: name_ref ( & name_ref. text ( ) ) , binding)
121+ let extracted = extracted
122+ . iter ( )
123+ . map ( |e| syntax. covering_element ( e. syntax ( ) . text_range ( ) ) )
124+ . collect :: < Vec < _ > > ( ) ;
125+ for extracted_syntax in extracted {
126+ // If `extracted` variable is a record field, we should rename it to `binding`,
127+ // otherwise we just need to replace `extracted` with `binding`.
128+
129+ if let Some ( record_pat_field) =
130+ extracted_syntax. ancestors ( ) . find_map ( ast:: RecordPatField :: cast)
131+ {
132+ if let Some ( name_ref) = record_pat_field. field_name ( ) {
133+ ted:: replace (
134+ record_pat_field. syntax ( ) ,
135+ ast:: make:: record_pat_field (
136+ ast:: make:: name_ref ( & name_ref. text ( ) ) ,
137+ binding. clone ( ) ,
138+ )
131139 . syntax ( )
132140 . clone_for_update ( ) ,
133- ) ;
141+ ) ;
142+ }
143+ } else {
144+ ted:: replace ( extracted_syntax, binding. clone ( ) . syntax ( ) . clone_for_update ( ) ) ;
134145 }
135- } else {
136- ted:: replace ( extracted_syntax, binding. syntax ( ) . clone_for_update ( ) ) ;
137146 }
138-
139147 syntax
140148}
141149
@@ -162,6 +170,39 @@ fn foo(opt: Option<()>) {
162170 ) ;
163171 }
164172
173+ #[ test]
174+ fn or_pattern_multiple_binding ( ) {
175+ check_assist (
176+ convert_match_to_let_else,
177+ r#"
178+ //- minicore: option
179+ enum Foo {
180+ A(u32),
181+ B(u32),
182+ C(String),
183+ }
184+
185+ fn foo(opt: Option<Foo>) -> Result<u32, ()> {
186+ let va$0lue = match opt {
187+ Some(Foo::A(it) | Foo::B(it)) => it,
188+ _ => return Err(()),
189+ };
190+ }
191+ "# ,
192+ r#"
193+ enum Foo {
194+ A(u32),
195+ B(u32),
196+ C(String),
197+ }
198+
199+ fn foo(opt: Option<Foo>) -> Result<u32, ()> {
200+ let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
201+ }
202+ "# ,
203+ ) ;
204+ }
205+
165206 #[ test]
166207 fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr ( ) {
167208 cov_mark:: check_count!( extracting_arm_is_not_an_identity_expr, 2 ) ;
0 commit comments