Skip to content

Commit

Permalink
add azure document intelligence extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Sep 4, 2024
1 parent e93971a commit 57ce063
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 0 deletions.
14 changes: 14 additions & 0 deletions config/config_extractors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/adrianliechti/llama/pkg/extractor"
"github.com/adrianliechti/llama/pkg/extractor/azure"
"github.com/adrianliechti/llama/pkg/extractor/code"
"github.com/adrianliechti/llama/pkg/extractor/tesseract"
"github.com/adrianliechti/llama/pkg/extractor/text"
Expand Down Expand Up @@ -70,6 +71,9 @@ func createExtractor(cfg extractorConfig) (extractor.Provider, error) {
case "code":
return codeExtractor(cfg)

case "azure":
return azureExtractor(cfg)

case "tesseract":
return tesseractExtractor(cfg)

Expand Down Expand Up @@ -112,6 +116,16 @@ func codeExtractor(cfg extractorConfig) (extractor.Provider, error) {
return code.New(options...)
}

// azureExtractor builds the Azure extractor from cfg. The configured URL is
// always forwarded to azure.New; the token is forwarded through
// azure.WithToken only when it is non-empty.
func azureExtractor(cfg extractorConfig) (extractor.Provider, error) {
	if cfg.Token == "" {
		// No credentials configured: create the client without any options.
		return azure.New(cfg.URL)
	}

	return azure.New(cfg.URL, azure.WithToken(cfg.Token))
}

func tesseractExtractor(cfg extractorConfig) (extractor.Provider, error) {
var options []tesseract.Option

Expand Down
182 changes: 182 additions & 0 deletions pkg/extractor/azure/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package azure

import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/url"
"strings"
"time"

"github.com/adrianliechti/llama/pkg/extractor"
"github.com/adrianliechti/llama/pkg/text"
)

var _ extractor.Provider = &Client{}

// Client is the extractor.Provider implementation of this package. It holds
// the endpoint, credentials, HTTP client and chunking settings used by Extract.
type Client struct {
// url is the service base URL passed to New; Extract appends the analyze path to it.
url string
// token is sent as the "Ocp-Apim-Subscription-Key" header on every request.
token string

// client performs all HTTP requests; New defaults it to http.DefaultClient.
client *http.Client

// chunkSize and chunkOverlap are handed to the text splitter that cuts the
// analyzed content into blocks; New defaults them to 4000 and 200.
chunkSize int
chunkOverlap int
}

// Option mutates a Client; New applies the given options, in order, after
// setting the defaults.
type Option func(*Client)

// New creates a Client for the Azure Document Intelligence endpoint at url.
//
// The client starts with http.DefaultClient, a chunk size of 4000 and a chunk
// overlap of 200; the given options are then applied in order. Nil options are
// skipped, and if the options leave the HTTP client unset (for example
// WithClient(nil)) http.DefaultClient is restored so that Extract never
// dereferences a nil client.
//
// An error is returned when url is empty.
func New(url string, options ...Option) (*Client, error) {
	if url == "" {
		return nil, errors.New("invalid url")
	}

	c := &Client{
		url: url,

		client: http.DefaultClient,

		chunkSize:    4000,
		chunkOverlap: 200,
	}

	for _, option := range options {
		// Calling a nil func would panic; treat it as "no option".
		if option == nil {
			continue
		}

		option(c)
	}

	if c.client == nil {
		c.client = http.DefaultClient
	}

	return c, nil
}

// WithClient returns an Option that replaces the HTTP client used for all
// requests issued by the Client.
func WithClient(client *http.Client) Option {
	configure := func(target *Client) {
		target.client = client
	}

	return configure
}

// WithToken returns an Option that stores token on the Client; Extract sends
// it as the "Ocp-Apim-Subscription-Key" request header.
func WithToken(token string) Option {
	configure := func(target *Client) {
		target.token = token
	}

	return configure
}

// WithChunkSize returns an Option that overrides the chunk size handed to the
// text splitter when the analyzed content is cut into blocks.
func WithChunkSize(size int) Option {
	configure := func(target *Client) {
		target.chunkSize = size
	}

	return configure
}

// WithChunkOverlap returns an Option that overrides the chunk overlap handed
// to the text splitter when the analyzed content is cut into blocks.
func WithChunkOverlap(overlap int) Option {
	configure := func(target *Client) {
		target.chunkOverlap = overlap
	}

	return configure
}

// Extract uploads the content of input to the Azure Document Intelligence
// "prebuilt-layout" model and waits for the analysis to complete.
//
// The analyze call is asynchronous: the service is expected to answer with
// 202 Accepted and an "Operation-Location" header. That URL is polled (first
// immediately, then every five seconds) while the operation reports
// "notStarted" or "running". Once it reports "succeeded", the returned content
// is normalized and split into blocks using the client's chunk size and
// overlap.
//
// options is accepted to satisfy extractor.Provider and is currently not used.
//
// An error is returned when the endpoint URL cannot be parsed, a request
// cannot be built or sent, the service answers with an unexpected status code,
// the Operation-Location header is missing, the operation ends in any status
// other than "succeeded", or ctx is done while waiting between polls.
func (c *Client) Extract(ctx context.Context, input extractor.File, options *extractor.ExtractOptions) (*extractor.Document, error) {
	// Pause between two status requests while the operation is still pending.
	const pollInterval = 5 * time.Second

	endpoint, err := url.Parse(strings.TrimRight(c.url, "/") + "/documentintelligence/documentModels/prebuilt-layout:analyze")

	if err != nil {
		return nil, err
	}

	query := endpoint.Query()
	query.Set("api-version", "2024-07-31-preview")

	endpoint.RawQuery = query.Encode()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), input.Content)

	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/octet-stream")
	req.Header.Set("Ocp-Apim-Subscription-Key", c.token)

	resp, err := c.client.Do(req)

	if err != nil {
		return nil, err
	}

	defer resp.Body.Close()

	if resp.StatusCode != http.StatusAccepted {
		return nil, convertError(resp)
	}

	operationURL := resp.Header.Get("Operation-Location")

	if operationURL == "" {
		return nil, errors.New("missing operation location")
	}

	for {
		operation, err := c.fetchOperation(ctx, operationURL)

		if err != nil {
			return nil, err
		}

		switch operation.Status {
		case OperationStatusRunning, OperationStatusNotStarted:
			// Wait before polling again, but give up as soon as the caller
			// cancels instead of blocking for the full interval.
			timer := time.NewTimer(pollInterval)

			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, ctx.Err()
			case <-timer.C:
			}

		case OperationStatusSucceeded:
			return convertAnalyzeResult(operation.Result, c.chunkSize, c.chunkOverlap)

		default:
			return nil, errors.New("operation " + string(operation.Status))
		}
	}
}

// fetchOperation issues a single status request against operationURL and
// decodes the answer into a fresh AnalyzeOperation. Doing this in its own
// function closes every poll response as soon as it has been read, rather
// than keeping all of them open until Extract returns.
func (c *Client) fetchOperation(ctx context.Context, operationURL string) (*AnalyzeOperation, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, operationURL, nil)

	if err != nil {
		return nil, err
	}

	req.Header.Set("Ocp-Apim-Subscription-Key", c.token)

	resp, err := c.client.Do(req)

	if err != nil {
		return nil, err
	}

	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, convertError(resp)
	}

	var operation AnalyzeOperation

	if err := json.NewDecoder(resp.Body).Decode(&operation); err != nil {
		return nil, err
	}

	return &operation, nil
}

// convertError turns an unexpected HTTP response into an error. The response
// body is used verbatim as the message; when the body is empty the standard
// status text for the response's status code is used instead.
func convertError(resp *http.Response) error {
	// A failed read is deliberately ignored: whatever was read so far (possibly
	// nothing) still decides which message is produced below.
	body, _ := io.ReadAll(resp.Body)

	if len(body) > 0 {
		return errors.New(string(body))
	}

	return errors.New(http.StatusText(resp.StatusCode))
}

// convertAnalyzeResult turns the content of an analyze result into an
// extractor.Document. The content is normalized, split with a text splitter
// configured with chunkSize and chunkOverlap, and every resulting chunk becomes
// one block. Only the block content is populated; block IDs are not set. The
// returned error is always nil.
func convertAnalyzeResult(response AnalyzeResult, chunkSize, chunkOverlap int) (*extractor.Document, error) {
	normalized := text.Normalize(response.Content)

	chunker := text.NewSplitter()
	chunker.ChunkSize = chunkSize
	chunker.ChunkOverlap = chunkOverlap

	document := &extractor.Document{}

	for _, chunk := range chunker.Split(normalized) {
		document.Blocks = append(document.Blocks, extractor.Block{Content: chunk})
	}

	return document, nil
}
21 changes: 21 additions & 0 deletions pkg/extractor/azure/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package azure

// OperationStatus is the value of the "status" field of an AnalyzeOperation.
type OperationStatus string

// Status values that Extract handles explicitly. Any other value decoded from
// the service is reported by Extract as an error.
const (
// OperationStatusSucceeded is the only status for which Extract converts the result.
OperationStatusSucceeded OperationStatus = "succeeded"
// OperationStatusRunning makes Extract wait and poll the operation again.
OperationStatusRunning OperationStatus = "running"
// OperationStatusNotStarted makes Extract wait and poll the operation again.
OperationStatusNotStarted OperationStatus = "notStarted"
)

// AnalyzeOperation is the JSON document Extract decodes from the operation
// URL while polling.
type AnalyzeOperation struct {
// Status is decoded from the "status" field and drives the polling loop.
Status OperationStatus `json:"status"`

// Result is decoded from the "analyzeResult" field; Extract only reads it
// once Status is OperationStatusSucceeded.
Result AnalyzeResult `json:"analyzeResult"`
}

// AnalyzeResult is the subset of the "analyzeResult" payload that this
// package decodes.
type AnalyzeResult struct {
// ModelID is decoded from the "modelId" field.
ModelID string `json:"modelId"`

// Content is decoded from the "content" field; it is the text that gets
// normalized and split into blocks.
Content string `json:"content"`
}

0 comments on commit 57ce063

Please sign in to comment.