Skip to content

Commit

Permalink
add type inferencer, new version, upgrades (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxdeliso authored Jan 2, 2024
1 parent 47b6d4c commit faa9986
Show file tree
Hide file tree
Showing 9 changed files with 1,023 additions and 484 deletions.
3 changes: 3 additions & 0 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@ export * from './church'
export * from './combinators'
export * from './evaluator'
export * from './expression'
export * from './lambda'
export * from './packer'
export * from './parser'
export * from './typedLambda'
export * from './types'
39 changes: 39 additions & 0 deletions lib/lambda.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { NonTerminal } from './nonterminal'

/**
* This is a single term variable with a name.
*
Expand All @@ -12,3 +14,40 @@ export const mkVar = (name: string): LambdaVar => ({
kind: 'lambda-var',
name
})

// λx.<body>, where x is a name
type UntypedLambdaAbs = {
kind: 'lambda-abs',
name: string,
// eslint-disable-next-line no-use-before-define
body: UntypedLambda
}

export const mkUntypedAbs =
// eslint-disable-next-line no-use-before-define
(name: string, body: UntypedLambda): UntypedLambda => ({
kind: 'lambda-abs',
name,
body
})

/**
* The legal terms of the untyped lambda calculus.
* e ::= x | λx.e | e e, where x is a variable name, and e is a valid expr
*/
export type UntypedLambda
= LambdaVar
| UntypedLambdaAbs
| NonTerminal<UntypedLambda>

export const prettyPrintUntypedLambda = (ut: UntypedLambda): string => {
switch (ut.kind) {
case 'lambda-var':
return ut.name
case 'lambda-abs':
return ${ut.name}.${prettyPrintUntypedLambda(ut.body)}`
case 'non-terminal':
return `(${prettyPrintUntypedLambda(ut.lft)}` +
`${prettyPrintUntypedLambda(ut.rgt)})`
}
}
53 changes: 28 additions & 25 deletions lib/typedLambda.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { LambdaVar } from './lambda'
import { NonTerminal } from './nonterminal'
import { Type, arrow, prettyPrintTy, typesEqual } from './types'
import {
Type,
arrow,
prettyPrintTy,
typesLitEq
} from './types'

/**
* This is a typed lambda abstraction, consisting of three parts.
Expand Down Expand Up @@ -49,13 +54,14 @@ export class TypeError extends Error { }
*/
export type Context = Map<string, Type>

export const addBinding = (ctx: Context, name: string, ty: Type): Context => {
if (ctx.get(name)) {
throw new TypeError('duplicated binding for name: ' + name)
}
export const addBinding =
(ctx: Context, name: string, ty: Type): Context => {
if (ctx.get(name)) {
throw new TypeError('duplicated binding for name: ' + name)
}

return ctx.set(name, ty)
}
return ctx.set(name, ty)
}

export const typecheck = (typedTerm: TypedLambda): Type => {
return typecheckGiven(new Map<string, Type>(), typedTerm)
Expand All @@ -71,8 +77,7 @@ export const typecheck = (typedTerm: TypedLambda): Type => {
*/
export const typecheckGiven = (ctx: Context, typedTerm: TypedLambda): Type => {
switch (typedTerm.kind) {
case 'lambda-var':
{
case 'lambda-var': {
const termName = typedTerm.name
const lookedUp = ctx.get(termName)

Expand All @@ -82,25 +87,23 @@ export const typecheckGiven = (ctx: Context, typedTerm: TypedLambda): Type => {

return lookedUp
}
case 'typed-lambda-abstraction':
{
case 'typed-lambda-abstraction': {
const updatedCtx = addBinding(ctx, typedTerm.varName, typedTerm.ty)
const bodyTy = typecheckGiven(updatedCtx, typedTerm.body)
return arrow(typedTerm.ty, bodyTy)
}
case 'non-terminal':
{
case 'non-terminal': {
const tyLft = typecheckGiven(ctx, typedTerm.lft)
const tyRgt = typecheckGiven(ctx, typedTerm.rgt)

if (tyLft.kind === 'type-var') {
throw new TypeError('arrow type expected')
if (tyLft.kind !== 'non-terminal') {
throw new TypeError('arrow type expected on lhs')
}

const takes = tyLft.lft
const gives = tyLft.rgt

if (!typesEqual(tyRgt, takes)) {
if (!typesLitEq(tyRgt, takes)) {
throw new TypeError('type mismatch')
}

Expand All @@ -109,7 +112,7 @@ export const typecheckGiven = (ctx: Context, typedTerm: TypedLambda): Type => {
}
}

export const prettyPrintTypedExpr = (expr: TypedLambda): string => {
export const prettyPrintTypedLambda = (expr: TypedLambda): string => {
switch (expr.kind) {
case 'lambda-var': {
return expr.name
Expand All @@ -120,27 +123,27 @@ export const prettyPrintTypedExpr = (expr: TypedLambda): string => {
':' +
prettyPrintTy(expr.ty) +
'.' +
prettyPrintTypedExpr(expr.body)
prettyPrintTypedLambda(expr.body)
}
case 'non-terminal': {
return '(' +
prettyPrintTypedExpr(expr.lft) +
prettyPrintTypedExpr(expr.rgt) +
prettyPrintTypedLambda(expr.lft) +
prettyPrintTypedLambda(expr.rgt) +
')'
}
}
}

export const typedTermsEq = (a: TypedLambda, b: TypedLambda): boolean => {
export const typedTermsLitEq = (a: TypedLambda, b: TypedLambda): boolean => {
if (a.kind === 'lambda-var' && b.kind === 'lambda-var') {
return a.name === b.name
} else if (a.kind === 'typed-lambda-abstraction' &&
b.kind === 'typed-lambda-abstraction') {
return typesEqual(a.ty, b.ty) &&
b.kind === 'typed-lambda-abstraction') {
return typesLitEq(a.ty, b.ty) &&
a.varName === b.varName &&
typedTermsEq(a.body, b.body)
typedTermsLitEq(a.body, b.body)
} else if (a.kind === 'non-terminal' && b.kind === 'non-terminal') {
return typedTermsEq(a.lft, b.lft) && typedTermsEq(a.rgt, b.rgt)
return typedTermsLitEq(a.lft, b.lft) && typedTermsLitEq(a.rgt, b.rgt)
} else {
return false
}
Expand Down
184 changes: 182 additions & 2 deletions lib/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { UntypedLambda } from './lambda'
import { NonTerminal, nt } from './nonterminal'
import { Context, TypedLambda, mkTypedAbs, typecheck } from './typedLambda'

export type TypeVariable = {
kind: 'type-var',
Expand All @@ -24,11 +26,17 @@ export const arrows = (...tys: Type[]): Type => tys.reduceRight(
(acc, ty) => nt<Type>(ty, acc)
)

export const typesEqual = (a: Type, b: Type): boolean => {
/**
* @param a some type
* @param b another type
* @returns true if the types are literally the same (i.e. composed of the
* same literals)
*/
export const typesLitEq = (a: Type, b: Type): boolean => {
if (a.kind === 'type-var' && b.kind === 'type-var') {
return a.typeName === b.typeName
} else if (a.kind === 'non-terminal' && b.kind === 'non-terminal') {
return typesEqual(a.lft, b.lft) && typesEqual(a.rgt, b.rgt)
return typesLitEq(a.lft, b.lft) && typesLitEq(a.rgt, b.rgt)
} else {
return false
}
Expand All @@ -41,3 +49,175 @@ export const prettyPrintTy = (ty: Type): string => {
return `(${prettyPrintTy(ty.lft)}${prettyPrintTy(ty.rgt)})`
}
}

/**
* This function runs a simplified variant of Algorithm W.
* https://en.wikipedia.org/wiki/Hindley%E2%80%93Milner_type_system
*
* @param term a term in the untyped lambda calculus.
* @returns [the term with types added, the type of that term].
* @throws TypeError if no valid type could be deduced.
*/
export const inferType = (
term: UntypedLambda
): [TypedLambda, Type] => {
const absBindings = new Map<string, Type>()
const inferredContext = new Map<string, Type>()
let inferred = algorithmW(term, tyVars(), absBindings, inferredContext)

inferredContext.forEach((combinedTy, termName) => {
const originalTy = absBindings.get(termName)
if (originalTy !== undefined && !typesLitEq(combinedTy, originalTy)) {
inferred = substituteType(inferred, originalTy, combinedTy)
}
})

const normalizationMappings = new Map<string, string>()
const vars = tyVars()
inferredContext.forEach((ty, termName) => {
inferredContext.set(termName, normalizeTy(ty, normalizationMappings, vars))
})

const typedTerm = attachTypes(term, inferredContext)

return [typedTerm, typecheck(typedTerm)]
}

const algorithmW = (
term: UntypedLambda,
nextVar: () => TypeVariable,
varBindings: Context,
constraints: Context): Type => {
switch (term.kind) {
case 'lambda-var': {
const contextType = varBindings.get(term.name)
if (contextType !== undefined) {
return contextType
} else {
return nextVar()
}
}
case 'lambda-abs': {
const paramType = nextVar()
varBindings.set(term.name, paramType)
constraints.set(term.name, paramType)
const bodyType = algorithmW(term.body, nextVar, varBindings, constraints)
return arrow(paramType, bodyType)
}

case 'non-terminal': {
const leftTy = algorithmW(term.lft, nextVar, varBindings, constraints)
const rgtTy = algorithmW(term.rgt, nextVar, varBindings, constraints)
const result = nextVar()
unify(leftTy, arrow(rgtTy, result), constraints)
return result
}
}
}

const unify = (
lft: Type,
rgt: Type,
unified: Context
): void => {
unified.forEach((contextType, termName) => {
const substituted = substituteType(contextType, lft, rgt)
unified.set(termName, substituted)
})
}

const substituteType = (original: Type, lft: Type, rgt: Type): Type => {
if (typesLitEq(lft, original)) {
return rgt
}

switch (original.kind) {
case 'type-var':
return original

case 'non-terminal':
return nt(
substituteType(original.lft, lft, rgt),
substituteType(original.rgt, lft, rgt)
)
}
}

function monoInts (): () => number {
let num = 0

const generator = () => {
const ret = num
num = num + 1
return ret
}

return generator
}

function tyVars (): () => TypeVariable {
const ordinals = monoInts()

const generator = () => {
const offset = ordinals()
if (offset > 25) {
throw new Error('too many variables')
}
const str = String.fromCharCode(97 + offset)
return mkTypeVar(str)
}

return generator
}

const normalizeTy = (
ty: Type,
mapping: Map<string, string> = new Map<string, string>(),
vars: () => TypeVariable): Type => {
switch (ty.kind) {
case 'type-var': {
const mapped = mapping.get(ty.typeName)

if (mapped === undefined) {
const newVar = vars()
mapping.set(ty.typeName, newVar.typeName)
return newVar
} else {
return mkTypeVar(mapped)
}
}
case 'non-terminal':
return nt(
normalizeTy(ty.lft, mapping, vars),
normalizeTy(ty.rgt, mapping, vars)
)
}
}

const attachTypes = (
untyped: UntypedLambda,
types: Context
): TypedLambda => {
switch (untyped.kind) {
case 'lambda-var':
return untyped
case 'lambda-abs': {
const ty = types.get(untyped.name)

if (ty === undefined) {
throw new TypeError('missing a type for term: ' + untyped.name)
}

return mkTypedAbs(
untyped.name,
ty,
attachTypes(untyped.body, types)
)
}
case 'non-terminal':
return nt(
attachTypes(untyped.lft, types),
attachTypes(untyped.rgt, types)
)
}
}
Loading

0 comments on commit faa9986

Please sign in to comment.