diff --git a/crates/squawk_ide/src/find_references.rs b/crates/squawk_ide/src/find_references.rs new file mode 100644 index 00000000..0f724603 --- /dev/null +++ b/crates/squawk_ide/src/find_references.rs @@ -0,0 +1,345 @@ +use crate::binder::{self, Binder}; +use crate::offsets::token_from_offset; +use crate::resolve; +use rowan::{TextRange, TextSize}; +use squawk_syntax::{ + SyntaxNodePtr, + ast::{self, AstNode}, + match_ast, +}; + +pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec { + let binder = binder::bind(file); + let Some(target) = find_target(file, offset, &binder) else { + return vec![]; + }; + + let mut refs = vec![]; + + for node in file.syntax().descendants() { + match_ast! { + match node { + ast::NameRef(name_ref) => { + if let Some(found) = resolve::resolve_name_ref(&binder, &name_ref) + && found == target + { + refs.push(name_ref.syntax().text_range()); + } + }, + ast::Name(name) => { + let found = SyntaxNodePtr::new(name.syntax()); + if found == target { + refs.push(name.syntax().text_range()); + } + }, + _ => (), + } + } + } + + refs.sort_by_key(|range| range.start()); + refs +} + +fn find_target(file: &ast::SourceFile, offset: TextSize, binder: &Binder) -> Option { + let token = token_from_offset(file, offset)?; + let parent = token.parent()?; + + if let Some(name) = ast::Name::cast(parent.clone()) { + return Some(SyntaxNodePtr::new(name.syntax())); + } + + if let Some(name_ref) = ast::NameRef::cast(parent.clone()) + && let Some(ptr) = resolve::resolve_name_ref(binder, &name_ref) + { + return Some(ptr); + } + + None +} + +#[cfg(test)] +mod test { + use crate::find_references::find_references; + use crate::test_utils::fixture; + use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle}; + use insta::assert_snapshot; + use squawk_syntax::ast; + + #[track_caller] + fn find_refs(sql: &str) -> String { + let (mut offset, sql) = fixture(sql); + offset = offset.checked_sub(1.into()).unwrap_or_default(); + let parse = ast::SourceFile::parse(&sql); + assert_eq!(parse.errors(), vec![]); + let file: ast::SourceFile = parse.tree(); + + let references = find_references(&file, offset); + + let offset_usize: usize = offset.into(); + + let labels: Vec = (1..=references.len()) + .map(|i| format!("{}. reference", i)) + .collect(); + + let mut snippet = Snippet::source(&sql).fold(true).annotation( + AnnotationKind::Context + .span(offset_usize..offset_usize + 1) + .label("0. query"), + ); + + for (i, range) in references.iter().enumerate() { + snippet = snippet.annotation( + AnnotationKind::Context + .span((*range).into()) + .label(&labels[i]), + ); + } + + let group = Level::INFO.primary_title("references").element(snippet); + let renderer = Renderer::plain().decor_style(DecorStyle::Unicode); + renderer + .render(&[group]) + .to_string() + .replace("info: references", "") + } + + #[test] + fn simple_table_reference() { + assert_snapshot!(find_refs(" +create table t(); +drop table t$0; +"), @r" + ╭▸ + 2 │ create table t(); + │ ─ 1. reference + 3 │ drop table t; + │ ┬ + │ │ + │ 0. query + ╰╴ 2. reference + "); + } + + #[test] + fn multiple_references() { + assert_snapshot!(find_refs(" +create table users(); +drop table users$0; +table users; +"), @r" + ╭▸ + 2 │ create table users(); + │ ───── 1. reference + 3 │ drop table users; + │ ┬───┬ + │ │ │ + │ │ 0. query + │ 2. reference + 4 │ table users; + ╰╴ ───── 3. reference + "); + } + + #[test] + fn find_from_definition() { + assert_snapshot!(find_refs(" +create table t$0(); +drop table t; +"), @r" + ╭▸ + 2 │ create table t(); + │ ┬ + │ │ + │ 0. query + │ 1. reference + 3 │ drop table t; + ╰╴ ─ 2. reference + "); + } + + #[test] + fn with_schema_qualified() { + assert_snapshot!(find_refs(" +create table public.users(); +drop table public.users$0; +table users; +"), @r" + ╭▸ + 2 │ create table public.users(); + │ ───── 1. reference + 3 │ drop table public.users; + │ ┬───┬ + │ │ │ + │ │ 0. query + │ 2. reference + 4 │ table users; + ╰╴ ───── 3. reference + "); + } + + #[test] + fn temp_table_shadows_public() { + assert_snapshot!(find_refs(" +create table t(); +create temp table t$0(); +drop table t; +"), @r" + ╭▸ + 3 │ create temp table t(); + │ ┬ + │ │ + │ 0. query + │ 1. reference + 4 │ drop table t; + ╰╴ ─ 2. reference + "); + } + + #[test] + fn different_schema_no_match() { + assert_snapshot!(find_refs(" +create table foo.t(); +create table bar.t$0(); +"), @r" + ╭▸ + 3 │ create table bar.t(); + │ ┬ + │ │ + │ 0. query + ╰╴ 1. reference + "); + } + + #[test] + fn with_search_path() { + assert_snapshot!(find_refs(" +set search_path to myschema; +create table myschema.users$0(); +drop table users; +"), @r" + ╭▸ + 3 │ create table myschema.users(); + │ ┬───┬ + │ │ │ + │ │ 0. query + │ 1. reference + 4 │ drop table users; + ╰╴ ───── 2. reference + "); + } + + #[test] + fn temp_table_with_pg_temp_schema() { + assert_snapshot!(find_refs(" +create temp table t(); +drop table pg_temp.t$0; +"), @r" + ╭▸ + 2 │ create temp table t(); + │ ─ 1. reference + 3 │ drop table pg_temp.t; + │ ┬ + │ │ + │ 0. query + ╰╴ 2. reference + "); + } + + #[test] + fn case_insensitive() { + assert_snapshot!(find_refs(" +create table Users(); +drop table USERS$0; +table users; +"), @r" + ╭▸ + 2 │ create table Users(); + │ ───── 1. reference + 3 │ drop table USERS; + │ ┬───┬ + │ │ │ + │ │ 0. query + │ 2. reference + 4 │ table users; + ╰╴ ───── 3. reference + "); + } + #[test] + fn case_insensitive_part_2() { + // we should see refs for `drop table` and `table` + assert_snapshot!(find_refs(r#" +create table actors(); +create table "Actors"(); +drop table ACTORS$0; +table actors; +"#), @r#" + ╭▸ + 2 │ create table actors(); + │ ────── 1. reference + 3 │ create table "Actors"(); + 4 │ drop table ACTORS; + │ ┬────┬ + │ │ │ + │ │ 0. query + │ 2. reference + 5 │ table actors; + ╰╴ ────── 3. reference + "#); + } + + #[test] + fn case_insensitive_with_schema() { + assert_snapshot!(find_refs(" +create table Public.Users(); +drop table PUBLIC.USERS$0; +table public.users; +"), @r" + ╭▸ + 2 │ create table Public.Users(); + │ ───── 1. reference + 3 │ drop table PUBLIC.USERS; + │ ┬───┬ + │ │ │ + │ │ 0. query + │ 2. reference + 4 │ table public.users; + ╰╴ ───── 3. reference + "); + } + + #[test] + fn no_partial_match() { + assert_snapshot!(find_refs(" +create table t$0(); +create table temp_t(); +"), @r" + ╭▸ + 2 │ create table t(); + │ ┬ + │ │ + │ 0. query + ╰╴ 1. reference + "); + } + + #[test] + fn identifier_boundaries() { + assert_snapshot!(find_refs(" +create table foo$0(); +drop table foo; +drop table foo1; +drop table barfoo; +drop table foo_bar; +"), @r" + ╭▸ + 2 │ create table foo(); + │ ┬─┬ + │ │ │ + │ │ 0. query + │ 1. reference + 3 │ drop table foo; + ╰╴ ─── 2. reference + "); + } +} diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 6b64eaaa..e2bc2682 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -47,6 +47,10 @@ pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> Option Result<()> { resolve_provider: None, })), selection_range_provider: Some(SelectionRangeProviderCapability::Simple(true)), + references_provider: Some(OneOf::Left(true)), definition_provider: Some(OneOf::Left(true)), ..Default::default() }) @@ -108,6 +110,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { "squawk/tokens" => { handle_tokens(&connection, req, &documents)?; } + References::METHOD => { + handle_references(&connection, req, &documents)?; + } _ => { info!("Ignoring unhandled request: {}", req.method); } @@ -236,6 +241,43 @@ fn handle_selection_range( Ok(()) } +fn handle_references( + connection: &Connection, + req: lsp_server::Request, + documents: &HashMap, +) -> Result<()> { + let params: ReferenceParams = serde_json::from_value(req.params)?; + let uri = params.text_document_position.text_document.uri; + let position = params.text_document_position.position; + + let content = documents.get(&uri).map_or("", |doc| &doc.content); + let parse: Parse = SourceFile::parse(content); + let file = parse.tree(); + let line_index = LineIndex::new(content); + let offset = lsp_utils::offset(&line_index, position).unwrap(); + + let ranges = find_references(&file, offset); + let include_declaration = params.context.include_declaration; + + let locations: Vec = ranges + .into_iter() + .filter(|range| include_declaration || !range.contains(offset)) + .map(|range| Location { + uri: uri.clone(), + range: lsp_utils::range(&line_index, range), + }) + .collect(); + + let resp = Response { + id: req.id, + result: Some(serde_json::to_value(&locations).unwrap()), + error: None, + }; + + connection.sender.send(Message::Response(resp))?; + Ok(()) +} + fn handle_code_action( connection: &Connection, req: lsp_server::Request,