diff --git a/mustache.go b/mustache.go index 79ec82c..2f6230a 100644 --- a/mustache.go +++ b/mustache.go @@ -36,6 +36,7 @@ type Template struct { curline int dir string elems *vector.Vector + partials map[string]*Template } type parseError struct { @@ -54,7 +55,7 @@ var ( ) // taken from pkg/template -func htmlEscape(w io.Writer, s []byte) { +func htmlEscape(w io.Writer, s []byte) (err os.Error) { var esc []byte last := 0 for i, c := range s { @@ -72,11 +73,19 @@ func htmlEscape(w io.Writer, s []byte) { default: continue } - w.Write(s[last:i]) - w.Write(esc) + _, err = w.Write(s[last:i]) + if err != nil { + return + } + _, err = w.Write(esc) + if err != nil { + return + } last = i + 1 } - w.Write(s[last:]) + _, err = w.Write(s[last:]) + + return } func (tmpl *Template) readString(s string) (string, os.Error) { @@ -122,6 +131,10 @@ func (tmpl *Template) readString(s string) (string, os.Error) { } func (tmpl *Template) parsePartial(name string) (*Template, os.Error) { + if p := tmpl.GetPartial(name); p != nil { + return p, nil + } + filenames := []string{ path.Join(tmpl.dir, name), path.Join(tmpl.dir, name+".mustache"), @@ -443,7 +456,7 @@ loop: return v } -func renderSection(section *sectionElement, contextChain *vector.Vector, buf io.Writer) { +func renderSection(section *sectionElement, contextChain *vector.Vector, buf io.Writer) (err os.Error) { value := lookup(contextChain, section.name) var context = contextChain.At(contextChain.Len() - 1).(reflect.Value) var contexts = new(vector.Vector) @@ -474,37 +487,64 @@ func renderSection(section *sectionElement, contextChain *vector.Vector, buf io. ctx := contexts.At(j).(reflect.Value) contextChain.Push(ctx) for i := 0; i < section.elems.Len(); i++ { - renderElement(section.elems.At(i), contextChain, buf) + err = renderElement(section.elems.At(i), contextChain, buf) + if err != nil { + return + } } contextChain.Pop() } + + return } -func renderElement(element interface{}, contextChain *vector.Vector, buf io.Writer) { +func renderElement(element interface{}, contextChain *vector.Vector, buf io.Writer) (err os.Error) { switch elem := element.(type) { case *textElement: - buf.Write(elem.text) + _, err = buf.Write(elem.text) + if err != nil { + return + } case *varElement: val := lookup(contextChain, elem.name) if val.IsValid() { if elem.raw { - fmt.Fprint(buf, val.Interface()) + _, err = fmt.Fprint(buf, val.Interface()) + if err != nil { + return + } } else { s := fmt.Sprint(val.Interface()) - htmlEscape(buf, []byte(s)) + err = htmlEscape(buf, []byte(s)) + if err != nil { + return + } } } case *sectionElement: - renderSection(elem, contextChain, buf) + err = renderSection(elem, contextChain, buf) + if err != nil { + return + } case *Template: - elem.renderTemplate(contextChain, buf) + err = elem.renderTemplate(contextChain, buf) + if err != nil { + return + } } + + return } -func (tmpl *Template) renderTemplate(contextChain *vector.Vector, buf io.Writer) { +func (tmpl *Template) renderTemplate(contextChain *vector.Vector, buf io.Writer) (err os.Error) { for i := 0; i < tmpl.elems.Len(); i++ { - renderElement(tmpl.elems.At(i), contextChain, buf) + err = renderElement(tmpl.elems.At(i), contextChain, buf) + if err != nil { + return + } } + + return } func (tmpl *Template) Render(context ...interface{}) string { @@ -518,34 +558,146 @@ func (tmpl *Template) Render(context ...interface{}) string { return buf.String() } -func ParseString(data string) (*Template, os.Error) { - cwd := os.Getenv("CWD") - tmpl := Template{data, "{{", "}}", 0, 1, cwd, new(vector.Vector)} +func (tmpl *Template)RenderTo(w io.Writer, context ...interface{}) (os.Error) { + var contextChain vector.Vector + for _, c := range(context) { + val := reflect.ValueOf(c) + contextChain.Push(val) + } + return tmpl.renderTemplate(&contextChain, w) +} + +func (tmpl *Template)initpartials() { + if tmpl.partials == nil { + tmpl.partials = make(map[string]*Template) + } +} + +func (tmpl *Template)GetPartial(name string) (*Template) { + tmpl.initpartials() + + if p, ok := tmpl.partials[name]; ok { + return p + } + + return nil +} + +func (tmpl *Template)SetPartial(name string, p *Template) (os.Error) { + if p == nil { + return os.NewError("Can't set nil partial.") + } + + tmpl.initpartials() + + tmpl.partials[name] = p + + return nil +} + +func (tmpl *Template)DeletePartial(name string) { + tmpl.initpartials() + + tmpl.partials[name] = nil, false +} + +func (tmpl *Template)ParseReader(r io.Reader) (os.Error) { + var buf bytes.Buffer + _, err := io.Copy(&buf, r) + if err != nil { + return err + } + + return tmpl.ParseString(buf.String()) +} + +func (tmpl *Template)ParseString(data string) (os.Error) { + tmpl.data = data + tmpl.elems = new(vector.Vector) + err := tmpl.parse() if err != nil { - return nil, err + return err } - return &tmpl, err + return nil } -func ParseFile(filename string) (*Template, os.Error) { +func (tmpl *Template)ParseFile(filename string) (os.Error) { data, err := ioutil.ReadFile(filename) if err != nil { - return nil, err + return err } dirname, _ := path.Split(filename) - tmpl := Template{string(data), "{{", "}}", 0, 1, dirname, new(vector.Vector)} + tmpl.data = string(data) + tmpl.dir = dirname + tmpl.elems = new(vector.Vector) + err = tmpl.parse() if err != nil { - return nil, err + return err } - return &tmpl, nil + return nil +} + +func NewTemplate() (tmpl *Template) { + tmpl = new(Template) + + var cwd string + cwd, err := os.Getwd() + if err != nil { + cwd = "" + } + + tmpl = &Template{ + "", + "{{", "}}", + 0, + 1, + cwd, + new(vector.Vector), + nil, + } + + return +} + +func ParseReader(r io.Reader) (*Template, os.Error) { + tmpl := NewTemplate() + + err := tmpl.ParseReader(r) + if err != nil { + return nil, err + } + + return tmpl, nil +} + +func ParseString(data string) (*Template, os.Error) { + tmpl := NewTemplate() + + err := tmpl.ParseString(data) + if err != nil { + return nil, err + } + + return tmpl, nil +} + +func ParseFile(filename string) (*Template, os.Error) { + tmpl := NewTemplate() + + err := tmpl.ParseFile(filename) + if err != nil { + return nil, err + } + + return tmpl, nil } func Render(data string, context ...interface{}) string { diff --git a/mustache_test.go b/mustache_test.go index bdd5928..425835e 100644 --- a/mustache_test.go +++ b/mustache_test.go @@ -175,14 +175,81 @@ func TestPartial(t *testing.T) { } func TestSectionPartial(t *testing.T) { filename := path.Join(path.Join(os.Getenv("PWD"), "tests"), "test3.mustache") - expected := "Mike\nJoe\n" + tmpl := NewTemplate() + + part, err := ParseString("{{Name}}: {{Id}}") + if err != nil { + t.Fatalf("Error parsing string: %v\n", err) + } + err = tmpl.SetPartial("codepart", part) + if err != nil { + t.Fatalf("Error setting partial: %v\n", err) + } + + err = tmpl.ParseFile(filename) + if err != nil { + t.Fatalf("Error parsing %v: %v\n", filename, err) + } + + expected := "Mike: 1\nJoe: 2\n" context := map[string]interface{}{"users": []User{{"Mike", 1}, {"Joe", 2}}} - output := RenderFile(filename, context) + output := tmpl.Render(context) if output != expected { t.Fatalf("testSectionPartial expected %q got %q", expected, output) } } +func TestRecursivePartial(t *testing.T) { + type nodeData map[string]interface{} + type node []nodeData + + nodes := node{nodeData{ + "contents": "1", + "children": node{nodeData{ + "contents": "2", + "children": node{nodeData{ + "contents": "3", + "children": nil, + }}, + }}, + }, + nodeData{"contents": "4", + "children": node{nodeData{ + "contents": "5", + "children": node{nodeData{ + "contents": "6", + "children": nil, + }}, + }}, + }} + + filename := path.Join("tests", "test4.mustache") + + tmpl := NewTemplate() + err := tmpl.SetPartial("node", tmpl) + if err != nil { + t.Fatalf("Error setting partial: %v\n", err) + } + err = tmpl.ParseFile(filename) + if err != nil { + t.Fatalf("Error parsing %v: %v\n", err) + } + + top := NewTemplate() + err = top.SetPartial("node", tmpl) + if err != nil { + t.Fatalf("Error setting partial: %v\n", err) + } + top.ParseString("