Skip to content

Commit

Permalink
Add return type validation
Browse files Browse the repository at this point in the history
  • Loading branch information
NiiRoZz committed Nov 30, 2024
1 parent d8c0fa9 commit 835e3b0
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/NZSL/Ast/SanitizeVisitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ namespace nzsl::Ast
StatementPtr Clone(ForEachStatement& node) override;
StatementPtr Clone(ImportStatement& node) override;
StatementPtr Clone(MultiStatement& node) override;
StatementPtr Clone(ReturnStatement& node) override;
StatementPtr Clone(ScopedStatement& node) override;
StatementPtr Clone(WhileStatement& node) override;

Expand Down Expand Up @@ -171,6 +172,7 @@ namespace nzsl::Ast
ValidationResult TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right, const SourceLocation& sourceLocation);

ValidationResult Validate(DeclareAliasStatement& node);
ValidationResult Validate(ReturnStatement& node);
ValidationResult Validate(WhileStatement& node);

ValidationResult Validate(AccessIndexExpression& node);
Expand Down
2 changes: 2 additions & 0 deletions include/NZSL/Lang/ErrorList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnexpectedEntryFunction, "{} is an en
NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterCount, "function {} expects {} parameter(s), but got {}", std::string, std::uint32_t, std::uint32_t)
NZSL_SHADERLANG_COMPILER_ERROR(FunctionCallUnmatchingParameterType, "function {} parameter #{} type mismatch (expected {}, got {})", std::string, std::uint32_t, std::string, std::string)
NZSL_SHADERLANG_COMPILER_ERROR(FunctionDeclarationInsideFunction, "a function cannot be defined inside another function")
NZSL_SHADERLANG_COMPILER_ERROR(FunctionReturnStatementWithAValue, "return-statement with a value, in function returning no value")
NZSL_SHADERLANG_COMPILER_ERROR(FunctionReturnStatementWithNoValue, "return-statement with no value, in function returning {}", std::string)
NZSL_SHADERLANG_COMPILER_ERROR(IdentifierAlreadyUsed, "identifier {} is already used", std::string)
NZSL_SHADERLANG_COMPILER_ERROR(ImportIdentifierAlreadyPresent, "{} identifier was already imported", std::string)
NZSL_SHADERLANG_COMPILER_ERROR(ImportIdentifierNotFound, "identifier {} not found in module {}", std::string, std::string)
Expand Down
41 changes: 41 additions & 0 deletions src/NZSL/Ast/SanitizeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2738,6 +2738,15 @@ NAZARA_WARNING_POP()
return clone;
}

StatementPtr SanitizeVisitor::Clone(ReturnStatement& node)
{
auto clone = Nz::StaticUniquePointerCast<ReturnStatement>(Cloner::Clone(node));

Validate(*clone);

return clone;
}

StatementPtr SanitizeVisitor::Clone(ScopedStatement& node)
{
MandatoryStatement(node.statement, node.sourceLocation);
Expand Down Expand Up @@ -4162,6 +4171,38 @@ NAZARA_WARNING_POP()
return ValidationResult::Validated;
}

auto SanitizeVisitor::Validate(ReturnStatement& node) -> ValidationResult
{
Nz::Assert(m_context->currentFunction);

auto& function = m_context->currentFunction;
if (!function->node->returnType.IsResultingValue())
return ValidationResult::Unresolved;

const bool functionHasNoReturnType = std::holds_alternative<NoType>(function->node->returnType.GetResultingValue());

if (!node.returnExpr)
{
if (!functionHasNoReturnType)
throw CompilerFunctionReturnStatementWithNoValueError{ node.sourceLocation, ToString(function->node->returnType.GetResultingValue(), node.sourceLocation) };

//If node doesn't have an expression and function doesn't have return type, then we can directly validate
return ValidationResult::Validated;
}

if (functionHasNoReturnType)
throw CompilerFunctionReturnStatementWithAValueError{ node.sourceLocation };

const ExpressionType* returnTypeOpt = GetExpressionType(MandatoryExpr(node.returnExpr, node.sourceLocation));
if (!returnTypeOpt)
return ValidationResult::Unresolved;

ExpressionType returnType = ResolveType(*returnTypeOpt, true, node.sourceLocation);
TypeMustMatch(returnType, function->node->returnType.GetResultingValue(), node.sourceLocation);

return ValidationResult::Validated;
}

auto SanitizeVisitor::Validate(WhileStatement& node) -> ValidationResult
{
const ExpressionType* conditionType = GetExpressionType(MandatoryExpr(node.condition, node.sourceLocation));
Expand Down
41 changes: 38 additions & 3 deletions tests/src/Tests/ErrorsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ external
data: uniform[Outer]
}
fn GetValue(data: Inner) -> i32
fn GetValue(data: Inner) -> f32
{
return data.value;
}
Expand Down Expand Up @@ -625,9 +625,9 @@ external
data: uniform[Outer]
}
fn GetValue(data: array[Inner, 3]) -> i32
fn GetValue(data: array[Inner, 3]) -> f32
{
return data[1];
return data[1].value;
}
fn main()
Expand Down Expand Up @@ -667,6 +667,41 @@ external

/************************************************************************/

SECTION("Functions")
{
CHECK_THROWS_WITH(Compile(R"(
[nzsl_version("1.0")]
module;
fn test() -> i32
{
return 42.666;
}
)"), "(7,2 -> 15): CUnmatchingTypes error: left expression type (f32) doesn't match right expression type (i32)");

CHECK_THROWS_WITH(Compile(R"(
[nzsl_version("1.0")]
module;
fn test() -> i32
{
return;
}
)"), "(7,2 -> 8): CFunctionReturnStatementWithNoValue error: return-statement with no value, in function returning i32");

CHECK_THROWS_WITH(Compile(R"(
[nzsl_version("1.0")]
module;
fn test()
{
return 10;
}
)"), "(7,2 -> 11): CFunctionReturnStatementWithAValue error: return-statement with a value, in function returning no value");
}

/************************************************************************/

SECTION("Import")
{
CHECK_THROWS_WITH(Compile(R"(
Expand Down
2 changes: 1 addition & 1 deletion tests/src/Tests/SerializationsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ struct FragOut
}
[entry(frag)]
fn main(input: FragIn)
fn main(input: FragIn) -> FragOut
{
let output: FragOut;
output.color = tex2D(texture, input.uv);
Expand Down

0 comments on commit 835e3b0

Please sign in to comment.