Parse lsblk data with pydantic
codefiles committed Nov 7, 2024
1 parent 0370e89 commit 9ef343e
Showing 3 changed files with 53 additions and 123 deletions.
1 change: 0 additions & 1 deletion archinstall/lib/disk/__init__.py
@@ -32,7 +32,6 @@
DiskEncryption,
Fido2Device,
LsblkInfo,
CleanType,
get_lsblk_info,
get_all_lsblk_info,
get_lsblk_by_mountpoint,
9 changes: 8 additions & 1 deletion archinstall/lib/disk/device_handler.py
@@ -217,13 +217,20 @@ def get_btrfs_info(
debug(f'Failed to read btrfs subvolume information: {err}')
return subvol_infos

# It is assumed that lsblk will contain the fields as
# "mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...]
# "fsroots": ["/@log", "/@home", "/@"...]
# we'll thereby map the fsroot, which are the mounted filesystem roots
# to the corresponding mountpoints
btrfs_subvol_info = dict(zip(lsblk_info.fsroots, lsblk_info.mountpoints))

try:
# ID 256 gen 16 top level 5 path @
for line in result.splitlines():
# expected output format:
# ID 257 gen 8 top level 5 path @home
name = Path(line.split(' ')[-1])
sub_vol_mountpoint = lsblk_info.btrfs_subvol_info.get(name, None)
sub_vol_mountpoint = btrfs_subvol_info.get(name, None)
subvol_infos.append(_BtrfsSubvolumeInfo(name, sub_vol_mountpoint))
except json.decoder.JSONDecodeError as err:
error(f"Could not decode lsblk JSON: {result}")
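The dict built here replaces the old LsblkInfo.btrfs_subvol_info property that the device_model.py diff below removes; with the property gone, LsblkInfo stays a plain data container, which suits the pydantic rewrite of that class. A minimal standalone sketch of the same lookup, using invented sample values for the lsblk fields and the `btrfs subvolume list` output:

```python
from pathlib import Path

# Invented sample values mirroring the lsblk fields described above; lsblk
# itself reports fsroots with a leading slash ('/@', '/@home', ...), which is
# stripped here so the keys line up with the subvolume names parsed below.
fsroots = [Path('@log'), Path('@home'), Path('@')]
mountpoints = [Path('/mnt/archinstall/log'), Path('/mnt/archinstall/home'), Path('/mnt/archinstall')]

# Map each mounted filesystem root to its mountpoint, as the new code does.
btrfs_subvol_info = dict(zip(fsroots, mountpoints))

# Invented `btrfs subvolume list` output.
result = 'ID 256 gen 16 top level 5 path @\nID 257 gen 8 top level 5 path @home'

for line in result.splitlines():
    # 'ID 257 gen 8 top level 5 path @home' -> '@home'
    name = Path(line.split(' ')[-1])
    # None for subvolumes that are not mounted (no '@log' line in this sample)
    sub_vol_mountpoint = btrfs_subvol_info.get(name, None)
    print(name, sub_vol_mountpoint)
# @ /mnt/archinstall
# @home /mnt/archinstall/home
```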
166 changes: 45 additions & 121 deletions archinstall/lib/disk/device_model.py
@@ -1,18 +1,17 @@
from __future__ import annotations

import dataclasses
import json
import math
import uuid
from dataclasses import dataclass, field
from enum import Enum
from enum import auto
from pathlib import Path
from typing import Optional, List, Dict, TYPE_CHECKING, Any
from typing import Union

import parted
from parted import Disk, Geometry, Partition
from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator

from ..exceptions import DiskError, SysCallError
from ..general import SysCommand
@@ -1292,138 +1291,64 @@ def parse_arg(cls, arg: Dict[str, str]) -> 'Fido2Device':
)


@dataclass
class LsblkInfo:
name: str = ''
path: Path = Path()
pkname: str = ''
size: Size = field(default_factory=lambda: Size(0, Unit.B, SectorSize.default()))
log_sec: int = 0
pttype: str = ''
ptuuid: str = ''
rota: bool = False
tran: Optional[str] = None
partn: Optional[int] = None
partuuid: Optional[str] = None
parttype: Optional[str] = None
uuid: Optional[str] = None
fstype: Optional[str] = None
fsver: Optional[str] = None
fsavail: Optional[str] = None
fsuse_percentage: Optional[str] = None
type: Optional[str] = None
mountpoint: Optional[Path] = None
mountpoints: List[Path] = field(default_factory=list)
fsroots: List[Path] = field(default_factory=list)
children: List[LsblkInfo] = field(default_factory=list)

def json(self) -> Dict[str, Any]:
return {
'name': self.name,
'path': str(self.path),
'pkname': self.pkname,
'size': self.size.format_size(Unit.MiB),
'log_sec': self.log_sec,
'pttype': self.pttype,
'ptuuid': self.ptuuid,
'rota': self.rota,
'tran': self.tran,
'partn': self.partn,
'partuuid': self.partuuid,
'parttype': self.parttype,
'uuid': self.uuid,
'fstype': self.fstype,
'fsver': self.fsver,
'fsavail': self.fsavail,
'fsuse_percentage': self.fsuse_percentage,
'type': self.type,
'mountpoint': str(self.mountpoint) if self.mountpoint else None,
'mountpoints': [str(m) for m in self.mountpoints],
'fsroots': [str(r) for r in self.fsroots],
'children': [c.json() for c in self.children]
}

@property
def btrfs_subvol_info(self) -> Dict[Path, Path]:
"""
It is assumed that lsblk will contain the fields as
"mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...]
"fsroots": ["/@log", "/@home", "/@"...]
we'll thereby map the fsroot, which are the mounted filesystem roots
to the corresponding mountpoints
"""
return dict(zip(self.fsroots, self.mountpoints))

@classmethod
def exclude(cls) -> List[str]:
return ['children']
class LsblkInfo(BaseModel):
name: str
path: Path
pkname: str | None
log_sec: int = Field(alias='log-sec')
size: Size
pttype: str | None
ptuuid: str | None
rota: bool
tran: str | None
partn: int | None
partuuid: str | None
parttype: str | None
uuid: str | None
fstype: str | None
fsver: str | None
fsavail: int | None
fsuse_percentage: str | None = Field(alias='fsuse%')
type: str
mountpoint: Path | None
mountpoints: list[Path]
fsroots: list[Path]
children: list[LsblkInfo] = Field(default_factory=list)

@field_validator('size', mode='before')
@classmethod
def fields(cls) -> List[str]:
return [f.name for f in dataclasses.fields(LsblkInfo) if f.name not in cls.exclude()]
def convert_size(cls, v: int, info: ValidationInfo) -> Size:
sector_size = SectorSize(info.data['log_sec'], Unit.B)
return Size(v, Unit.B, sector_size)

@field_validator('mountpoints', 'fsroots', mode='before')
@classmethod
def from_json(cls, blockdevice: Dict[str, Any]) -> LsblkInfo:
lsblk_info = cls()

for f in cls.fields():
lsblk_field = _clean_field(f, CleanType.Blockdevice)
data_field = _clean_field(f, CleanType.Dataclass)

val: Any = None
if isinstance(getattr(lsblk_info, data_field), Path):
val = Path(blockdevice[lsblk_field])
elif isinstance(getattr(lsblk_info, data_field), Size):
sector_size = SectorSize(blockdevice['log-sec'], Unit.B)
val = Size(blockdevice[lsblk_field], Unit.B, sector_size)
else:
val = blockdevice[lsblk_field]
def remove_none(cls, v: list[Path | None]) -> list[Path]:
return [item for item in v if item is not None]

setattr(lsblk_info, data_field, val)
@field_serializer('size', when_used='json')
def serialize_size(self, size: Size) -> str:
return size.format_size(Unit.MiB)

lsblk_info.children = [LsblkInfo.from_json(child) for child in blockdevice.get('children', [])]

lsblk_info.mountpoint = Path(lsblk_info.mountpoint) if lsblk_info.mountpoint else None

# sometimes lsblk returns 'mountpoints': [null]
lsblk_info.mountpoints = [Path(mnt) for mnt in lsblk_info.mountpoints if mnt]

fs_roots = []
for r in lsblk_info.fsroots:
if r:
path = Path(r)
# store the fsroot entries without the leading /
fs_roots.append(path.relative_to(path.anchor))
lsblk_info.fsroots = fs_roots

return lsblk_info


class CleanType(Enum):
Blockdevice = auto()
Dataclass = auto()
Lsblk = auto()
@classmethod
def fields(cls) -> list[str]:
return [
field.alias or name
for name, field in cls.model_fields.items()
if name != 'children'
]


def _clean_field(name: str, clean_type: CleanType) -> str:
match clean_type:
case CleanType.Blockdevice:
return name.replace('_percentage', '%').replace('_', '-')
case CleanType.Dataclass:
return name.lower().replace('-', '_').replace('%', '_percentage')
case CleanType.Lsblk:
return name.replace('_percentage', '%').replace('_', '-')
class LsblkOutput(BaseModel):
blockdevices: list[LsblkInfo]


def _fetch_lsblk_info(
dev_path: Optional[Union[Path, str]] = None,
reverse: bool = False,
full_dev_path: bool = False
) -> List[LsblkInfo]:
fields = [_clean_field(f, CleanType.Lsblk) for f in LsblkInfo.fields()]
cmd = ['lsblk', '--json', '--bytes', '--output', ','.join(fields)]
cmd = ['lsblk', '--json', '--bytes', '--output', ','.join(LsblkInfo.fields())]

if reverse:
cmd.append('--inverse')
@@ -1453,8 +1378,7 @@ def _fetch_lsblk_info(
error(f"Could not decode lsblk JSON:\n{worker.output().decode().rstrip()}")
raise err

blockdevices = data['blockdevices']
return [LsblkInfo.from_json(device) for device in blockdevices]
return LsblkOutput(**data).blockdevices


def get_lsblk_info(
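For reference, a minimal self-contained sketch of the pydantic mechanics the new model relies on: field aliases for lsblk columns that are not valid Python identifiers ('log-sec', 'fsuse%'), a mode='before' validator that builds a size object from the already-validated log_sec value exposed through ValidationInfo.data (which is why log_sec is declared ahead of size), a validator that drops the null entries lsblk can emit in mountpoints, a JSON serializer for size, and a wrapper model for the top-level blockdevices key. Size here is a hypothetical stand-in for archinstall's own Size/SectorSize classes, and the sample JSON is invented.

```python
from __future__ import annotations

import json
from pathlib import Path

from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator


class Size(BaseModel):
    """Hypothetical stand-in for archinstall's Size/SectorSize types."""
    value: int        # bytes
    sector_size: int  # logical sector size in bytes


class BlockDevice(BaseModel):
    """Trimmed-down analogue of the LsblkInfo model in the diff above."""
    name: str
    path: Path
    # lsblk columns that are not valid identifiers are mapped via aliases
    log_sec: int = Field(alias='log-sec')
    size: Size
    fsuse_percentage: str | None = Field(alias='fsuse%')
    mountpoints: list[Path]

    @field_validator('size', mode='before')
    @classmethod
    def convert_size(cls, v: int, info: ValidationInfo) -> Size:
        # fields validate in declaration order, so the already-validated
        # 'log_sec' value is available here through info.data
        return Size(value=v, sector_size=info.data['log_sec'])

    @field_validator('mountpoints', mode='before')
    @classmethod
    def remove_none(cls, v: list[Path | None]) -> list[Path]:
        # lsblk sometimes reports 'mountpoints': [null]
        return [item for item in v if item is not None]

    @field_serializer('size', when_used='json')
    def serialize_size(self, size: Size) -> str:
        return f'{size.value / 2**20:.2f} MiB'

    @classmethod
    def fields(cls) -> list[str]:
        # lsblk --output expects the column names, i.e. the alias where one is set
        return [field.alias or name for name, field in cls.model_fields.items()]


class LsblkOutput(BaseModel):
    blockdevices: list[BlockDevice]


# Invented sample in the shape `lsblk --json --bytes --output ...` produces
sample = '''
{
  "blockdevices": [
    {"name": "sda1", "path": "/dev/sda1", "log-sec": 512, "size": 536870912,
     "fsuse%": "4%", "mountpoints": [null, "/boot"]}
  ]
}
'''

output = LsblkOutput(**json.loads(sample))
dev = output.blockdevices[0]
print(dev.size)               # value=536870912 sector_size=512
print(dev.mountpoints)        # [PosixPath('/boot')]
print(BlockDevice.fields())   # ['name', 'path', 'log-sec', 'size', 'fsuse%', 'mountpoints']
print(dev.model_dump_json())  # size rendered by the serializer as "512.00 MiB"
```

Deriving the --output column list from model_fields is what lets the new fields() classmethod stand in for the removed CleanType/_clean_field name-mangling helpers, while LsblkOutput(**data).blockdevices replaces the per-device LsblkInfo.from_json() loop.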
