Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add use statements in Amy #92

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
version = 3.7.1
runner.dialect = scala3

# force a blank first line
docstrings.blankFirstLine = yes
Expand Down
151 changes: 118 additions & 33 deletions compiler/src/main/scala/amyc/analyzer/NameAnalyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,32 @@ object NameAnalyzer extends Pipeline[N.Program, S.Program] {
registerModules(p)

// Step 2: Check name uniqueness in modules
for mod <- p.modules do
checkModuleConsistency(mod)
for mod <- p.modules do checkModuleConsistency(mod)

// Step 3: Discover types
for mod <- p.modules do
registerTypes(mod)
for mod <- p.modules do registerTypes(mod)

// Step 4: Discover type constructors
for m <- p.modules do
registerConstructors(m)
for m <- p.modules do registerConstructors(m)

// Step 5: Discover functions signatures.
for m <- p.modules do
registerFunctions(m)
for m <- p.modules do registerFunctions(m)

// Step 6: We now know all definitions in the program.
// Reconstruct modules and analyse function bodies/ expressions
transformProgram(p)

// TODO HR : Use statements in the tree are not necessary for here
// TODO HR : Drop them from the tree makes it easier to implement
// TODO HR : (no need to worry about type checking or code generation)

}

// ==============================================================================================
// ===================================== REGISTER FUCTIONS ======================================
// ==============================================================================================

/**
*
* @param mod
* @param Context
*/
Expand All @@ -66,51 +65,52 @@ object NameAnalyzer extends Pipeline[N.Program, S.Program] {
symbols.addFunction(mod.name, name, argTypes, retType2)

/**
*
* @param mod
* @param Context
*/
def registerConstructors(mod: N.ModuleDef)(using Context) =
for cc@N.CaseClassDef(name, fields, parent) <- mod.defs do
for cc @ N.CaseClassDef(name, fields, parent) <- mod.defs do
val argTypes = fields map (tt => transformType(tt, mod.name))
val retType = symbols.getType(mod.name, parent).getOrElse {
reporter.fatal(s"Parent class $parent not found", cc)
}
symbols.addConstructor(mod.name, name, argTypes, retType)

/**
*
* @param prog
* @param Context
*/
def registerModules(prog: N.Program)(using Context) =
val modNames = prog.modules.groupBy(_.name)
for (name, modules) <- modNames do
for(name, modules) <- modNames do
if modules.size > 1 then
reporter.fatal(s"Two modules named $name in program", modules.head.position)
reporter.fatal(
s"Two modules named $name in program",
modules.head.position
)
for mod <- modNames.keys.toList do
val id = symbols.addModule(mod)
ctx.withScope(id)

/**
*
* @param mod
* @param Context
*/
def registerTypes(mod: N.ModuleDef)(using Context) =
for N.AbstractClassDef(name) <- mod.defs do
symbols.addType(mod.name, name)
for N.AbstractClassDef(name) <- mod.defs do symbols.addType(mod.name, name)

/**
*
* @param mod
* @param Context
*/
def checkModuleConsistency(mod: N.ModuleDef)(using Context) =
val names = mod.defs.groupBy(_.name)
for (name, defs) <- names do
if (defs.size > 1) {
reporter.fatal(s"Two definitions named $name in module ${mod.name}", defs.head)
val names = mod.defs.groupBy(_.name)
for(name, defs) <- names do
if(defs.size > 1) {
reporter.fatal(
s"Two definitions named $name in module ${mod.name}",
defs.head
)
}

private def registerUnnamed(using Context) =
Expand All @@ -123,16 +123,101 @@ object NameAnalyzer extends Pipeline[N.Program, S.Program] {
symbols.addType(modName, "Unit")
symbols.addType(modName, "String")
// register bin op
symbols.addInfixFunction(modName, "+", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType))
symbols.addInfixFunction(modName, "-", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType))
symbols.addInfixFunction(modName, "*", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType))
symbols.addInfixFunction(modName, "/", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType))
symbols.addInfixFunction(modName, "%", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType))
symbols.addInfixFunction(modName, "<", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType))
symbols.addInfixFunction(modName, "<=", List(S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType), S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)), S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType))
symbols.addInfixFunction(modName, "&&", List(S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType), S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)), S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType))
symbols.addInfixFunction(modName, "||", List(S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType), S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)), S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType))
symbols.addInfixFunction(modName, "==", List(S.TTypeTree(NoType), S.TTypeTree(NoType)), S.TTypeTree(NoType)) // A lot of patches everywhere to make it work, this will remain like this until the implementation of polymorphic functions
symbols.addInfixFunction(modName, "++", List(S.ClassTypeTree(stdDef.StringType).withType(stdType.StringType), S.ClassTypeTree(stdDef.StringType).withType(stdType.StringType)), S.ClassTypeTree(stdDef.StringType).withType(stdType.StringType))
symbols.addInfixFunction(
modName,
"+",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
)
symbols.addInfixFunction(
modName,
"-",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
)
symbols.addInfixFunction(
modName,
"*",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
)
symbols.addInfixFunction(
modName,
"/",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
)
symbols.addInfixFunction(
modName,
"%",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
)
symbols.addInfixFunction(
modName,
"<",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)
)
symbols.addInfixFunction(
modName,
"<=",
List(
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType),
S.ClassTypeTree(stdDef.IntType).withType(stdType.IntType)
),
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)
)
symbols.addInfixFunction(
modName,
"&&",
List(
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType),
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)
),
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)
)
symbols.addInfixFunction(
modName,
"||",
List(
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType),
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)
),
S.ClassTypeTree(stdDef.BooleanType).withType(stdType.BooleanType)
)
symbols.addInfixFunction(
modName,
"==",
List(S.TTypeTree(NoType), S.TTypeTree(NoType)),
S.TTypeTree(NoType)
) // A lot of patches everywhere to make it work, this will remain like this until the implementation of polymorphic functions
symbols.addInfixFunction(
modName,
"++",
List(
S.ClassTypeTree(stdDef.StringType).withType(stdType.StringType),
S.ClassTypeTree(stdDef.StringType).withType(stdType.StringType)
),
S.ClassTypeTree(stdDef.StringType).withType(stdType.StringType)
)

}
53 changes: 23 additions & 30 deletions compiler/src/main/scala/amyc/analyzer/Scope.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,90 +8,87 @@ import amyc.core.Symbols.*
private type Bag = Map[String, Symbol]

/**
*
* @param parent
* @param params
* @param locals
*/
sealed case class Scope protected (parent: Option[Scope], params : Bag, locals : Bag) :
sealed case class Scope protected (
parent: Option[Scope],
params: Bag,
locals: Bag
):

/**
* Create a new Scope with the mapping (name -> id) in locals
* The parent of the new Scope is the caller
* @param str (name) -
* @param id (Identifier) -
* Create a new Scope with the mapping (name -> id) in locals The parent of
* the new Scope is the caller
* @param str
* (name) -
* @param id
* (Identifier) -
* @return
*/
final def withLocal(name : String, id : Symbol) : Scope =
final def withLocal(name: String, id: Symbol): Scope =
Scope(Some(this), params, locals + (name -> id))

/**
*
* @param locals
* @return
*/
final def withLocals(locals: Bag): Scope =
Scope(Some(this), params, locals)

/**
* Create a new Scope with the mapping (name -> id) in params
* The parent of the new Scope is the caller
* Create a new Scope with the mapping (name -> id) in params The parent of
* the new Scope is the caller
* @param name
* @param id
* @return
*/
final def withParam(name : String, id : Symbol) : Scope =
final def withParam(name: String, id: Symbol): Scope =
Scope(Some(this), params + (name -> id), locals)

/**
*
* @param params
* @return
*/
final def withParams(params : Bag): Scope =
final def withParams(params: Bag): Scope =
Scope(Some(this), params, locals)

/**
*
* @param name
* @return
*/
def isLocal(name: String) : Boolean =
def isLocal(name: String): Boolean =
locals.contains(name) || parent.map(_.isLocal(name)).get

/**
*
* @param name
* @return
*/
def isParam(name: String): Boolean =
params.contains(name) || parent.map(_.isParam(name)).get

/**
* Resolve a new in the current Scope
* The Search is recursive !!
* Resolve a new in the current Scope The Search is recursive !!
* @param name
* @param Context
* @return
*/
def resolve(name : String) : Option[Symbol] =
def resolve(name: String): Option[Symbol] =
resolveInScope(name) orElse parent.flatMap(_.resolve(name))

/**
* Resolve a new in the current Scope
* The Search is not recursive !!
* Resolve a new in the current Scope The Search is not recursive !!
* @param name
* @return
*/
def resolveInScope(name : String) : Option[Symbol] =
def resolveInScope(name: String): Option[Symbol] =
// Local variables shadow parameters!
locals.get(name) orElse params.get(name)


/**
*
*/
object Scope :
object Scope:

/**
* Create a fresh Scope
Expand All @@ -100,16 +97,14 @@ object Scope :
def fresh: Scope = EmptyScope

/**
*
* @param lhs
* @param rhs
* @param parent
* @return
*/
def combine(lhs : Scope, rhs : Scope)(using parent : Scope) : Scope =
def combine(lhs: Scope, rhs: Scope)(using parent: Scope): Scope =
Scope(Some(parent), lhs.params ++ rhs.params, lhs.locals ++ rhs.locals)


/**
* Empty Scope
*/
Expand All @@ -122,5 +117,3 @@ object EmptyScope extends Scope(None, Map.empty, Map.empty):
/* Override to avoid Exception in Scope::isParam */
override def isParam(name: String): Boolean =
false


Loading