Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add supported fields to Authenticate struct type #62

8 changes: 1 addition & 7 deletions certverify/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,14 @@ type PoolVerifier struct {
}

// NewPoolVerifier creates a new Verifier
func NewPoolVerifier(rootsPEM []byte, intsPEM []byte, keyUsages ...x509.ExtKeyUsage) (*PoolVerifier, error) {
func NewPoolVerifier(rootsPEM []byte, keyUsages ...x509.ExtKeyUsage) (*PoolVerifier, error) {
opts := x509.VerifyOptions{
KeyUsages: keyUsages,
Roots: x509.NewCertPool(),
}
if len(rootsPEM) == 0 || !opts.Roots.AppendCertsFromPEM(rootsPEM) {
return nil, errors.New("could not append root CA(s)")
}
if len(intsPEM) > 0 {
opts.Intermediates = x509.NewCertPool()
if !opts.Intermediates.AppendCertsFromPEM(intsPEM) {
return nil, errors.New("could not append intermediate CA(s)")
}
}
return &PoolVerifier{
verifyOpts: opts,
}, nil
Expand Down
12 changes: 2 additions & 10 deletions cmd/nanomdm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ func main() {
flListen = flag.String("listen", ":9000", "HTTP listen address")
flAPIKey = flag.String("api", "", "API key for API endpoints")
flVersion = flag.Bool("version", false, "print version")
flRootsPath = flag.String("ca", "", "path to PEM CA cert(s)")
flIntsPath = flag.String("intermediate", "", "path to PEM intermediate cert(s)")
flRootsPath = flag.String("ca", "", "path to CA cert for verification")
flWebhook = flag.String("webhook-url", "", "URL to send requests to")
flCertHeader = flag.String("cert-header", "", "HTTP header containing URL-escaped TLS client certificate")
flDebug = flag.Bool("debug", false, "log debug messages")
Expand Down Expand Up @@ -82,14 +81,7 @@ func main() {
if err != nil {
stdlog.Fatal(err)
}
var intsPEM []byte
if *flIntsPath != "" {
intsPEM, err = os.ReadFile(*flIntsPath)
if err != nil {
stdlog.Fatal(err)
}
}
verifier, err := certverify.NewPoolVerifier(caPEM, intsPEM, x509.ExtKeyUsageClientAuth)
verifier, err := certverify.NewPoolVerifier(caPEM, x509.ExtKeyUsageClientAuth)
if err != nil {
stdlog.Fatal(err)
}
Expand Down
10 changes: 2 additions & 8 deletions docs/operations-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,9 @@ API authorization in NanoMDM is simply HTTP Basic authentication using "nanomdm"

### -ca string

* path to PEM CA cert(s)
* Path to CA cert for verification

NanoMDM validates that the device identity certificate is issued from specific CAs. This switch is the path to a file of PEM-encoded CAs to validate enrollments against.

### -intermediate string

* path to PEM intermediate cert(s)

NanoMDM validates that the device identity certificate is issued from specific CAs. This switch is the path to a file of PEM-encoded intermediate certificates that can be used to build a chain of trust to the CAs to validate enrollments against.
NanoMDM validates that the device identity certificate is issued from specific CAs. This switch is the path to a file of PEM-encoded CAs to validate against.

### -cert-header string

Expand Down
22 changes: 2 additions & 20 deletions http/mdm/mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package mdm

import (
"errors"
"fmt"
"net/http"
"strings"

Expand Down Expand Up @@ -38,21 +37,12 @@ func CheckinHandler(svc service.Checkin, logger log.Logger) http.HandlerFunc {
}
respBytes, err := service.CheckinRequest(svc, mdmReqFromHTTPReq(r), bodyBytes)
if err != nil {
logs := []interface{}{"msg", "check-in request"}
logger.Info("msg", "check-in request", "err", err)
httpStatus := http.StatusInternalServerError
var statusErr *service.HTTPStatusError
if errors.As(err, &statusErr) {
httpStatus = statusErr.Status
err = fmt.Errorf("HTTP error: %w", statusErr.Unwrap())
}
// manualy unwrapping the `StatusErr` is not necessary as `errors.As` manually unwraps
var parseErr *mdm.ParseError
if errors.As(err, &parseErr) {
logs = append(logs, "content", string(parseErr.Content))
err = fmt.Errorf("parse error: %w", parseErr.Unwrap())
}
logs = append(logs, "http_status", httpStatus, "err", err)
logger.Info(logs...)
http.Error(w, http.StatusText(httpStatus), httpStatus)
}
w.Write(respBytes)
Expand All @@ -71,20 +61,12 @@ func CommandAndReportResultsHandler(svc service.CommandAndReportResults, logger
}
respBytes, err := service.CommandAndReportResultsRequest(svc, mdmReqFromHTTPReq(r), bodyBytes)
if err != nil {
logs := []interface{}{"msg", "command report results"}
logger.Info("msg", "command report results", "err", err)
httpStatus := http.StatusInternalServerError
var statusErr *service.HTTPStatusError
if errors.As(err, &statusErr) {
httpStatus = statusErr.Status
err = fmt.Errorf("HTTP error: %w", statusErr.Unwrap())
}
var parseErr *mdm.ParseError
if errors.As(err, &parseErr) {
logs = append(logs, "content", string(parseErr.Content))
err = fmt.Errorf("parse error: %w", parseErr.Unwrap())
}
logs = append(logs, "http_status", httpStatus, "err", err)
logger.Info(logs...)
http.Error(w, http.StatusText(httpStatus), httpStatus)
}
w.Write(respBytes)
Expand Down
9 changes: 6 additions & 3 deletions mdm/checkin.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ type Authenticate struct {
Topic string
Raw []byte `plist:"-"` // Original Authenticate XML plist

// Additional fields required in AuthenticateRequest as specified
// in the Apple documentation.
DeviceName string
Model string
ModelName string

// Fields that may be present but are not strictly required for the
// operation of the MDM protocol. Nice-to-haves.
SerialNumber string
Expand Down Expand Up @@ -149,9 +155,6 @@ func (w *checkinUnmarshaller) UnmarshalPlist(f func(interface{}) error) error {
func DecodeCheckin(rawMessage []byte) (message interface{}, err error) {
w := &checkinUnmarshaller{raw: rawMessage}
err = plist.Unmarshal(rawMessage, w)
if err != nil {
err = &ParseError{Err: err, Content: rawMessage}
}
message = w.message
return
}
4 changes: 2 additions & 2 deletions mdm/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func DecodeCommandResults(rawResults []byte) (results *CommandResults, err error
results = new(CommandResults)
err = plist.Unmarshal(rawResults, results)
if err != nil {
return nil, &ParseError{Err: err, Content: rawResults}
return
}
results.Raw = rawResults
if results.Status == "" {
Expand All @@ -58,7 +58,7 @@ func DecodeCommand(rawCommand []byte) (command *Command, err error) {
command = new(Command)
err = plist.Unmarshal(rawCommand, command)
if err != nil {
return nil, &ParseError{Err: err, Content: rawCommand}
return
}
command.Raw = rawCommand
if command.CommandUUID == "" || command.Command.RequestType == "" {
Expand Down
17 changes: 0 additions & 17 deletions mdm/mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"crypto/x509"
"errors"
"fmt"
)

// Enrollment represents the various enrollment-related data sent with requests.
Expand Down Expand Up @@ -60,19 +59,3 @@ func (r *Request) Clone() *Request {
*r2 = *r
return r2
}

// ParseError represents a failure to parse an MDM structure (usually Apple Plist)
type ParseError struct {
Err error
Content []byte
}

// Unwrap returns the underlying error of the ParseError
func (e *ParseError) Unwrap() error {
return e.Err
}

// Error formats the ParseError as a string
func (e *ParseError) Error() string {
return fmt.Sprintf("parse error: %v: raw content: %v", e.Err, string(e.Content))
}
4 changes: 2 additions & 2 deletions storage/mysql/bstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
func (s *MySQLStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error {
_, err := s.db.ExecContext(
r.Context,
`UPDATE devices SET bootstrap_token_b64 = ?, bootstrap_token_at = CURRENT_TIMESTAMP WHERE id = ? LIMIT 1;`,
`UPDATE nano_devices SET bootstrap_token_b64 = ?, bootstrap_token_at = CURRENT_TIMESTAMP WHERE id = ? LIMIT 1;`,
nullEmptyString(msg.BootstrapToken.BootstrapToken.String()),
r.ID,
)
Expand All @@ -21,7 +21,7 @@ func (s *MySQLStorage) RetrieveBootstrapToken(r *mdm.Request, _ *mdm.GetBootstra
var tokenB64 string
err := s.db.QueryRowContext(
r.Context,
`SELECT bootstrap_token_b64 FROM devices WHERE id = ?;`,
`SELECT bootstrap_token_b64 FROM nano_devices WHERE id = ?;`,
r.ID,
).Scan(&tokenB64)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions storage/mysql/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@ func (s *MySQLStorage) queryRowContextRowExists(ctx context.Context, query strin
func (s *MySQLStorage) EnrollmentHasCertHash(r *mdm.Request, _ string) (bool, error) {
return s.queryRowContextRowExists(
r.Context,
`SELECT COUNT(*) FROM cert_auth_associations WHERE id = ?;`,
`SELECT COUNT(*) FROM nano_cert_auth_associations WHERE id = ?;`,
r.ID,
)
}

func (s *MySQLStorage) HasCertHash(r *mdm.Request, hash string) (bool, error) {
return s.queryRowContextRowExists(
r.Context,
`SELECT COUNT(*) FROM cert_auth_associations WHERE sha256 = ?;`,
`SELECT COUNT(*) FROM nano_cert_auth_associations WHERE sha256 = ?;`,
strings.ToLower(hash),
)
}

func (s *MySQLStorage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) {
return s.queryRowContextRowExists(
r.Context,
`SELECT COUNT(*) FROM cert_auth_associations WHERE id = ? AND sha256 = ?;`,
`SELECT COUNT(*) FROM nano_cert_auth_associations WHERE id = ? AND sha256 = ?;`,
r.ID, strings.ToLower(hash),
)
}

func (s *MySQLStorage) AssociateCertHash(r *mdm.Request, hash string) error {
_, err := s.db.ExecContext(
r.Context, `
INSERT INTO cert_auth_associations (id, sha256) VALUES (?, ?) AS new
INSERT INTO nano_cert_auth_associations (id, sha256) VALUES (?, ?)
ON DUPLICATE KEY
UPDATE sha256 = new.sha256;`,
UPDATE sha256 = VALUES(sha256);`,
r.ID,
strings.ToLower(hash),
)
Expand Down
4 changes: 2 additions & 2 deletions storage/mysql/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func (s *MySQLStorage) RetrieveMigrationCheckins(ctx context.Context, c chan<- i
// then we should synthesize a TokenUpdate to transfer it over.
deviceRows, err := s.db.QueryContext(
ctx,
`SELECT authenticate, token_update FROM devices;`,
`SELECT authenticate, token_update FROM nano_devices;`,
)
if err != nil {
return err
Expand All @@ -36,7 +36,7 @@ func (s *MySQLStorage) RetrieveMigrationCheckins(ctx context.Context, c chan<- i
}
userRows, err := s.db.QueryContext(
ctx,
`SELECT token_update FROM users;`,
`SELECT token_update FROM nano_users;`,
)
if err != nil {
return err
Expand Down
Loading