diff --git a/Codec/Archive/Tar/Read.hs b/Codec/Archive/Tar/Read.hs index 67816bb..87703dd 100644 --- a/Codec/Archive/Tar/Read.hs +++ b/Codec/Archive/Tar/Read.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE CPP, DeriveDataTypeable, BangPatterns #-} +{-# LANGUAGE BangPatterns #-} ----------------------------------------------------------------------------- -- | -- Module : Codec.Archive.Tar.Read @@ -12,18 +12,22 @@ -- Portability : portable -- ----------------------------------------------------------------------------- -module Codec.Archive.Tar.Read (read, FormatError(..)) where +module Codec.Archive.Tar.Read + ( read + , FormatError(..) + ) where import Codec.Archive.Tar.Types import Data.Char (ord) import Data.Int (Int64) -import Data.Bits (Bits(shiftL)) +import Data.Bits (Bits(shiftL, (.&.), complement)) import Control.Exception (Exception(..)) import Data.Typeable (Typeable) import Control.Applicative import Control.Monad import Control.DeepSeq +import Control.Monad.Trans.State.Lazy import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BS.Char8 @@ -64,97 +68,128 @@ instance NFData FormatError where -- * The conversion is done lazily. -- read :: LBS.ByteString -> Entries FormatError -read = unfoldEntries getEntry - -getEntry :: LBS.ByteString -> Either FormatError (Maybe (Entry, LBS.ByteString)) -getEntry bs - | BS.length header < 512 = Left TruncatedArchive - - -- Tar files end with at least two blocks of all '0'. Checking this serves - -- two purposes. It checks the format but also forces the tail of the data - -- which is necessary to close the file if it came from a lazily read file. - -- - -- It's tempting to fall into trailer parsing as soon as LBS.head bs == '\0', - -- because, if interpreted as an 'Entry', it means that 'entryTarPath' is an empty - -- string. Yet it's not a concern of this function: parse it as an 'Entry' - -- and let further pipeline such as 'checkEntrySecurity' deal with it. After all, - -- it might be a format extension with unknown semantics. Such somewhat malformed - -- archives do exist in the wild, see https://github.com/haskell/tar/issues/73. - -- - -- Only if an entire block is null, we assume that we are parsing a trailer. - | LBS.all (== 0) (LBS.take 512 bs) = case LBS.splitAt 1024 bs of - (end, trailing) - | LBS.length end /= 1024 -> Left ShortTrailer - | not (LBS.all (== 0) end) -> Left BadTrailer - | not (LBS.all (== 0) trailing) -> Left TrailingJunk - | otherwise -> Right Nothing - - | otherwise = do - - case (chksum_, format_) of +read = evalState (readStreaming getN get) + where + getN :: Int64 -> State LBS.ByteString LBS.ByteString + getN n = do + (pref, st) <- LBS.splitAt n <$> get + put st + pure pref + +readStreaming + :: Monad m + => (Int64 -> m LBS.ByteString) + -> m LBS.ByteString + -> m (Entries FormatError) +readStreaming = (unfoldEntriesM id .) . getEntryStreaming + +getEntryStreaming + :: Monad m + => (Int64 -> m LBS.ByteString) + -> m LBS.ByteString + -> m (Either FormatError (Maybe Entry)) +getEntryStreaming getN getAll = do + header <- getN 512 + if LBS.length header < 512 then pure (Left TruncatedArchive) else do + + -- Tar files end with at least two blocks of all '0'. Checking this serves + -- two purposes. It checks the format but also forces the tail of the data + -- which is necessary to close the file if it came from a lazily read file. + -- + -- It's tempting to fall into trailer parsing as soon as LBS.head bs == '\0', + -- because, if interpreted as an 'Entry', it means that 'entryTarPath' is an empty + -- string. Yet it's not a concern of this function: parse it as an 'Entry' + -- and let further pipeline such as 'checkEntrySecurity' deal with it. After all, + -- it might be a format extension with unknown semantics. Such somewhat malformed + -- archives do exist in the wild, see https://github.com/haskell/tar/issues/73. + -- + -- Only if an entire block is null, we assume that we are parsing a trailer. + if LBS.all (== 0) header then do + nextBlock <- getN 512 + if LBS.length nextBlock < 512 then pure (Left ShortTrailer) + else if LBS.all (== 0) nextBlock then do + remainder <- getAll + pure $ if LBS.all (== 0) remainder then Right Nothing else Left TrailingJunk + else pure (Left BadTrailer) + + else case parseHeader header of + Left err -> pure $ Left err + Right (name, mode, uid, gid, size, mtime, typecode, linkname, format, uname, gname, devmajor, devminor, prefix) -> do + + -- It is crucial to get (size + padding) in one monadic operation + -- and drop padding in a pure. If you get size bytes first, + -- then skip padding, unpacking in constant memory will become impossible. + let paddedSize = (size + 511) .&. complement 511 + paddedContent <- getN paddedSize + let content = LBS.take size paddedContent + + pure $ Right $ Just $ Entry { + entryTarPath = TarPath name prefix, + entryContent = case typecode of + '\0' -> NormalFile content size + '0' -> NormalFile content size + '1' -> HardLink (LinkTarget linkname) + '2' -> SymbolicLink (LinkTarget linkname) + _ | format == V7Format + -> OtherEntryType typecode content size + '3' -> CharacterDevice devmajor devminor + '4' -> BlockDevice devmajor devminor + '5' -> Directory + '6' -> NamedPipe + '7' -> NormalFile content size + _ -> OtherEntryType typecode content size, + entryPermissions = mode, + entryOwnership = Ownership (BS.Char8.unpack uname) + (BS.Char8.unpack gname) uid gid, + entryTime = mtime, + entryFormat = format + } + +parseHeader + :: LBS.ByteString + -> Either FormatError (BS.ByteString, Permissions, Int, Int, Int64, EpochTime, Char, BS.ByteString, Format, BS.ByteString, BS.ByteString, DevMajor, DevMinor, BS.ByteString) +parseHeader header' = do + case (chksum_, format_ magic) of (Right chksum, _ ) | correctChecksum header chksum -> return () (Right _, Right _) -> Left ChecksumIncorrect _ -> Left NotTarFormat - -- These fields are partial, have to check them - format <- format_; mode <- mode_; - uid <- uid_; gid <- gid_; - size <- size_; mtime <- mtime_; - devmajor <- devmajor_; devminor <- devminor_; - - let content = LBS.take size (LBS.drop 512 bs) - padding = (512 - size) `mod` 512 - bs' = LBS.drop (512 + size + padding) bs - - entry = Entry { - entryTarPath = TarPath name prefix, - entryContent = case typecode of - '\0' -> NormalFile content size - '0' -> NormalFile content size - '1' -> HardLink (LinkTarget linkname) - '2' -> SymbolicLink (LinkTarget linkname) - _ | format == V7Format - -> OtherEntryType typecode content size - '3' -> CharacterDevice devmajor devminor - '4' -> BlockDevice devmajor devminor - '5' -> Directory - '6' -> NamedPipe - '7' -> NormalFile content size - _ -> OtherEntryType typecode content size, - entryPermissions = mode, - entryOwnership = Ownership (BS.Char8.unpack uname) - (BS.Char8.unpack gname) uid gid, - entryTime = mtime, - entryFormat = format - } - - return (Just (entry, bs')) + mode <- mode_ + uid <- uid_ + gid <- gid_ + size <- size_ + mtime <- mtime_ + format <- format_ magic + devmajor <- devmajor_ + devminor <- devminor_ + pure (name, mode, uid, gid, size, mtime, typecode, linkname, format, uname, gname, devmajor, devminor, prefix) where - header = LBS.toStrict (LBS.take 512 bs) - - name = getString 0 100 header - mode_ = getOct 100 8 header - uid_ = getOct 108 8 header - gid_ = getOct 116 8 header - size_ = getOct 124 12 header - mtime_ = getOct 136 12 header - chksum_ = getOct 148 8 header - typecode = getByte 156 header - linkname = getString 157 100 header - magic = getChars 257 8 header - uname = getString 265 32 header - gname = getString 297 32 header - devmajor_ = getOct 329 8 header - devminor_ = getOct 337 8 header - prefix = getString 345 155 header --- trailing = getBytes 500 12 header - - format_ - | magic == ustarMagic = return UstarFormat - | magic == gnuMagic = return GnuFormat - | magic == v7Magic = return V7Format - | otherwise = Left UnrecognisedTarFormat + header = LBS.toStrict header' + + name = getString 0 100 header + mode_ = getOct 100 8 header + uid_ = getOct 108 8 header + gid_ = getOct 116 8 header + size_ = getOct 124 12 header + mtime_ = getOct 136 12 header + chksum_ = getOct 148 8 header + typecode = getByte 156 header + linkname = getString 157 100 header + magic = getChars 257 8 header + uname = getString 265 32 header + gname = getString 297 32 header + devmajor_ = getOct 329 8 header + devminor_ = getOct 337 8 header + prefix = getString 345 155 header + -- trailing = getBytes 500 12 header + +format_ :: BS.ByteString -> Either FormatError Format +format_ magic + | magic == ustarMagic = return UstarFormat + | magic == gnuMagic = return GnuFormat + | magic == v7Magic = return V7Format + | otherwise = Left UnrecognisedTarFormat v7Magic, ustarMagic, gnuMagic :: BS.ByteString v7Magic = BS.Char8.pack "\0\0\0\0\0\0\0\0" diff --git a/Codec/Archive/Tar/Types.hs b/Codec/Archive/Tar/Types.hs index 0571393..69d43b5 100644 --- a/Codec/Archive/Tar/Types.hs +++ b/Codec/Archive/Tar/Types.hs @@ -65,6 +65,7 @@ module Codec.Archive.Tar.Types ( foldEntries, foldlEntries, unfoldEntries, + unfoldEntriesM, ) where import Data.Int (Int64) @@ -604,6 +605,21 @@ unfoldEntries f = unfold Right Nothing -> Done Right (Just (e, x')) -> Next e (unfold x') +unfoldEntriesM + :: Monad m + => (forall a. m a -> m a) + -- ^ id or unsafeInterleaveIO + -> m (Either e (Maybe (GenEntry tarPath linkTarget))) + -> m (GenEntries tarPath linkTarget e) +unfoldEntriesM interleave f = unfold + where + unfold = do + f' <- f + case f' of + Left err -> pure $ Fail err + Right Nothing -> pure Done + Right (Just e) -> Next e <$> interleave unfold + -- | This is like the standard 'foldr' function on lists, but for 'Entries'. -- Compared to 'foldr' it takes an extra function to account for the -- possibility of failure. diff --git a/tar.cabal b/tar.cabal index 243ac5b..80cb496 100644 --- a/tar.cabal +++ b/tar.cabal @@ -56,6 +56,7 @@ library tar-internal directory >= 1.3.1 && < 1.4, filepath < 1.6, time < 1.13, + transformers < 0.7, exposed-modules: Codec.Archive.Tar