11import type { Position , TextDocument } from "@cursorless/common" ;
2- import { showError , type TreeSitter } from "@cursorless/common" ;
3- import { groupBy , uniq } from "lodash-es" ;
4- import type { Point , Query } from "web-tree-sitter" ;
2+ import { type TreeSitter } from "@cursorless/common" ;
3+ import type * as treeSitter from "web-tree-sitter" ;
54import { ide } from "../../singletons/ide.singleton" ;
65import { getNodeRange } from "../../util/nodeSelectors" ;
76import type {
7+ MutableQueryCapture ,
88 MutableQueryMatch ,
9- QueryCapture ,
109 QueryMatch ,
1110} from "./QueryCapture" ;
1211import { checkCaptureStartEnd } from "./checkCaptureStartEnd" ;
1312import { isContainedInErrorNode } from "./isContainedInErrorNode" ;
14- import { parsePredicates } from "./parsePredicates" ;
15- import { predicateToString } from "./predicateToString" ;
16- import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf" ;
13+ import { normalizeCaptureName } from "./normalizeCaptureName" ;
14+ import { parsePredicatesWithErrorHandling } from "./parsePredicatesWithErrorHandling" ;
15+ import { positionToPoint } from "./positionToPoint" ;
16+ import {
17+ getStartOfEndOfRange ,
18+ rewriteStartOfEndOf ,
19+ } from "./rewriteStartOfEndOf" ;
1720import { treeSitterQueryCache } from "./treeSitterQueryCache" ;
1821
1922/**
2023 * Wrapper around a tree-sitter query that provides a more convenient API, and
2124 * defines our own custom predicate operators
2225 */
2326export class TreeSitterQuery {
27+ private shouldCheckCaptures : boolean ;
28+
2429 private constructor (
2530 private treeSitter : TreeSitter ,
2631
2732 /**
2833 * The raw tree-sitter query as parsed by tree-sitter from the query file
2934 */
30- private query : Query ,
35+ private query : treeSitter . Query ,
3136
3237 /**
3338 * The predicates for each pattern in the query. Each element of the outer
3439 * array corresponds to a pattern, and each element of the inner array
3540 * corresponds to a predicate for that pattern.
3641 */
3742 private patternPredicates : ( ( match : MutableQueryMatch ) => boolean ) [ ] [ ] ,
38- ) { }
39-
40- static create ( languageId : string , treeSitter : TreeSitter , query : Query ) {
41- const { errors, predicates } = parsePredicates ( query . predicates ) ;
42-
43- if ( errors . length > 0 ) {
44- for ( const error of errors ) {
45- const context = [
46- `language ${ languageId } ` ,
47- `pattern ${ error . patternIdx } ` ,
48- `predicate \`${ predicateToString (
49- query . predicates [ error . patternIdx ] [ error . predicateIdx ] ,
50- ) } \``,
51- ] . join ( ", " ) ;
52-
53- void showError (
54- ide ( ) . messages ,
55- "TreeSitterQuery.parsePredicates" ,
56- `Error parsing predicate for ${ context } : ${ error . error } ` ,
57- ) ;
58- }
43+ ) {
44+ this . shouldCheckCaptures = ide ( ) . runMode !== "production" ;
45+ }
5946
60- // We show errors to the user, but we don't want to crash the extension
61- // unless we're in test mode
62- if ( ide ( ) . runMode === "test" ) {
63- throw new Error ( "Invalid predicates" ) ;
64- }
65- }
47+ static create (
48+ languageId : string ,
49+ treeSitter : TreeSitter ,
50+ query : treeSitter . Query ,
51+ ) {
52+ const predicates = parsePredicatesWithErrorHandling ( languageId , query ) ;
6653
6754 return new TreeSitterQuery ( treeSitter , query , predicates ) ;
6855 }
6956
57+ hasCapture ( name : string ) : boolean {
58+ return this . query . captureNames . some (
59+ ( n ) => normalizeCaptureName ( n ) === name ,
60+ ) ;
61+ }
62+
7063 matches (
7164 document : TextDocument ,
7265 start ?: Position ,
@@ -84,74 +77,114 @@ export class TreeSitterQuery {
8477 start ?: Position ,
8578 end ?: Position ,
8679 ) : QueryMatch [ ] {
87- return this . query
88- . matches ( this . treeSitter . getTree ( document ) . rootNode , {
89- startPosition : start == null ? undefined : positionToPoint ( start ) ,
90- endPosition : end == null ? undefined : positionToPoint ( end ) ,
91- } )
92- . map (
93- ( { pattern, captures } ) : MutableQueryMatch => ( {
94- patternIdx : pattern ,
95- captures : captures . map ( ( { name, node } ) => ( {
96- name,
97- node,
98- document,
99- range : getNodeRange ( node ) ,
100- insertionDelimiter : undefined ,
101- allowMultiple : false ,
102- hasError : ( ) => isContainedInErrorNode ( node ) ,
103- } ) ) ,
104- } ) ,
105- )
106- . filter ( ( match ) =>
107- this . patternPredicates [ match . patternIdx ] . every ( ( predicate ) =>
108- predicate ( match ) ,
109- ) ,
110- )
111- . map ( ( match ) : QueryMatch => {
112- // Merge the ranges of all captures with the same name into a single
113- // range and return one capture with that name. We consider captures
114- // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
115- // name, for which we'd return a capture with name `foo`.
116- const captures : QueryCapture [ ] = Object . entries (
117- groupBy ( match . captures , ( { name } ) => normalizeCaptureName ( name ) ) ,
118- ) . map ( ( [ name , captures ] ) => {
119- captures = rewriteStartOfEndOf ( captures ) ;
120- const capturesAreValid = checkCaptureStartEnd (
121- captures ,
122- ide ( ) . messages ,
123- ) ;
124-
125- if ( ! capturesAreValid && ide ( ) . runMode === "test" ) {
126- throw new Error ( "Invalid captures" ) ;
127- }
128-
129- return {
130- name,
131- range : captures
132- . map ( ( { range } ) => range )
133- . reduce ( ( accumulator , range ) => range . union ( accumulator ) ) ,
134- allowMultiple : captures . some ( ( capture ) => capture . allowMultiple ) ,
135- insertionDelimiter : captures . find (
136- ( capture ) => capture . insertionDelimiter != null ,
137- ) ?. insertionDelimiter ,
138- hasError : ( ) => captures . some ( ( capture ) => capture . hasError ( ) ) ,
139- } ;
140- } ) ;
141-
142- return { ...match , captures } ;
143- } ) ;
80+ const matches = this . getTreeMatches ( document , start , end ) ;
81+ const results : QueryMatch [ ] = [ ] ;
82+
83+ for ( const match of matches ) {
84+ const mutableMatch = this . createMutableQueryMatch ( document , match ) ;
85+
86+ if ( ! this . runPredicates ( mutableMatch ) ) {
87+ continue ;
88+ }
89+
90+ results . push ( this . createQueryMatch ( mutableMatch ) ) ;
91+ }
92+
93+ return results ;
14494 }
14595
146- get captureNames ( ) {
147- return uniq ( this . query . captureNames . map ( normalizeCaptureName ) ) ;
96+ private getTreeMatches (
97+ document : TextDocument ,
98+ start ?: Position ,
99+ end ?: Position ,
100+ ) {
101+ const { rootNode } = this . treeSitter . getTree ( document ) ;
102+ return this . query . matches ( rootNode , {
103+ startPosition : start != null ? positionToPoint ( start ) : undefined ,
104+ endPosition : end != null ? positionToPoint ( end ) : undefined ,
105+ } ) ;
148106 }
149- }
150107
151- function normalizeCaptureName ( name : string ) : string {
152- return name . replace ( / ( \. ( s t a r t | e n d ) ) ? ( \. ( s t a r t O f | e n d O f ) ) ? $ / , "" ) ;
153- }
108+ private createMutableQueryMatch (
109+ document : TextDocument ,
110+ match : treeSitter . QueryMatch ,
111+ ) : MutableQueryMatch {
112+ return {
113+ patternIdx : match . pattern ,
114+ captures : match . captures . map ( ( { name, node } ) => ( {
115+ name,
116+ node,
117+ document,
118+ range : getNodeRange ( node ) ,
119+ insertionDelimiter : undefined ,
120+ allowMultiple : false ,
121+ hasError : ( ) => isContainedInErrorNode ( node ) ,
122+ } ) ) ,
123+ } ;
124+ }
154125
155- function positionToPoint ( start : Position ) : Point {
156- return { row : start . line , column : start . character } ;
126+ private runPredicates ( match : MutableQueryMatch ) : boolean {
127+ for ( const predicate of this . patternPredicates [ match . patternIdx ] ) {
128+ if ( ! predicate ( match ) ) {
129+ return false ;
130+ }
131+ }
132+ return true ;
133+ }
134+
135+ private createQueryMatch ( match : MutableQueryMatch ) : QueryMatch {
136+ const result : MutableQueryCapture [ ] = [ ] ;
137+ const map = new Map <
138+ string ,
139+ { acc : MutableQueryCapture ; captures : MutableQueryCapture [ ] }
140+ > ( ) ;
141+
142+ // Merge the ranges of all captures with the same name into a single
143+ // range and return one capture with that name. We consider captures
144+ // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
145+ // name, for which we'd return a capture with name `foo`.
146+
147+ for ( const capture of match . captures ) {
148+ const name = normalizeCaptureName ( capture . name ) ;
149+ const range = getStartOfEndOfRange ( capture ) ;
150+ const existing = map . get ( name ) ;
151+
152+ if ( existing == null ) {
153+ const captures = [ capture ] ;
154+ const acc = {
155+ ...capture ,
156+ name,
157+ range,
158+ hasError : ( ) => captures . some ( ( c ) => c . hasError ( ) ) ,
159+ } ;
160+ result . push ( acc ) ;
161+ map . set ( name , { acc, captures } ) ;
162+ } else {
163+ existing . acc . range = existing . acc . range . union ( range ) ;
164+ existing . acc . allowMultiple =
165+ existing . acc . allowMultiple || capture . allowMultiple ;
166+ existing . acc . insertionDelimiter =
167+ existing . acc . insertionDelimiter ?? capture . insertionDelimiter ;
168+ existing . captures . push ( capture ) ;
169+ }
170+ }
171+
172+ if ( this . shouldCheckCaptures ) {
173+ this . checkCaptures ( Array . from ( map . values ( ) ) ) ;
174+ }
175+
176+ return { captures : result } ;
177+ }
178+
179+ private checkCaptures ( matches : { captures : MutableQueryCapture [ ] } [ ] ) {
180+ for ( const match of matches ) {
181+ const capturesAreValid = checkCaptureStartEnd (
182+ rewriteStartOfEndOf ( match . captures ) ,
183+ ide ( ) . messages ,
184+ ) ;
185+ if ( ! capturesAreValid && ide ( ) . runMode === "test" ) {
186+ throw new Error ( "Invalid captures" ) ;
187+ }
188+ }
189+ }
157190}
0 commit comments