Skip to content

Commit

Permalink
intern valtrees
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukas Markeffsky committed Feb 8, 2025
1 parent 79f82ad commit e252733
Show file tree
Hide file tree
Showing 14 changed files with 151 additions and 95 deletions.
26 changes: 12 additions & 14 deletions compiler/rustc_const_eval/src/const_eval/valtrees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rustc_abi::{BackendRepr, VariantIdx};
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_middle::mir::interpret::{EvalToValTreeResult, GlobalId, ReportedErrorInfo};
use rustc_middle::ty::layout::{LayoutCx, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::{bug, mir};
use rustc_span::DUMMY_SP;
use tracing::{debug, instrument, trace};
Expand All @@ -25,11 +25,13 @@ fn branches<'tcx>(
variant: Option<VariantIdx>,
num_nodes: &mut usize,
) -> ValTreeCreationResult<'tcx> {
let tcx = *ecx.tcx;
let place = match variant {
Some(variant) => ecx.project_downcast(place, variant).unwrap(),
None => place.clone(),
};
let variant = variant.map(|variant| Some(ty::ValTree::Leaf(ScalarInt::from(variant.as_u32()))));
let variant =
variant.map(|variant| Some(ty::ValTree::from_scalar_int(tcx, variant.as_u32().into())));
debug!(?place, ?variant);

let mut fields = Vec::with_capacity(n);
Expand All @@ -52,7 +54,7 @@ fn branches<'tcx>(
*num_nodes += 1;
}

Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(branches)))
Ok(ty::ValTree::from_branches(tcx, branches))
}

#[instrument(skip(ecx), level = "debug")]
Expand All @@ -70,7 +72,7 @@ fn slice_branches<'tcx>(
elems.push(valtree);
}

Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(elems)))
Ok(ty::ValTree::from_branches(*ecx.tcx, elems))
}

#[instrument(skip(ecx), level = "debug")]
Expand All @@ -79,6 +81,7 @@ fn const_to_valtree_inner<'tcx>(
place: &MPlaceTy<'tcx>,
num_nodes: &mut usize,
) -> ValTreeCreationResult<'tcx> {
let tcx = *ecx.tcx;
let ty = place.layout.ty;
debug!("ty kind: {:?}", ty.kind());

Expand All @@ -89,14 +92,14 @@ fn const_to_valtree_inner<'tcx>(
match ty.kind() {
ty::FnDef(..) => {
*num_nodes += 1;
Ok(ty::ValTree::zst())
Ok(ty::ValTree::zst(tcx))
}
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
let val = ecx.read_immediate(place).unwrap();
let val = val.to_scalar_int().unwrap();
*num_nodes += 1;

Ok(ty::ValTree::Leaf(val))
Ok(ty::ValTree::from_scalar_int(tcx, val))
}

ty::Pat(base, ..) => {
Expand Down Expand Up @@ -127,7 +130,7 @@ fn const_to_valtree_inner<'tcx>(
return Err(ValTreeCreationError::NonSupportedType(ty));
};
// It's just a ScalarInt!
Ok(ty::ValTree::Leaf(val))
Ok(ty::ValTree::from_scalar_int(tcx, val))
}

// Technically we could allow function pointers (represented as `ty::Instance`), but this is not guaranteed to
Expand Down Expand Up @@ -287,16 +290,11 @@ pub fn valtree_to_const_value<'tcx>(
// FIXME: Does this need an example?
match *cv.ty.kind() {
ty::FnDef(..) => {
assert!(cv.valtree.unwrap_branch().is_empty());
assert!(cv.valtree.is_zst());
mir::ConstValue::ZeroSized
}
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(_, _) => {
match cv.valtree {
ty::ValTree::Leaf(scalar_int) => mir::ConstValue::Scalar(Scalar::Int(scalar_int)),
ty::ValTree::Branch(_) => bug!(
"ValTrees for Bool, Int, Uint, Float, Char or RawPtr should have the form ValTree::Leaf"
),
}
mir::ConstValue::Scalar(Scalar::Int(cv.valtree.unwrap_leaf()))
}
ty::Pat(ty, _) => {
let cv = ty::Value { valtree: cv.valtree, ty };
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ macro_rules! arena_types {
[] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem,
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
[] pats: rustc_middle::ty::PatternKind<'tcx>,
[] valtree: rustc_middle::ty::ValTreeKind<'tcx>,

// Note that this deliberately duplicates items in the `rustc_hir::arena`,
// since we need to allocate this type on both the `rustc_hir` arena
Expand Down
15 changes: 9 additions & 6 deletions compiler/rustc_middle/src/ty/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ty::Pattern<'tcx> {
}
}

impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ty::ValTree<'tcx> {
fn encode(&self, e: &mut E) {
self.0.0.encode(e);
}
}

impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ConstAllocation<'tcx> {
fn encode(&self, e: &mut E) {
self.inner().encode(e)
Expand Down Expand Up @@ -356,12 +362,9 @@ impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> Decodable<D> for ty::Pattern<'tcx> {
}
}

impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for [ty::ValTree<'tcx>] {
fn decode(decoder: &mut D) -> &'tcx Self {
decoder
.interner()
.arena
.alloc_from_iter((0..decoder.read_usize()).map(|_| Decodable::decode(decoder)))
impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> Decodable<D> for ty::ValTree<'tcx> {
fn decode(decoder: &mut D) -> Self {
decoder.interner().mk_valtree(Decodable::decode(decoder))
}
}

Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
pub type UnevaluatedConst<'tcx> = ir::UnevaluatedConst<TyCtxt<'tcx>>;

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(ConstKind<'_>, 32);
rustc_data_structures::static_assert_size!(ConstKind<'_>, 24);

#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable)]
#[rustc_pass_by_value]
Expand Down Expand Up @@ -190,15 +190,15 @@ impl<'tcx> Const<'tcx> {
.size;
ty::Const::new_value(
tcx,
ty::ValTree::from_scalar_int(ScalarInt::try_from_uint(bits, size).unwrap()),
ty::ValTree::from_scalar_int(tcx, ScalarInt::try_from_uint(bits, size).unwrap()),
ty,
)
}

#[inline]
/// Creates an interned zst constant.
pub fn zero_sized(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
ty::Const::new_value(tcx, ty::ValTree::zst(), ty)
ty::Const::new_value(tcx, ty::ValTree::zst(tcx), ty)
}

#[inline]
Expand Down
92 changes: 67 additions & 25 deletions compiler/rustc_middle/src/ty/consts/valtree.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::fmt;
use std::ops::Deref;

use rustc_data_structures::intern::Interned;
use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};

use super::ScalarInt;
Expand All @@ -16,9 +20,9 @@ use crate::ty::{self, Ty, TyCtxt};
///
/// `ValTree` does not have this problem with representation, as it only contains integers or
/// lists of (nested) `ValTree`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[derive(HashStable, TyEncodable, TyDecodable)]
pub enum ValTree<'tcx> {
pub enum ValTreeKind<'tcx> {
/// integers, `bool`, `char` are represented as scalars.
/// See the `ScalarInt` documentation for how `ScalarInt` guarantees that equal values
/// of these types have the same representation.
Expand All @@ -33,50 +37,88 @@ pub enum ValTree<'tcx> {
/// the fields of the variant.
///
/// ZST types are represented as an empty slice.
Branch(&'tcx [ValTree<'tcx>]),
Branch(Box<[ValTree<'tcx>]>),
}

impl<'tcx> ValTree<'tcx> {
pub fn zst() -> Self {
Self::Branch(&[])
}

impl<'tcx> ValTreeKind<'tcx> {
#[inline]
pub fn unwrap_leaf(self) -> ScalarInt {
pub fn unwrap_leaf(&self) -> ScalarInt {
match self {
Self::Leaf(s) => s,
Self::Leaf(s) => *s,
_ => bug!("expected leaf, got {:?}", self),
}
}

#[inline]
pub fn unwrap_branch(self) -> &'tcx [Self] {
pub fn unwrap_branch(&self) -> &[ValTree<'tcx>] {
match self {
Self::Branch(branch) => branch,
Self::Branch(branch) => &**branch,
_ => bug!("expected branch, got {:?}", self),
}
}

pub fn from_raw_bytes<'a>(tcx: TyCtxt<'tcx>, bytes: &'a [u8]) -> Self {
let branches = bytes.iter().map(|b| Self::Leaf(ScalarInt::from(*b)));
let interned = tcx.arena.alloc_from_iter(branches);
pub fn try_to_scalar(&self) -> Option<Scalar> {
self.try_to_scalar_int().map(Scalar::Int)
}

Self::Branch(interned)
pub fn try_to_scalar_int(&self) -> Option<ScalarInt> {
match self {
Self::Leaf(s) => Some(*s),
Self::Branch(_) => None,
}
}

pub fn from_scalar_int(i: ScalarInt) -> Self {
Self::Leaf(i)
pub fn try_to_branch(&self) -> Option<&[ValTree<'tcx>]> {
match self {
Self::Branch(branch) => Some(&**branch),
Self::Leaf(_) => None,
}
}
}

pub fn try_to_scalar(self) -> Option<Scalar> {
self.try_to_scalar_int().map(Scalar::Int)
/// An interned valtree. Use this rather than `ValTreeKind`, whenever possible.
///
/// See the docs of [`ValTreeKind`] or the [dev guide] for an explanation of this type.
///
/// [dev guide]: https://rustc-dev-guide.rust-lang.org/mir/index.html#valtrees
#[derive(Copy, Clone, Hash, Eq, PartialEq)]
#[derive(HashStable)]
pub struct ValTree<'tcx>(pub(crate) Interned<'tcx, ValTreeKind<'tcx>>);

impl<'tcx> ValTree<'tcx> {
pub fn zst(tcx: TyCtxt<'tcx>) -> Self {
tcx.consts.valtree_zst
}

pub fn try_to_scalar_int(self) -> Option<ScalarInt> {
match self {
Self::Leaf(s) => Some(s),
Self::Branch(_) => None,
}
pub fn is_zst(self) -> bool {
matches!(*self, ValTreeKind::Branch(box []))
}

pub fn from_raw_bytes(tcx: TyCtxt<'tcx>, bytes: &[u8]) -> Self {
let branches = bytes.iter().map(|b| tcx.mk_valtree(ValTreeKind::Leaf(ScalarInt::from(*b))));
Self::from_branches(tcx, branches)
}

pub fn from_branches(tcx: TyCtxt<'tcx>, branches: impl IntoIterator<Item = Self>) -> Self {
tcx.mk_valtree(ValTreeKind::Branch(branches.into_iter().collect()))
}

pub fn from_scalar_int(tcx: TyCtxt<'tcx>, i: ScalarInt) -> Self {
tcx.mk_valtree(ValTreeKind::Leaf(i))
}
}

impl<'tcx> Deref for ValTree<'tcx> {
type Target = &'tcx ValTreeKind<'tcx>;

fn deref(&self) -> &&'tcx ValTreeKind<'tcx> {
&self.0.0
}
}

impl fmt::Debug for ValTree<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}

Expand Down
24 changes: 20 additions & 4 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ use crate::ty::{
GenericArgsRef, GenericParamDefKind, List, ListWithCachedTypeInfo, ParamConst, ParamTy,
Pattern, PatternKind, PolyExistentialPredicate, PolyFnSig, Predicate, PredicateKind,
PredicatePolarity, Region, RegionKind, ReprOptions, TraitObjectVisitor, Ty, TyKind, TyVid,
Visibility,
ValTree, ValTreeKind, Visibility,
};

#[allow(rustc::usage_of_ty_tykind)]
Expand Down Expand Up @@ -798,6 +798,7 @@ pub struct CtxtInterners<'tcx> {
local_def_ids: InternedSet<'tcx, List<LocalDefId>>,
captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>,
offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>,
valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>,
}

impl<'tcx> CtxtInterners<'tcx> {
Expand Down Expand Up @@ -827,6 +828,7 @@ impl<'tcx> CtxtInterners<'tcx> {
local_def_ids: Default::default(),
captures: Default::default(),
offset_of: Default::default(),
valtree: Default::default(),
}
}

Expand Down Expand Up @@ -1018,6 +1020,8 @@ pub struct CommonConsts<'tcx> {
pub unit: Const<'tcx>,
pub true_: Const<'tcx>,
pub false_: Const<'tcx>,
/// Use [`ty::ValTree::zst`] instead.
pub(crate) valtree_zst: ValTree<'tcx>,
}

impl<'tcx> CommonTypes<'tcx> {
Expand Down Expand Up @@ -1118,19 +1122,30 @@ impl<'tcx> CommonConsts<'tcx> {
)
};

let mk_valtree = |v| {
ty::ValTree(Interned::new_unchecked(
interners.valtree.intern(v, |v| InternedInSet(interners.arena.alloc(v))).0,
))
};

let valtree_zst = mk_valtree(ty::ValTreeKind::Branch(Box::default()));
let valtree_true = mk_valtree(ty::ValTreeKind::Leaf(ty::ScalarInt::TRUE));
let valtree_false = mk_valtree(ty::ValTreeKind::Leaf(ty::ScalarInt::FALSE));

CommonConsts {
unit: mk_const(ty::ConstKind::Value(ty::Value {
ty: types.unit,
valtree: ty::ValTree::zst(),
valtree: valtree_zst,
})),
true_: mk_const(ty::ConstKind::Value(ty::Value {
ty: types.bool,
valtree: ty::ValTree::Leaf(ty::ScalarInt::TRUE),
valtree: valtree_true,
})),
false_: mk_const(ty::ConstKind::Value(ty::Value {
ty: types.bool,
valtree: ty::ValTree::Leaf(ty::ScalarInt::FALSE),
valtree: valtree_false,
})),
valtree_zst,
}
}
}
Expand Down Expand Up @@ -2533,6 +2548,7 @@ direct_interners! {
ExternalConstraints -> ExternalConstraints<'tcx>,
predefined_opaques_in_body: pub mk_predefined_opaques_in_body(PredefinedOpaquesData<TyCtxt<'tcx>>):
PredefinedOpaques -> PredefinedOpaques<'tcx>,
valtree: pub mk_valtree(ValTreeKind<'tcx>): ValTree -> ValTree<'tcx>,
}

macro_rules! slice_interners {
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ pub use self::closure::{
place_to_string_for_capture,
};
pub use self::consts::{
Const, ConstInt, ConstKind, Expr, ExprKind, ScalarInt, UnevaluatedConst, ValTree, Value,
Const, ConstInt, ConstKind, Expr, ExprKind, ScalarInt, UnevaluatedConst, ValTree, ValTreeKind,
Value,
};
pub use self::context::{
CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift, TyCtxt,
Expand Down
Loading

0 comments on commit e252733

Please sign in to comment.