Skip to content

Commit

Permalink
Unify types: resolve type params
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Nov 25, 2024
1 parent 917811f commit 36e94c7
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 91 deletions.
2 changes: 2 additions & 0 deletions src/ast/type.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { BaseAstNode, Param } from '.'
import { ParseNode, filterNonAstNodes } from '../parser'
import { Context } from '../scope'
import { InferredType } from '../typecheck'
import { Hole, buildHole } from './match'
import { Identifier, Name, buildIdentifier, buildName } from './operand'
import { FieldDef } from './type-def'
Expand Down Expand Up @@ -36,6 +37,7 @@ export type TypeParam = BaseAstNode & {
name: Name
key?: string
bounds: Identifier[]
unified?: InferredType
}

export const buildTypeParam = (node: ParseNode, ctx: Context): TypeParam => {
Expand Down
3 changes: 1 addition & 2 deletions src/phase/top-scope-type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { AstNode } from '../ast'
import { Identifier } from '../ast/operand'
import { fieldToParamType, paramToParamType, typeToParamType } from '../ast/type'
import { Context } from '../scope'
import { makeErrorType, makeTypeParam } from '../typecheck'
import { makeErrorType } from '../typecheck'
import { assert, todo, unreachable } from '../util/todo'

/**
Expand Down Expand Up @@ -152,7 +152,6 @@ export const setTopScopeType = (node: AstNode, ctx: Context) => {
break
}
case 'type-param': {
node.type = makeTypeParam(node)
break
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/phase/type-bound.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,13 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
if (node.def.type) {
node.type = instantiateType(node.def.type, ctx)
} else {
node.type = node
node.type = instantiateType(node, ctx)
}
break
} else {
node.type = makeInferredType()
if (parentBound) {
addBounds(node.type!, [parentBound])
}
// no def means it *is* the definition
assert(!!parentBound)
node.type = parentBound
}
break
}
Expand Down
69 changes: 17 additions & 52 deletions src/phase/type-unify.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { typeError } from '../semantic/error'
import { ErrorType, InferredType, inferredTypeToString, makeErrorType } from '../typecheck'
import { dedup, zip } from '../util/array'
import { assign } from '../util/object'
import { assert, unreachable } from '../util/todo'
import { assert, todo, unreachable } from '../util/todo'

/**
* Unify type bounds
Expand Down Expand Up @@ -151,9 +151,6 @@ export const unifyType = (type: InferredType, ctx: Context): void => {
assign(type, unified)
break
}
case 'inferred-fn':
// TODO
break
case 'field-pattern': {
unifyType(type.operandType, ctx)
const variant = type.fieldPattern.variant
Expand Down Expand Up @@ -204,14 +201,9 @@ export const unifyType = (type: InferredType, ctx: Context): void => {
break
}
case 'identifier':
if (type.def?.type?.kind === 'type-param') {
assign(type, type.def.type)
break
}
break
case 'fn-type':
case 'inferred-fn':
case 'name':
case 'type-param':
case 'hole':
case 'error':
break
Expand Down Expand Up @@ -257,22 +249,6 @@ const unify_ = (a: InferredType, b: InferredType, ctx: Context): InferredType =>
}
return t
}
case 'fn-type': {
const t: InferredType = {
kind: 'inferred-fn',
// TODO
typeParams: [],
params: zip(a.params, b.params, (a_, b_) => {
assert(!!b_.type)
return {
name: a_.name === b_.name?.value ? a_.name : undefined,
type: unify(a_.type, b_.type!, ctx)
}
}),
returnType: unify(a.returnType, b.returnType, ctx)
}
return t
}
}
break
}
Expand All @@ -298,6 +274,21 @@ const unify_ = (a: InferredType, b: InferredType, ctx: Context): InferredType =>
return u
}
}
if (a.def?.kind === 'type-param') {
if (b.def?.kind === 'type-param') {
todo('tp <> tp')
}
if (a.def.unified) {
const u = unify(a.def.unified, b, ctx)
assign(a, u)
assign(b, u)
return u
} else {
a.def.unified = b
assign(a, b)
return b
}
}
}
case 'inferred-fn':
case 'fn-type':
Expand All @@ -309,28 +300,9 @@ const unify_ = (a: InferredType, b: InferredType, ctx: Context): InferredType =>
assign(a, e)
assign(b, e)
return a
case 'type-param':
break
}
break
}
case 'type-param': {
// HACK to unify method signatures unify(traitMethod.type, implMethod.type)
if (b.kind === 'type-param' && a.type.name.value === b.type.name.value) {
assign(b, a)
return a
}
if (a.unified) {
const u = unify(a.unified, b, ctx)
assign(a.unified, u)
return u
}
if (b.kind !== 'hole') {
// TODO: unify with type bounds
a.unified = b
}
return b
}
case 'hole':
return b
case 'error':
Expand Down Expand Up @@ -358,11 +330,6 @@ const extractReturnType = (type: InferredType, ctx: Context): InferredType | und
case 'field-pattern':
unifyType(type, ctx)
return extractReturnType(type, ctx)
case 'type-param':
if (type.unified) {
return extractReturnType(type.unified, ctx)
}
return undefined
case 'inferred-fn':
return type.returnType
case 'identifier':
Expand Down Expand Up @@ -408,8 +375,6 @@ const extractIds = (t: InferredType): Identifier[] => {
return [t]
}
break
case 'type-param':
return t.type.bounds
}
return []
}
8 changes: 7 additions & 1 deletion src/semantic/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ export const genericError = (ctx: Context, def: AstNode, msg: string = 'error',

export const typeError = (ctx: Context, node: AstNode, e: ErrorType_, notes?: string[]): SemanticError => {
const msg = `type error (${e.errorKind})${e.message ? `: ${e.message}` : ''}`
const stackStr = e.stack && e.stack.length > 0 ? e.stack.map(t => `\n in ${t}`).join('') : ''
const stackStr =
e.stack && e.stack.length > 0
? e.stack
.toReversed()
.map(t => `\n in ${t}`)
.join('')
: ''
return semanticError(45, ctx, node, `${msg}${stackStr}`, notes)
}
83 changes: 52 additions & 31 deletions src/typecheck/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { AstNode } from '../ast'
import { FieldPattern } from '../ast/match'
import { Identifier } from '../ast/operand'
import { FnType, Type, TypeParam } from '../ast/type'
import { Context, idToString } from '../scope'
import { assert } from '../util/todo'
import { assert, unreachable } from '../util/todo'

/**
* TODO: attach "source" node to a type to indicate where this type is coming from
Expand All @@ -13,14 +14,9 @@ export type InferredType =
kind: 'inferred'
bounds: InferredType[]
}
| {
kind: 'type-param'
type: TypeParam
unified?: InferredType
}
| {
kind: 'inferred-fn'
typeParams: InferredType[]
typeParams: TypeParam[]
params: InferredParamType[]
returnType: InferredType
}
Expand Down Expand Up @@ -61,8 +57,6 @@ export type ErrorType_ = {

export const makeInferredType = (bounds: InferredType[] = []) => ({ kind: <const>'inferred', bounds })

export const makeTypeParam = (type: TypeParam) => ({ kind: <const>'type-param', type })

export const makeFieldPatternType = (operandType: InferredType, fieldPattern: FieldPattern) => ({
kind: <const>'field-pattern',
operandType,
Expand All @@ -80,24 +74,59 @@ export const makeErrorType = (message?: string, errorKind: ErrorTypeKind = 'othe
}
})

export const makeInferredFnType = (t: FnType, ctx: Context) => ({
kind: <const>'inferred-fn',
typeParams: t.typeParams.map(makeTypeParam),
params: t.params.map(pt => {
assert(!!pt.type)
return { name: pt.name?.value, type: instantiateType(pt.type!, ctx) }
}),
returnType: instantiateType(t.returnType ?? ctx.stdTypeIds.unit, ctx)
})
export const makeInferredFnType = (t: FnType, ctx: Context) => {
const updateTypeParamRef = (t: InferredType, from: TypeParam, to: TypeParam) => {
switch (t.kind) {
case 'identifier':
if (t.def === from) {
t.def = to
} else {
t.typeArgs.forEach(ta => updateTypeParamRef(ta, from, to))
}
break
case 'inferred-fn':
t.params.forEach(p => updateTypeParamRef(p.type, from, to))
updateTypeParamRef(t.returnType, from, to)
break
case 'hole':
case 'error':
break
default:
unreachable(t.kind)
break
}
}

const typeParams = t.typeParams.map(tp => ({ ...tp }))

const fnType: InferredType = {
kind: <const>'inferred-fn',
typeParams,
params: t.params.map(pt => {
assert(!!pt.type)
return { name: pt.name?.value, type: instantiateType(pt.type!, ctx) }
}),
returnType: instantiateType(t.returnType ?? ctx.stdTypeIds.unit, ctx)
}

typeParams.forEach((tp, i) => {
fnType.params.forEach(p => updateTypeParamRef(p.type, t.typeParams[i], tp))
updateTypeParamRef(fnType.returnType, t.typeParams[i], tp)
})

return fnType
}

export const instantiateType = (t: InferredType, ctx: Context): InferredType => {
switch (t.kind) {
case 'fn-type': {
return makeInferredFnType(t, ctx)
}
case 'identifier': {
return { ...t, typeArgs: t.typeArgs.map(ta => instantiateType(ta, ctx) as Identifier) }
}
default:
// dereference type to avoid modifying by unifyTypeBounds phase
return { ...t }
return t
}
}

Expand All @@ -106,15 +135,8 @@ export const inferredTypeToString = (t: InferredType, depth = 0): string => {
switch (t.kind) {
case 'inferred':
return `[${t.bounds.map(b => inferredTypeToString(b, depth + 1)).join(', ')}]`
case 'type-param':
const unified = t.unified ? ` (${inferredTypeToString(t.unified, depth + 1)})` : ''
const bounds =
t.type.bounds.length > 0
? `: ${t.type.bounds.map(b => inferredTypeToString(b, depth + 1)).join(' + ')}`
: ''
return `<${t.type.name.value}${bounds}${unified}>`
case 'inferred-fn':
const tps = t.typeParams.length > 0 ? `<${t.typeParams.map(inferredTypeToString)}>` : ''
const tps = t.typeParams.length > 0 ? `<${t.typeParams.map(typeParamToString)}>` : ''
const pts = t.params
.map(pt => `${pt.name ? `${pt.name}: ` : ''}${inferredTypeToString(pt.type)}`)
.join(', ')
Expand Down Expand Up @@ -150,14 +172,13 @@ export const typeToString = (t: Type): string => {
return '_'
case 'name':
return t.value
// default:
// return unreachable(inspect(t))
}
}

export const typeParamToString = (tp: TypeParam): string => {
const bounds = tp.bounds.length > 0 ? `: ${tp.bounds.map(b => typeToString(b)).join(' + ')}` : ''
return `${tp.name.value}${bounds}`
const unified = tp.unified ? ` (${inferredTypeToString(tp.unified)})` : ''
return `${tp.name.value}${bounds}${unified}`
}

export const addBounds = (type: InferredType, bounds: InferredType[]): void => {
Expand Down

0 comments on commit 36e94c7

Please sign in to comment.