Skip to content

Commit

Permalink
Merge pull request #4 from DePasqualeOrg/handle-llama-3.2
Browse files Browse the repository at this point in the history
Handle Llama 3.2 chat template
  • Loading branch information
johnmai-dev authored Oct 3, 2024
2 parents 4ffa95c + 2c91972 commit b435eb6
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 30 deletions.
4 changes: 4 additions & 0 deletions Sources/Ast.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,7 @@ struct KeywordArgumentExpression: Expression {
var key: Identifier
var value: any Expression
}

struct NullLiteral: Literal {
var value: Any? = nil
}
4 changes: 3 additions & 1 deletion Sources/Environment.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Environment {
args[0] is UndefinedValue
},
"equalto": { _ in
throw JinjaError.syntaxNotSupported
throw JinjaError.syntaxNotSupported("equalto")
},
]

Expand Down Expand Up @@ -165,6 +165,8 @@ class Environment {
}

return ObjectValue(value: object)
case is NullValue:
return NullValue()
default:
throw JinjaError.runtime("Cannot convert to runtime value: \(input) type:\(type(of: input))")
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ enum JinjaError: Error, LocalizedError {
case parser(String)
case runtime(String)
case todo(String)
case syntaxNotSupported
case syntaxNotSupported(String)

var errorDescription: String? {
switch self {
case .syntax(let message): return "Syntax error: \(message)"
case .parser(let message): return "Parser error: \(message)"
case .runtime(let message): return "Runtime error: \(message)"
case .todo(let message): return "Todo error: \(message)"
case .syntaxNotSupported: return "Syntax not supported"
case .syntaxNotSupported(let string): return "Syntax not supported: \(string)"
}
}
}
4 changes: 4 additions & 0 deletions Sources/Lexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum TokenType: String {

case numericLiteral = "NumericLiteral"
case booleanLiteral = "BooleanLiteral"
case nullLiteral = "NullLiteral"
case stringLiteral = "StringLiteral"
case identifier = "Identifier"
case equals = "Equals"
Expand Down Expand Up @@ -69,8 +70,10 @@ let keywords: [String: TokenType] = [
"and": .and,
"or": .or,
"not": .not,
// Literals
"true": .booleanLiteral,
"false": .booleanLiteral,
"none": .nullLiteral,
]

func isWord(char: String) -> Bool {
Expand Down Expand Up @@ -226,6 +229,7 @@ func tokenize(_ source: String, options: PreprocessOptions = PreprocessOptions()
case .identifier,
.numericLiteral,
.booleanLiteral,
.nullLiteral,
.stringLiteral,
.closeParen,
.closeSquareBracket:
Expand Down
22 changes: 8 additions & 14 deletions Sources/Parser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -187,21 +187,18 @@ func parse(tokens: [Token]) throws -> Program {
while typeof(.is) {
current += 1
let negate = typeof(.not)

if negate {
current += 1
}

var filter = try parsePrimaryExpression()

if let boolLiteralFlter = filter as? BoolLiteral {
filter = Identifier(value: String(boolLiteralFlter.value))
if let boolLiteralFilter = filter as? BoolLiteral {
filter = Identifier(value: String(boolLiteralFilter.value))
} else if filter is NullLiteral {
filter = Identifier(value: "none")
}

if let test = filter as? Identifier {
operand = TestExpression(operand: operand as! Expression, negate: negate, test: test)
}
else {
} else {
throw JinjaError.syntax("Expected identifier for the test")
}
}
Expand Down Expand Up @@ -373,6 +370,9 @@ func parse(tokens: [Token]) throws -> Program {
case .booleanLiteral:
current += 1
return BoolLiteral(value: token.value == "true")
case .nullLiteral:
current += 1
return NullLiteral()
case .identifier:
current += 1
return Identifier(value: token.value)
Expand All @@ -389,13 +389,11 @@ func parse(tokens: [Token]) throws -> Program {
var values: [Expression] = []
while !typeof(.closeSquareBracket) {
try values.append(parseExpression() as! Expression)

if typeof(.comma) {
current += 1
}
}
current += 1

return ArrayLiteral(value: values)
case .openCurlyBracket:
current += 1
Expand All @@ -404,16 +402,12 @@ func parse(tokens: [Token]) throws -> Program {
let key = try parseExpression()
try expect(type: .colon, error: "Expected colon between key and value in object literal")
let value = try parseExpression()

values.append((key as! Expression, value as! Expression))

if typeof(.comma) {
current += 1
}
}

current += 1

return ObjectLiteral(value: values)
default:
throw JinjaError.syntax("Unexpected token: \(token.type)")
Expand Down
52 changes: 40 additions & 12 deletions Sources/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ struct Interpreter {
throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))")
}
default:
throw JinjaError.syntaxNotSupported
throw JinjaError.syntaxNotSupported(String(describing: node.loopvar))
}

let evaluated = try self.evaluateBlock(statements: node.body, environment: scope)
Expand Down Expand Up @@ -353,21 +353,21 @@ struct Interpreter {
}
else if let left = left as? NumericValue, let right = right as? NumericValue {
switch node.operation.value {
case "+": throw JinjaError.syntaxNotSupported
case "-": throw JinjaError.syntaxNotSupported
case "*": throw JinjaError.syntaxNotSupported
case "/": throw JinjaError.syntaxNotSupported
case "+": throw JinjaError.syntaxNotSupported("+")
case "-": throw JinjaError.syntaxNotSupported("-")
case "*": throw JinjaError.syntaxNotSupported("*")
case "/": throw JinjaError.syntaxNotSupported("/")
case "%":
switch left.value {
case is Int:
return NumericValue(value: left.value as! Int % (right.value as! Int))
default:
throw JinjaError.runtime("Unknown value type:\(type(of: left.value))")
}
case "<": throw JinjaError.syntaxNotSupported
case ">": throw JinjaError.syntaxNotSupported
case ">=": throw JinjaError.syntaxNotSupported
case "<=": throw JinjaError.syntaxNotSupported
case "<": throw JinjaError.syntaxNotSupported("<")
case ">": throw JinjaError.syntaxNotSupported(">")
case ">=": throw JinjaError.syntaxNotSupported(">=")
case "<=": throw JinjaError.syntaxNotSupported("<=")
default:
throw JinjaError.runtime("Unknown operation type:\(node.operation.value)")
}
Expand All @@ -380,7 +380,7 @@ struct Interpreter {
}
}
else if right is ArrayValue {
throw JinjaError.syntaxNotSupported
throw JinjaError.syntaxNotSupported("right is ArrayValue")
}

if left is StringValue || right is StringValue {
Expand Down Expand Up @@ -428,7 +428,20 @@ struct Interpreter {
}

if left is StringValue, right is ObjectValue {
throw JinjaError.syntaxNotSupported
switch node.operation.value {
case "in":
if let leftString = (left as? StringValue)?.value,
let rightObject = right as? ObjectValue {
return BooleanValue(value: rightObject.value.keys.contains(leftString))
}
case "not in":
if let leftString = (left as? StringValue)?.value,
let rightObject = right as? ObjectValue {
return BooleanValue(value: !rightObject.value.keys.contains(leftString))
}
default:
throw JinjaError.runtime("Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue")
}
}

throw JinjaError.syntax(
Expand Down Expand Up @@ -664,6 +677,17 @@ struct Interpreter {
throw JinjaError.runtime("Unknown filter: \(node.filter)")
}

func evaluateTestExpression(node: TestExpression, environment: Environment) throws -> any RuntimeValue {
let operand = try self.evaluate(statement: node.operand, environment: environment)

if let testFunction = environment.tests[node.test.value] {
let result = try testFunction(operand)
return BooleanValue(value: node.negate ? !result : result)
} else {
throw JinjaError.runtime("Unknown test: \(node.test.value)")
}
}

func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue {
if let statement {
switch statement {
Expand Down Expand Up @@ -693,8 +717,12 @@ struct Interpreter {
return BooleanValue(value: statement.value)
case let statement as FilterExpression:
return try self.evaluateFilterExpression(node: statement, environment: environment)
case let statement as TestExpression:
return try self.evaluateTestExpression(node: statement, environment: environment)
case is NullLiteral:
return NullValue()
default:
throw JinjaError.runtime("Unknown node type: \(type(of:statement))")
throw JinjaError.runtime("Unknown node type: \(type(of:statement)), statement: \(String(describing: statement))")
}
}
else {
Expand Down
1 change: 1 addition & 0 deletions Sources/Template.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public struct Template {

try env.set(name: "false", value: false)
try env.set(name: "true", value: true)
try env.set(name: "none", value: NullValue())
try env.set(
name: "raise_exception",
value: { (args: String) throws in
Expand Down
44 changes: 43 additions & 1 deletion Tests/LexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ final class LexerTests: XCTestCase {
"UNDEFINED_VARIABLES": "{{ undefined_variable }}",
"UNDEFINED_ACCESS": "{{ object.undefined_attribute }}",

// Null
"NULL_VARIABLE": "{% if not null_val is defined %}{% set null_val = none %}{% endif %}{% if null_val is not none %}{{ 'fail' }}{% else %}{{ 'pass' }}{% endif %}",

// Ternary operator
"TERNARY_OPERATOR":
"|{{ 'a' if true else 'b' }}|{{ 'a' if false else 'b' }}|{{ 'a' if 1 + 1 == 2 else 'b' }}|{{ 'a' if 1 + 1 == 3 or 1 * 2 == 3 else 'b' }}|",
Expand Down Expand Up @@ -2032,7 +2035,7 @@ final class LexerTests: XCTestCase {
Token(value: "unknown", type: .stringLiteral),
Token(value: ")", type: .closeParen),
Token(value: "is", type: .is),
Token(value: "none", type: .identifier),
Token(value: "none", type: .nullLiteral),
Token(value: "}}", type: .closeExpression),
Token(value: "|", type: .text),
Token(value: "{{", type: .openExpression),
Expand Down Expand Up @@ -2177,6 +2180,45 @@ final class LexerTests: XCTestCase {
Token(value: "}}", type: .closeExpression),
],

// Null
"NULL_VARIABLE": [
Token(value: "{%", type: .openStatement),
Token(value: "if", type: .if),
Token(value: "not", type: .not),
Token(value: "null_val", type: .identifier),
Token(value: "is", type: .is),
Token(value: "defined", type: .identifier),
Token(value: "%}", type: .closeStatement),
Token(value: "{%", type: .openStatement),
Token(value: "set", type: .set),
Token(value: "null_val", type: .identifier),
Token(value: "=", type: .equals),
Token(value: "none", type: .nullLiteral),
Token(value: "%}", type: .closeStatement),
Token(value: "{%", type: .openStatement),
Token(value: "endif", type: .endIf),
Token(value: "%}", type: .closeStatement),
Token(value: "{%", type: .openStatement),
Token(value: "if", type: .if),
Token(value: "null_val", type: .identifier),
Token(value: "is", type: .is),
Token(value: "not", type: .not),
Token(value: "none", type: .nullLiteral),
Token(value: "%}", type: .closeStatement),
Token(value: "{{", type: .openExpression),
Token(value: "fail", type: .stringLiteral),
Token(value: "}}", type: .closeExpression),
Token(value: "{%", type: .openStatement),
Token(value: "else", type: .else),
Token(value: "%}", type: .closeStatement),
Token(value: "{{", type: .openExpression),
Token(value: "pass", type: .stringLiteral),
Token(value: "}}", type: .closeExpression),
Token(value: "{%", type: .openStatement),
Token(value: "endif", type: .endIf),
Token(value: "%}", type: .closeStatement),
],

// Ternary operator
"TERNARY_OPERATOR": [
Token(value: "|", type: .text),
Expand Down

0 comments on commit b435eb6

Please sign in to comment.