Skip to content

Make syntax highlight generation scale better. #5642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- Added Widget.preflight_checks to perform some debug checks after a widget is instantiated, to catch common errors. https://github.com/Textualize/textual/pull/5588

### Fixed

- Fixed TextArea's syntax highlighting. Some highlighting details were not being
applied. For example, in CSS, the text 'padding: 10px 0;' was shown in a
single colour. Now the 'px' appears in a different colour to the rest of the
text.

- Fixed a cause of slow editing for syntax highlighed TextArea widgets with
large documents.


## [2.1.2] - 2025-02-26

### Fixed
Expand Down
42 changes: 34 additions & 8 deletions src/textual/document/_syntax_aware_document.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import ContextManager

try:
from tree_sitter import Language, Node, Parser, Query, Tree

Expand All @@ -12,6 +15,35 @@
from textual.document._document import Document, EditResult, Location, _utf8_encode


@contextmanager
def temporary_query_point_range(
query: Query,
start_point: tuple[int, int] | None,
end_point: tuple[int, int] | None,
) -> ContextManager[None]:
"""Temporarily change the start and/or end point for a tree-sitter Query.

Args:
query: The tree-sitter Query.
start_point: The (row, column byte) to start the query at.
end_point: The (row, column byte) to end the query at.
"""
# Note: Although not documented for the tree-sitter Python API, an
# end-point of (0, 0) means 'end of document'.
default_point_range = [(0, 0), (0, 0)]

point_range = list(default_point_range)
if start_point is not None:
point_range[0] = start_point
if end_point is not None:
point_range[1] = end_point
query.set_point_range(point_range)
try:
yield None
finally:
query.set_point_range(default_point_range)


class SyntaxAwareDocumentError(Exception):
"""General error raised when SyntaxAwareDocument is used incorrectly."""

Expand Down Expand Up @@ -128,14 +160,8 @@ def query_syntax_tree(
"tree-sitter is not available on this architecture."
)

captures_kwargs = {}
if start_point is not None:
captures_kwargs["start_point"] = start_point
if end_point is not None:
captures_kwargs["end_point"] = end_point

captures = query.captures(self._syntax_tree.root_node, **captures_kwargs)
return captures
with temporary_query_point_range(query, start_point, end_point):
return query.captures(self._syntax_tree.root_node)

def replace_range(self, start: Location, end: Location, text: str) -> EditResult:
"""Replace text at the given range.
Expand Down
146 changes: 110 additions & 36 deletions src/textual/widgets/_text_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,105 @@ class LanguageDoesNotExist(Exception):
"""


class HighlightMap:
"""Lazy evaluated pseudo dictionary mapping lines to highlight information.

This allows TextArea syntax highlighting to scale.

Args:
text_area_widget: The associated `TextArea` widget.
"""

BLOCK_SIZE = 50

def __init__(self, text_area: TextArea):
self.text_area: TextArea = text_area
"""The text area associated with this highlight map."""

self._highlighted_blocks: set[int] = set()
"""The set of blocks that have been highlighted. Each block covers BLOCK_SIZE
lines.
"""

self._highlights: dict[int, list[Highlight]] = defaultdict(list)
"""A mapping from line index to a list of Highlight instances."""

def reset(self) -> None:
"""Reset so that future lookups rebuild the highlight map."""
self._highlights.clear()
self._highlighted_blocks.clear()

@property
def document(self) -> DocumentBase:
"""The text document being highlighted."""
return self.text_area.document

def __getitem__(self, index: int) -> list[Highlight]:
block_index = index // self.BLOCK_SIZE
if block_index not in self._highlighted_blocks:
self._highlighted_blocks.add(block_index)
self._build_part_of_highlight_map(block_index * self.BLOCK_SIZE)
return self._highlights[index]

def _build_part_of_highlight_map(self, start_index: int) -> None:
"""Build part of the highlight map.

Args:
start_index: The start of the block of line for which to build the map.
"""
highlights = self._highlights
start_point = (start_index, 0)
end_index = min(self.document.line_count, start_index + self.BLOCK_SIZE)
end_point = (end_index, 0)
captures = self.document.query_syntax_tree(
self.text_area._highlight_query,
start_point=start_point,
end_point=end_point,
)
for highlight_name, nodes in captures.items():
for node in nodes:
node_start_row, node_start_column = node.start_point
node_end_row, node_end_column = node.end_point
if node_start_row == node_end_row:
highlight = node_start_column, node_end_column, highlight_name
highlights[node_start_row].append(highlight)
else:
# Add the first line of the node range
highlights[node_start_row].append(
(node_start_column, None, highlight_name)
)

# Add the middle lines - entire row of this node is highlighted
middle_highlight = (0, None, highlight_name)
for node_row in range(node_start_row + 1, node_end_row):
highlights[node_row].append(middle_highlight)

# Add the last line of the node range
highlights[node_end_row].append(
(0, node_end_column, highlight_name)
)

# The highlights for each line need to be sorted. Each highlight is of
# the form:
#
# a, b, highlight-name
#
# Where a is a number and b is a number or ``None``. These highlights need
# to be sorted in ascending order of ``a``. When two highlights have the same
# value of ``a`` then the one with the larger a--b range comes first, with ``None``
# being considered larger than any number.
def sort_key(highlight: Highlight) -> tuple[int, int, int]:
a, b, _ = highlight
max_range_index = 1
if b is None:
max_range_index = 0
b = a
return a, max_range_index, a - b

for line_index in range(start_index, end_index):
highlights.get(line_index, []).sort(key=sort_key)


@dataclass
class TextAreaLanguage:
"""A container for a language which has been registered with the TextArea.
Expand Down Expand Up @@ -456,15 +555,15 @@ def __init__(
cursor is currently at. If the cursor is at a bracket, or there's no matching
bracket, this will be `None`."""

self._highlights: dict[int, list[Highlight]] = defaultdict(list)
"""Mapping line numbers to the set of highlights for that line."""

self._highlight_query: "Query | None" = None
"""The query that's currently being used for highlighting."""

self.document: DocumentBase = Document(text)
"""The document this widget is currently editing."""

self._highlights: HighlightMap = HighlightMap(self)
"""Mapping line numbers to the set of highlights for that line."""

self.wrapped_document: WrappedDocument = WrappedDocument(self.document)
"""The wrapped view of the document."""

Expand Down Expand Up @@ -592,36 +691,11 @@ def check_consume_key(self, key: str, character: str | None = None) -> bool:
# Otherwise we capture all printable keys
return character is not None and character.isprintable()

def _build_highlight_map(self) -> None:
"""Query the tree for ranges to highlights, and update the internal highlights mapping."""
highlights = self._highlights
highlights.clear()
if not self._highlight_query:
return

captures = self.document.query_syntax_tree(self._highlight_query)
for highlight_name, nodes in captures.items():
for node in nodes:
node_start_row, node_start_column = node.start_point
node_end_row, node_end_column = node.end_point

if node_start_row == node_end_row:
highlight = (node_start_column, node_end_column, highlight_name)
highlights[node_start_row].append(highlight)
else:
# Add the first line of the node range
highlights[node_start_row].append(
(node_start_column, None, highlight_name)
)

# Add the middle lines - entire row of this node is highlighted
for node_row in range(node_start_row + 1, node_end_row):
highlights[node_row].append((0, None, highlight_name))
def _reset_highlights(self) -> None:
"""Reset the lazily evaluated highlight map."""

# Add the last line of the node range
highlights[node_end_row].append(
(0, node_end_column, highlight_name)
)
if self._highlight_query:
self._highlights.reset()

def _watch_has_focus(self, focus: bool) -> None:
self._cursor_visible = focus
Expand Down Expand Up @@ -935,7 +1009,7 @@ def _set_document(self, text: str, language: str | None) -> None:
self.document = document
self.wrapped_document = WrappedDocument(document, tab_width=self.indent_width)
self.navigator = DocumentNavigator(self.wrapped_document)
self._build_highlight_map()
self._reset_highlights()
self.move_cursor((0, 0))
self._rewrap_and_refresh_virtual_size()

Expand Down Expand Up @@ -1348,7 +1422,7 @@ def edit(self, edit: Edit) -> EditResult:

self._refresh_size()
edit.after(self)
self._build_highlight_map()
self._reset_highlights()
self.post_message(self.Changed(self))
return result

Expand Down Expand Up @@ -1411,7 +1485,7 @@ def _undo_batch(self, edits: Sequence[Edit]) -> None:
self._refresh_size()
for edit in reversed(edits):
edit.after(self)
self._build_highlight_map()
self._reset_highlights()
self.post_message(self.Changed(self))

def _redo_batch(self, edits: Sequence[Edit]) -> None:
Expand Down Expand Up @@ -1459,7 +1533,7 @@ def _redo_batch(self, edits: Sequence[Edit]) -> None:
self._refresh_size()
for edit in edits:
edit.after(self)
self._build_highlight_map()
self._reset_highlights()
self.post_message(self.Changed(self))

async def _on_key(self, event: events.Key) -> None:
Expand Down
Loading
Loading