Skip to content

Commit

Permalink
Sync to upstream/release/652 (#1525)
Browse files Browse the repository at this point in the history
## What's new?

* Add support for mixed-mode type checking, which allows modules checked
in the old type solver to be checked and autocompleted by the new one.
* Generalize `RequireResolver` to support require-by-string semantics in
`luau-analyze`.
* Fix a bug in incremental autocomplete where `DefId`s associated with
index expressions were not correctly picked up.
* Fix a bug that prevented "complex" types in generic parameters (for
example, `local x: X<(() -> ())?>`).

### Issues fixed
* #1507
* #1518

---

Internal Contributors:

Co-authored-by: Aaron Weiss <aaronweiss@roblox.com>
Co-authored-by: Andy Friesen <afriesen@roblox.com>
Co-authored-by: Hunter Goldstein <hgoldstein@roblox.com>
Co-authored-by: Varun Saini <vsaini@roblox.com>
Co-authored-by: Vighnesh Vijay <vvijay@roblox.com>
6 people authored Nov 15, 2024
1 parent d1025d0 commit e905e30
Showing 44 changed files with 1,750 additions and 905 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
/luau
/luau-tests
/luau-analyze
/luau-bytecode
/luau-compile
__pycache__
.cache
19 changes: 0 additions & 19 deletions Analysis/include/Luau/DataFlowGraph.h
Original file line number Diff line number Diff line change
@@ -126,25 +126,6 @@ struct DataFlowGraphBuilder
NotNull<InternalErrorReporter> handle
);

/**
* Takes a stale graph along with a list of scopes, a small fragment of the ast, and a cursor position
* and constructs the DataFlowGraph for just that fragment. This method will fabricate defs in the final
* DFG for things that have been referenced and exist in the stale dfg.
* For example, the fragment local z = x + y will populate defs for x and y from the stale graph.
* @param staleGraph - the old DFG
* @param scopes - the old DfgScopes in the graph
* @param fragment - the Ast Fragment to re-build the root for
* @param cursorPos - the current location of the cursor - used to determine which scope we are currently in
* @param handle - for internal compiler errors
*/
static DataFlowGraph updateGraph(
const DataFlowGraph& staleGraph,
const std::vector<std::unique_ptr<DfgScope>>& scopes,
AstStatBlock* fragment,
const Position& cursorPos,
NotNull<InternalErrorReporter> handle
);

private:
DataFlowGraphBuilder() = default;

13 changes: 10 additions & 3 deletions Analysis/include/Luau/FragmentAutocomplete.h
Original file line number Diff line number Diff line change
@@ -49,14 +49,20 @@ struct FragmentAutocompleteResult

FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos);

FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos);
FragmentParseResult parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
);

FragmentTypeCheckResult typecheckFragment(
Frontend& frontend,
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src
std::string_view src,
std::optional<Position> fragmentEndPosition
);

FragmentAutocompleteResult fragmentAutocomplete(
@@ -65,7 +71,8 @@ FragmentAutocompleteResult fragmentAutocomplete(
const ModuleName& moduleName,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback
StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition = std::nullopt
);


3 changes: 3 additions & 0 deletions Analysis/include/Luau/Module.h
Original file line number Diff line number Diff line change
@@ -68,6 +68,9 @@ struct Module
{
~Module();

// TODO: Clip this when we clip FFlagLuauSolverV2
bool checkedInNewSolver = false;

ModuleName name;
std::string humanReadableName;

1 change: 0 additions & 1 deletion Analysis/src/AutocompleteCore.cpp
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@

LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions2)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTINT(LuauTypeInferIterationLimit)
LUAU_FASTINT(LuauTypeInferRecursionLimit)

1 change: 0 additions & 1 deletion Analysis/src/BuiltinDefinitions.cpp
Original file line number Diff line number Diff line change
@@ -29,7 +29,6 @@
*/

LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2)
LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix)

209 changes: 51 additions & 158 deletions Analysis/src/ConstraintGenerator.cpp
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit)
LUAU_FASTFLAG(DebugLuauLogSolverToJson)
LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(DebugLuauEqSatSimplification)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins2)

LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues)
@@ -225,8 +224,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block)

Checkpoint start = checkpoint(this);

ControlFlow cf =
DFInt::LuauTypeSolverRelease >= 646 ? visitBlockWithoutChildScope(scope, block) : visitBlockWithoutChildScope_DEPRECATED(scope, block);
ControlFlow cf = visitBlockWithoutChildScope(scope, block);
if (cf == ControlFlow::None)
addConstraint(scope, block->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, rootScope->returnType});

@@ -876,123 +874,6 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& sco
return firstControlFlow.value_or(ControlFlow::None);
}

ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block)
{
RecursionCounter counter{&recursionCount};

if (recursionCount >= FInt::LuauCheckRecursionLimit)
{
reportCodeTooComplex(block->location);
return ControlFlow::None;
}

std::unordered_map<Name, Location> aliasDefinitionLocations;

// In order to enable mutually-recursive type aliases, we need to
// populate the type bindings before we actually check any of the
// alias statements.
for (AstStat* stat : block->body)
{
if (auto alias = stat->as<AstStatTypeAlias>())
{
if (scope->exportedTypeBindings.count(alias->name.value) || scope->privateTypeBindings.count(alias->name.value))
{
auto it = aliasDefinitionLocations.find(alias->name.value);
LUAU_ASSERT(it != aliasDefinitionLocations.end());
reportError(alias->location, DuplicateTypeDefinition{alias->name.value, it->second});
continue;
}

// A type alias might have no name if the code is syntactically
// illegal. We mustn't prepopulate anything in this case.
if (alias->name == kParseNameError || alias->name == "typeof")
continue;

ScopePtr defnScope = childScope(alias, scope);

TypeId initialType = arena->addType(BlockedType{});
TypeFun initialFun{initialType};

for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true))
{
initialFun.typeParams.push_back(gen);
}

for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true))
{
initialFun.typePackParams.push_back(genPack);
}

if (alias->exported)
scope->exportedTypeBindings[alias->name.value] = std::move(initialFun);
else
scope->privateTypeBindings[alias->name.value] = std::move(initialFun);

astTypeAliasDefiningScopes[alias] = defnScope;
aliasDefinitionLocations[alias->name.value] = alias->location;
}
else if (auto function = stat->as<AstStatTypeFunction>())
{
// If a type function w/ same name has already been defined, error for having duplicates
if (scope->exportedTypeBindings.count(function->name.value) || scope->privateTypeBindings.count(function->name.value))
{
auto it = aliasDefinitionLocations.find(function->name.value);
LUAU_ASSERT(it != aliasDefinitionLocations.end());
reportError(function->location, DuplicateTypeDefinition{function->name.value, it->second});
continue;
}

if (scope->parent != globalScope)
{
reportError(function->location, GenericError{"Local user-defined functions are not supported yet"});
continue;
}

ScopePtr defnScope = childScope(function, scope);

// Create TypeFunctionInstanceType

std::vector<TypeId> typeParams;
typeParams.reserve(function->body->args.size);

std::vector<GenericTypeDefinition> quantifiedTypeParams;
quantifiedTypeParams.reserve(function->body->args.size);

for (size_t i = 0; i < function->body->args.size; i++)
{
std::string name = format("T%zu", i);
TypeId ty = arena->addType(GenericType{name});
typeParams.push_back(ty);

GenericTypeDefinition genericTy{ty};
quantifiedTypeParams.push_back(genericTy);
}

if (std::optional<std::string> error = typeFunctionRuntime->registerFunction(function))
reportError(function->location, GenericError{*error});

TypeId typeFunctionTy =
arena->addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, {}});

TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy};

// Set type bindings and definition locations for this user-defined type function
scope->privateTypeBindings[function->name.value] = std::move(typeFunction);
aliasDefinitionLocations[function->name.value] = function->location;
}
}

std::optional<ControlFlow> firstControlFlow;
for (AstStat* stat : block->body)
{
ControlFlow cf = visit(scope, stat);
if (cf != ControlFlow::None && !firstControlFlow)
firstControlFlow = cf;
}

return firstControlFlow.value_or(ControlFlow::None);
}

ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStat* stat)
{
RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit};
@@ -1336,10 +1217,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatRepeat* rep
{
ScopePtr repeatScope = childScope(repeat, scope);

if (DFInt::LuauTypeSolverRelease >= 646)
visitBlockWithoutChildScope(repeatScope, repeat->body);
else
visitBlockWithoutChildScope_DEPRECATED(repeatScope, repeat->body);
visitBlockWithoutChildScope(repeatScope, repeat->body);

check(repeatScope, repeat->condition);

@@ -1513,8 +1391,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatBlock* bloc
{
ScopePtr innerScope = childScope(block, scope);

ControlFlow flow = DFInt::LuauTypeSolverRelease >= 646 ? visitBlockWithoutChildScope(innerScope, block)
: visitBlockWithoutChildScope_DEPRECATED(innerScope, block);
ControlFlow flow = visitBlockWithoutChildScope(innerScope, block);

// An AstStatBlock has linear control flow, i.e. one entry and one exit, so we can inherit
// all the changes to the environment occurred by the statements in that block.
@@ -1705,7 +1582,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunctio
TypeFun typeFunction = bindingIt->second;

// Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver
if (auto typeFunctionTy = get<TypeFunctionInstanceType>(DFInt::LuauTypeSolverRelease >= 646 ? follow(typeFunction.type) : typeFunction.type))
if (auto typeFunctionTy = get<TypeFunctionInstanceType>(follow(typeFunction.type)))
{
TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments});
addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy});
@@ -3026,32 +2903,12 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
{
Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice};
std::vector<TypeId> toBlock;
if (DFInt::LuauTypeSolverRelease >= 648)
{
// This logic is incomplete as we want to re-run this
// _after_ blocked types have resolved, but this
// allows us to do some bidirectional inference.
toBlock = findBlockedTypesIn(expr, NotNull{&module->astTypes});
if (toBlock.empty())
{
matchLiteralType(
NotNull{&module->astTypes},
NotNull{&module->astExpectedTypes},
builtinTypes,
arena,
NotNull{&unifier},
*expectedType,
ty,
expr,
toBlock
);
// The visitor we ran prior should ensure that there are no
// blocked types that we would encounter while matching on
// this expression.
LUAU_ASSERT(toBlock.empty());
}
}
else
// This logic is incomplete as we want to re-run this
// _after_ blocked types have resolved, but this
// allows us to do some bidirectional inference.
toBlock = findBlockedTypesIn(expr, NotNull{&module->astTypes});

if (toBlock.empty())
{
matchLiteralType(
NotNull{&module->astTypes},
@@ -3063,7 +2920,11 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr,
ty,
expr,
toBlock
);
);
// The visitor we ran prior should ensure that there are no
// blocked types that we would encounter while matching on
// this expression.
LUAU_ASSERT(toBlock.empty());
}
}

@@ -3265,8 +3126,7 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu
void ConstraintGenerator::checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn)
{
// If it is possible for execution to reach the end of the function, the return type must be compatible with ()
ControlFlow cf =
DFInt::LuauTypeSolverRelease >= 646 ? visitBlockWithoutChildScope(scope, fn->body) : visitBlockWithoutChildScope_DEPRECATED(scope, fn->body);
ControlFlow cf = visitBlockWithoutChildScope(scope, fn->body);
if (cf == ControlFlow::None)
addConstraint(scope, fn->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, scope->returnType});
}
@@ -3745,11 +3605,18 @@ struct FragmentTypeCheckGlobalPrepopulator : AstVisitor
const NotNull<Scope> globalScope;
const NotNull<Scope> currentScope;
const NotNull<const DataFlowGraph> dfg;
const NotNull<TypeArena> arena;

FragmentTypeCheckGlobalPrepopulator(NotNull<Scope> globalScope, NotNull<Scope> currentScope, NotNull<const DataFlowGraph> dfg)
FragmentTypeCheckGlobalPrepopulator(
NotNull<Scope> globalScope,
NotNull<Scope> currentScope,
NotNull<const DataFlowGraph> dfg,
NotNull<TypeArena> arena
)
: globalScope(globalScope)
, currentScope(currentScope)
, dfg(dfg)
, arena(arena)
{
}

@@ -3761,6 +3628,32 @@ struct FragmentTypeCheckGlobalPrepopulator : AstVisitor
// We only want to write into the current scope the type of the global
currentScope->lvalueTypes[def] = *ty;
}
else if (auto ty = currentScope->lookup(global->name))
{
// We are trying to create a binding for a brand new function, so we actually do have to write it into the scope.
DefId def = dfg->getDef(global);
// We only want to write into the current scope the type of the global
currentScope->lvalueTypes[def] = *ty;
}

return true;
}

bool visit(AstStatFunction* function) override
{
if (AstExprGlobal* g = function->name->as<AstExprGlobal>())
{
if (auto ty = globalScope->lookup(g->name))
{
currentScope->bindings[g->name] = Binding{*ty};
}
else
{
// Hasn't existed since a previous typecheck
TypeId bt = arena->addType(BlockedType{});
currentScope->bindings[g->name] = Binding{bt};
}
}

return true;
}
@@ -3814,7 +3707,7 @@ struct GlobalPrepopulator : AstVisitor

void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program)
{
FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg};
FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg, arena};
if (prepareModuleScope)
prepareModuleScope(module->name, resumeScope);
program->visit(&gp);
71 changes: 20 additions & 51 deletions Analysis/src/ConstraintSolver.cpp
Original file line number Diff line number Diff line change
@@ -31,7 +31,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies)
LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings)
LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification)
LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations)
@@ -919,19 +918,10 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul

auto bindResult = [this, &c, constraint](TypeId result)
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
auto cTarget = follow(c.target);
LUAU_ASSERT(get<PendingExpansionType>(cTarget));
shiftReferences(cTarget, result);
bind(constraint, cTarget, result);
}
else
{
LUAU_ASSERT(get<PendingExpansionType>(c.target));
shiftReferences(c.target, result);
bind(constraint, c.target, result);
}
auto cTarget = follow(c.target);
LUAU_ASSERT(get<PendingExpansionType>(cTarget));
shiftReferences(cTarget, result);
bind(constraint, cTarget, result);
};

std::optional<TypeFun> tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value)
@@ -959,7 +949,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
// Due to how pending expansion types and TypeFun's are created
// If this check passes, we have created a cyclic / corecursive type alias
// of size 0
TypeId lhs = DFInt::LuauTypeSolverRelease >= 646 ? follow(c.target) : c.target;
TypeId lhs = follow(c.target);
TypeId rhs = tf->type;
if (occursCheck(lhs, rhs))
{
@@ -1343,21 +1333,18 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
if (isBlocked(argsPack))
return true;

if (DFInt::LuauTypeSolverRelease >= 648)
// This is expensive as we need to traverse a (potentially large)
// literal up front in order to determine if there are any blocked
// types, otherwise we may run `matchTypeLiteral` multiple times,
// which right now may fail due to being non-idempotent (it
// destructively updates the underlying literal type).
auto blockedTypes = findBlockedArgTypesIn(c.callSite, c.astTypes);
for (const auto ty : blockedTypes)
{
// This is expensive as we need to traverse a (potentially large)
// literal up front in order to determine if there are any blocked
// types, otherwise we may run `matchTypeLiteral` multiple times,
// which right now may fail due to being non-idempotent (it
// destructively updates the underlying literal type).
auto blockedTypes = findBlockedArgTypesIn(c.callSite, c.astTypes);
for (const auto ty : blockedTypes)
{
block(ty, constraint);
}
if (!blockedTypes.empty())
return false;
block(ty, constraint);
}
if (!blockedTypes.empty())
return false;

// We know the type of the function and the arguments it expects to receive.
// We also know the TypeIds of the actual arguments that will be passed.
@@ -1454,17 +1441,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull<con
Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}};
std::vector<TypeId> toBlock;
(void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock);
if (DFInt::LuauTypeSolverRelease >= 648)
{
LUAU_ASSERT(toBlock.empty());
}
else
{
for (auto t : toBlock)
block(t, constraint);
if (!toBlock.empty())
return false;
}
LUAU_ASSERT(toBlock.empty());
}
}

@@ -1498,17 +1475,9 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull<con
else if (expectedType && maybeSingleton(*expectedType))
bindTo = freeType->lowerBound;

if (DFInt::LuauTypeSolverRelease >= 645)
{
auto ty = follow(c.freeType);
shiftReferences(ty, bindTo);
bind(constraint, ty, bindTo);
}
else
{
shiftReferences(c.freeType, bindTo);
bind(constraint, c.freeType, bindTo);
}
auto ty = follow(c.freeType);
shiftReferences(ty, bindTo);
bind(constraint, ty, bindTo);

return true;
}
@@ -1793,7 +1762,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull<const

if (auto lhsFree = getMutable<FreeType>(lhsType))
{
auto lhsFreeUpperBound = DFInt::LuauTypeSolverRelease >= 648 ? follow(lhsFree->upperBound) : lhsFree->upperBound;
auto lhsFreeUpperBound = follow(lhsFree->upperBound);
if (get<TableType>(lhsFreeUpperBound) || get<MetatableType>(lhsFreeUpperBound))
lhsType = lhsFreeUpperBound;
else
55 changes: 0 additions & 55 deletions Analysis/src/DataFlowGraph.cpp
Original file line number Diff line number Diff line change
@@ -182,8 +182,6 @@ DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull<InternalE
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");

LUAU_ASSERT(FFlag::LuauSolverV2);

DataFlowGraphBuilder builder;
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope(block->location);
@@ -208,8 +206,6 @@ std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>

LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");

LUAU_ASSERT(FFlag::LuauSolverV2);

DataFlowGraphBuilder builder;
builder.handle = handle;
DfgScope* moduleScope = builder.makeChildScope(block->location);
@@ -226,56 +222,6 @@ std::pair<std::shared_ptr<DataFlowGraph>, std::vector<std::unique_ptr<DfgScope>>
return {std::make_shared<DataFlowGraph>(std::move(builder.graph)), std::move(builder.scopes)};
}

DataFlowGraph DataFlowGraphBuilder::updateGraph(
const DataFlowGraph& staleGraph,
const std::vector<std::unique_ptr<DfgScope>>& scopes,
AstStatBlock* fragment,
const Position& cursorPos,
NotNull<InternalErrorReporter> handle
)
{
LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking");
LUAU_ASSERT(FFlag::LuauSolverV2);

DataFlowGraphBuilder builder;
builder.handle = handle;
// Generate a list of prepopulated locals
ReferencedDefFinder finder;
fragment->visit(&finder);
for (AstLocal* loc : finder.referencedLocalDefs)
{
if (staleGraph.localDefs.contains(loc))
{
builder.graph.localDefs[loc] = *staleGraph.localDefs.find(loc);
}
}

// Figure out which scope we should start re-accumulating DFG information from again
DfgScope* nearest = nullptr;
for (auto& sc : scopes)
{
if (nearest == nullptr || (sc->location.begin <= cursorPos && nearest->location.begin < sc->location.begin))
nearest = sc.get();
}

// The scope stack should start with the nearest enclosing scope so we can resume DFG'ing correctly
PushScope ps{builder.scopeStack, nearest};
// Conspire for the current scope in the scope stack to be a fresh dfg scope, parented to the above nearest enclosing scope, so any insertions are
// isolated there
DfgScope* scope = builder.makeChildScope(fragment->location);
PushScope psAgain{builder.scopeStack, scope};

builder.visitBlockWithoutChildScope(fragment);

if (FFlag::DebugLuauFreezeArena)
{
builder.defArena->allocator.freeze();
builder.keyArena->allocator.freeze();
}

return std::move(builder.graph);
}

void DataFlowGraphBuilder::resolveCaptures()
{
for (const auto& [_, capture] : captures)
@@ -982,7 +928,6 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c)
DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i)
{
auto [parentDef, parentKey] = visitExpr(i->expr);

std::string index = i->index.value;

DefId def = lookup(parentDef, index);
11 changes: 6 additions & 5 deletions Analysis/src/EqSatSimplification.cpp
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
#include <vector>

LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplification)
LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplificationToDot)
LUAU_FASTFLAGVARIABLE(DebugLuauExtraEqSatSanityChecks)

namespace Luau::EqSatSimplification
@@ -2327,7 +2328,7 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl
int count = 0;
const int MAX_COUNT = 1000;

if (FFlag::DebugLuauLogSimplification)
if (FFlag::DebugLuauLogSimplificationToDot)
std::ofstream("begin.dot") << toDot(simplifier->stringCache, simplifier->egraph);

auto& egraph = simplifier->egraph;
@@ -2409,11 +2410,11 @@ std::optional<EqSatSimplificationResult> eqSatSimplify(NotNull<Simplifier> simpl

++count;

if (FFlag::DebugLuauLogSimplification)
{
if (isFresh)
std::cout << "count=" << std::setw(3) << count << "\t" << subst.desc << '\n';
if (FFlag::DebugLuauLogSimplification && isFresh)
std::cout << "count=" << std::setw(3) << count << "\t" << subst.desc << '\n';

if (FFlag::DebugLuauLogSimplificationToDot)
{
std::string filename = format("step%03d.dot", count);
std::ofstream(filename) << toDot(simplifier->stringCache, egraph);
}
151 changes: 129 additions & 22 deletions Analysis/src/FragmentAutocomplete.cpp
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

#include "Luau/Ast.h"
#include "Luau/AstQuery.h"
#include "Luau/Autocomplete.h"
#include "Luau/Common.h"
#include "Luau/EqSatSimplification.h"
#include "Luau/Parser.h"
@@ -21,6 +22,7 @@

#include "AutocompleteCore.h"


LUAU_FASTINT(LuauTypeInferRecursionLimit);
LUAU_FASTINT(LuauTypeInferIterationLimit);
LUAU_FASTINT(LuauTarjanChildLimit)
@@ -48,6 +50,14 @@ void copyModuleMap(Luau::DenseHashMap<K, V>& result, const Luau::DenseHashMap<K,
namespace Luau
{

static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional<FrontendOptions> options)
{
if (FFlag::LuauSolverV2 || !options)
return frontend.moduleResolver;

return options->forAutocomplete ? frontend.moduleResolverForAutocomplete : frontend.moduleResolver;
}

FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos)
{
std::vector<AstNode*> ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos);
@@ -93,13 +103,19 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* ro

/**
* Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that
* document and attempts to get the concrete text between those points. It returns a pair of:
* document and attempts to get the concrete text between those points. It returns a tuple of:
* - start offset that represents an index in the source `char*` corresponding to startPos
* - length, that represents how many more bytes to read to get to endPos.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (1, 4) - (1, 8). This function returns the pair {3, 7},
* which corresponds to the string " bar "
* - cursorPos, that represents the position of the cursor relative to the start offset.
* Example - your document is "foo bar baz" and getDocumentOffsets is passed (0, 4), (0, 7), (0, 8). This function returns the tuple {3, 5,
* Position{0, 4}}, which corresponds to the string " bar "
*/
std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos)
std::tuple<size_t, size_t, Position> getDocumentOffsets(
const std::string_view& src,
const Position& startPos,
Position cursorPos,
const Position& endPos
)
{
size_t lineCount = 0;
size_t colCount = 0;
@@ -108,7 +124,12 @@ std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const
size_t startOffset = 0;
size_t endOffset = 0;
bool foundStart = false;
bool foundCursor = false;
bool foundEnd = false;

unsigned int colOffsetFromStart = 0;
unsigned int lineOffsetFromStart = 0;

for (char c : src)
{
if (foundStart && foundEnd)
@@ -120,6 +141,12 @@ std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const
startOffset = docOffset;
}

if (cursorPos.line == lineCount && cursorPos.column == colCount)
{
foundCursor = true;
cursorPos = {lineOffsetFromStart, colOffsetFromStart};
}

if (endPos.line == lineCount && endPos.column == colCount)
{
endOffset = docOffset;
@@ -135,20 +162,32 @@ std::pair<size_t, size_t> getDocumentOffsets(const std::string_view& src, const

if (c == '\n')
{
if (foundStart)
{
lineOffsetFromStart++;
colOffsetFromStart = 0;
}
lineCount++;
colCount = 0;
}
else
{
if (foundStart)
colOffsetFromStart++;
colCount++;
}
docOffset++;
}

if (foundStart && !foundEnd)
endOffset = src.length();

if (foundStart && !foundCursor)
cursorPos = {lineOffsetFromStart, colOffsetFromStart};

size_t min = std::min(startOffset, endOffset);
size_t len = std::max(startOffset, endOffset) - min;
return {min, len};
return {min, len, cursorPos};
}

ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement)
@@ -167,12 +206,17 @@ ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStateme
return closest;
}

FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_view src, const Position& cursorPos)
FragmentParseResult parseFragment(
const SourceModule& srcModule,
std::string_view src,
const Position& cursorPos,
std::optional<Position> fragmentEndPosition
)
{
FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos);
ParseOptions opts;
opts.allowDeclarationSyntax = false;
opts.captureComments = false;
opts.captureComments = true;
opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)};
AstStat* nearestStatement = result.nearestStatement;

@@ -182,7 +226,7 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
// statement spans multiple lines
bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line;

const Position endPos = cursorPos;
const Position endPos = fragmentEndPosition.value_or(cursorPos);

// We start by re-parsing everything (we'll refine this as we go)
Position startPos = srcModule.root->location.begin;
@@ -193,10 +237,13 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
// Statement spans one line && cursorPos is on a different line
else if (!multiline && cursorPos.line != nearestStatement->location.end.line)
startPos = nearestStatement->location.end;
else if (multiline && nearestStatement->location.end.line < cursorPos.line)
startPos = nearestStatement->location.end;
else
startPos = nearestStatement->location.begin;

auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos);
auto [offsetStart, parseLength, cursorInFragment] = getDocumentOffsets(src, startPos, cursorPos, endPos);


const char* srcStart = src.data() + offsetStart;
std::string_view dbg = src.substr(offsetStart, parseLength);
@@ -207,7 +254,11 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie
ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts);

std::vector<AstNode*> fabricatedAncestry = std::move(result.ancestry);
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end);

// Get the ancestry for the fragment at the offset cursor position.
// Consumers have the option to request with fragment end position, so we cannot just use the end position of our parse result as the
// cursor position. Instead, use the cursor position calculated as an offset from our start position.
std::vector<AstNode*> fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, cursorInFragment);
fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end());
if (nearestStatement == nullptr)
nearestStatement = p.root;
@@ -242,6 +293,46 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr<Allocator> alloc)
return incrementalModule;
}

struct MixedModeIncrementalTCDefFinder : public AstVisitor
{
bool visit(AstExprLocal* local) override
{
referencedLocalDefs.push_back({local->local, local});
return true;
}
// ast defs is just a mapping from expr -> def in general
// will get built up by the dfg builder

// localDefs, we need to copy over
std::vector<std::pair<AstLocal*, AstExpr*>> referencedLocalDefs;
};

void mixedModeCompatibility(
const ScopePtr& bottomScopeStale,
const ScopePtr& myFakeScope,
const ModulePtr& stale,
NotNull<DataFlowGraph> dfg,
AstStatBlock* program
)
{
// This code does the following
// traverse program
// look for ast refs for locals
// ask for the corresponding defId from dfg
// given that defId, and that expression, in the incremental module, map lvalue types from defID to

MixedModeIncrementalTCDefFinder finder;
program->visit(&finder);
std::vector<std::pair<AstLocal*, AstExpr*>> locals = std::move(finder.referencedLocalDefs);
for (auto [loc, expr] : locals)
{
if (std::optional<Binding> binding = bottomScopeStale->linearSearchForBinding(loc->name.value, true))
{
myFakeScope->lvalueTypes[dfg->getDef(expr)] = binding->typeId;
}
}
}

FragmentTypeCheckResult typecheckFragment_(
Frontend& frontend,
AstStatBlock* root,
@@ -255,6 +346,7 @@ FragmentTypeCheckResult typecheckFragment_(
freeze(stale->internalTypes);
freeze(stale->interfaceTypes);
ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator));
incrementalModule->checkedInNewSolver = true;
unfreeze(incrementalModule->internalTypes);
unfreeze(incrementalModule->interfaceTypes);

@@ -280,30 +372,35 @@ FragmentTypeCheckResult typecheckFragment_(
TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits});

/// Create a DataFlowGraph just for the surrounding context
auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler);

auto dfg = DataFlowGraphBuilder::build(root, iceHandler);
SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes);

FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);

/// Contraint Generator
ConstraintGenerator cg{
incrementalModule,
NotNull{&normalizer},
NotNull{simplifier.get()},
NotNull{&typeFunctionRuntime},
NotNull{&frontend.moduleResolver},
NotNull{&resolver},
frontend.builtinTypes,
iceHandler,
stale->getModuleScope(),
nullptr,
nullptr,
NotNull{&updatedDfg},
NotNull{&dfg},
{}
};

cg.rootScope = stale->getModuleScope().get();
// Any additions to the scope must occur in a fresh scope
auto freshChildOfNearestScope = std::make_shared<Scope>(closestScope);
incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope);

// Update freshChildOfNearestScope with the appropriate lvalueTypes
mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root);

// closest Scope -> children = { ...., freshChildOfNearestScope}
// We need to trim nearestChild from the scope hierarcy
closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()});
@@ -323,10 +420,10 @@ FragmentTypeCheckResult typecheckFragment_(
NotNull(cg.rootScope),
borrowConstraints(cg.constraints),
incrementalModule->name,
NotNull{&frontend.moduleResolver},
NotNull{&resolver},
{},
nullptr,
NotNull{&updatedDfg},
NotNull{&dfg},
limits
};

@@ -358,7 +455,8 @@ FragmentTypeCheckResult typecheckFragment(
const ModuleName& moduleName,
const Position& cursorPos,
std::optional<FrontendOptions> opts,
std::string_view src
std::string_view src,
std::optional<Position> fragmentEndPosition
)
{
const SourceModule* sourceModule = frontend.getSourceModule(moduleName);
@@ -368,8 +466,15 @@ FragmentTypeCheckResult typecheckFragment(
return {};
}

ModulePtr module = frontend.moduleResolver.getModule(moduleName);
FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos);
FrontendModuleResolver& resolver = getModuleResolver(frontend, opts);
ModulePtr module = resolver.getModule(moduleName);
if (!module)
{
LUAU_ASSERT(!"Expected Module for fragment typecheck");
return {};
}

FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition);
FrontendOptions frontendOptions = opts.value_or(frontend.options);
const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement);
FragmentTypeCheckResult result =
@@ -385,10 +490,10 @@ FragmentAutocompleteResult fragmentAutocomplete(
const ModuleName& moduleName,
Position cursorPosition,
std::optional<FrontendOptions> opts,
StringCompletionCallback callback
StringCompletionCallback callback,
std::optional<Position> fragmentEndPosition
)
{
LUAU_ASSERT(FFlag::LuauSolverV2);
LUAU_ASSERT(FFlag::LuauAllowFragmentParsing);
LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2);
LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete);
@@ -400,7 +505,8 @@ FragmentAutocompleteResult fragmentAutocomplete(
return {};
}

auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src);
auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition);

TypeArena arenaForFragmentAutocomplete;
auto result = Luau::autocomplete_(
tcResult.incrementalModule,
@@ -413,6 +519,7 @@ FragmentAutocompleteResult fragmentAutocomplete(
frontend.fileResolver,
callback
);

return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)};
}

3 changes: 3 additions & 0 deletions Analysis/src/Frontend.cpp
Original file line number Diff line number Diff line change
@@ -50,6 +50,7 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false)

LUAU_FASTFLAG(StudioReportLuauAny2)
LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2)
LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule)

namespace Luau
{
@@ -1285,6 +1286,8 @@ ModulePtr check(
LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str());

ModulePtr result = std::make_shared<Module>();
if (FFlag::LuauStoreSolverTypeOnModule)
result->checkedInNewSolver = true;
result->name = sourceModule.name;
result->humanReadableName = sourceModule.humanReadableName;
result->mode = mode;
127 changes: 39 additions & 88 deletions Analysis/src/Generalization.cpp
Original file line number Diff line number Diff line change
@@ -9,8 +9,6 @@
#include "Luau/TypePack.h"
#include "Luau/VisitType.h"

LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

namespace Luau
{

@@ -528,12 +526,7 @@ struct TypeCacher : TypeOnceVisitor
DenseHashSet<TypePackId> uncacheablePacks{nullptr};

explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes)
// CLI-120975: once we roll out release 646, we _want_ to visit bound
// types to ensure they're marked as uncacheable if the types they are
// bound to are also uncacheable. Hence: if LuauTypeSolverRelease is
// less than 646, skip bound types (the prior behavior). Otherwise,
// do not skip bound types.
: TypeOnceVisitor(/* skipBoundTypes */ DFInt::LuauTypeSolverRelease < 646)
: TypeOnceVisitor(/* skipBoundTypes */ false)
, cachedTypes(cachedTypes)
{
}
@@ -570,33 +563,19 @@ struct TypeCacher : TypeOnceVisitor

bool visit(TypeId ty) override
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}
else
{
return true;
}
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}

bool visit(TypeId ty, const BoundType& btv) override
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
traverse(btv.boundTo);
if (isUncacheable(btv.boundTo))
markUncacheable(ty);
return false;
}
else
{
return true;
}
traverse(btv.boundTo);
if (isUncacheable(btv.boundTo))
markUncacheable(ty);
return false;
}

bool visit(TypeId ty, const FreeType& ft) override
@@ -623,15 +602,8 @@ struct TypeCacher : TypeOnceVisitor

bool visit(TypeId ty, const ErrorType&) override
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
cache(ty);
return false;
}
else
{
return true;
}
cache(ty);
return false;
}

bool visit(TypeId ty, const PrimitiveType&) override
@@ -773,20 +745,13 @@ struct TypeCacher : TypeOnceVisitor

bool visit(TypeId ty, const MetatableType& mtv) override
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
traverse(mtv.table);
traverse(mtv.metatable);
if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable))
markUncacheable(ty);
else
cache(ty);
return false;
}
traverse(mtv.table);
traverse(mtv.metatable);
if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable))
markUncacheable(ty);
else
{
return true;
}
cache(ty);
return false;
}

bool visit(TypeId ty, const ClassType&) override
@@ -911,18 +876,11 @@ struct TypeCacher : TypeOnceVisitor

bool visit(TypePackId tp) override
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable, which will segfault down the line.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}
else
{
return true;
}
// NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
// otherwise it's prone to marking types that cannot be cached as
// cacheable, which will segfault down the line.
LUAU_ASSERT(false);
LUAU_UNREACHABLE();
}

bool visit(TypePackId tp, const FreeTypePack&) override
@@ -967,35 +925,28 @@ struct TypeCacher : TypeOnceVisitor
}

bool visit(TypePackId tp, const BoundTypePack& btp) override {
if (DFInt::LuauTypeSolverRelease >= 645) {
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
return false;
}
return true;
traverse(btp.boundTo);
if (isUncacheable(btp.boundTo))
markUncacheable(tp);
return false;
}

bool visit(TypePackId tp, const TypePack& typ) override
{
if (DFInt::LuauTypeSolverRelease >= 646)
bool uncacheable = false;
for (TypeId ty : typ.head)
{
bool uncacheable = false;
for (TypeId ty : typ.head)
{
traverse(ty);
uncacheable |= isUncacheable(ty);
}
if (typ.tail)
{
traverse(*typ.tail);
uncacheable |= isUncacheable(*typ.tail);
}
if (uncacheable)
markUncacheable(tp);
return false;
traverse(ty);
uncacheable |= isUncacheable(ty);
}
return true;
if (typ.tail)
{
traverse(*typ.tail);
uncacheable |= isUncacheable(*typ.tail);
}
if (uncacheable)
markUncacheable(tp);
return false;
}
};

46 changes: 15 additions & 31 deletions Analysis/src/Module.cpp
Original file line number Diff line number Diff line change
@@ -15,7 +15,6 @@
#include <algorithm>

LUAU_FASTFLAG(LuauSolverV2);
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

namespace Luau
{
@@ -132,34 +131,27 @@ struct ClonePublicInterface : Substitution
}

ftv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
if (FFlag::LuauSolverV2)
ftv->scope = nullptr;
}
else if (TableType* ttv = getMutable<TableType>(result))
{
ttv->level = TypeLevel{0, 0};
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
if (FFlag::LuauSolverV2)
ttv->scope = nullptr;
}

if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
if (FFlag::LuauSolverV2)
{
if (auto freety = getMutable<FreeType>(result))
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
module->errors.emplace_back(
freety->scope->location,
module->name,
InternalError{"Free type is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
module->errors.emplace_back(
freety->scope->location,
module->name,
InternalError{"Free type is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
result = builtinTypes->errorRecoveryType();
}
else
{
freety->scope = nullptr;
}
}
else if (auto genericty = getMutable<GenericType>(result))
{
@@ -172,26 +164,18 @@ struct ClonePublicInterface : Substitution

TypePackId clean(TypePackId tp) override
{
if (FFlag::LuauSolverV2 && DFInt::LuauTypeSolverRelease >= 645)
if (FFlag::LuauSolverV2)
{
auto clonedTp = clone(tp);
if (auto ftp = getMutable<FreeTypePack>(clonedTp))
{

if (DFInt::LuauTypeSolverRelease >= 646)
{
module->errors.emplace_back(
ftp->scope->location,
module->name,
InternalError{"Free type pack is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
module->errors.emplace_back(
ftp->scope->location,
module->name,
InternalError{"Free type pack is escaping its module; please report this bug at "
"https://github.com/luau-lang/luau/issues"}
);
clonedTp = builtinTypes->errorRecoveryTypePack();
}
else
{
ftp->scope = nullptr;
}
clonedTp = builtinTypes->errorRecoveryTypePack();
}
else if (auto gtp = getMutable<GenericTypePack>(clonedTp))
gtp->scope = nullptr;
126 changes: 125 additions & 1 deletion Analysis/src/NonStrictTypeChecker.cpp
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
#include <iterator>

LUAU_FASTFLAGVARIABLE(LuauUserTypeFunNonstrict)
LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict)

namespace Luau
{
@@ -537,9 +538,132 @@ struct NonStrictTypeChecker
return {};
}


NonStrictContext visit(AstExprCall* call)
{
if (FFlag::LuauCountSelfCallsNonstrict)
return visitCall(call);
else
return visitCall_DEPRECATED(call);
}

// rename this to `visit` when `FFlag::LuauCountSelfCallsNonstrict` is removed, and clean up above `visit`.
NonStrictContext visitCall(AstExprCall* call)
{
LUAU_ASSERT(FFlag::LuauCountSelfCallsNonstrict);

NonStrictContext fresh{};
TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func);
if (!originalCallTy)
return fresh;

TypeId fnTy = *originalCallTy;
if (auto fn = get<FunctionType>(follow(fnTy)); fn && fn->isCheckedFunction)
{
// We know fn is a checked function, which means it looks like:
// (S1, ... SN) -> T &
// (~S1, unknown^N-1) -> error &
// (unknown, ~S2, unknown^N-2) -> error
// ...
// ...
// (unknown^N-1, ~S_N) -> error

std::vector<AstExpr*> arguments;
arguments.reserve(call->args.size + (call->self ? 1 : 0));
if (call->self)
{
if (auto indexExpr = call->func->as<AstExprIndexName>())
arguments.push_back(indexExpr->expr);
else
ice->ice("method call expression has no 'self'");
}
arguments.insert(arguments.end(), call->args.begin(), call->args.end());

std::vector<TypeId> argTypes;
argTypes.reserve(arguments.size());

// Move all the types over from the argument typepack for `fn`
TypePackIterator curr = begin(fn->argTypes);
TypePackIterator fin = end(fn->argTypes);
for (; curr != fin; curr++)
argTypes.push_back(*curr);

// Pad out the rest with the variadic as needed.
if (auto argTail = curr.tail())
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail)))
{
while (argTypes.size() < arguments.size())
{
argTypes.push_back(vtp->ty);
}
}
}

std::string functionName = getFunctionNameAsString(*call->func).value_or("");
if (arguments.size() > argTypes.size())
{
// We are passing more arguments than we expect, so we should error
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location);
return fresh;
}

for (size_t i = 0; i < arguments.size(); i++)
{
// For example, if the arg is "hi"
// The actual arg type is string
// The expected arg type is number
// The type of the argument in the overload is ~number
// We will compare arg and ~number
AstExpr* arg = arguments[i];
TypeId expectedArgType = argTypes[i];
std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
DefId def = dfg->getDef(arg);
TypeId runTimeErrorTy;
// If we're dealing with any, negating any will cause all subtype tests to fail
// However, when someone calls this function, they're going to want to be able to pass it anything,
// for that reason, we manually inject never into the context so that the runtime test will always pass.
if (!norm)
reportError(NormalizationTooComplex{}, arg->location);

if (norm && get<AnyType>(norm->tops))
runTimeErrorTy = builtinTypes->neverType;
else
runTimeErrorTy = getOrCreateNegation(expectedArgType);
fresh.addContext(def, runTimeErrorTy);
}

// Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types
for (size_t i = 0; i < arguments.size(); i++)
{
AstExpr* arg = arguments[i];
if (auto runTimeFailureType = willRunTimeError(arg, fresh))
reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location);
}

if (arguments.size() < argTypes.size())
{
// We are passing fewer arguments than we expect
// so we need to ensure that the rest of the args are optional.
bool remainingArgsOptional = true;
for (size_t i = arguments.size(); i < argTypes.size(); i++)
remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]);

if (!remainingArgsOptional)
{
reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location);
return fresh;
}
}
}

return fresh;
}

// Remove with `FFlag::LuauCountSelfCallsNonstrict` clean up.
NonStrictContext visitCall_DEPRECATED(AstExprCall* call)
{
LUAU_ASSERT(!FFlag::LuauCountSelfCallsNonstrict);

NonStrictContext fresh{};
TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func);
if (!originalCallTy)
6 changes: 0 additions & 6 deletions Analysis/src/Simplify.cpp
Original file line number Diff line number Diff line change
@@ -1411,8 +1411,6 @@ TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet<TypeId>& seen)

SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{
LUAU_ASSERT(FFlag::LuauSolverV2);

TypeSimplifier s{builtinTypes, arena};

// fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str());
@@ -1426,8 +1424,6 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<

SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, std::set<TypeId> parts)
{
LUAU_ASSERT(FFlag::LuauSolverV2);

TypeSimplifier s{builtinTypes, arena};

TypeId res = s.intersectFromParts(std::move(parts));
@@ -1437,8 +1433,6 @@ SimplifyResult simplifyIntersection(NotNull<BuiltinTypes> builtinTypes, NotNull<

SimplifyResult simplifyUnion(NotNull<BuiltinTypes> builtinTypes, NotNull<TypeArena> arena, TypeId left, TypeId right)
{
LUAU_ASSERT(FFlag::LuauSolverV2);

TypeSimplifier s{builtinTypes, arena};

TypeId res = s.union_(left, right);
15 changes: 3 additions & 12 deletions Analysis/src/Subtyping.cpp
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@
#include <algorithm>

LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauRetrySubtypingWithoutHiddenPack)

namespace Luau
@@ -1395,17 +1394,9 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl

SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull<Scope> scope)
{
if (DFInt::LuauTypeSolverRelease >= 646)
{
return isCovariantWith(env, subMt->table, superMt->table, scope)
.withBothComponent(TypePath::TypeField::Table)
.andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable));
}
else
{
return isCovariantWith(env, subMt->table, superMt->table, scope)
.andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable));
}
return isCovariantWith(env, subMt->table, superMt->table, scope)
.withBothComponent(TypePath::TypeField::Table)
.andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable));
}

SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const TableType* superTable, NotNull<Scope> scope)
16 changes: 2 additions & 14 deletions Analysis/src/TableLiteralInference.cpp
Original file line number Diff line number Diff line change
@@ -9,8 +9,6 @@
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h"

LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

namespace Luau
{

@@ -376,21 +374,11 @@ TypeId matchLiteralType(
const TypeId* keyTy = astTypes->find(item.key);
LUAU_ASSERT(keyTy);
TypeId tKey = follow(*keyTy);
if (DFInt::LuauTypeSolverRelease >= 648)
{
LUAU_ASSERT(!is<BlockedType>(tKey));
}
else if (get<BlockedType>(tKey))
toBlock.push_back(tKey);
LUAU_ASSERT(!is<BlockedType>(tKey));
const TypeId* propTy = astTypes->find(item.value);
LUAU_ASSERT(propTy);
TypeId tProp = follow(*propTy);
if (DFInt::LuauTypeSolverRelease >= 648)
{
LUAU_ASSERT(!is<BlockedType>(tKey));
}
else if (get<BlockedType>(tProp))
toBlock.push_back(tProp);
LUAU_ASSERT(!is<BlockedType>(tProp));
// Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings)
if (!item.key->as<AstExprConstantString>() && expectedTableTy->indexer)
(*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType;
20 changes: 12 additions & 8 deletions Analysis/src/TypeChecker2.cpp
Original file line number Diff line number Diff line change
@@ -32,7 +32,8 @@

LUAU_FASTFLAG(DebugLuauMagicTypes)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctions2)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues)

namespace Luau
{
@@ -1625,10 +1626,6 @@ void TypeChecker2::indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const M
indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType);
else
{
if (!(DFInt::LuauTypeSolverRelease >= 647))
{
LUAU_ASSERT(tt || get<PrimitiveType>(follow(metaTable->table)));
}
// CLI-122161: We're not handling unions correctly (probably).
reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location);
}
@@ -1836,11 +1833,18 @@ void TypeChecker2::visit(AstExprFunction* fn)

void TypeChecker2::visit(AstExprTable* expr)
{
// TODO!
for (const AstExprTable::Item& item : expr->items)
{
if (item.key)
visit(item.key, ValueContext::LValue);
if (FFlag::LuauTableKeysAreRValues)
{
if (item.key)
visit(item.key, ValueContext::RValue);
}
else
{
if (item.key)
visit(item.key, ValueContext::LValue);
}
visit(item.value, ValueContext::RValue);
}
}
5 changes: 1 addition & 4 deletions Analysis/src/TypeFunction.cpp
Original file line number Diff line number Diff line change
@@ -53,8 +53,6 @@ LUAU_FASTFLAG(LuauRemoveNotAnyHack)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState)
LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal)

LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

namespace Luau
{

@@ -848,8 +846,7 @@ TypeFunctionReductionResult<TypeId> lenTypeFunction(
return {ctx->builtins->numberType, false, {}, {}};

// we use the normalized operand here in case there was an intersection or union.
TypeId normalizedOperand =
DFInt::LuauTypeSolverRelease >= 646 ? follow(ctx->normalizer->typeFromNormal(*normTy)) : ctx->normalizer->typeFromNormal(*normTy);
TypeId normalizedOperand = follow(ctx->normalizer->typeFromNormal(*normTy));
if (normTy->hasTopTable() || get<TableType>(normalizedOperand))
return {ctx->builtins->numberType, false, {}, {}};

10 changes: 3 additions & 7 deletions Analysis/src/TypeFunctionRuntime.cpp
Original file line number Diff line number Diff line change
@@ -14,7 +14,6 @@
#include <vector>

LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixRegister)
LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite)

@@ -699,13 +698,10 @@ static int setTableIndexer(lua_State* L)
TypeFunctionTypeId key = getTypeUserData(L, 2);
TypeFunctionTypeId value = getTypeUserData(L, 3);

if (DFInt::LuauTypeSolverRelease >= 646)
if (auto tfnt = get<TypeFunctionNeverType>(key))
{
if (auto tfnt = get<TypeFunctionNeverType>(key))
{
tftt->indexer = std::nullopt;
return 0;
}
tftt->indexer = std::nullopt;
return 0;
}

tftt->indexer = TypeFunctionTableIndexer{key, value};
53 changes: 0 additions & 53 deletions Analysis/src/TypeInfer.cpp
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections)
LUAU_FASTFLAGVARIABLE(LuauMetatableFollow)
LUAU_FASTFLAGVARIABLE(LuauRequireCyclesDontAlwaysReturnAny)

@@ -3490,7 +3489,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
}
}

if (FFlag::LuauAcceptIndexingTableUnionsIntersections)
{
// We're going to have a whole vector.
std::vector<TableType*> tableTypes{};
@@ -3641,57 +3639,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex

return addType(IntersectionType{{resultTypes.begin(), resultTypes.end()}});
}
else
{
TableType* exprTable = getMutableTableType(exprType);
if (!exprTable)
{
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return errorRecoveryType(scope);
}

if (value)
{
const auto& it = exprTable->props.find(value->value.data);
if (it != exprTable->props.end())
{
return it->second.type();
}
else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)
{
TypeId resultType = freshType(scope);
Property& property = exprTable->props[value->value.data];
property.setType(resultType);
property.location = expr.index->location;
return resultType;
}
}

if (exprTable->indexer)
{
const TableIndexer& indexer = *exprTable->indexer;
unify(indexType, indexer.indexType, scope, expr.index->location);
return indexer.indexResultType;
}
else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)
{
TypeId indexerType = freshType(exprTable->level);
unify(indexType, indexerType, scope, expr.location);
TypeId indexResultType = freshType(exprTable->level);

exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)};
return indexResultType;
}
else
{
/*
* If we use [] indexing to fetch a property from a sealed table that
* has no indexer, we have no idea if it will work so we just return any
* and hope for the best.
*/
return anyType;
}
}
}

// Answers the question: "Can I define another function with this name?"
7 changes: 1 addition & 6 deletions Ast/src/Ast.cpp
Original file line number Diff line number Diff line change
@@ -3,12 +3,7 @@

#include "Luau/Common.h"

LUAU_FASTFLAG(LuauNativeAttribute);

// The default value here is 643 because the first release in which this was implemented is 644,
// and actively we want new changes to be off by default until they're enabled consciously.
// The flag is placed in AST project here to be common in all Luau libraries
LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeSolverRelease, 643)
LUAU_FASTFLAG(LuauNativeAttribute)

namespace Luau
{
74 changes: 69 additions & 5 deletions Ast/src/Parser.cpp
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport)
LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing)
LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck)
LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams)

namespace Luau
{
@@ -1773,6 +1774,11 @@ AstType* Parser::parseFunctionTypeTail(
);
}

static bool isTypeFollow(Lexeme::Type c)
{
return c == '|' || c == '?' || c == '&';
}

// Type ::=
// nil |
// Name[`.' Name] [`<' namelist `>'] |
@@ -2953,17 +2959,75 @@ AstArray<AstTypeOrPack> Parser::parseTypeParams()
if (shouldParseTypePack(lexer))
{
AstTypePack* typePack = parseTypePack();

parameters.push_back({{}, typePack});
}
else if (lexer.current().type == '(')
{
auto [type, typePack] = parseSimpleTypeOrPack();
if (FFlag::LuauAllowComplexTypesInGenericParams)
{
Location begin = lexer.current().location;
AstType* type = nullptr;
AstTypePack* typePack = nullptr;
Lexeme::Type c = lexer.current().type;

if (typePack)
parameters.push_back({{}, typePack});
if (c != '|' && c != '&')
{
auto typeOrTypePack = parseSimpleType(/* allowPack */ true, /* inDeclarationContext */ false);
type = typeOrTypePack.type;
typePack = typeOrTypePack.typePack;
}

// Consider the following type:
//
// X<(T)>
//
// Is this a type pack or a parenthesized type? The
// assumption will be a type pack, as that's what allows one
// to express either a singular type pack or a potential
// complex type.

if (typePack)
{
auto explicitTypePack = typePack->as<AstTypePackExplicit>();
if (explicitTypePack && explicitTypePack->typeList.tailType == nullptr && explicitTypePack->typeList.types.size == 1 &&
isTypeFollow(lexer.current().type))
{
// If we parsed an explicit type pack with a single
// type in it (something of the form `(T)`), and
// the next lexeme is one that follows a type
// (&, |, ?), then assume that this was actually a
// parenthesized type.
parameters.push_back({parseTypeSuffix(explicitTypePack->typeList.types.data[0], begin), {}});
}
else
{
// Otherwise, it's a type pack.
parameters.push_back({{}, typePack});
}
}
else
{
// There's two cases in which `typePack` will be null:
// - We try to parse a simple type or a type pack, and
// we get a simple type: there's no ambiguity and
// we attempt to parse a complex type.
// - The next lexeme was a `|` or `&` indicating a
// union or intersection type with a leading
// separator. We just fall right into
// `parseTypeSuffix`, which allows its first
// argument to be `nullptr`
parameters.push_back({parseTypeSuffix(type, begin), {}});
}
}
else
parameters.push_back({type, {}});
{
auto [type, typePack] = parseSimpleTypeOrPack();

if (typePack)
parameters.push_back({{}, typePack});
else
parameters.push_back({type, {}});
}
}
else if (lexer.current().type == '>' && parameters.empty())
{
60 changes: 53 additions & 7 deletions CLI/Analyze.cpp
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@

#include "FileUtils.h"
#include "Flags.h"
#include "Require.h"

#include <condition_variable>
#include <functional>
@@ -170,14 +171,17 @@ struct CliFileResolver : Luau::FileResolver
{
if (Luau::AstExprConstantString* expr = node->as<Luau::AstExprConstantString>())
{
Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau";
if (!readFile(name))
{
// fall back to .lua if a module with .luau doesn't exist
name = std::string(expr->value.data, expr->value.size) + ".lua";
}
std::string path{expr->value.data, expr->value.size};

AnalysisRequireContext requireContext{context->name};
AnalysisCacheManager cacheManager;
AnalysisErrorHandler errorHandler;

RequireResolver resolver(path, requireContext, cacheManager, errorHandler);
RequireResolver::ResolvedRequire resolvedRequire = resolver.resolveRequire();

return {{name}};
if (resolvedRequire.status == RequireResolver::ModuleStatus::FileRead)
return {{resolvedRequire.identifier}};
}

return std::nullopt;
@@ -189,6 +193,48 @@ struct CliFileResolver : Luau::FileResolver
return "stdin";
return name;
}

private:
struct AnalysisRequireContext : RequireResolver::RequireContext
{
explicit AnalysisRequireContext(std::string path)
: path(std::move(path))
{
}

std::string getPath() override
{
return path;
}

bool isRequireAllowed() override
{
return true;
}

bool isStdin() override
{
return path == "-";
}

std::string createNewIdentifer(const std::string& path) override
{
return path;
}

private:
std::string path;
};

struct AnalysisCacheManager : public RequireResolver::CacheManager
{
AnalysisCacheManager() = default;
};

struct AnalysisErrorHandler : RequireResolver::ErrorHandler
{
AnalysisErrorHandler() = default;
};
};

struct CliConfigResolver : Luau::ConfigResolver
33 changes: 23 additions & 10 deletions CLI/FileUtils.cpp
Original file line number Diff line number Diff line change
@@ -108,6 +108,7 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
// - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty
// - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc.
std::string resolvedPathPrefix;
bool isResolvedPathRelative = false;

if (isAbsolutePath(path))
{
@@ -118,19 +119,19 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
}
else
{
pathComponents = splitPath(path);
size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1;
baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix));
if (isAbsolutePath(baseFilePath))
{
// path is relative and baseFilePath is absolute, we use baseFilePath's prefix
size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1;
resolvedPathPrefix = baseFilePath.substr(0, afterPrefix);
baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix));
}
else
{
// path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative)
baseFilePathComponents = splitPath(baseFilePath);
isResolvedPathRelative = true;
}
pathComponents = splitPath(path);
}

// Remove filename from components
@@ -145,7 +146,7 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
{
if (baseFilePathComponents.empty())
{
if (resolvedPathPrefix.empty()) // only when final resolved path will be relative
if (isResolvedPathRelative)
numPrependedParents++; // "../" will later be added to the beginning of the resolved path
}
else if (baseFilePathComponents.back() != "..")
@@ -159,13 +160,25 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath)
}
}

// Join baseFilePathComponents to form the resolved path
std::string resolvedPath = resolvedPathPrefix;
// Only when resolvedPath will be relative
for (int i = 0; i < numPrependedParents; i++)
// Create resolved path prefix for relative paths
if (isResolvedPathRelative)
{
resolvedPath += "../";
if (numPrependedParents > 0)
{
resolvedPathPrefix.reserve(numPrependedParents * 3);
for (int i = 0; i < numPrependedParents; i++)
{
resolvedPathPrefix += "../";
}
}
else
{
resolvedPathPrefix = "./";
}
}

// Join baseFilePathComponents to form the resolved path
std::string resolvedPath = resolvedPathPrefix;
for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter)
{
if (iter != baseFilePathComponents.begin())
10 changes: 0 additions & 10 deletions CLI/Flags.cpp
Original file line number Diff line number Diff line change
@@ -2,14 +2,11 @@
#include "Luau/Common.h"
#include "Luau/ExperimentalFlags.h"

#include <limits> // TODO: remove with LuauTypeSolverRelease
#include <string_view>

#include <stdio.h>
#include <string.h>

LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

static void setLuauFlag(std::string_view name, bool state)
{
for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
@@ -26,13 +23,6 @@ static void setLuauFlag(std::string_view name, bool state)

static void setLuauFlags(bool state)
{
if (state)
{
// Setting flags to 'true' means enabling all Luau flags including new type solver
// In that case, it is provided with all fixes enabled (as if each fix had its own boolean flag)
DFInt::LuauTypeSolverRelease.value = std::numeric_limits<int>::max();
}

for (Luau::FValue<bool>* flag = Luau::FValue<bool>::list; flag; flag = flag->next)
if (strncmp(flag->name, "Luau", 4) == 0)
flag->value = state;
109 changes: 103 additions & 6 deletions CLI/Repl.cpp
Original file line number Diff line number Diff line change
@@ -19,6 +19,8 @@
#include "isocline.h"

#include <memory>
#include <string>
#include <string_view>

#ifdef _WIN32
#include <io.h>
@@ -119,18 +121,113 @@ static int finishrequire(lua_State* L)
return 1;
}

struct RuntimeRequireContext : public RequireResolver::RequireContext
{
// In the context of the REPL, source is the calling context's chunkname.
//
// These chunknames have certain prefixes that indicate context. These
// are used when displaying debug information (see luaO_chunkid).
//
// Generally, the '@' prefix is used for filepaths, and the '=' prefix is
// used for custom chunknames, such as =stdin.
explicit RuntimeRequireContext(std::string source)
: source(std::move(source))
{
}

std::string getPath() override
{
return source.substr(1);
}

bool isRequireAllowed() override
{
return isStdin() || (!source.empty() && source[0] == '@');
}

bool isStdin() override
{
return source == "=stdin";
}

std::string createNewIdentifer(const std::string& path) override
{
return "@" + path;
}

private:
std::string source;
};

struct RuntimeCacheManager : public RequireResolver::CacheManager
{
explicit RuntimeCacheManager(lua_State* L)
: L(L)
{
}

bool isCached(const std::string& path) override
{
luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1);
lua_getfield(L, -1, path.c_str());
bool cached = !lua_isnil(L, -1);
lua_pop(L, 2);

if (cached)
cacheKey = path;

return cached;
}

std::string cacheKey;

private:
lua_State* L;
};

struct RuntimeErrorHandler : RequireResolver::ErrorHandler
{
explicit RuntimeErrorHandler(lua_State* L)
: L(L)
{
}

void reportError(const std::string message) override
{
luaL_errorL(L, "%s", message.c_str());
}

private:
lua_State* L;
};

static int lua_require(lua_State* L)
{
std::string name = luaL_checkstring(L, 1);

RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name));
RequireResolver::ResolvedRequire resolvedRequire;
{
lua_Debug ar;
lua_getinfo(L, 1, "s", &ar);

RuntimeRequireContext requireContext{ar.source};
RuntimeCacheManager cacheManager{L};
RuntimeErrorHandler errorHandler{L};

RequireResolver resolver(std::move(name), requireContext, cacheManager, errorHandler);

resolvedRequire = resolver.resolveRequire(
[L, &cacheKey = cacheManager.cacheKey](const RequireResolver::ModuleStatus status)
{
lua_getfield(L, LUA_REGISTRYINDEX, "_MODULES");
if (status == RequireResolver::ModuleStatus::Cached)
lua_getfield(L, -1, cacheKey.c_str());
}
);
}

if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached)
return finishrequire(L);
else if (resolvedRequire.status == RequireResolver::ModuleStatus::Ambiguous)
luaL_errorL(L, "require path could not be resolved to a unique file");
else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound)
luaL_errorL(L, "error requiring module");

// module needs to run in a new thread, isolated from the rest
// note: we create ML on main thread so that it doesn't inherit environment of L
@@ -143,7 +240,7 @@ static int lua_require(lua_State* L)

// now we can compile & run module on the new thread
std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts());
if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
if (luau_load(ML, resolvedRequire.identifier.c_str(), bytecode.data(), bytecode.size(), 0) == 0)
{
if (codegen)
{
248 changes: 128 additions & 120 deletions CLI/Require.cpp
Original file line number Diff line number Diff line change
@@ -9,165 +9,170 @@
#include <array>
#include <utility>

RequireResolver::RequireResolver(lua_State* L, std::string path)
static constexpr char kRequireErrorGeneric[] = "error requiring module";

RequireResolver::RequireResolver(std::string path, RequireContext& requireContext, CacheManager& cacheManager, ErrorHandler& errorHandler)
: pathToResolve(std::move(path))
, L(L)
, requireContext(requireContext)
, cacheManager(cacheManager)
, errorHandler(errorHandler)
{
lua_Debug ar;
lua_getinfo(L, 1, "s", &ar);
sourceChunkname = ar.source;
}

if (!isRequireAllowed(sourceChunkname))
luaL_errorL(L, "require is not supported in this context");
RequireResolver::ResolvedRequire RequireResolver::resolveRequire(std::function<void(const ModuleStatus)> completionCallback)
{
if (isRequireResolved)
{
errorHandler.reportError("require statement has already been resolved");
return ResolvedRequire{ModuleStatus::ErrorReported};
}

if (isAbsolutePath(pathToResolve))
luaL_argerrorL(L, 1, "cannot require an absolute path");
if (!initialize())
return ResolvedRequire{ModuleStatus::ErrorReported};

std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/');
resolvedRequire.status = findModule();

if (!isPrefixValid())
luaL_argerrorL(L, 1, "require path must start with a valid prefix: ./, ../, or @");
if (completionCallback)
completionCallback(resolvedRequire.status);

substituteAliasIfPresent(pathToResolve);
isRequireResolved = true;
return resolvedRequire;
}

[[nodiscard]] RequireResolver::ResolvedRequire RequireResolver::resolveRequire(lua_State* L, std::string path)
static bool hasValidPrefix(std::string_view path)
{
RequireResolver resolver(L, std::move(path));
ModuleStatus status = resolver.findModule();
if (status != ModuleStatus::FileRead)
return ResolvedRequire{status};
else
return ResolvedRequire{status, std::move(resolver.chunkname), std::move(resolver.absolutePath), std::move(resolver.sourceCode)};
return path.compare(0, 2, "./") == 0 || path.compare(0, 3, "../") == 0 || path.compare(0, 1, "@") == 0;
}

RequireResolver::ModuleStatus RequireResolver::findModule()
static bool isPathAmbiguous(const std::string& path)
{
resolveAndStoreDefaultPaths();
bool found = false;
for (const char* suffix : {".luau", ".lua"})
{
if (isFile(path + suffix))
{
if (found)
return true;
else
found = true;
}
}
if (isDirectory(path) && found)
return true;

// Put _MODULES table on stack for checking and saving to the cache
luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1);
return false;
}

bool RequireResolver::initialize()
{
if (!requireContext.isRequireAllowed())
{
errorHandler.reportError("require is not supported in this context");
return false;
}

return findModuleImpl();
if (isAbsolutePath(pathToResolve))
{
errorHandler.reportError("cannot require an absolute path");
return false;
}

std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/');

if (!hasValidPrefix(pathToResolve))
{
errorHandler.reportError("require path must start with a valid prefix: ./, ../, or @");
return false;
}

return substituteAliasIfPresent(pathToResolve);
}

RequireResolver::ModuleStatus RequireResolver::findModuleImpl()
RequireResolver::ModuleStatus RequireResolver::findModule()
{
if (isPathAmbiguous(absolutePath))
return ModuleStatus::Ambiguous;
if (!resolveAndStoreDefaultPaths())
return ModuleStatus::ErrorReported;

static const std::array<const char*, 4> possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"};
if (isPathAmbiguous(resolvedRequire.absolutePath))
{
errorHandler.reportError("require path could not be resolved to a unique file");
return ModuleStatus::ErrorReported;
}

size_t unsuffixedAbsolutePathSize = absolutePath.size();
static constexpr std::array<const char*, 4> possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"};
size_t unsuffixedAbsolutePathSize = resolvedRequire.absolutePath.size();

for (const char* possibleSuffix : possibleSuffixes)
{
absolutePath += possibleSuffix;
resolvedRequire.absolutePath += possibleSuffix;

// Check cache for module
lua_getfield(L, -1, absolutePath.c_str());
if (!lua_isnil(L, -1))
{
if (cacheManager.isCached(resolvedRequire.absolutePath))
return ModuleStatus::Cached;
}
lua_pop(L, 1);

// Try to read the matching file
std::optional<std::string> source = readFile(absolutePath);
if (source)
if (std::optional<std::string> source = readFile(resolvedRequire.absolutePath))
{
chunkname = "=" + chunkname + possibleSuffix;
sourceCode = *source;
resolvedRequire.identifier = requireContext.createNewIdentifer(resolvedRequire.identifier + possibleSuffix);
resolvedRequire.sourceCode = *source;
return ModuleStatus::FileRead;
}

absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix
resolvedRequire.absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix
}

if (hasFileExtension(absolutePath, {".luau", ".lua"}) && isFile(absolutePath))
luaL_argerrorL(L, 1, "error requiring module: consider removing the file extension");

return ModuleStatus::NotFound;
}

bool RequireResolver::isPathAmbiguous(const std::string& path)
{
bool found = false;
for (const char* suffix : {".luau", ".lua"})
if (hasFileExtension(resolvedRequire.absolutePath, {".luau", ".lua"}) && isFile(resolvedRequire.absolutePath))
{
if (isFile(path + suffix))
{
if (found)
return true;
else
found = true;
}
errorHandler.reportError("error requiring module: consider removing the file extension");
return ModuleStatus::ErrorReported;
}
if (isDirectory(path) && found)
return true;

return false;
}

bool RequireResolver::isRequireAllowed(std::string_view sourceChunkname)
{
LUAU_ASSERT(!sourceChunkname.empty());
return (sourceChunkname[0] == '=' || sourceChunkname[0] == '@');
}

bool RequireResolver::isPrefixValid()
{
return pathToResolve.compare(0, 2, "./") == 0 || pathToResolve.compare(0, 3, "../") == 0 || pathToResolve.compare(0, 1, "@") == 0;
errorHandler.reportError(kRequireErrorGeneric);
return ModuleStatus::ErrorReported;
}

void RequireResolver::resolveAndStoreDefaultPaths()
bool RequireResolver::resolveAndStoreDefaultPaths()
{
if (!isAbsolutePath(pathToResolve))
{
std::string chunknameContext = getRequiringContextRelative();
std::string identifierContext = getRequiringContextRelative();
std::optional<std::string> absolutePathContext = getRequiringContextAbsolute();

if (!absolutePathContext)
luaL_errorL(L, "error requiring module");
return false;

// resolvePath automatically sanitizes/normalizes the paths
std::optional<std::string> chunknameOpt = resolvePath(pathToResolve, chunknameContext);
std::optional<std::string> absolutePathOpt = resolvePath(pathToResolve, *absolutePathContext);

if (!chunknameOpt || !absolutePathOpt)
luaL_errorL(L, "error requiring module");

chunkname = std::move(*chunknameOpt);
absolutePath = std::move(*absolutePathOpt);
resolvedRequire.identifier = resolvePath(pathToResolve, identifierContext);
resolvedRequire.absolutePath = resolvePath(pathToResolve, *absolutePathContext);
}
else
{
// Here we must explicitly sanitize, as the path is taken as is
std::optional<std::string> sanitizedPath = normalizePath(pathToResolve);
if (!sanitizedPath)
luaL_errorL(L, "error requiring module");

chunkname = *sanitizedPath;
absolutePath = std::move(*sanitizedPath);
std::string sanitizedPath = normalizePath(pathToResolve);
resolvedRequire.identifier = sanitizedPath;
resolvedRequire.absolutePath = std::move(sanitizedPath);
}
return true;
}

std::optional<std::string> RequireResolver::getRequiringContextAbsolute()
{
std::string requiringFile;
if (isAbsolutePath(sourceChunkname.substr(1)))
if (isAbsolutePath(requireContext.getPath()))
{
// We already have an absolute path for the requiring file
requiringFile = sourceChunkname.substr(1);
requiringFile = requireContext.getPath();
}
else
{
// Requiring file's stored path is relative to the CWD, must make absolute
std::optional<std::string> cwd = getCurrentWorkingDirectory();
if (!cwd)
{
errorHandler.reportError("could not determine current working directory");
return std::nullopt;
}

if (sourceChunkname.substr(1) == "stdin")
if (requireContext.isStdin())
{
// Require statement is being executed from REPL input prompt
// The requiring context is the pseudo-file "stdin" in the CWD
@@ -176,11 +181,7 @@ std::optional<std::string> RequireResolver::getRequiringContextAbsolute()
else
{
// Require statement is being executed in a file, must resolve relative to CWD
std::optional<std::string> requiringFileOpt = resolvePath(sourceChunkname.substr(1), joinPaths(*cwd, "stdin"));
if (!requiringFileOpt)
return std::nullopt;

requiringFile = *requiringFileOpt;
requiringFile = resolvePath(requireContext.getPath(), joinPaths(*cwd, "stdin"));
}
}
std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/');
@@ -189,17 +190,13 @@ std::optional<std::string> RequireResolver::getRequiringContextAbsolute()

std::string RequireResolver::getRequiringContextRelative()
{
std::string baseFilePath;
if (sourceChunkname.substr(1) != "stdin")
baseFilePath = sourceChunkname.substr(1);

return baseFilePath;
return requireContext.isStdin() ? "" : requireContext.getPath();
}

void RequireResolver::substituteAliasIfPresent(std::string& path)
bool RequireResolver::substituteAliasIfPresent(std::string& path)
{
if (path.size() < 1 || path[0] != '@')
return;
return true;

// To ignore the '@' alias prefix when processing the alias
const size_t aliasStartPos = 1;
@@ -215,17 +212,19 @@ void RequireResolver::substituteAliasIfPresent(std::string& path)

// Not worth searching when potentialAlias cannot be an alias
if (!Luau::isValidAlias(potentialAlias))
luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str());

std::optional<std::string> alias = getAlias(potentialAlias);
if (alias)
{
path = *alias + path.substr(potentialAlias.size() + 1);
errorHandler.reportError("@" + potentialAlias + " is not a valid alias");
return false;
}
else

if (std::optional<std::string> alias = getAlias(potentialAlias))
{
luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str());
path = *alias + path.substr(potentialAlias.size() + 1);
return true;
}

errorHandler.reportError("@" + potentialAlias + " is not a valid alias");
return false;
}

std::optional<std::string> RequireResolver::getAlias(std::string alias)
@@ -241,7 +240,8 @@ std::optional<std::string> RequireResolver::getAlias(std::string alias)
);
while (!config.aliases.contains(alias) && !isConfigFullyResolved)
{
parseNextConfig();
if (!parseNextConfig())
return std::nullopt; // error parsing config
}
if (!config.aliases.contains(alias) && isConfigFullyResolved)
return std::nullopt; // could not find alias
@@ -250,17 +250,17 @@ std::optional<std::string> RequireResolver::getAlias(std::string alias)
return resolvePath(aliasInfo.value, aliasInfo.configLocation);
}

void RequireResolver::parseNextConfig()
bool RequireResolver::parseNextConfig()
{
if (isConfigFullyResolved)
return; // no config files left to parse
return true; // no config files left to parse

std::optional<std::string> directory;
if (lastSearchedDir.empty())
{
std::optional<std::string> requiringFile = getRequiringContextAbsolute();
if (!requiringFile)
luaL_errorL(L, "error requiring module");
return false;

directory = getParentPath(*requiringFile);
}
@@ -270,13 +270,16 @@ void RequireResolver::parseNextConfig()
if (directory)
{
lastSearchedDir = *directory;
parseConfigInDirectory(*directory);
if (!parseConfigInDirectory(*directory))
return false;
}
else
isConfigFullyResolved = true;

return true;
}

void RequireResolver::parseConfigInDirectory(const std::string& directory)
bool RequireResolver::parseConfigInDirectory(const std::string& directory)
{
std::string configPath = joinPaths(directory, Luau::kConfigName);

@@ -291,6 +294,11 @@ void RequireResolver::parseConfigInDirectory(const std::string& directory)
{
std::optional<std::string> error = Luau::parseConfig(*contents, config, opts);
if (error)
luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str());
{
errorHandler.reportError("error parsing " + configPath + "(" + *error + ")");
return false;
}
}
}

return true;
}
66 changes: 43 additions & 23 deletions CLI/Require.h
Original file line number Diff line number Diff line change
@@ -1,64 +1,84 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once

#include "lua.h"
#include "lualib.h"

#include "Luau/Config.h"

#include <functional>
#include <string>
#include <string_view>

class RequireResolver
{
public:
std::string chunkname;
std::string absolutePath;
std::string sourceCode;

enum class ModuleStatus
{
Cached,
FileRead,
Ambiguous,
NotFound
ErrorReported
};

struct ResolvedRequire
{
ModuleStatus status;
std::string chunkName;
std::string identifier;
std::string absolutePath;
std::string sourceCode;
};

[[nodiscard]] ResolvedRequire static resolveRequire(lua_State* L, std::string path);
struct RequireContext
{
virtual ~RequireContext() = default;
virtual std::string getPath() = 0;
virtual bool isRequireAllowed() = 0;
virtual bool isStdin() = 0;
virtual std::string createNewIdentifer(const std::string& path) = 0;
};

struct CacheManager
{
virtual ~CacheManager() = default;
virtual bool isCached(const std::string& path)
{
return false;
}
};

struct ErrorHandler
{
virtual ~ErrorHandler() = default;
virtual void reportError(const std::string message) {}
};

RequireResolver(std::string pathToResolve, RequireContext& requireContext, CacheManager& cacheManager, ErrorHandler& errorHandler);

[[nodiscard]] ResolvedRequire resolveRequire(std::function<void(const ModuleStatus)> completionCallback = nullptr);

private:
std::string pathToResolve;
std::string_view sourceChunkname;

RequireResolver(lua_State* L, std::string path);
RequireContext& requireContext;
CacheManager& cacheManager;
ErrorHandler& errorHandler;

ResolvedRequire resolvedRequire;
bool isRequireResolved = false;

ModuleStatus findModule();
lua_State* L;
Luau::Config config;
std::string lastSearchedDir;
bool isConfigFullyResolved = false;

bool isRequireAllowed(std::string_view sourceChunkname);
bool isPrefixValid();
[[nodiscard]] bool initialize();

void resolveAndStoreDefaultPaths();
ModuleStatus findModule();
ModuleStatus findModuleImpl();
bool isPathAmbiguous(const std::string& path);

[[nodiscard]] bool resolveAndStoreDefaultPaths();
std::optional<std::string> getRequiringContextAbsolute();
std::string getRequiringContextRelative();

void substituteAliasIfPresent(std::string& path);
[[nodiscard]] bool substituteAliasIfPresent(std::string& path);
std::optional<std::string> getAlias(std::string alias);

void parseNextConfig();
void parseConfigInDirectory(const std::string& path);
};
[[nodiscard]] bool parseNextConfig();
[[nodiscard]] bool parseConfigInDirectory(const std::string& directory);
};
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.c
REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o)
REPL_CLI_TARGET=$(BUILD)/luau

ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Analyze.cpp
ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Require.cpp CLI/Analyze.cpp
ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o)
ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze

4 changes: 3 additions & 1 deletion Sources.cmake
Original file line number Diff line number Diff line change
@@ -410,7 +410,9 @@ endif()
if(TARGET Luau.Analyze.CLI)
# Luau.Analyze.CLI Sources
target_sources(Luau.Analyze.CLI PRIVATE
CLI/Analyze.cpp)
CLI/Analyze.cpp
CLI/Require.cpp
)
endif()

if(TARGET Luau.Ast.CLI)
7 changes: 7 additions & 0 deletions VM/src/lobject.cpp
Original file line number Diff line number Diff line change
@@ -116,6 +116,13 @@ const char* luaO_pushfstring(lua_State* L, const char* fmt, ...)
return msg;
}

// Possible chunkname prefixes:
//
// '=' prefix: meant to represent custom chunknames. When truncation is needed,
// the beginning of the chunkname is kept.
//
// '@' prefix: meant to represent filepaths. When truncation is needed, the end
// of the filepath is kept, as this is more useful for identifying the file.
const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen)
{
if (*source == '=')
9 changes: 4 additions & 5 deletions tests/Fixture.cpp
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@ static const char* mainModuleName = "MainModule";

LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

LUAU_FASTFLAGVARIABLE(DebugLuauForceAllNewSolverTests);

@@ -243,21 +242,21 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars
return result.root;
}

CheckResult Fixture::check(Mode mode, const std::string& source)
CheckResult Fixture::check(Mode mode, const std::string& source, std::optional<FrontendOptions> options)
{
ModuleName mm = fromString(mainModuleName);
configResolver.defaultConfig.mode = mode;
fileResolver.source[mm] = std::move(source);
frontend.markDirty(mm);

CheckResult result = frontend.check(mm);
CheckResult result = frontend.check(mm, options);

return result;
}

CheckResult Fixture::check(const std::string& source)
CheckResult Fixture::check(const std::string& source, std::optional<FrontendOptions> options)
{
return check(Mode::Strict, source);
return check(Mode::Strict, source, options);
}

LintResult Fixture::lint(const std::string& source, const std::optional<LintOptions>& lintOptions)
4 changes: 2 additions & 2 deletions tests/Fixture.h
Original file line number Diff line number Diff line change
@@ -76,8 +76,8 @@ struct Fixture

// Throws Luau::ParseErrors if the parse fails.
AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {});
CheckResult check(Mode mode, const std::string& source);
CheckResult check(const std::string& source);
CheckResult check(Mode mode, const std::string& source, std::optional<FrontendOptions> = std::nullopt);
CheckResult check(const std::string& source, std::optional<FrontendOptions> = std::nullopt);

LintResult lint(const std::string& source, const std::optional<LintOptions>& lintOptions = {});
LintResult lintModule(const ModuleName& moduleName, const std::optional<LintOptions>& lintOptions = {});
910 changes: 764 additions & 146 deletions tests/FragmentAutocomplete.test.cpp

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions tests/NonStrictTypeChecker.test.cpp
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
#include "Fixture.h"

#include "Luau/Ast.h"
#include "Luau/BuiltinDefinitions.h"
#include "Luau/Common.h"
#include "Luau/IostreamHelpers.h"
#include "Luau/ModuleResolver.h"
@@ -13,6 +14,8 @@
#include "doctest.h"
#include <iostream>

LUAU_FASTFLAG(LuauCountSelfCallsNonstrict)

using namespace Luau;

#define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \
@@ -576,4 +579,25 @@ buffer.readi8(b, 0)
LUAU_REQUIRE_NO_ERRORS(result);
}

TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_method_calls")
{
ScopedFastFlag sff{FFlag::LuauCountSelfCallsNonstrict, true};

Luau::unfreeze(frontend.globals.globalTypes);
Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes);

registerBuiltinGlobals(frontend, frontend.globals);
registerTestTypes();

Luau::freeze(frontend.globals.globalTypes);
Luau::freeze(frontend.globalsForAutocomplete.globalTypes);

CheckResult result = checkNonStrict(R"(
local test = "test"
test:lower()
)");

LUAU_REQUIRE_NO_ERRORS(result);
}

TEST_SUITE_END();
65 changes: 65 additions & 0 deletions tests/Parser.test.cpp
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr)
LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2)
LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport)
LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams)

namespace
{
@@ -3678,5 +3679,69 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed")
matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses.");
}

TEST_CASE_FIXTURE(Fixture, "grouped_function_type")
{
ScopedFastFlag _{FFlag::LuauAllowComplexTypesInGenericParams, true};
const auto root = parse(R"(
type X<T> = T
local x: X<(() -> ())?>
)");
LUAU_ASSERT(root);
CHECK_EQ(root->body.size, 2);
auto assignment = root->body.data[1]->as<AstStatLocal>();
LUAU_ASSERT(assignment);
CHECK_EQ(assignment->vars.size, 1);
CHECK_EQ(assignment->values.size, 0);
auto binding = assignment->vars.data[0];
CHECK_EQ(binding->name, "x");
auto genericTy = binding->annotation->as<AstTypeReference>();
LUAU_ASSERT(genericTy);
CHECK_EQ(genericTy->parameters.size, 1);
auto paramTy = genericTy->parameters.data[0];
LUAU_ASSERT(paramTy.type);
auto unionTy = paramTy.type->as<AstTypeUnion>();
LUAU_ASSERT(unionTy);
CHECK_EQ(unionTy->types.size, 2);
CHECK(unionTy->types.data[0]->is<AstTypeFunction>()); // () -> ()
CHECK(unionTy->types.data[1]->is<AstTypeReference>()); // nil
}

TEST_CASE_FIXTURE(Fixture, "complex_union_in_generic_ty")
{
ScopedFastFlag _{FFlag::LuauAllowComplexTypesInGenericParams, true};
const auto root = parse(R"(
type X<T> = T
local x: X<
| number
| boolean
| string
>
)");
LUAU_ASSERT(root);
CHECK_EQ(root->body.size, 2);
auto assignment = root->body.data[1]->as<AstStatLocal>();
LUAU_ASSERT(assignment);
CHECK_EQ(assignment->vars.size, 1);
CHECK_EQ(assignment->values.size, 0);
auto binding = assignment->vars.data[0];
CHECK_EQ(binding->name, "x");
auto genericTy = binding->annotation->as<AstTypeReference>();
LUAU_ASSERT(genericTy);
CHECK_EQ(genericTy->parameters.size, 1);
auto paramTy = genericTy->parameters.data[0];
LUAU_ASSERT(paramTy.type);
auto unionTy = paramTy.type->as<AstTypeUnion>();
LUAU_ASSERT(unionTy);
CHECK_EQ(unionTy->types.size, 3);
// NOTE: These are `const char*` so we can compare them to `AstName`s later.
std::vector<const char*> expectedTypes{"number", "boolean", "string"};
for (size_t i = 0; i < expectedTypes.size(); i++)
{
auto ty = unionTy->types.data[i]->as<AstTypeReference>();
LUAU_ASSERT(ty);
CHECK_EQ(ty->name, expectedTypes[i]);
}
}


TEST_SUITE_END();
11 changes: 9 additions & 2 deletions tests/RequireByString.test.cpp
Original file line number Diff line number Diff line change
@@ -225,8 +225,8 @@ TEST_CASE("PathResolution")

CHECK(resolvePath("../module", "") == "../module");
CHECK(resolvePath("../../module", "") == "../../module");
CHECK(resolvePath("../module/..", "") == "..");
CHECK(resolvePath("../module/../..", "") == "../..");
CHECK(resolvePath("../module/..", "") == "../");
CHECK(resolvePath("../module/../..", "") == "../../");

CHECK(resolvePath("../dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency");
CHECK(resolvePath("../dependency/", prefix + "Users/modules/module.luau") == prefix + "Users/dependency");
@@ -400,6 +400,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLua")
REQUIRE_FALSE_MESSAGE(lua_isnil(L, -1), "Cache did not contain module result");
}

TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCachedResult")
{
std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/validate_cache";
runProtectedRequire(relativePath);
assertOutputContainsAll({"true"});
}

TEST_CASE_FIXTURE(ReplWithPathFixture, "LoadStringRelative")
{
runCode(L, "return pcall(function() return loadstring(\"require('a/relative/path')\")() end)");
1 change: 0 additions & 1 deletion tests/TypeInfer.builtins.test.cpp
Original file line number Diff line number Diff line change
@@ -10,7 +10,6 @@
using namespace Luau;

LUAU_FASTFLAG(LuauSolverV2)
LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTypestateBuiltins2)
LUAU_FASTFLAG(LuauStringFormatArityFix)

30 changes: 23 additions & 7 deletions tests/TypeInfer.tables.test.cpp
Original file line number Diff line number Diff line change
@@ -18,10 +18,8 @@ using namespace Luau;
LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauInstantiateInSubtyping)
LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering)
LUAU_FASTFLAG(LuauAcceptIndexingTableUnionsIntersections)
LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack)

LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)
LUAU_FASTFLAG(LuauTableKeysAreRValues)

TEST_SUITE_BEGIN("TableTests");

@@ -4802,8 +4800,6 @@ end

TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table")
{
ScopedFastFlag sff{FFlag::LuauAcceptIndexingTableUnionsIntersections, true};

CheckResult result = check(R"(
local test = if true then { "meow", "woof" } else { 4, 81 }
local test2 = test[1]
@@ -4820,8 +4816,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table")

TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table2")
{
ScopedFastFlag sff{FFlag::LuauAcceptIndexingTableUnionsIntersections, true};

CheckResult result = check(R"(
local test = if true then {} else {}
local test2 = test[1]
@@ -4936,4 +4930,26 @@ TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager")
)"));
}


TEST_CASE_FIXTURE(BuiltinsFixture, "read_only_property_reads")
{
ScopedFastFlag newSolver{FFlag::LuauSolverV2, true};
ScopedFastFlag sff{FFlag::LuauTableKeysAreRValues, true};

// none of the `t.id` accesses here should error
auto result = check(R"(
--!strict
type readonlyTable = {read id: number}
local t:readonlyTable = {id = 1}

local _:{number} = {[t.id] = 1}
local _:{number} = {[t.id::number] = 1}

local arr:{number} = {}
arr[t.id] = 1
)");

LUAU_REQUIRE_NO_ERRORS(result);
}

TEST_SUITE_END();
6 changes: 1 addition & 5 deletions tests/TypeInfer.unionTypes.test.cpp
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
using namespace Luau;

LUAU_FASTFLAG(LuauSolverV2)
LUAU_FASTFLAG(LuauAcceptIndexingTableUnionsIntersections)

TEST_SUITE_BEGIN("UnionTypes");

@@ -645,10 +644,7 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash")
// this is a cyclic union of number arrays, so it _is_ a table, even if it's a nonsense type.
// no need to generate a NotATable error here. The new solver automatically handles this and
// correctly reports no errors.
if (FFlag::LuauAcceptIndexingTableUnionsIntersections || FFlag::LuauSolverV2)
LUAU_REQUIRE_NO_ERRORS(result);
else
LUAU_REQUIRE_ERROR_COUNT(1, result);
LUAU_REQUIRE_NO_ERRORS(result);
}

TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect")
9 changes: 0 additions & 9 deletions tests/main.cpp
Original file line number Diff line number Diff line change
@@ -27,13 +27,10 @@
#include <sys/sysctl.h>
#endif

#include <limits> // TODO: remove with LuauTypeSolverRelease
#include <optional>

#include <stdio.h>

LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease)

// Indicates if verbose output is enabled; can be overridden via --verbose
// Currently, this enables output from 'print', but other verbose output could be enabled eventually.
bool verbose = false;
@@ -415,12 +412,6 @@ int main(int argc, char** argv)
printf("Using RNG seed %u\n", *randomSeed);
}

// New Luau type solver uses a temporary scheme where fixes are made under a single version flag
// When flags are enabled, new solver is enabled with all new features and fixes
// When it's disabled, this value should have no effect (all uses under a new solver)
// Flag setup argument can still be used to override this to a specific value if desired
DFInt::LuauTypeSolverRelease.value = std::numeric_limits<int>::max();

if (std::vector<doctest::String> flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags))
setFastFlags(flags);

4 changes: 4 additions & 0 deletions tests/require/without_config/validate_cache.luau
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
local result1 = require("./dependency")
local result2 = require("./dependency")
assert(result1 == result2)
return {}

0 comments on commit e905e30

Please sign in to comment.