Skip to content

Commit

Permalink
Merge pull request jdkato#77 from itzamna314/master
Browse files Browse the repository at this point in the history
Adds ModelFromFS
  • Loading branch information
jdkato authored Sep 21, 2021
2 parents 6d09853 + ddaf0b7 commit a376476
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 19 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/jdkato/prose/v3

go 1.13
go 1.16

require (
github.com/neurosnap/sentences v1.0.6 // indirect
Expand Down
65 changes: 52 additions & 13 deletions model.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package prose

import (
"io"
"io/fs"
"os"
"path/filepath"
)
Expand Down Expand Up @@ -60,12 +62,45 @@ func ModelFromData(name string, sources ...DataSource) *Model {

// ModelFromDisk loads a Model from the user-provided location.
//
// path is wrapped as an fs.FS with os.DirFS and handed to loadClassifier to
// build the extracter; the Model's Name is the last element of path
// (filepath.Base). The tagger is a fresh newPerceptronTagger().
//
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `name, classifier := loadClassifier(path)` line below looks like the removed
// pre-change line sitting next to its replacement — confirm against the
// repository before treating this span as compilable Go.
func ModelFromDisk(path string) *Model {
name, classifier := loadClassifier(path)
// Expose the model directory as a read-only filesystem rooted at path.
filesys := os.DirFS(path)
return &Model{
Name: filepath.Base(path),

extracter: loadClassifier(filesys),
tagger: newPerceptronTagger(),
}
}

// ModelFromFS loads a Model from the entry called name inside filesys.
//
// filesys is walked from its root with fs.WalkDir. The first entry encountered
// whose base name equals name is turned into a sub-filesystem with fs.Sub and
// passed to loadClassifier. The returned Model's Name is name and its tagger
// is a fresh newPerceptronTagger().
//
// NOTE(review): the match below tests only d.Name(), so a regular file called
// name would be selected as well — consider also requiring d.IsDir().
// NOTE(review): when no entry matches, the walk returns nil and modelFS is
// still nil when it is passed to loadClassifier — TODO handle the not-found
// case explicitly.
func ModelFromFS(name string, filesys fs.FS) *Model {
// Locate a folder matching name within filesys
var modelFS fs.FS
err := fs.WalkDir(filesys, ".", func(path string, d fs.DirEntry, err error) error {
// Propagate any error reported by the walk for this entry.
if err != nil {
return err
}

// Model located. Exit tree traversal
if d.Name() == name {
modelFS, err = fs.Sub(filesys, path)
if err != nil {
return err
}
// io.EOF is used purely as a sentinel to stop the walk early; it is
// filtered out again right after fs.WalkDir returns.
return io.EOF
}

return nil
})
// Anything other than the io.EOF sentinel is handed to checkError.
if err != io.EOF {
checkError(err)
}

return &Model{
Name: name,

// NOTE(review): the next two lines look like removed pre-change lines from
// the rendered diff (classifier is not defined in this function) — confirm.
extracter: classifier,
tagger: newPerceptronTagger()}
extracter: loadClassifier(modelFS),
tagger: newPerceptronTagger(),
}
}

// Write saves a Model to the user-provided location.
Expand Down Expand Up @@ -96,24 +131,28 @@ func loadTagger(path string) *perceptronTagger {
return newTrainedPerceptronTagger(model)
}*/

// loadClassifier gob-decodes mapping.gob, weights.gob and labels.gob from the
// "Maxent" directory and builds an entity extracter from them through
// newMaxentClassifier and newTrainedEntityExtracter. Every error is passed to
// checkError.
//
// NOTE(review): this listing is a rendered diff without +/- markers. Both the
// old (path string) and the new (filesys fs.FS) signature appear below, and
// the filepath.Join/dec lines look like the removed pre-change body
// interleaved with the fs.Sub/maxent.Open replacement — confirm against the
// repository.
// NOTE(review): the files returned by maxent.Open are not closed anywhere in
// the lines visible here — consider closing each one after decoding.
func loadClassifier(path string) (string, *entityExtracter) {
func loadClassifier(filesys fs.FS) *entityExtracter {
// Decode targets for the three gob files.
var mapping map[string]int
var weights []float64
var labels []string

loc := filepath.Join(path, "Maxent")
dec := getDiskAsset(filepath.Join(loc, "mapping.gob"))
checkError(dec.Decode(&mapping))
// Narrow filesys to its "Maxent" subdirectory; all three files live there.
maxent, err := fs.Sub(filesys, "Maxent")
checkError(err)

file, err := maxent.Open("mapping.gob")
checkError(err)
checkError(getDiskAsset(file).Decode(&mapping))

dec = getDiskAsset(filepath.Join(loc, "weights.gob"))
checkError(dec.Decode(&weights))
file, err = maxent.Open("weights.gob")
checkError(err)
checkError(getDiskAsset(file).Decode(&weights))

dec = getDiskAsset(filepath.Join(loc, "labels.gob"))
checkError(dec.Decode(&labels))
file, err = maxent.Open("labels.gob")
checkError(err)
checkError(getDiskAsset(file).Decode(&labels))

// Assemble the classifier from the decoded parts and wrap it.
model := newMaxentClassifier(weights, mapping, labels)
name := filepath.Base(path)
return name, newTrainedEntityExtracter(model)
return newTrainedEntityExtracter(model)
}

func defaultModel(tagging, classifying bool) *Model {
Expand Down
39 changes: 39 additions & 0 deletions model_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package prose

import (
"embed"
"io/fs"
"os"
"path/filepath"
"testing"
Expand All @@ -26,3 +28,40 @@ func TestModelFromDisk(t *testing.T) {
t.Errorf("ModelFromDisk() expected = temp, got = %v", model.Name)
}
}

// embeddedModel holds the testdata/PRODUCT tree embedded into the test binary;
// TestModelFromFS passes it to ModelFromFS as the source filesystem.
//
//go:embed testdata/PRODUCT
var embeddedModel embed.FS

// TestModelFromFS checks that the model embedded under testdata/PRODUCT can be
// loaded through ModelFromFS and then used to extract a single PRODUCT entity
// from a short sentence.
func TestModelFromFS(t *testing.T) {
	// Sanity-check that the embedded tree can be traversed before loading
	// from it. Walk errors are propagated and reported rather than dropped.
	walkErr := fs.WalkDir(embeddedModel, ".", func(path string, d fs.DirEntry, err error) error {
		return err
	})
	if walkErr != nil {
		t.Fatalf("Failed to walk embedded model: %v", walkErr)
	}

	// Load the embedded PRODUCT model
	model := ModelFromFS("PRODUCT", embeddedModel)
	if model.Name != "PRODUCT" {
		t.Errorf("ModelFromFS() expected = PRODUCT, got = %v", model.Name)
	}

	doc, err := NewDocument("Windows 10 is an operating system",
		UsingModel(model))
	if err != nil {
		// doc must not be used when err is non-nil, so stop here instead of
		// continuing into doc.Entities().
		t.Fatalf("Failed to create doc with ModelFromFS: %v", err)
	}

	ents := doc.Entities()

	// Fatal: the index accesses below require exactly one entity.
	if len(ents) != 1 {
		t.Fatalf("Expected 1 entity, got %v", ents)
	}

	if ents[0].Text != "Windows 10" {
		t.Errorf("Expected to find entity 'Windows 10' with ModelFromFS, got = %v", ents[0].Text)
	}

	if ents[0].Label != "PRODUCT" {
		t.Errorf("Expected to tag entity with PRODUCT, got = %v", ents[0].Label)
	}
}
8 changes: 3 additions & 5 deletions utilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package prose
import (
"bytes"
"encoding/gob"
"os"
"io/fs"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -46,10 +46,8 @@ func getAsset(folder, name string) *gob.Decoder {
return gob.NewDecoder(bytes.NewReader(b))
}

// getDiskAsset returns a gob.Decoder that reads from the given asset.
//
// NOTE(review): this listing is a rendered diff without +/- markers. The
// (path string) signature and the os.Open/checkError lines look like the
// removed pre-change version; the (file fs.File) signature and its one-line
// body look like the replacement — confirm against the repository.
// NOTE(review): the fs.File variant neither opens nor closes file, so closing
// it is left to the caller.
func getDiskAsset(path string) *gob.Decoder {
f, err := os.Open(path)
checkError(err)
return gob.NewDecoder(f)
func getDiskAsset(file fs.File) *gob.Decoder {
return gob.NewDecoder(file)
}

func hasAnyPrefix(s string, prefixes []string) bool {
Expand Down

0 comments on commit a376476

Please sign in to comment.