Skip to content

Commit

Permalink
Unify type bounds: resolve method call with impl lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Sep 11, 2024
1 parent b547a64 commit 55bcc30
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 67 deletions.
62 changes: 27 additions & 35 deletions src/phase/type-unify.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { AstNode } from '../ast'
import { Identifier } from '../ast/operand'
import { TypeDef } from '../ast/type-def'
import { Context, Definition, addError } from '../scope'
import { Context, addError } from '../scope'
import { typeError } from '../semantic/error'
import { findMethodDefForMethodCall } from '../semantic/impl'
import {
ErrorType,
InferredType,
addBounds,
boundFromCall,
inferredTypeToString,
instantiateDefType,
Expand Down Expand Up @@ -177,7 +175,9 @@ export const unifyType = (type: InferredType, ctx: Context): void => {
break
case 'field-access':
unifyType(type.operandType, ctx)
const typeDefs = extractDefs(type.operandType).filter(def => def.kind === 'type-def')
const typeDefs = extractIds(type.operandType)
.map(id => id.def)
.filter(def => def?.kind === 'type-def')
if (typeDefs.length > 1) {
// TODO
assign(type, makeErrorType('multiple defs', 'todo'))
Expand Down Expand Up @@ -206,39 +206,28 @@ export const unifyType = (type: InferredType, ctx: Context): void => {
break
case 'method-call':
unifyType(type.operandType, ctx)
const defs = extractDefs(type.operandType)
if (defs.length > 1) {
const operandIdCandidates = extractIds(type.operandType)
if (operandIdCandidates.length > 1) {
// TODO
assign(type, makeErrorType('multiple defs', 'todo'))
break
}
if (defs.length > 0) {
const def = defs[0]
const block =
def.kind === 'type-def' ? def.impl?.block : def.kind === 'trait-def' ? def.block : undefined
const m = block?.statements.find(s => s.kind === 'fn-def' && s.name.value === type.op.name.value)
if (!m) {
// TODO: check traits impld by operandType
const fnDef = findMethodDefForMethodCall(type, ctx)
if (!fnDef) {
const notFoundError = makeErrorType(
`method ${type.op.name.value} not found in type ${inferredTypeToString(type.operandType)}`,
'no-method'
)
assign(type, notFoundError)
break
}
const callType = makeInferredType([
instantiateDefType(fnDef.type!, ctx),
boundFromCall(type.op.call.args.map(a => a.type!))
])
assign(type, makeReturnType(callType))
unifyType(type, ctx)
if (operandIdCandidates.length > 0) {
const operandId = operandIdCandidates[0]
const fnDef = findMethodDefForMethodCall(operandId, type.op, ctx)
if (!fnDef) {
const notFoundError = makeErrorType(
`method ${type.op.name.value} not found in type ${inferredTypeToString(type.operandType)}`,
'no-method'
)
assign(type, notFoundError)
break
}
const mType = instantiateDefType(m.type!, ctx)
addBounds(mType, [boundFromCall(type.op.call.args.map(a => a.type!))])
assign(type, makeReturnType(mType))
const callType = makeInferredType([
instantiateDefType(fnDef.type!, ctx),
boundFromCall(type.op.call.args.map(a => a.type!))
])
assign(type, makeReturnType(callType))
unifyType(type, ctx)
break
}
Expand Down Expand Up @@ -335,7 +324,7 @@ const unify_ = (a: InferredType, b: InferredType, ctx: Context, stack: [string,
case 'fn-type':
case 'name':
const e = makeErrorType(
`failed unify [${[inferredTypeToString(a), inferredTypeToString(b)].join(', ')}]`,
`failed unify [${a.kind}, ${b.kind}] [${[a, b].map(inferredTypeToString).join(', ')}]`,
'no-unify'
)
assign(a, e)
Expand Down Expand Up @@ -428,15 +417,18 @@ export const findTypeErrors = (t: InferredType): ErrorType[] => {
return []
}

const extractDefs = (t: InferredType): Definition[] => {
/**
* TODO: better name
*/
const extractIds = (t: InferredType): Identifier[] => {
switch (t.kind) {
case 'identifier':
if (t.def) {
return [t.def]
return [t]
}
break
case 'type-param':
return <TypeDef[]>t.type.bounds.map(b => b.def)
return t.type.bounds
}
return []
}
62 changes: 30 additions & 32 deletions src/semantic/impl.ts
Original file line number Diff line number Diff line change
@@ -1,43 +1,41 @@
import { MethodCallOp } from '../ast/op'
import { FnDef, TraitDef } from '../ast/statement'
import { Context } from '../scope'
import { InferredType } from '../typecheck'
import { dedup } from '../util/array'
import { assert, todo, unreachable } from '../util/todo'
import { todo } from '../util/todo'

/**
* TODO(perf): track implemented traits for a type in typeDef.impldTraits, populate it in a separate phase
*/
export const findMethodDefForMethodCall = (type: InferredType, ctx: Context): FnDef | undefined => {
if (type.kind !== 'method-call') {
assert(false, type.kind)
return unreachable()
}
if (type.operandType.kind === 'identifier') {
const def = type.operandType.def
const impls = ctx.packages.flatMap(p =>
p.modules.flatMap(m => m.impls).filter(impl => impl.forTrait && impl.forTrait.def === def)
)
const impldTraits = dedup(
impls
.map(impl => impl.identifier)
.filter(i => i.def && i.def.kind === 'trait-def')
.map(i => <TraitDef>i.def)
)
const matchingFns = impldTraits.flatMap(t =>
t.block.statements
.filter(s => s.kind === 'fn-def' && s.name.value === type.op.name.value)
.map(s => <FnDef>s)
)
if (matchingFns.length === 1) {
return matchingFns[0]
} else {
// TODO: error
return undefined
}
}
if (type.operandType.kind === 'type-param') {
// TODO
export const findMethodDefForMethodCall = (
operandType: InferredType,
op: MethodCallOp,
ctx: Context
): FnDef | undefined => {
if (operandType.kind !== 'identifier') return todo(operandType.kind)

const def = operandType.def
const block = def?.kind === 'type-def' ? def?.impl?.block : def?.kind === 'trait-def' ? def?.block : undefined
const m = <FnDef>block?.statements.find(s => s.kind === 'fn-def' && s.name.value === op.name.value)
if (m) return m

const impls = ctx.packages.flatMap(p =>
p.modules.flatMap(m => m.impls).filter(impl => impl.forTrait && impl.forTrait.def === def)
)
const impldTraits = dedup(
impls
.map(impl => impl.identifier)
.filter(i => i.def && i.def.kind === 'trait-def')
.map(i => <TraitDef>i.def)
)
const matchingFns = impldTraits.flatMap(t =>
t.block.statements.filter(s => s.kind === 'fn-def' && s.name.value === op.name.value).map(s => <FnDef>s)
)
if (matchingFns.length === 1) {
return matchingFns[0]
} else {
// TODO: error
return undefined
}
return todo(type.operandType.kind)
}

0 comments on commit 55bcc30

Please sign in to comment.