Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from guppylang_internals.error import GuppyError
from guppylang_internals.experimental import check_capturing_closures_enabled
from guppylang_internals.nodes import CheckedNestedFunctionDef, NestedFunctionDef
from guppylang_internals.span import function_header_span
from guppylang_internals.tys.param import Parameter, TypeParam
from guppylang_internals.tys.parsing import (
TypeParsingCtx,
Expand Down Expand Up @@ -306,8 +307,7 @@ def check_signature(
UnsupportedError(func_def.args.defaults[0], "Default arguments")
)
if func_def.returns is None:
err = MissingReturnAnnotationError(func_def)
# TODO: Error location is incorrect
err = MissingReturnAnnotationError(function_header_span(func_def))
if all(r.value is None for r in return_nodes_in_ast(func_def)):
err.add_sub_diagnostic(
MissingReturnAnnotationError.ReturnNone(None, func_def.name)
Expand Down
30 changes: 29 additions & 1 deletion guppylang-internals/src/guppylang_internals/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import TypeAlias

from guppylang_internals.ast_util import get_file, get_line_offset
from guppylang_internals.ast_util import get_file, get_line_offset, get_source
from guppylang_internals.error import InternalGuppyError
from guppylang_internals.ipython_inspect import normalize_ipython_dummy_files

Expand Down Expand Up @@ -126,6 +126,34 @@ def to_span(x: ToSpan) -> Span:
return Span(start, end)


def function_header_span(func_def: ast.FunctionDef) -> Span:
"""Returns a span covering only the function header up to and including `:`."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also describe cases where this is not possible, i.e. what the span will cover when one cannot identify the header span. So perhaps something like (with suitable line breaks for max line length):

Suggested change
"""Returns a span covering only the function header up to and including `:`."""
"""Returns a span covering only the function header up to and including `:` on a best-effort basis, falling back to the full function definition."""

However, if you can prove that your below cases can all make the span exact (via a short argument), you do not need to add this.

start = to_span(func_def).start
source = get_source(func_def)
file = get_file(func_def)
line_offset = get_line_offset(func_def)
# `check_signature` is only called on AST nodes that have been processed by
# `annotate_location`, so source metadata is always available.
assert source is not None
assert file is not None
assert line_offset is not None

lines = source.splitlines()
line_idx = func_def.lineno - 1
paren_depth = 0
for i, line in enumerate(lines[line_idx:], start=line_idx):
col_begin = func_def.col_offset if i == line_idx else 0
for col, char in enumerate(line[col_begin:], start=col_begin):
if char == "(":
paren_depth += 1
elif char == ")":
paren_depth -= 1
elif char == ":" and paren_depth == 0:
return Span(start, Loc(file, i + line_offset, col + 1))

raise InternalGuppyError("function_header_span: Could not find header colon")


#: List of source lines in a file
SourceLines: TypeAlias = list[str]

Expand Down
4 changes: 1 addition & 3 deletions tests/error/misc_errors/return_not_annotated.err
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ Error: Missing type annotation (at $FILE:5:0)
3 |
4 | @compile_guppy
5 | def foo(x: bool):
| ^^^^^^^^^^^^^^^^^
6 | return x
| ^^^^^^^^^^^^ Return type must be annotated
| ^^^^^^^^^^^^^^^^^ Return type must be annotated

Guppy compilation failed due to 1 previous error
2 changes: 0 additions & 2 deletions tests/error/misc_errors/return_not_annotated_none1.err
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ Error: Missing type annotation (at $FILE:5:0)
3 |
4 | @compile_guppy
5 | def foo():
| ^^^^^^^^^^
6 | return
| ^^^^^^^^^^ Return type must be annotated

Help: Looks like `foo` doesn't return anything. Consider annotating it with `->
Expand Down
5 changes: 1 addition & 4 deletions tests/error/misc_errors/return_not_annotated_none2.err
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ Error: Missing type annotation (at $FILE:5:0)
3 |
4 | @compile_guppy
5 | def foo():
| ^^^^^^^^^^
| ...
7 | return x
| ^^^^^^^^^^^^^^^^ Return type must be annotated
| ^^^^^^^^^^ Return type must be annotated

Help: Looks like `foo` doesn't return anything. Consider annotating it with `->
None`.
Expand Down
42 changes: 42 additions & 0 deletions tests/test_function_header_span.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import ast

import pytest
from guppylang_internals.ast_util import annotate_location
from guppylang_internals.span import function_header_span, to_span


def _parse_func(source: str) -> ast.FunctionDef:
node = ast.parse(source).body[0]
assert isinstance(node, ast.FunctionDef)
annotate_location(node, source, "test.py", 1)
return node


def _header_text(func_def: ast.FunctionDef) -> str:
span = function_header_span(func_def)
source = func_def.source # type: ignore[attr-defined]
lines = source.splitlines()
if span.is_multiline:
parts = [lines[span.start.line - 1][span.start.column :]]
parts.extend(
lines[line_no - 1] for line_no in range(span.start.line + 1, span.end.line)
)
parts.append(lines[span.end.line - 1][: span.end.column])
return "\n".join(parts)
line = lines[span.start.line - 1]
return line[span.start.column : span.end.column]


@pytest.mark.parametrize(
("source", "expected_header"),
[
("def foo():\n return", "def foo():"),
("def foo(x: bool):\n return x", "def foo(x: bool):"),
("def foo(\n x: bool,\n):\n return x", "def foo(\n x: bool,\n):"),
("def foo() :\n return", "def foo() :"),
],
)
def test_function_header_span(source: str, expected_header: str) -> None:
func_def = _parse_func(source)
assert _header_text(func_def) == expected_header
assert function_header_span(func_def).end <= to_span(func_def).end
Loading