Skip to content

Commit

Permalink
Type bounds: refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Nov 21, 2024
1 parent 7fae681 commit 5725e68
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 70 deletions.
20 changes: 17 additions & 3 deletions src/ast/type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Context } from '../scope'
import { assert } from '../util/todo'
import { Hole, buildHole } from './match'
import { Identifier, Name, buildIdentifier, buildName } from './operand'
import { FieldDef } from './type-def'

export type Type = Identifier | FnType | Hole | Name

Expand Down Expand Up @@ -85,14 +86,27 @@ export const paramToParamType = (param: Param): ParamType => {
kind: 'param-type',
parseNode: param.parseNode,
name: expr.kind === 'name' ? expr : undefined,
paramType: param.paramType!
paramType: param.paramType!,
type: param.type
}
}

export const typeToParamType = (type: Type): ParamType => {
export const typeToParamType = (type: Type, name?: Name): ParamType => {
return {
kind: 'param-type',
parseNode: type.parseNode,
paramType: type
name,
paramType: type,
type
}
}

export const fieldToParamType = (field: FieldDef): ParamType => {
return {
kind: 'param-type',
parseNode: field.parseNode,
name: field.name,
paramType: field.fieldType,
type: field.type
}
}
69 changes: 36 additions & 33 deletions src/phase/top-scope-type.ts
Original file line number Diff line number Diff line change
@@ -1,34 +1,10 @@
import { AstNode } from '../ast'
import { Identifier } from '../ast/operand'
import { paramToParamType, typeToParamType } from '../ast/type'
import { fieldToParamType, paramToParamType, typeToParamType } from '../ast/type'
import { Context } from '../scope'
import { makeErrorType, makeTypeParam } from '../typecheck'
import { assert, todo, unreachable } from '../util/todo'

/**
* Set inferred def types of topScope nodes
*/
export const setTopScopeDefType = (node: AstNode, ctx: Context) => {
switch (node.kind) {
case 'module': {
node.block.statements.forEach(s => setTopScopeDefType(s, ctx))
break
}
case 'trait-def': {
node.typeParams.forEach(g => setTopScopeDefType(g, ctx))
break
}
case 'type-def': {
node.typeParams.forEach(g => setTopScopeDefType(g, ctx))
break
}
case 'type-param': {
node.type = makeTypeParam(node)
break
}
}
}

/**
* Set inferred types of topScope nodes
*/
Expand All @@ -39,9 +15,17 @@ export const setTopScopeType = (node: AstNode, ctx: Context) => {
break
}
case 'var-def': {
if (node.pattern.expr.kind !== 'name') return unreachable()
if (node.pattern.expr.kind !== 'name') return unreachable('top level destructuring')
if (node.expr) {
setTopScopeType(node.expr, ctx)
}
const def = node.pattern.expr
def.type = node.varType ?? makeErrorType()
def.type = node.varType ?? node.expr?.type ?? makeErrorType()
break
}
case 'operand-expr': {
setTopScopeType(node.operand, ctx)
node.type = node.operand.type
break
}
case 'fn-def': {
Expand All @@ -68,18 +52,18 @@ export const setTopScopeType = (node: AstNode, ctx: Context) => {
kind: 'identifier',
parseNode: node.parseNode,
names: [node.typeDef!.name],
typeArgs: node.typeDef!.typeParams.map(g => ({
typeArgs: node.typeDef!.typeParams.map(ta => ({
kind: 'identifier',
names: [g.name],
names: [ta.name],
typeArgs: [],
def: g
def: ta
})),
def: node.typeDef
}
const fnType: AstNode = {
kind: <const>'fn-type',
typeParams: node.typeDef!.typeParams,
paramTypes: node.fields.map(f => typeToParamType(f.fieldType)),
paramTypes: node.fields.map(f => fieldToParamType(f)),
returnType: typeDefId
}
setTopScopeType(fnType, ctx)
Expand All @@ -89,8 +73,27 @@ export const setTopScopeType = (node: AstNode, ctx: Context) => {
case 'field-def': {
assert(!!node.fieldType)
setTopScopeType(node.fieldType!, ctx)
node.type = node.fieldType!
setTopScopeType(node.name, ctx)
// TODO: ugly
const typeDef = node.variant!.typeDef!
const typeDefId: Identifier = {
kind: 'identifier',
parseNode: node.parseNode,
names: [typeDef.name],
typeArgs: typeDef.typeParams.map(ta => ({
kind: 'identifier',
names: [ta.name],
typeArgs: [],
def: ta
})),
def: typeDef
}
const accessorType: AstNode = {
kind: <const>'fn-type',
typeParams: typeDef.typeParams,
paramTypes: [typeToParamType(typeDefId)],
returnType: node.fieldType
}
node.type = accessorType
break
}
case 'trait-def':
Expand Down
26 changes: 13 additions & 13 deletions src/phase/type-bound.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
InferredType,
addBounds,
boundFromCall,
instantiateDefType,
instantiateDefType as instantiateType,
makeErrorType,
makeFieldPatternType,
makeInferredType,
Expand Down Expand Up @@ -87,13 +87,13 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
}
if (node.statements.length === 0) {
addBounds(node.type!, [
instantiateDefType(ctx.stdTypeIds.unit?.type ?? makeErrorType('no def', 'no-def'), ctx)
instantiateType(ctx.stdTypeIds.unit?.type ?? makeErrorType('no def', 'no-def'), ctx)
])
}
break
}
case 'param': {
collectTypeBounds(node.pattern, ctx, instantiateDefType(node.paramType!.type!, ctx))
collectTypeBounds(node.pattern, ctx, instantiateType(node.paramType!.type!, ctx))
break
}
case 'type-param': {
Expand Down Expand Up @@ -133,7 +133,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
case 'name': {
if (node.def) {
assert(!!node.def.type, `no def type ${inspect(node.def)}`)
node.type = instantiateDefType(node.def.type!, ctx)
node.type = instantiateType(node.def.type!, ctx)
break
} else {
node.type = makeInferredType()
Expand All @@ -152,7 +152,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
collectTypeBounds(node.operand, ctx)
switch (node.op.kind) {
case 'call-op': {
const fnType = instantiateDefType(node.operand.type!, ctx)
const fnType = makeInferredType([instantiateType(node.operand.type!, ctx)])
node.op.args.forEach(a => collectTypeBounds(a, ctx))
addBounds(fnType, [boundFromCall(node.op.args.map(a => a.type!))])
node.type = makeReturnType(fnType)
Expand Down Expand Up @@ -182,7 +182,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
assert(!!methodId)
const methodDef = findById(methodId!, ctx)
assert(!!methodDef)
const fnType = instantiateDefType(methodDef!.type!, ctx)
const fnType = instantiateType(methodDef!.type!, ctx)
addBounds(fnType, [boundFromCall([node.lOperand.type!, node.rOperand.type!])])
node.type = makeReturnType(fnType)
break
Expand Down Expand Up @@ -213,7 +213,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
collectTypeBounds(node.expr, ctx, node.varType ? node.varType : undefined)
}
collectTypeBounds(node.pattern, ctx, node.expr?.type)
node.type = instantiateDefType(ctx.stdTypeIds.unit?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.unit?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'fn-def': {
Expand All @@ -239,27 +239,27 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
}
case 'string-interpolated': {
node.tokens.filter(t => typeof t !== 'string').forEach(t => collectTypeBounds(t, ctx))
node.type = instantiateDefType(ctx.stdTypeIds.string?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.string?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'string-literal': {
node.type = instantiateDefType(ctx.stdTypeIds.string?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.string?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'char-literal': {
node.type = instantiateDefType(ctx.stdTypeIds.char?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.char?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'int-literal': {
node.type = instantiateDefType(ctx.stdTypeIds.int?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.int?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'float-literal': {
node.type = instantiateDefType(ctx.stdTypeIds.float?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.float?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'bool-literal': {
node.type = instantiateDefType(ctx.stdTypeIds.bool?.type ?? makeErrorType('no def', 'no-def'), ctx)
node.type = instantiateType(ctx.stdTypeIds.bool?.type ?? makeErrorType('no def', 'no-def'), ctx)
break
}
case 'compose-op': {
Expand Down
3 changes: 1 addition & 2 deletions src/semantic/semantic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { resolveModuleScope } from '../phase/module-resolve'
import { resolveName } from '../phase/name-resolve'
import { setStdTypeIds } from '../phase/std-type'
import { desugar } from '../phase/sugar'
import { setTopScopeDefType, setTopScopeType } from '../phase/top-scope-type'
import { setTopScopeType } from '../phase/top-scope-type'
import { collectTypeBounds } from '../phase/type-bound'
import { unifyTypeBounds } from '../phase/type-unify'
import { Context, eachModule, idToString, pathToId } from '../scope'
Expand Down Expand Up @@ -48,7 +48,6 @@ describe('semantic', () => {
setStdTypeIds,
desugar,
resolveName,
setTopScopeDefType,
setTopScopeType,
collectTypeBounds,
unifyTypeBounds
Expand Down
32 changes: 13 additions & 19 deletions src/typecheck/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { FieldPattern } from '../ast/match'
import { Type, TypeParam } from '../ast/type'
import { FnType, Type, TypeParam } from '../ast/type'
import { Context, idToString } from '../scope'
import { assert } from '../util/todo'

Expand All @@ -19,7 +19,7 @@ export type InferredType =
}
| {
kind: 'inferred-fn'
generics: InferredType[]
typeParams: InferredType[]
params: InferredType[]
returnType: InferredType
}
Expand Down Expand Up @@ -74,22 +74,16 @@ export const makeErrorType = (message?: string, errorKind: ErrorTypeKind = 'othe
}
})

export const instantiateDefType = (t: InferredType, ctx: Context): InferredType => {
export const instantiateDefType = <T extends InferredType>(t: T, ctx: Context): T => {
switch (t.kind) {
case 'fn-type': {
return makeInferredType([
{
kind: 'inferred-fn',
generics: t.typeParams.map(g => {
assert(!!g.type)
return instantiateDefType(g.type!, ctx)
}),
params: t.paramTypes.map(pt => {
return instantiateDefType(pt, ctx)
}),
returnType: instantiateDefType(t.returnType ?? ctx.stdTypeIds.unit, ctx)
}
])
const inst: FnType = {
kind: 'fn-type',
typeParams: t.typeParams.map(tp => ({ ...tp })),
paramTypes: t.paramTypes.map(pt => ({ ...pt, type: instantiateDefType(pt.type!, ctx) })),
returnType: instantiateDefType(t.returnType ?? ctx.stdTypeIds.unit, ctx)
}
return inst as any
}
default:
return t
Expand Down Expand Up @@ -136,7 +130,7 @@ export const typeToString = (t: Type): string => {
return idToString(t)
case 'fn-type':
const main = `fn(${t.paramTypes
.map(pt => `${pt.name.value}: ${typeToString(pt.paramType)}`)
.map(pt => (pt.name ? `${pt.name.value}: ` : '') + typeToString(pt.paramType))
.join(', ')}): ${typeToString(t.returnType)}`
const typeArgs = t.typeParams.length > 0 ? `<${t.typeParams.map(g => g.name.value).join(', ')}>` : ''
return typeArgs + main
Expand All @@ -152,9 +146,9 @@ export const addBounds = (type: InferredType, bounds: InferredType[]): void => {
type.bounds.push(...bounds)
return
}
assert(false, type.kind)
assert(false, `adding bounds to kind ${type.kind}`)
}

export const boundFromCall = (args: InferredType[]): InferredType => {
return { kind: 'inferred-fn', generics: [], params: args, returnType: { kind: 'hole' } }
return { kind: 'inferred-fn', typeParams: [], params: args, returnType: { kind: 'hole' } }
}

0 comments on commit 5725e68

Please sign in to comment.