From c92501ce1c49bf9aacfc0a3644cbccd03dd9ebfa Mon Sep 17 00:00:00 2001 From: Chen Shuaimin Date: Sat, 14 Oct 2023 16:20:51 +0800 Subject: [PATCH] Search across multiple files with ripgrep --- .github/workflows/ci.yml | 31 ++- .neoconf.json | 20 ++ README.md | 4 +- lua/ssr.lua | 471 ------------------------------------- lua/ssr/config.lua | 24 ++ lua/ssr/file.lua | 147 ++++++++++++ lua/ssr/init.lua | 15 ++ lua/ssr/parse.lua | 79 ------- lua/ssr/range.lua | 50 ++++ lua/ssr/replace.lua | 59 +++++ lua/ssr/search.lua | 267 +++++++++++---------- lua/ssr/ui/confirm_win.lua | 139 +++++++++++ lua/ssr/ui/init.lua | 130 ++++++++++ lua/ssr/ui/main_win.lua | 241 +++++++++++++++++++ lua/ssr/ui/result_list.lua | 205 ++++++++++++++++ lua/ssr/utils.lua | 113 ++++++--- tests/ssr_spec.lua | 288 +++++++++++++---------- 17 files changed, 1432 insertions(+), 851 deletions(-) create mode 100644 .neoconf.json delete mode 100644 lua/ssr.lua create mode 100644 lua/ssr/config.lua create mode 100644 lua/ssr/file.lua create mode 100644 lua/ssr/init.lua delete mode 100644 lua/ssr/parse.lua create mode 100644 lua/ssr/range.lua create mode 100644 lua/ssr/replace.lua create mode 100644 lua/ssr/ui/confirm_win.lua create mode 100644 lua/ssr/ui/init.lua create mode 100644 lua/ssr/ui/main_win.lua create mode 100644 lua/ssr/ui/result_list.lua diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a0df4a..6781c41 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,34 +9,41 @@ jobs: strategy: matrix: - nvim: [v0.9.1, nightly] + nvim: [nightly] + + env: + RIPGREP_VERSION: "14.1.0" + VIM: ~/.local/share/nvim/share/nvim/runtime steps: - uses: actions/checkout@v3 - - name: Set Envs + - name: Add PATH run: | - echo "VIM=~/.local/share/nvim/share/nvim/runtime" >> $GITHUB_ENV - echo "PATH=~/.local/share/nvim/bin:$PATH" >> $GITHUB_ENV + echo "$HOME/.local/share/nvim/bin" >> $GITHUB_PATH + echo "$HOME/.local/share/ripgrep" >> $GITHUB_PATH - name: Cache Dependencies id: cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: - path: ~/.local/share/nvim - key: ${{ runner.os }}-nvim-${{ matrix.nvim }} + key: ${{ runner.os }}-nvim-${{ matrix.nvim }}-rg-${{ env.RIPGREP_VERSION }} + path: | + ~/.local/share/nvim + ~/.local/share/ripgrep - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - mkdir -p ~/.local/share/nvim/ + mkdir -p ~/.local/share/{nvim,ripgrep} + curl -sL "https://github.com/BurntSushi/ripgrep/releases/download/$RIPGREP_VERSION/ripgrep-$RIPGREP_VERSION-x86_64-unknown-linux-musl.tar.gz" | tar xzf - --strip-components=1 -C ~/.local/share/ripgrep curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.nvim }}/nvim-linux64.tar.gz" | tar xzf - --strip-components=1 -C ~/.local/share/nvim/ git clone --depth 1 https://github.com/nvim-treesitter/nvim-treesitter.git ~/.local/share/nvim/site/pack/vendor/start/nvim-treesitter git clone --depth 1 https://github.com/nvim-lua/plenary.nvim ~/.local/share/nvim/site/pack/vendor/start/plenary.nvim - ln -s $(pwd) ~/.local/share/nvim/site/pack/vendor/start - nvim --headless -c 'TSInstallSync python javascript lua rust go' -c 'q' + ln -s $PWD ~/.local/share/nvim/site/pack/vendor/start + nvim --headless '+TSInstallSync python javascript html lua rust go' +q - name: Run tests run: | - nvim --version - nvim --headless -c 'PlenaryBustedDirectory tests/' + nvim --version | head -1 && rg --version | head -1 + nvim --headless '+PlenaryBustedDirectory tests/' diff --git a/.neoconf.json b/.neoconf.json new file mode 100644 index 0000000..39b80bd --- /dev/null +++ b/.neoconf.json @@ -0,0 +1,20 @@ +{ + "neodev": { + "library": { + "enabled": true, + "plugins": ["plenary.nvim"] + } + }, + "neoconf": { + "plugins": { + "lua_ls": { + "enabled": true + } + } + }, + "lspconfig": { + "lua_ls": { + "Lua.format.enable": false + } + } +} diff --git a/README.md b/README.md index ab12b59..7f355b9 100644 --- a/README.md +++ b/README.md @@ -45,8 +45,8 @@ First put your cursor on the structure you want to search and replace (if you are not sure, select a region instead), then open SSR by pressing `sr`. In the SSR float window you can see the placeholder search code, you can -replace part of it with wildcards. A wildcard is an identifier starts with `$`, -like `$name`. A `$name` wildcard in the search pattern will match any AST node +replace part of it with captures. A capture is an identifier starts with `$`, +like `$name`. A `$name` capture in the search pattern will match any AST node and `$name` will reference it in the replacement. Press `` to replace all matches in current buffer, or `` to diff --git a/lua/ssr.lua b/lua/ssr.lua deleted file mode 100644 index a32cfe4..0000000 --- a/lua/ssr.lua +++ /dev/null @@ -1,471 +0,0 @@ -local api = vim.api -local ts = vim.treesitter -local fn = vim.fn -local keymap = vim.keymap -local highlight = vim.highlight -local ParseContext = require("ssr.parse").ParseContext -local search = require("ssr.search").search -local replace = require("ssr.search").replace -local u = require "ssr.utils" - -local M = {} - ----@class Config -local config = { - border = "rounded", - min_width = 50, - min_height = 5, - max_width = 120, - max_height = 25, - adjust_window = true, - keymaps = { - close = "q", - next_match = "n", - prev_match = "N", - replace_confirm = "", - replace_all = "", - }, -} - --- Set config options. ----@param cfg Config? -function M.setup(cfg) - if cfg then - config = vim.tbl_deep_extend("force", config, cfg) - end -end - ----@type table -local win_uis = {} - ----@class Ui ----@field ns number ----@field cur_search_ns number ----@field augroup number ----@field ui_buf buffer ----@field extmarks {status: number, search: number, replace: number} ----@field origin_win window ----@field lang string ----@field parse_context ParseContext ----@field buf_matches table -local Ui = {} - ----@return Ui? -function Ui.new() - local self = setmetatable({}, { __index = Ui }) - - self.origin_win = api.nvim_get_current_win() - local origin_buf = api.nvim_win_get_buf(self.origin_win) - local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) - if not lang then - return u.notify("Treesitter language not found") - end - self.lang = lang - - local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(self.origin_win)) - if not origin_node then - return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) - end - if origin_node:has_error() then - return u.notify "You have syntax errors in the selected node" - end - local parse_context = ParseContext.new(origin_buf, origin_node) - if not parse_context then - return u.notify "Can't find a proper context to parse the pattern" - end - self.parse_context = parse_context - - self.buf_matches = {} - self.ns = api.nvim_create_namespace("ssr_" .. self.origin_win) -- TODO - self.cur_search_ns = api.nvim_create_namespace("ssr_cur_match_" .. self.origin_win) - self.augroup = api.nvim_create_augroup("ssr_augroup_" .. self.origin_win, {}) - - -- Init ui buffer - self.ui_buf = api.nvim_create_buf(false, true) - vim.bo[self.ui_buf].filetype = "ssr" - - local placeholder = ts.get_node_text(origin_node, origin_buf) - placeholder = "\n\n" .. placeholder .. "\n\n" - placeholder = vim.split(placeholder, "\n") - u.remove_indent(placeholder, u.get_indent(origin_buf, origin_node:start())) - api.nvim_buf_set_lines(self.ui_buf, 0, -1, true, placeholder) - -- Enable syntax highlights - ts.start(self.ui_buf, self.lang) - - local function virt_text(row, text) - return api.nvim_buf_set_extmark(self.ui_buf, self.ns, row, 0, { virt_text = text, virt_text_pos = "overlay" }) - end - self.extmarks = { - status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), - search = virt_text(1, { { "SEARCH:", "String" } }), - replace = virt_text(#placeholder - 2, { { "REPLACE:", "String" } }), - } - - local function map(key, func) - keymap.set("n", key, function() - func(self) - end, { buffer = self.ui_buf, nowait = true }) - end - map(config.keymaps.replace_confirm, self.replace_confirm) - map(config.keymaps.replace_all, self.replace_all) - map(config.keymaps.next_match, function() - self:goto_match(self:next_match_idx()) - end) - map(config.keymaps.prev_match, function() - self:goto_match(self:prev_match_idx()) - end) - - -- Open float window - local width, height = u.get_win_size(placeholder, config) - local ui_win = api.nvim_open_win(self.ui_buf, true, { - relative = "win", - anchor = "NE", - row = 1, - col = api.nvim_win_get_width(0) - 1, - style = "minimal", - border = config.border, - width = width, - height = height, - }) - u.set_cursor(ui_win, 2, 0) - fn.matchadd("Title", [[$\w\+]]) - - map(config.keymaps.close, function() - api.nvim_win_close(ui_win, false) - end) - - api.nvim_create_autocmd({ "TextChanged", "TextChangedI" }, { - group = self.augroup, - buffer = self.ui_buf, - callback = function() - if config.adjust_window then - local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local width, height = u.get_win_size(lines, config) - if api.nvim_win_get_width(ui_win) ~= width then - api.nvim_win_set_width(ui_win, width) - end - if api.nvim_win_get_height(ui_win) ~= height then - api.nvim_win_set_height(ui_win, height) - end - end - self:search() - end, - }) - - -- SSR window is bound to the original window (not buffer!), which is the same behavior as IDEs and browsers. - api.nvim_create_autocmd("BufWinEnter", { - group = self.augroup, - callback = function(event) - if event.buf == self.ui_buf then - return - end - - local win = api.nvim_get_current_win() - if win == ui_win then - -- Prevent accidentally opening another file in the ssr window. - -- Adapted from neo-tree.nvim. - vim.schedule(function() - api.nvim_win_set_buf(ui_win, self.ui_buf) - local name = api.nvim_buf_get_name(event.buf) - api.nvim_win_call(self.origin_win, function() - pcall(api.nvim_buf_delete, event.buf, {}) - if name ~= "" then - vim.cmd.edit(name) - end - end) - api.nvim_set_current_win(self.origin_win) - end) - return - elseif win ~= self.origin_win then - return - end - - if ts.language.get_lang(vim.bo[event.buf].filetype) ~= self.lang then - return self:set_status "N/A" - end - self:search() - end, - }) - - api.nvim_create_autocmd("WinClosed", { - group = self.augroup, - buffer = self.ui_buf, - callback = function() - win_uis[self.origin_win] = nil - api.nvim_clear_autocmds { group = self.augroup } - api.nvim_buf_delete(self.ui_buf, {}) - for buf in pairs(self.buf_matches) do - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - end - end, - }) - - win_uis[self.origin_win] = self - return self -end - -function Ui:search() - local pattern = self:get_input() - local buf = api.nvim_win_get_buf(self.origin_win) - self.buf_matches[buf] = {} - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - - local start = vim.loop.hrtime() - local node, source = self.parse_context:parse(pattern) - if node:has_error() then - return self:set_status "Error" - end - self.buf_matches[buf] = search(buf, node, source, self.ns) - local elapsed = (vim.loop.hrtime() - start) / 1E6 - for _, match in ipairs(self.buf_matches[buf]) do - local start_row, start_col, end_row, end_col = match.range:get() - highlight.range(buf, self.ns, "Search", { start_row, start_col }, { end_row, end_col }, {}) - end - self:set_status(string.format("%d found in %dms", #self.buf_matches[buf], elapsed)) -end - -function Ui:next_match_idx() - local cursor_row, cursor_col = u.get_cursor(self.origin_win) - local buf = api.nvim_win_get_buf(self.origin_win) - for idx, matches in pairs(self.buf_matches[buf]) do - local start_row, start_col = matches.range:get() - if start_row > cursor_row or (start_row == cursor_row and start_col > cursor_col) then - return idx - end - end - return 1 -end - -function Ui:prev_match_idx() - local cursor_row, cursor_col = u.get_cursor(self.origin_win) - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - for idx = #matches, 1, -1 do - local start_row, start_col = matches[idx].range:get() - if start_row < cursor_row or (start_row == cursor_row and start_col < cursor_col) then - return idx - end - end - return #matches -end - -function Ui:goto_match(match_idx) - local buf = api.nvim_win_get_buf(self.origin_win) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - local matches = self.buf_matches[buf] - local start_row, start_col, end_row, end_col = matches[match_idx].range:get() - u.set_cursor(self.origin_win, start_row, start_col) - highlight.range( - buf, - self.cur_search_ns, - "CurSearch", - { start_row, start_col }, - { end_row, end_col }, - { priority = vim.highlight.priorities.user + 100 } - ) - api.nvim_buf_set_extmark(buf, self.cur_search_ns, start_row, start_col, { - virt_text_pos = "eol", - virt_text = { { string.format("[%d/%d]", match_idx, #matches), "DiagnosticVirtualTextInfo" } }, - }) -end - -function Ui:replace_all() - self:search() - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - if #matches == 0 then - return self:set_status "pattern not found" - end - local _, template = self:get_input() - local start = vim.loop.hrtime() - for _, match in ipairs(matches) do - replace(buf, match, template) - end - local elapsed = (vim.loop.hrtime() - start) / 1E6 - self:set_status(string.format("%d replaced in %dms", #matches, elapsed)) -end - -function Ui:replace_confirm() - self:search() - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - if #matches == 0 then - return self:set_status "pattern not found" - end - - local confirm_buf = api.nvim_create_buf(false, true) - vim.bo[confirm_buf].filetype = "ssr_confirm" - local choices = { - "• Yes", - "• No", - "──────────────", - "• All", - "• Quit", - "• Last replace", - } - local separator_idx = 3 - api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) - for idx = 0, #choices - 1 do - if idx + 1 ~= separator_idx then - api.nvim_buf_set_extmark(confirm_buf, self.ns, idx, 4, { hl_group = "Underlined", end_row = idx, end_col = 5 }) - end - end - - local function open_confirm_win(match_idx) - self:goto_match(match_idx) - local _, _, end_row, end_col = matches[match_idx].range:get() - local cfg = { - relative = "win", - win = self.origin_win, - bufpos = { end_row, end_col }, - style = "minimal", - border = config.border, - width = 14, - height = 6, - } - if vim.fn.has "nvim-0.9" == 1 then - cfg.title = "Replace?" - cfg.title_pos = "center" - end - return api.nvim_open_win(confirm_buf, true, cfg) - end - - local match_idx = 1 - local replaced = 0 - local cursor = 1 - local _, template = self:get_input() - self:set_status(string.format("replacing 0/%d", #matches)) - - while match_idx <= #matches do - local confirm_win = open_confirm_win(match_idx) - - ---@type string - local key - while true do - -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. - api.nvim_buf_clear_namespace(confirm_buf, self.cur_search_ns, 0, -1) - api.nvim_buf_set_extmark( - confirm_buf, - self.cur_search_ns, - cursor - 1, - 0, - { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } - ) - api.nvim_buf_set_extmark(confirm_buf, self.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) - vim.cmd.redraw() - - local ok, char = pcall(vim.fn.getcharstr) - key = ok and vim.fn.keytrans(char) or "" - if key == "j" then - if cursor == separator_idx - 1 then -- skip separator - cursor = separator_idx + 1 - elseif cursor == #choices then -- wrap - cursor = 1 - else - cursor = cursor + 1 - end - elseif key == "k" then - if cursor == separator_idx + 1 then -- skip separator - cursor = separator_idx - 1 - elseif cursor == 1 then -- wrap - cursor = #choices - else - cursor = cursor - 1 - end - elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then - fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) - else - break - end - end - - if key == "" then - key = ({ "y", "n", "", "a", "q", "l" })[cursor] - end - - if key == "y" then - replace(buf, matches[match_idx], template) - replaced = replaced + 1 - match_idx = match_idx + 1 - elseif key == "n" then - match_idx = match_idx + 1 - elseif key == "a" then - for i = match_idx, #matches do - replace(buf, matches[i], template) - end - replaced = replaced + #matches + 1 - match_idx - match_idx = #matches + 1 - elseif key == "l" then - replace(buf, matches[match_idx], template) - replaced = replaced + 1 - match_idx = #matches + 1 - elseif key == "q" or key == "" or key == "" then - match_idx = #matches + 1 - end - api.nvim_win_close(confirm_win, false) - self:set_status(string.format("replacing %d/%d", replaced, #matches)) - end - - api.nvim_buf_delete(confirm_buf, {}) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - self:set_status(string.format("%d/%d replaced", replaced, #matches)) -end - -function Ui:get_input() - local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local pattern_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.search, {})[1] - local template_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.replace, {})[1] - local pattern = vim.trim(table.concat(lines, "\n", pattern_pos + 2, template_pos)) - local template = vim.trim(table.concat(lines, "\n", template_pos + 1, #lines)) - return pattern, template -end - ----@param status string -function Ui:set_status(status) - api.nvim_buf_set_extmark(self.ui_buf, self.ns, 0, 0, { - id = self.extmarks.status, - virt_text = { - { "[SSR] ", "Comment" }, - { status }, - { " (Press ? for help)", "Comment" }, - }, - virt_text_pos = "overlay", - }) -end - ----@param win window? ----@return Ui? -function Ui.from_win(win) - if win == nil or win == 0 then - win = api.nvim_get_current_win() - end - local ui = win_uis[win] - if not ui then - return u.notify "No open SSR window" - end - return ui -end - -function M.open() - return Ui.new() -end - --- Replace all matches. -function M.replace_all() - local ui = Ui.from_win() - if ui then - ui:replace_all() - end -end - --- Confirm each match. -function M.replace_confirm() - local ui = Ui.from_win() - if ui then - ui:replace_confirm() - end -end - -return M diff --git a/lua/ssr/config.lua b/lua/ssr/config.lua new file mode 100644 index 0000000..7326c92 --- /dev/null +++ b/lua/ssr/config.lua @@ -0,0 +1,24 @@ +local M = {} + +---@class Config +M.opts = { + border = "rounded", + min_width = 50, + max_width = 120, + min_height = 6, + max_height = 25, + adjust_window = true, + keymaps = { + close = "q", + next_match = "n", + prev_match = "N", + replace_confirm = "", + replace_all = "", + }, +} + +function M.set(config) + M.opts = vim.tbl_deep_extend("force", M.opts, config) +end + +return M diff --git a/lua/ssr/file.lua b/lua/ssr/file.lua new file mode 100644 index 0000000..d43b416 --- /dev/null +++ b/lua/ssr/file.lua @@ -0,0 +1,147 @@ +local api = vim.api +local ts = vim.treesitter +local uv = vim.uv or vim.loop + +-- File contents and it's parsed tree +-- Unloaded buffers are read with libuv because loading a vim buffer can be up to 100x slower. +---@class ssr.File +---@field path string +---@field source string | buffer +---@field tree vim.treesitter.LanguageTree +-- Only if `source` is file content +---@field lines? string[] +---@field mtime? { nsec: integer, sec: integer } +local File = {} + +---@type table +local cache = {} + +---@param path string +---@return ssr.File? +function File.new(path) + -- First check if the file is already opened as a buffer. + local buf = vim.fn.bufnr(path) + if buf ~= -1 then + cache[path] = nil + if vim.bo[buf].filetype == "" then + local ft = vim.filetype.match { buf = buf } + api.nvim_buf_call(buf, function() + vim.cmd("noautocmd setlocal filetype=" .. ft) + end) + end + return setmetatable({ + path = path, + source = buf, + tree = ts.get_parser(buf), + }, { __index = File }) + end + + local fd = uv.fs_open(path, "r", 438) + if not fd then + return + end + local stat = uv.fs_fstat(fd) ---@cast stat -? + local self = cache[path] + if self then + if stat.mtime.sec == self.mtime.sec and stat.mtime.nsec == self.mtime.nsec then + uv.fs_close(fd) + return self + else + cache[path] = nil + end + end + local source = uv.fs_read(fd, stat.size, 0) --[[@as string]] + uv.fs_close(fd) + local lines = vim.split(source, "\n", { plain = true }) + local ft = vim.filetype.match { filename = path, contents = lines } -- not work for .ts + if not ft then + return + end + local lang = ts.language.get_lang(ft) + if not lang then + return + end + local has_parser, tree = pcall(ts.get_string_parser, source, lang) + if not has_parser then + return + end + tree:parse(true) + self = setmetatable({ + path = path, + source = source, + tree = tree, + filetype = ft, + lines = lines, + mtime = stat.mtime, + }, { __index = File }) + cache[path] = self + return self +end + +---@param line integer +---@return string +function File:get_line(line) + if type(self.source) == "number" then + return api.nvim_buf_get_lines(self.source --[[@as integer]], line, line + 1, true)[1] + end + return self.lines[line + 1] +end + +---@return integer +function File:load_buf() + if type(self.source) == "integer" then + return self.source --[[@as integer]] + end + local buf = vim.fn.bufadd(self.path) + self.source = buf + vim.fn.bufload(buf) + -- api.nvim_buf_call(buf, function() + -- vim.cmd("noautocmd setlocal filetype=" .. self.filetype) + -- end) + self.lines = nil + self.mtime = nil + cache[self.path] = nil + return buf +end + +---@param dir string +---@param regex string +---@param on_file fun(file: ssr.File) +---@param on_end fun() +---@return nil +function File.grep(dir, regex, on_file, on_end) + vim.system( + { "rg", "--line-buffered", "--files-with-matches", "--multiline", regex, dir }, + { + text = true, + stdout = vim.schedule_wrap(function(err, files) + if err then + error(files) + end + if not files then + on_end() + return + end + for _, path in ipairs(vim.split(files, "\n", { plain = true, trimempty = true })) do + local file = File.new(path) + if file then + on_file(file) + end + end + end), + }, + vim.schedule_wrap(function(obj) + if obj.code == 1 then -- no match was found + on_end() + elseif obj.code ~= 0 then + error(obj.stderr) + end + end) + ) +end + +function File.clear_cache() + cache = {} +end + +return File diff --git a/lua/ssr/init.lua b/lua/ssr/init.lua new file mode 100644 index 0000000..63b3880 --- /dev/null +++ b/lua/ssr/init.lua @@ -0,0 +1,15 @@ +local M = {} + +--- Set config options. Optional. +---@param config Config? +function M.setup(config) + if config then + require("ssr.config").set(config) + end +end + +function M.open() + require("ssr.ui").new() +end + +return M diff --git a/lua/ssr/parse.lua b/lua/ssr/parse.lua deleted file mode 100644 index 2e78f40..0000000 --- a/lua/ssr/parse.lua +++ /dev/null @@ -1,79 +0,0 @@ -local ts = vim.treesitter -local wildcard_prefix = require("ssr.search").wildcard_prefix - -local M = {} - ----@class ParseContext ----@field lang string ----@field before string ----@field after string ----@field pad_rows integer ----@field pad_cols integer -local ParseContext = {} -ParseContext.__index = ParseContext -M.ParseContext = ParseContext - --- Create a context in which `origin_node` (and user input) will be parsed correctly. ----@param buf buffer ----@param origin_node TSNode ----@return ParseContext? -function ParseContext.new(buf, origin_node) - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return - end - local self = setmetatable({ lang = lang }, { __index = ParseContext }) - - local origin_start_row, origin_start_col, origin_start_byte = origin_node:start() - local _, _, origin_end_byte = origin_node:end_() - local origin_lines = vim.split(ts.get_node_text(origin_node, buf), "\n") - local origin_sexpr = origin_node:sexpr() - ---@type TSNode? - local context_node = origin_node - - -- Find an ancestor of `origin_node` - while context_node do - local context_text = ts.get_node_text(context_node, buf) - local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root() - - -- Get the range of `origin_text` relative to the string `context_text`. - local context_start_row, context_start_col = context_node:start() - local start_row = origin_start_row - context_start_row - local start_col = origin_start_col - if start_row == 0 then - start_col = origin_start_col - context_start_col - end - local end_row = start_row + #origin_lines - 1 - local end_col = #origin_lines[#origin_lines] - if end_row == start_row then - end_col = end_col + start_col - end - local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) - if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then - local context_start_byte - self.start_row, self.start_col, context_start_byte = context_node:start() - self.before = context_text:sub(1, origin_start_byte - context_start_byte) - self.after = context_text:sub(origin_end_byte - context_start_byte + 1) - self.pad_rows = start_row - self.pad_cols = start_col - return self - end - -- Try next parent - context_node = context_node:parent() - end -end - --- Parse search pattern to syntax tree in proper context. ----@param pattern string ----@return TSNode?, string -function ParseContext:parse(pattern) - -- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error. - pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1") - local context_text = self.before .. pattern .. self.after - local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root() - local lines = vim.split(pattern, "\n") - local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) - return node, context_text -end - -return M diff --git a/lua/ssr/range.lua b/lua/ssr/range.lua new file mode 100644 index 0000000..cacd0d0 --- /dev/null +++ b/lua/ssr/range.lua @@ -0,0 +1,50 @@ +local api = vim.api +local u = require "ssr.utils" + +---@class ssr.Range +---@field start_row integer +---@field start_col integer +---@field end_row integer +---@field end_col integer +local Range = {} + +---@param node TSNode +---@return ssr.Range +function Range.from_node(node) + local start_row, start_col, end_row, end_col = node:range() + return setmetatable({ + start_row = start_row, + start_col = start_col, + end_row = end_row, + end_col = end_col, + }, { __index = Range }) +end + +---@param other ssr.Range +---@return boolean +function Range:before(other) + return self.end_row < other.start_row or (self.end_row == other.start_row and self.end_col <= other.start_col) +end + +---@param other ssr.Range +---@return boolean +function Range:inside(other) + return ( + (self.start_row > other.start_row or (self.start_row == other.start_row and self.start_col > other.start_col)) + and (self.end_row < other.end_row or (self.end_row == other.end_row and self.end_col <= other.end_col)) + ) +end + +-- Extmark-based ranges automatically adjust as buffer contents change. +---@param buf integer +---@return integer +function Range:to_extmark(buf) + return api.nvim_buf_set_extmark(buf, u.namespace, self.start_row, self.start_col, { + end_row = self.end_row, + end_col = self.end_col, + right_gravity = false, + end_right_gravity = true, + }) +end + +return Range diff --git a/lua/ssr/replace.lua b/lua/ssr/replace.lua new file mode 100644 index 0000000..9f96f4d --- /dev/null +++ b/lua/ssr/replace.lua @@ -0,0 +1,59 @@ +local api = vim.api +local u = require "ssr.utils" + +local M = {} + +---@class ssr.PinnedMatch +---@field buf integer +---@field range integer +---@field captures integer[] +M.PinnedMatch = {} + +-- Convert `ssr.SearchResults` to extmark-based version. +---@param matches ssr.Matches +---@return ssr.PinnedMatch[] +function M.pin_matches(matches) + local res = {} + for _, row in ipairs(matches) do + local buf = row.file:load_buf() + for _, match in ipairs(row.matches) do + local pinned = { buf = buf, range = match.range:to_extmark(buf), captures = {} } + for var, range in pairs(match.captures) do + pinned.captures[var] = range:to_extmark(buf) + end + table.insert(res, pinned) + end + end + return res +end + +---@param buf integer +---@param id integer +---@return integer, number, number, number +local function get_extmark_range(buf, id) + local extmark = api.nvim_buf_get_extmark_by_id(buf, u.namespace, id, { details = true }) + return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col +end + +--- Render template and replace one match. +---@param match ssr.PinnedMatch +---@param template string +function M.replace(match, template) + -- Render templates with captured nodes. + local replacement = template:gsub("()%$([_%a%d]+)", function(pos, var) + local start_row, start_col, end_row, end_col = get_extmark_range(match.buf, match.captures[var]) + local capture_lines = api.nvim_buf_get_text(match.buf, start_row, start_col, end_row, end_col, {}) + u.remove_indent(capture_lines, u.get_indent(match.buf, start_row)) + local var_lines = vim.split(template:sub(1, pos), "\n") + local var_line = var_lines[#var_lines] + local template_indent = var_line:match "^%s*" + u.add_indent(capture_lines, template_indent) + return table.concat(capture_lines, "\n") + end) + replacement = vim.split(replacement, "\n") + local start_row, start_col, end_row, end_col = get_extmark_range(match.buf, match.range) + u.add_indent(replacement, u.get_indent(match.buf, start_row)) + api.nvim_buf_set_text(match.buf, start_row, start_col, end_row, end_col, replacement) +end + +return M diff --git a/lua/ssr/search.lua b/lua/ssr/search.lua index ec6cdcc..9767183 100644 --- a/lua/ssr/search.lua +++ b/lua/ssr/search.lua @@ -1,57 +1,15 @@ -local api = vim.api local ts = vim.treesitter +local Range = require "ssr.range" local u = require "ssr.utils" -local M = {} - -M.wildcard_prefix = "__ssr_var_" - ----@class Match ----@field range ExtmarkRange ----@field captures ExtmarkRange[] - ----@class ExtmarkRange ----@field ns number ----@field buf buffer ----@field extmark number -local ExtmarkRange = {} -M.ExtmarkRange = ExtmarkRange - ----@param ns number ----@param buf buffer ----@param node TSNode ----@return ExtmarkRange -function ExtmarkRange.new(ns, buf, node) - local start_row, start_col, end_row, end_col = node:range() - return setmetatable({ - ns = ns, - buf = buf, - extmark = api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_row = end_row, - end_col = end_col, - right_gravity = false, - end_right_gravity = true, - }), - }, { __index = ExtmarkRange }) -end - ----@return number, number, number, number -function ExtmarkRange:get() - local extmark = api.nvim_buf_get_extmark_by_id(self.buf, self.ns, self.extmark, { details = true }) - return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col -end - -- Compare if two captured trees can match. --- The check is loose because users want to match different types of node. +-- The check is loose because we want to match different types of node. -- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) - ---@param node1 TSNode? - ---@param node2 TSNode? + ---@param node1 TSNode + ---@param node2 TSNode ---@return boolean local function tree_match(node1, node2) - if not node1 or not node2 then - return false - end if node1:named() ~= node2:named() then return false end @@ -62,40 +20,58 @@ ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) return false end for i = 0, node1:child_count() - 1 do - if not tree_match(node1:child(i), node2:child(i)) then + if + not tree_match(node1:child(i) --[[@as TSNode]], node2:child(i) --[[@as TSNode]]) + then return false end end return true end return tree_match(match[pred[2]], match[pred[3]]) -end, true) +end, { force = true }) + +-- In grammars like Lua some important nodes do not have a field name. +local crucial_nodes_without_field_name = { + ["+"] = true, + ["-"] = true, + ["*"] = true, + ["/"] = true, + ["#"] = true, + ["~"] = true, + ["and"] = true, + ["or"] = true, + ["not"] = true, +} -- Build a TS sexpr represting the node. +-- This function is more strict than `TSNode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@param source string ----@return string, table +---@return string sexpr +---@return table captures local function build_sexpr(node, source) - ---@type table - local wildcards = {} + ---@type table + local captures = {} local next_idx = 1 - -- This function is more strict than `tsnode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode - ---@return string + ---@return string? local function build(node) - local text = ts.get_node_text(node, source) + if not node:named() then + return string.format('"%s"', u.ts_str_escape(node:type())) + end - -- Special identifier __ssr_var_name is a named wildcard. - -- Handle this early to make sure wildcard captures largest node. - local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$") + -- Handle captures early to capture the largest node. + local text = ts.get_node_text(node, source) + local var = text:match("^" .. u.capture_prefix .. "([_%a%d]+)$") if var then - if not wildcards[var] then - wildcards[var] = next_idx + if not captures[var] then + captures[var] = next_idx next_idx = next_idx + 1 return "(_) @" .. var else - -- Same wildcard should match the same subtree. + -- Same capture should match the same subtree. local sexpr = string.format("(_) @_%d (#ssr-tree-match? @_%d @%s)", next_idx, next_idx, var) next_idx = next_idx + 1 return sexpr @@ -104,106 +80,141 @@ local function build_sexpr(node, source) -- Leaf nodes (keyword, identifier, literal and symbol) should match text. if node:named_child_count() == 0 then - local sexpr = string.format("(%s) @_%d (#eq? @_%d %s)", node:type(), next_idx, next_idx, u.to_ts_query_str(text)) + local sexpr = string.format('(%s) @_%d (#eq? @_%d "%s")', node:type(), next_idx, next_idx, u.ts_str_escape(text)) next_idx = next_idx + 1 return sexpr end -- Normal nodes - local sexpr = "(" .. node:type() + local sexpr = "" local add_anchor = false - for child, name in node:iter_children() do - -- Imagine using Rust's match on (name, child:named()). - if name and child:named() then - sexpr = sexpr .. string.format(" %s: %s", name, build(child)) - elseif name and not child:named() then - sexpr = sexpr .. string.format(" %s: %s", name, u.to_ts_query_str(child:type())) - elseif not name and child:named() then + for child, field in node:iter_children() do + if field then + if add_anchor then + sexpr = sexpr .. " ." + add_anchor = false + end + sexpr = string.format("%s %s: %s", sexpr, field, build(child)) + elseif child:named() or crucial_nodes_without_field_name[child:type()] then -- Pin child position with anchor `.` sexpr = string.format(" %s . %s", sexpr, build(child)) add_anchor = true - else - -- Ignore commas and parentheses end end if add_anchor then sexpr = sexpr .. " ." end - sexpr = sexpr .. ")" - return sexpr + return string.format("(%s %s)", node:type(), sexpr) end local sexpr = string.format("(%s) @all", build(node)) - return sexpr, wildcards + return sexpr, captures end ----@param buf buffer ---@param node TSNode ----@param source string ----@return Match[] -function M.search(buf, node, source, ns) - local sexpr, wildcards = build_sexpr(node, source) - local parse_query = ts.query.parse or ts.parse_query - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return {} +local function build_rough_regex(node, source) + local regex = {} + local function build(node) + if node:child_count() == 0 then + local text = ts.get_node_text(node, source) + if text:match("^" .. u.capture_prefix .. "([_%a%d]+)$") then + table.insert(regex, ".+") + else + table.insert(regex, u.regex_escape(text)) + end + else + for child in node:iter_children() do + build(child) + end + end end - local query = parse_query(lang, sexpr) - local matches = {} - local has_parser, parser = pcall(ts.get_parser, buf, lang) - if not has_parser then - return {} + build(node) + return table.concat(regex, "\\s*") +end + +---@class ssr.Searcher +---@field pattern string +---@field queries table +---@field captures table +local Searcher = {} + +---@param lang string +---@param pattern string +---@return vim.treesitter.Query | vim.NIL, table?, TSNode? +local function parse_pattern(lang, pattern) + local node = ts.get_string_parser(pattern, lang):parse()[1]:root() + local lines = vim.split(pattern, "\n", { plain = true }) + node = node:named_descendant_for_range(0, 0, #lines - 1, #lines[#lines]) --[[@as TSNode]] + if node:has_error() then + return vim.NIL end - local root = parser:parse(true)[1]:root() - for _, nodes in query:iter_matches(root, buf, 0, -1) do - ---@type table - local captures = {} - for var, idx in pairs(wildcards) do - captures[var] = ExtmarkRange.new(ns, buf, nodes[idx]) - end - local match = { range = ExtmarkRange.new(ns, buf, nodes[#nodes]), captures = captures } - table.insert(matches, match) + local sexpr, captures = build_sexpr(node, pattern) + local query = ts.query.parse(lang, sexpr) + return query, captures, node +end + +---@param lang string +---@param pattern string +---@return ssr.Searcher? +---@return string rough_regex +function Searcher.new(lang, pattern) + -- $ can cause syntax errors in some languages + pattern = pattern:gsub("%$([_%a%d]+)", u.capture_prefix .. "%1") + local query, captures, node = parse_pattern(lang, pattern) + if query ~= vim.NIL then + return setmetatable( + { pattern = pattern, queries = { [lang] = query }, captures = captures }, + { __index = Searcher } + ), + build_rough_regex(node, pattern) end +end + +-- A single match, including its captures. +---@class ssr.Match +---@field range ssr.Range +---@field captures table + +---@param file ssr.File +---@return ssr.Match[] +function Searcher:search(file) + ---@type ssr.Match[] + local matches = {} + file.tree:for_each_tree(function(tree, lang_tree) -- must called :parse(true) + local lang = lang_tree:lang() + local query = self.queries[lang] + if query == vim.NIL then -- cached failure + return + elseif not query then + query = parse_pattern(lang, self.pattern) + self.queries[lang] = query + if query == vim.NIL then + return + end + end + for _, nodes in query:iter_matches(tree:root(), file.source, 0, -1) do + local range = Range.from_node(nodes[#nodes]) + local captures = {} + for var, idx in pairs(self.captures) do + captures[var] = Range.from_node(nodes[idx]) + end + table.insert(matches, { range = range, captures = captures }) + end + end) -- Sort matches from -- buffer top to bottom, to make goto next/prev match intuitive -- inner to outer for recursive matches, to make replacing correct - ---@param match1 { range: ExtmarkRange, captures: table} - ---@param match2 { range: ExtmarkRange, captures: table} + ---@param match1 ssr.Match + ---@param match2 ssr.Match ---@return boolean table.sort(matches, function(match1, match2) - local start_row1, start_col1, end_row1, end_col1 = match1.range:get() - local start_row2, start_col2, end_row2, end_col2 = match2.range:get() - if end_row1 < start_row2 or (end_row1 == start_row2 and end_col1 <= start_col2) then + if match1.range:before(match2.range) then return true end - return (start_row1 > start_row2 or (start_row1 == start_row2 and start_col1 > start_col2)) - and (end_row1 < end_row2 or (end_row1 == end_row2 and end_col1 <= end_col2)) + return match1.range:inside(match2.range) end) - return matches end ---- Render template and replace one match. ----@param buf buffer ----@param match Match ----@param template string -function M.replace(buf, match, template) - -- Render templates with captured nodes. - local replace = template:gsub("()%$([_%a%d]+)", function(pos, var) - local start_row, start_col, end_row, end_col = match.captures[var]:get() - local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) - u.remove_indent(lines, u.get_indent(buf, start_row)) - local var_lines = vim.split(template:sub(1, pos), "\n") - local var_line = var_lines[#var_lines] - local template_indent = var_line:match "^%s*" - u.add_indent(lines, template_indent) - return table.concat(lines, "\n") - end) - replace = vim.split(replace, "\n") - local start_row, start_col, end_row, end_col = match.range:get() - u.add_indent(replace, u.get_indent(buf, start_row)) - api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace) -end - -return M +return Searcher diff --git a/lua/ssr/ui/confirm_win.lua b/lua/ssr/ui/confirm_win.lua new file mode 100644 index 0000000..e02fa19 --- /dev/null +++ b/lua/ssr/ui/confirm_win.lua @@ -0,0 +1,139 @@ +local api = vim.api + +---@class ConfirmWin +local ConfirmWin = {} + +function ConfirmWin.new() end + +function ConfirmWin:open() + local buf = api.nvim_win_get_buf(self.origin_win) + local matches = self.matches[buf] + if #matches == 0 then + return self:set_status "pattern not found" + end + + local confirm_buf = api.nvim_create_buf(false, true) + vim.bo[confirm_buf].filetype = "ssr_confirm" + local choices = { + "• Yes", + "• No", + "──────────────", + "• All", + "• Quit", + "• Last replace", + } + local separator_idx = 3 + api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) + for idx = 0, #choices - 1 do + if idx + 1 ~= separator_idx then + api.nvim_buf_set_extmark( + confirm_buf, + u.namespace, + idx, + 4, + { hl_group = "Underlined", end_row = idx, end_col = 5 } + ) + end + end + + local function open_confirm_win(match_idx) + self:goto_match(match_idx) + local _, _, end_row, end_col = matches[match_idx].range:get() + local cfg = { + relative = "win", + win = self.origin_win, + bufpos = { end_row, end_col }, + style = "minimal", + border = config.options.border, + width = 14, + height = 6, + } + if vim.fn.has "nvim-0.9" == 1 then + cfg.title = "Replace?" + cfg.title_pos = "center" + end + return api.nvim_open_win(confirm_buf, true, cfg) + end + + local match_idx = 1 + local replaced = 0 + local cursor = 1 + local _, template = self:get_input() + self:set_status(string.format("replacing 0/%d", #matches)) + + while match_idx <= #matches do + local confirm_win = open_confirm_win(match_idx) + + ---@type string + local key + while true do + -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. + api.nvim_buf_clear_namespace(confirm_buf, u.cur_search_ns, 0, -1) + api.nvim_buf_set_extmark( + confirm_buf, + u.cur_search_ns, + cursor - 1, + 0, + { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } + ) + api.nvim_buf_set_extmark(confirm_buf, u.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) + vim.cmd.redraw() + + local ok, char = pcall(vim.fn.getcharstr) + key = ok and vim.fn.keytrans(char) or "" + if key == "j" then + if cursor == separator_idx - 1 then -- skip separator + cursor = separator_idx + 1 + elseif cursor == #choices then -- wrap + cursor = 1 + else + cursor = cursor + 1 + end + elseif key == "k" then + if cursor == separator_idx + 1 then -- skip separator + cursor = separator_idx - 1 + elseif cursor == 1 then -- wrap + cursor = #choices + else + cursor = cursor - 1 + end + elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then + vim.fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) + else + break + end + end + + if key == "" then + key = ({ "y", "n", "", "a", "q", "l" })[cursor] + end + + if key == "y" then + replace(buf, matches[match_idx], template) + replaced = replaced + 1 + match_idx = match_idx + 1 + elseif key == "n" then + match_idx = match_idx + 1 + elseif key == "a" then + for i = match_idx, #matches do + replace(buf, matches[i], template) + end + replaced = replaced + #matches + 1 - match_idx + match_idx = #matches + 1 + elseif key == "l" then + replace(buf, matches[match_idx], template) + replaced = replaced + 1 + match_idx = #matches + 1 + elseif key == "q" or key == "" or key == "" then + match_idx = #matches + 1 + end + api.nvim_win_close(confirm_win, false) + self:set_status(string.format("replacing %d/%d", replaced, #matches)) + end + + api.nvim_buf_delete(confirm_buf, {}) + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) + self:set_status(string.format("%d/%d replaced", replaced, #matches)) +end + +return ConfirmWin diff --git a/lua/ssr/ui/init.lua b/lua/ssr/ui/init.lua new file mode 100644 index 0000000..d157df7 --- /dev/null +++ b/lua/ssr/ui/init.lua @@ -0,0 +1,130 @@ +local api = vim.api +local ts = vim.treesitter +local config = require "ssr.config" +local Searcher = require "ssr.search" +local replace = require("ssr.replace").replace +local pin_matches = require("ssr.replace").pin_matches +local File = require "ssr.file" +local MainWin = require "ssr.ui.main_win" +local u = require "ssr.utils" + +---@alias ssr.Matches { file: ssr.File, matches: ssr.Match[] }[] + +---@class Ui +---@field lang string +---@field matches ssr.Matches +---@field last_pattern string +---@field main_win MainWin +local Ui = {} + +---@return Ui? +function Ui.new() + local self = setmetatable({ matches = {} }, { __index = Ui }) + + -- Pre-checks + local origin_win = api.nvim_get_current_win() + local origin_buf = api.nvim_win_get_buf(origin_win) + local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) + if not lang then + return u.notify(string.format("Treesitter language not found for filetype '%s'", vim.bo[origin_buf].filetype)) + end + self.lang = lang + local node = u.node_for_range(origin_buf, self.lang, u.get_selection(origin_win)) + if not node then + return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) + end + if node:has_error() then + return u.notify "You have syntax errors in the selected node" + end + -- Extend the selected node if it can't be parsed without context. + repeat + local text = ts.get_node_text(node, origin_buf) + local root = ts.get_string_parser(text, self.lang):parse()[1]:root() + local lines = vim.split(text, "\n", { plain = true }) + local n = root:named_descendant_for_range(0, 0, #lines - 1, #lines[#lines]) + if not n:has_error() then + break + end + node = node:parent() + until not node + if not node then + return u.notify "Selected node can't be properly parsed." + end + + local placeholder = vim.split(ts.get_node_text(node, origin_buf), "\n", { plain = true }) + u.remove_indent(placeholder, u.get_indent(origin_buf, node:start())) + + self.main_win = MainWin.new(lang, placeholder, { "" }, origin_win) + + self.main_win:on({ "InsertLeave", "TextChanged" }, function() + self:search() + end) + + self.main_win:on_key(config.opts.keymaps.replace_all, function() + self:replace_all() + end) + + self:search() + return self +end + +function Ui:search() + local pattern = self.main_win:get_input() + if pattern == self.last_pattern then + return + end + self.last_pattern = pattern + + self.matches = {} + local found = 0 + local matched_files = 0 + local start = vim.loop.hrtime() + local searcher, rough_regex = Searcher.new(self.lang, pattern) + if not searcher then + return self:set_status "Error" + end + + File.grep(vim.loop.cwd(), rough_regex, function(file) + local matches = searcher:search(file) + if #matches == 0 then + return + end + found = found + #matches + matched_files = matched_files + 1 + table.insert(self.matches, { file = file, matches = matches }) + end, function() + local elapsed = (vim.loop.hrtime() - start) / 1E6 + self.main_win.result_list:set(self.matches) + self:set_status(string.format("%d found in %d files (%dms)", found, matched_files, elapsed)) + end) +end + +function Ui:replace_all() + if #self.matches == 0 then + return self:set_status "pattern not found" + end + local _, template = self.main_win:get_input() + local start = vim.loop.hrtime() + local pinned = pin_matches(self.matches) + for _, match in ipairs(pinned) do + replace(match, template) + end + local elapsed = (vim.loop.hrtime() - start) / 1E6 + self:set_status(string.format("%d replaced in %d files (%dms)", #self.matches, 0, elapsed)) +end + +---@param status string +---@return nil +function Ui:set_status(status) + api.nvim_buf_set_extmark(self.main_win.buf, u.namespace, 0, 0, { + id = self.main_win.extmarks.status, + virt_text = { + { "[SSR] ", "Comment" }, + { status }, + { " (Press ? for help)", "Comment" }, + }, + virt_text_pos = "overlay", + }) +end + +return Ui diff --git a/lua/ssr/ui/main_win.lua b/lua/ssr/ui/main_win.lua new file mode 100644 index 0000000..b81a830 --- /dev/null +++ b/lua/ssr/ui/main_win.lua @@ -0,0 +1,241 @@ +local api = vim.api +local ts = vim.treesitter +local config = require "ssr.config" +local ResultList = require "ssr.ui.result_list" +local u = require "ssr.utils" + +---@class MainWin +---@field buf integer +---@field win integer +---@field origin_win integer +---@field lang string +---@field last_pattern string[] +---@field last_template string[] +---@field result_list ResultList +local MainWin = {} + +function MainWin.new(lang, pattern, template, origin_win) + local self = setmetatable({ + lang = lang, + last_pattern = pattern, + last_template = template, + origin_win = origin_win, + }, { __index = MainWin }) + + self.buf = api.nvim_create_buf(false, true) + vim.bo[self.buf].filetype = "ssr" + + local lines = self:render() + self:open_win(u.get_win_size(lines)) + + self.result_list = ResultList.new(self.buf, self.win, self.extmarks.results) + + self:setup_autocmds() + self:setup_keymaps() + + return self +end + +---@private +function MainWin:render() + ts.stop(self.buf) + api.nvim_buf_clear_namespace(self.buf, u.namespace, 0, -1) + + local lines = { + "", -- [SSR] + "```" .. self.lang, -- SEARCH: + } + vim.list_extend(lines, self.last_pattern) + table.insert(lines, "") -- REPLACE: + vim.list_extend(lines, self.last_template) + table.insert(lines, "```") -- RESULTS: + api.nvim_buf_set_lines(self.buf, 0, -1, true, lines) + + -- Enable syntax highlights for input area. + local parser = ts.get_parser(self.buf, "markdown") + parser:parse(true) + parser:for_each_tree(function(tree, lang_tree) + if tree:root():start() == 2 then + ts.highlighter.new(lang_tree) + end + end) + + local function virt_text(row, text) + return api.nvim_buf_set_extmark(self.buf, u.namespace, row, 0, { virt_text = text, virt_text_pos = "overlay" }) + end + self.extmarks = { + status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), + search = virt_text(1, { { "SEARCH: ", "String" } }), -- Extra spaces to cover too long language name. + replace = virt_text(#lines - 3, { { "REPLACE:", "String" } }), + results = virt_text(#lines - 1, { { "RESULTS:", "String" } }), + } + + return lines +end + +---@private +function MainWin:check(lines) + if #lines < 6 then + return false + end + + local function get_index(extmark) + return api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + 1 + end + + return get_index(self.extmarks.status) == 1 + and lines[1] == "" + and get_index(self.extmarks.search) == 2 + and lines[2] == "```" .. self.lang + and lines[get_index(self.extmarks.replace)] == "" + and lines[get_index(self.extmarks.results)] == "```" +end + +---@private +function MainWin:open_win(width, height) + self.win = api.nvim_open_win(self.buf, true, { + relative = "editor", + anchor = "NE", + row = 0, + col = vim.o.columns - 1, + style = "minimal", + border = config.opts.border, + width = width, + height = height, + }) + vim.wo[self.win].wrap = false + if vim.fn.has "nvim-0.10" == 1 then + vim.wo[self.win].winfixbuf = true + end + u.set_cursor(self.win, 2, 0) + vim.fn.matchadd("Title", [[$\w\+]]) +end + +function MainWin:on(event, func) + api.nvim_create_autocmd(event, { + group = u.augroup, + buffer = self.buf, + callback = func, + }) +end + +---@private +function MainWin:setup_autocmds() + self:on({ "TextChanged", "TextChangedI" }, function() + local lines = api.nvim_buf_get_lines(self.buf, 0, -1, true) + if not self:check(lines) then + self:render() + self.result_list.extmark = self.extmarks.results + self.result_list:set {} + u.set_cursor(self.win, 2, 0) + end + if not config.opts.adjust_window then + return + end + local width, height = u.get_win_size(lines) + if api.nvim_win_get_width(self.win) ~= width then + api.nvim_win_set_width(self.win, width) + end + if api.nvim_win_get_height(self.win) ~= height then + api.nvim_win_set_height(self.win, height) + end + end) + + self:on("BufWinEnter", function(event) + if event.buf == self.buf then + return + end + local win = api.nvim_get_current_win() + if win ~= self.win then + return + end + -- Prevent accidentally opening another file in the ssr window. + -- Adapted from neo-tree.nvim. + vim.schedule(function() + api.nvim_win_set_buf(self.win, self.buf) + local name = api.nvim_buf_get_name(event.buf) + api.nvim_win_call(self.origin_win, function() + pcall(api.nvim_buf_delete, event.buf, {}) + if name ~= "" then + vim.cmd.edit(name) + end + end) + api.nvim_set_current_win(self.origin_win) + end) + end) + + self:on("WinClosed", function() + api.nvim_clear_autocmds { group = u.augroup } + api.nvim_buf_delete(self.buf, {}) + end) +end + +function MainWin:on_key(key, func) + vim.keymap.set("n", key, func, { buffer = self.buf, nowait = true }) +end + +---@private +function MainWin:setup_keymaps() + self:on_key(config.opts.keymaps.close, function() + api.nvim_win_close(self.win, false) + end) + + self:on_key("gg", function() + u.set_cursor(self.win, 2, 0) + end) + + self:on_key("j", function() + local cursor = u.get_cursor(self.win) + for _, extmark in ipairs { self.extmarks.replace, self.extmarks.results } do + local skip_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + if cursor == skip_pos - 1 then + return pcall(u.set_cursor, self.win, skip_pos + 1, 0) + end + end + vim.fn.feedkeys("j", "n") + end) + + self:on_key("k", function() + local cursor = u.get_cursor(self.win) + if cursor <= 2 then + return u.set_cursor(self.win, 2, 0) + end + for _, extmark in ipairs { self.extmarks.replace, self.extmarks.results } do + local skip_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + if cursor == skip_pos + 1 then + return pcall(u.set_cursor, self.win, skip_pos - 1, 0) + end + end + vim.fn.feedkeys("k", "n") + end) + + self:on_key("l", function() + local cursor = u.get_cursor(self.win) + if cursor < self.result_list:get_start() then + return vim.fn.feedkeys("l", "n") + end + self.result_list:set_folded(false) + end) + + self:on_key("h", function() + local cursor = u.get_cursor(self.win) + if cursor < self.result_list:get_start() then + return vim.fn.feedkeys("h", "n") + end + self.result_list:set_folded(true) + end) +end + +function MainWin:get_input() + local pattern_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.search, {})[1] + local template_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.replace, {})[1] + local results_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.results, {})[1] + local lines = api.nvim_buf_get_lines(self.buf, 0, results_pos, true) + local pattern = vim.list_slice(lines, pattern_pos + 2, template_pos) + local template = vim.list_slice(lines, template_pos + 2) + self.last_pattern = pattern + self.last_template = template + return vim.trim(table.concat(pattern, "\n")), vim.trim(table.concat(template, "\n")) +end + +return MainWin diff --git a/lua/ssr/ui/result_list.lua b/lua/ssr/ui/result_list.lua new file mode 100644 index 0000000..815e6b5 --- /dev/null +++ b/lua/ssr/ui/result_list.lua @@ -0,0 +1,205 @@ +local api = vim.api +local config = require "ssr.config" +local u = require "ssr.utils" + +-- List item per line +---@class Item +---@field fold_idx integer which fold this line belongs to, 1-based +---@field match_idx integer which match this line belongs to, 0-based, 0 for filename + +-- A foldable region that may span multiple lines +---@class Fold +---@field folded boolean +---@field filename string +---@field path string +---@field preview_lines string[] +local Fold = {} + +---@param folded boolean +---@param file ssr.File +---@param matches ssr.Match[] +---@return Fold +function Fold.new(folded, file, matches) + local preview_lines = {} + for _, match in ipairs(matches) do + local line = file:get_line(match.range.start_row) + line = line:gsub("^%s*", "") + table.insert(preview_lines, "│ " .. line) + end + return setmetatable({ + folded = folded, + filename = vim.fn.fnamemodify(file.path, ":t"), + path = vim.fn.fnamemodify(file.path, ":~:.:h"), + preview_lines = preview_lines, + }, { __index = Fold }) +end + +function Fold:len() + if self.folded then + return 1 + end + return 1 + #self.preview_lines +end + +---@private +function Fold:get_lines() + if self.folded then + return { string.format(" %s %s %d", self.filename, self.path, #self.preview_lines) } + end + local lines = { string.format(" %s %s %d", self.filename, self.path, #self.preview_lines) } + vim.list_extend(lines, self.preview_lines) + return lines +end + +function Fold:highlight(buf, row) + local col = 4 -- "" is 3 bytes, plus 1 space + api.nvim_buf_add_highlight(buf, u.namespace, "Directory", row, col, col + #self.filename) + col = col + #self.filename + 1 + api.nvim_buf_add_highlight(buf, u.namespace, "Comment", row, col, col + #self.path) + col = col + #self.path + 1 + api.nvim_buf_add_highlight(buf, u.namespace, "Number", row, col, col + #(tostring(self.preview_lines))) +end + +---@class ResultList +---@field buf integer +---@field win integer +---@field extmark integer +---@field folds Fold[] +---@field items Item[] +local ResultList = {} + +function ResultList.new(buf, win, extmark) + local self = setmetatable({ + buf = buf, + win = win, + extmark = extmark, + folds = {}, + items = {}, + }, { __index = ResultList }) + + vim.keymap.set("n", config.opts.keymaps.next_match, function() + self:next_match() + end, { buffer = self.buf, nowait = true }) + vim.keymap.set("n", config.opts.keymaps.prev_match, function() + self:prev_match() + end, { buffer = self.buf, nowait = true }) + + return self +end + +---@private +function ResultList:get_start() + return api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmark, {})[1] + 1 +end + +---@params matches ssr.Matches +function ResultList:set(matches) + self.folds = {} + self.items = {} + local start = self:get_start() + api.nvim_buf_clear_namespace(self.buf, u.namespace, start, -1) + + local lines = {} + for fold_idx, row in ipairs(matches) do + local fold = Fold.new(fold_idx ~= 1, row.file, row.matches) + table.insert(self.folds, fold) + for match_idx, line in ipairs(fold:get_lines()) do + table.insert(lines, line) + table.insert(self.items, { fold_idx = fold_idx, match_idx = match_idx - 1 }) + end + end + api.nvim_buf_set_lines(self.buf, start, -1, true, lines) + + for _, fold in ipairs(self.folds) do + fold:highlight(self.buf, start) + start = start + fold:len() + end +end + +---@param folded boolean +---@param cursor integer? +function ResultList:set_folded(folded, cursor) + local result_start = self:get_start() + cursor = cursor or u.get_cursor(self.win) - result_start + local item = self.items[cursor + 1] -- +1 beacause `cursor` is 0-based + local fold = self.folds[item.fold_idx] + if fold.folded == folded then + return + end + + local start = cursor - item.match_idx -- like C macro `container_of` + local end_ = start + fold:len() + fold.folded = folded + local lines = fold:get_lines() + local items = {} + for i = 0, #lines - 1 do + table.insert(items, { fold_idx = item.fold_idx, match_idx = i }) + end + u.list_replace(self.items, start, end_, items) + start = result_start + start + end_ = result_start + end_ + api.nvim_buf_set_lines(self.buf, start, end_, true, lines) + fold:highlight(self.buf, start) + if folded then + u.set_cursor(self.win, start, 0) + end +end + +function ResultList:next_match() + local cursor = u.get_cursor(self.win) + local result_start = self:get_start() + cursor = cursor > result_start and cursor - result_start or 0 + local item = self.items[cursor + 1] -- +1: lua index + if not item then + return + end + if item.match_idx == 0 then + self:set_folded(false, cursor) + end + cursor = cursor + 1 + item = self.items[cursor + 1] + if not item then + return + end + if item.match_idx == 0 then + self:set_folded(false, cursor) + cursor = cursor + 1 + end + return u.set_cursor(self.win, cursor + result_start, 0) +end + +function ResultList:prev_match() + local cursor = u.get_cursor(self.win) + local result_start = self:get_start() + if cursor <= result_start then + if #self.items == 0 then + return + end + self:set_folded(false, #self.items - 1) + return u.set_cursor(self.win, result_start + #self.items - 1, 0) + end + + cursor = cursor - result_start + local item = self.items[cursor + 1] + if not item then + return + end + if item.match_idx <= 1 then + cursor = cursor - item.match_idx - 1 + item = self.items[cursor + 1] + if not item then + return + end + local fold = self.folds[item.fold_idx] + if fold.folded then + self:set_folded(false, cursor) + cursor = cursor + #fold.preview_lines + end + return u.set_cursor(self.win, result_start + cursor, 0) + end + + cursor = cursor - 1 + return u.set_cursor(self.win, cursor + result_start, 0) +end + +return ResultList diff --git a/lua/ssr/utils.lua b/lua/ssr/utils.lua index 099ed92..744235e 100644 --- a/lua/ssr/utils.lua +++ b/lua/ssr/utils.lua @@ -1,8 +1,14 @@ local api = vim.api local ts = vim.treesitter +local config = require "ssr.config" local M = {} +M.capture_prefix = "__ssr_capture_" +M.namespace = api.nvim_create_namespace "ssr_ns" +M.cur_search_ns = api.nvim_create_namespace "ssr_cur_search_ns" +M.augroup = api.nvim_create_augroup("ssr_augroup", {}) + -- Send a notification titled SSR. ---@param msg string ---@return nil @@ -11,14 +17,14 @@ function M.notify(msg) end -- Get (0,0)-indexed cursor position. ----@param win window +---@param win integer function M.get_cursor(win) local cursor = api.nvim_win_get_cursor(win) return cursor[1] - 1, cursor[2] end -- Set (0,0)-indexed cursor position. ----@param win window +---@param win integer ---@param row integer ---@param col integer function M.set_cursor(win, row, col) @@ -26,8 +32,8 @@ function M.set_cursor(win, row, col) end -- Get selected region, works in many modes. ----@param win window ----@return number, number, number, number +---@param win integer +---@return integer, number, number, number function M.get_selection(win) local mode = api.nvim_get_mode().mode local cursor_row, cursor_col = M.get_cursor(win) @@ -54,12 +60,12 @@ function M.get_selection(win) end -- Get smallest node for the range. ----@param buf buffer +---@param buf integer ---@param lang string ----@param start_row number ----@param start_col number ----@param end_row number ----@param end_col number +---@param start_row integer +---@param start_col integer +---@param end_row integer +---@param end_col integer ---@return TSNode? function M.node_for_range(buf, lang, start_row, start_col, end_row, end_col) local has_parser, parser = pcall(ts.get_parser, buf, lang) @@ -68,8 +74,8 @@ function M.node_for_range(buf, lang, start_row, start_col, end_row, end_col) end end ----@param buf buffer ----@param row number +---@param buf integer +---@param row integer function M.get_indent(buf, row) local line = api.nvim_buf_get_lines(buf, row, row + 1, true)[1] return line:match "^%s*" @@ -92,25 +98,15 @@ function M.remove_indent(lines, indent) end end --- Escape special characters in s and quote it in double quotes. ----@param s string -function M.to_ts_query_str(s) - s = s:gsub([[\]], [[\\]]) - s = s:gsub([["]], [[\"]]) - s = s:gsub("\n", [[\n]]) - return '"' .. s .. '"' -end - -- Compute window size to show giving lines. ---@param lines string[] ----@param config Config ----@return number ----@return number -function M.get_win_size(lines, config) - ---@param i number - ---@param min number - ---@param max number - ---@return number +---@return integer +---@return integer +function M.get_win_size(lines) + ---@param i integer + ---@param min integer + ---@param max integer + ---@return integer local function clamp(i, min, max) return math.min(math.max(i, min), max) end @@ -122,9 +118,66 @@ function M.get_win_size(lines, config) end end - width = clamp(width, config.min_width, config.max_width) - local height = clamp(#lines, config.min_height, config.max_height) + width = clamp(width, config.opts.min_width, config.opts.max_width) + local height = clamp(#lines, config.opts.min_height, config.opts.max_height) return width, height end +-- Escapes all special characters in s. +-- The string returned may be safely used as a string content in a TS query. +---@param s string +function M.ts_str_escape(s) + s = s:gsub([[\]], [[\\]]) + s = s:gsub([["]], [[\"]]) + s = s:gsub("\n", [[\n]]) + return s +end + +-- https://github.com/rust-lang/regex/blob/17284451f10aa06c6c42e622e3529b98513901a8/regex-syntax/src/lib.rs#L272 +local regex_meta_chars = { + ["\\"] = true, + ["."] = true, + ["+"] = true, + ["*"] = true, + ["?"] = true, + ["("] = true, + [")"] = true, + ["|"] = true, + ["["] = true, + ["]"] = true, + ["{"] = true, + ["}"] = true, + ["^"] = true, + ["$"] = true, + ["#"] = true, + ["&"] = true, + ["-"] = true, + ["~"] = true, +} + +-- Escapes all regular expression meta characters in s. +-- The string returned may be safely used as a literal in a regular expression. +---@param s string +---@return string +function M.regex_escape(s) + local escaped = s:gsub(".", function(ch) + return regex_meta_chars[ch] and "\\" .. ch or ch + end) + return escaped +end + +---@generic T +---@param list table +---@param start integer 0-based +---@param end_ integer exclusive +---@param replacement table +function M.list_replace(list, start, end_, replacement) + for _ = start + 1, end_ do + table.remove(list, start + 1) + end + for i = start + 1, start + #replacement do + table.insert(list, i, replacement[i - start]) + end +end + return M diff --git a/tests/ssr_spec.lua b/tests/ssr_spec.lua index f6fcd17..ff42cdd 100644 --- a/tests/ssr_spec.lua +++ b/tests/ssr_spec.lua @@ -1,99 +1,116 @@ -local u = require "ssr.utils" -local ParseContext = require("ssr.parse").ParseContext local ts = vim.treesitter -local search = require("ssr.search").search -local replace = require("ssr.search").replace +local uv = vim.uv or vim.loop +local Searcher = require "ssr.search" +local pin_matches = require("ssr.replace").pin_matches +local replace = require("ssr.replace").replace +local File = require "ssr.file" +---@type string[] local tests = {} +---@param s string local function t(s) table.insert(tests, s) end -t [[ python operators - -a - b +t [[ operators +a = b + c ==>> x + +==== t.py +a = b + c +a = b - c +a = b or c ==== -a + b ==>> (+ a b) +x +a = b - c +a = b or c +]] + +t [[ operators 2 +a = b + c ==>> x + +==== t.lua +a = b + c +a = b - c +a = b or c ==== -(+ a b) -a - b +x +a = b - c +a = b or c ]] -t [[ python complex string -<""" +t [[ complex string +""" line 1 -\r\n\a\?\\ +\r\n\a\\ 'a'"'"'b' -"""> -==== +""" ==>> x +==== t.py """ line 1 -\r\n\a\?\\ +\r\n\a\\ 'a'"'"'b' """ -==>> -x ==== x ]] -t [[ javascript keywords - -const a = 1 -==== +t [[ keywords let a = 1 ==>> x +==== t.js +let a = 1 +const a = 1 ==== x const a = 1 ]] -t [[ lua func args - -f(1, 3) -==== +t [[ func args f(1, 3) ==>> x +==== t.lua +f(1, 2, 3) +f(1, 3) ==== f(1, 2, 3) x ]] -t [[ lua recursive 1 - -==== +t [[ recursive 1 f($a) ==>> $a.f() +==== recursive.lua +f(f(f(0))) ==== 0.f().f().f() ]] -t [[ rust recursive 2 -f(f(, 2), 3) -==== +t [[ recursive 2 f($a, $b) ==>> $a.f($b) +==== t.rs +f(f(f(0, 1), 2), 3) ==== 0.f(1).f(2).f(3) ]] -t [[ rust recursive 3 -f(3, f(2, )) -==== +t [[ recursive 3 f($a, $b) ==>> $a.f($b) +==== t.rs +f(3, f(2, f(1, 0))) ==== 3.f(2.f(1.f(0))) ]] -t [[ python indent 1 -def f(): - -==== +t [[ indent 1 if $a: $b ==>> if $a: if True: $b +==== t.py +def f(): + if foo: + if bar: + pass ==== def f(): if foo: @@ -103,18 +120,18 @@ def f(): pass ]] -t [[ python indent 2 -def f(): - if len(a) != 0: - do_a(a) - -==== +t [[ indent 2 if len($a) != 0: $b ==>> if $a: $b +==== t.py +def f(): + if len(a) != 0: + do_a(a) + if len(b) != 0: + do_b(b) ==== def f(): if a: @@ -123,59 +140,45 @@ def f(): do_b(b) ]] -t [[ rust question mark -let foo = ; -==== +t [[ question mark $a? ==>> try!($a) +==== t.rs +let foo = bar().await?; ==== let foo = try!(bar().await); ]] -t [[ rust rust-analyzer ssr example -String::from() -==== +t [[ rust-analyzer ssr example foo($a, $b) ==>> ($a).foo($b) +==== t.rs +String::from(foo(y + 5, z)) ==== String::from((y + 5).foo(z)) ]] -t [[ go parse Go := in function -func main() { - -} -==== -$a, _ := os.LookupEnv($b) -==>> -$a := os.Getenv($b) -==== -func main() { - commit := os.Getenv("GITHUB_SHA") -} -]] - -t [[ go match Go if err +t [[ match Go if err +if err != nil { panic(err) } ==>> x +==== t.go fn main() { - + } } ==== -if err != nil { panic(err) } ==>> x -==== fn main() { x } ]] -t [[ rust reused wildcard: compound assignments -; +t [[ reused capture: compound assignments +$a = $a + $b; ==>> $a += $b; +==== t.rs +idx = idx + 1; bar = foo + idx; *foo.bar() = * foo . bar () + 1; (foo + bar) = (foo + bar) + 1; (foo + bar) = (foo - bar) + 1; ==== -$a = $a + $b ==>> $a += $b -==== idx += 1; bar = foo + idx; *foo.bar() += 1; @@ -183,18 +186,18 @@ bar = foo + idx; (foo + bar) = (foo - bar) + 1; ]] -t [[ python reused wildcard: indent -def f(): - -==== +t [[ reused capture: indent if $foo: if $foo: $body ==>> if $foo: $body +==== t.py +def f(): + if await foo.bar(baz): + if await foo.bar(baz): + pass ==== def f(): if await foo.bar(baz): @@ -202,74 +205,101 @@ def f(): ]] -- two `foo`s have different type: `property_identifier` and `identifier` -t [[ javascript reused wildcard: match different node types 1 -<{ foo: foo }> -{ foo: bar } -==== +t [[ reused capture: match different node types 1 { $a: $a } ==>> { $a } +==== t.js +{ foo: foo } +{ foo: bar } ==== { foo } { foo: bar } ]] -t [[ lua reused wildcard: match different node types 2 - -local a = vim.api -==== +t [[ reused capture: match different node types 2 local $a = vim.$a ==>> x +==== t.lua +local api = vim.api +local a = vim.api ==== x local a = vim.api ]] +t [[ multiple files +local $a = vim.$a ==>> _G.g_$a = vim.$a + +==== t.lua +local api = vim.api +local fn = vim.fn +==== +_G.g_api = vim.api +_G.g_fn = vim.fn + +==== README.md +# Example +```lua +local F = vim.F +local uv = vim.uv +``` +==== +# Example +```lua +_G.g_F = vim.F +_G.g_uv = vim.uv +``` +]] + describe("", function() -- Plenary runs nvim with `--noplugin` argument. - -- Make sure nvim-treesitter is loaded, which populates vim.treesitter's ft_to_lang table. + -- Load nvim-treesitter to make `ts.language.get_lang()` work. require "nvim-treesitter" for _, s in ipairs(tests) do - local ft, desc, content, pattern, template, expected = - s:match "^ (%a-) (.-)\n(.-)%s?====%s?(.-)%s?==>>%s?(.-)%s?====%s?(.-)%s?$" - content = vim.split(content, "\n") - expected = vim.split(expected, "\n") - local start_row, start_col, end_row, end_col - for idx, line in ipairs(content) do - local col = line:find "<" - if col then - start_row = idx - 1 - start_col = col - 1 - end - line = line:gsub("<", "") - col = line:find ">" - if col then - end_row = idx - 1 - end_col = col - 1 - end - line = line:gsub(">", "") - content[idx] = line - end - + local desc, pattern, template, rest = s:match "^ (.-)\n(.-)%s?==>>%s?(.-)\n%s*==(.-)$" it(desc, function() - local ns = vim.api.nvim_create_namespace "" - local buf = vim.api.nvim_create_buf(false, true) - vim.bo[buf].filetype = ft - vim.api.nvim_buf_set_lines(buf, 0, -1, true, content) - local lang = ts.language.get_lang(vim.bo[buf].filetype) - assert(lang, "language not found") - local origin_node = u.node_for_range(buf, lang, start_row, start_col, end_row, end_col) + local dir = vim.fn.tempname() + assert(uv.fs_mkdir(dir, 448)) + + local expected_files = {} + local lang + for fname, before, after in (rest .. "=="):gmatch "== (.-)\n(.-)====\n(.-)%s*==" do + after = after .. "\n" -- Vim always adds a \n to files. + fname = vim.fs.joinpath(dir, fname) + local fd = assert(uv.fs_open(fname, "w", 438)) + assert(uv.fs_write(fd, before) > 0) + assert(uv.fs_close(fd)) + expected_files[fname] = after + lang = lang or assert(ts.language.get_lang(vim.filetype.match { filename = fname })) + end - local parse_context = ParseContext.new(buf, origin_node) - assert(parse_context) - local node, source = parse_context:parse(pattern) - local matches = search(buf, node, source, ns) + local searcher, rough_regex = assert(Searcher.new(lang, pattern)) + ---@type ssr.Matches + local results = {} + local done = false + File.grep(dir, rough_regex, function(file) + local matches = searcher:search(file) + assert.is_true(#matches > 0) + table.insert(results, { file = file, matches = matches }) + end, function() + done = true + end) + vim.wait(1000, function() + return done + end) - for _, match in ipairs(matches) do - replace(buf, match, template) + local pinned_matches = pin_matches(results) + for _, match in ipairs(pinned_matches) do + replace(match, template) end - local actual = vim.api.nvim_buf_get_lines(buf, 0, -1, true) - vim.api.nvim_buf_delete(buf, {}) - assert.are.same(expected, actual) + vim.cmd "silent noautocmd wa" + for fname, expected in pairs(expected_files) do + local fd = assert(uv.fs_open(fname, "r", 438)) + local stat = assert(uv.fs_fstat(fd)) + local actual = uv.fs_read(fd, stat.size, 0) + uv.fs_close(fd) + assert.are.same(expected, actual) + end end) end end)