From 71050865df44bb87be7a406da6a8396fbabed818 Mon Sep 17 00:00:00 2001 From: Tobias Wood Date: Wed, 28 Feb 2024 18:27:07 +0000 Subject: [PATCH] Remove Tile/TileToMatch --- python/riesling/plot.py | 2 +- src/cmd/graphics.cpp | 2 +- src/cmd/sense-calib.cpp | 10 +--------- src/cmd/sense-sim.cpp | 5 ++--- src/sense/sense.cpp | 39 +++++++++++++++++++++++---------------- src/sense/sense.hpp | 2 +- src/tensorOps.hpp | 25 ------------------------- test/op/sense.cpp | 3 +-- 8 files changed, 30 insertions(+), 58 deletions(-) diff --git a/python/riesling/plot.py b/python/riesling/plot.py index dd67f0e1..3df54b2a 100644 --- a/python/riesling/plot.py +++ b/python/riesling/plot.py @@ -65,7 +65,7 @@ def _comp(data, component, cmap, clim, climp): else: raise(f'Unknown component {component}') - if not clim: + if clim is None: if climp is None: climp = (2, 99) if component == 'mag': diff --git a/src/cmd/graphics.cpp b/src/cmd/graphics.cpp index c169a715..82bc5901 100644 --- a/src/cmd/graphics.cpp +++ b/src/cmd/graphics.cpp @@ -120,7 +120,7 @@ int main_graphics(args::Subparser &parser) args::Flag grey(parser, "G", "Greyscale", {"grey", 'g'}); args::Flag log(parser, "L", "Logarithmic intensity", {"log", 'l'}); - args::ValueFlag slN(parser, "N", "Number of slices", {"num", 'n'}, 0); + args::ValueFlag slN(parser, "N", "Number of slices (0 for all)", {"num", 'n'}, 8); args::ValueFlag slStart(parser, "S", "Start slice", {"start"}, 0); args::ValueFlag slEnd(parser, "S", "End slice", {"end"}); args::ValueFlag slDim(parser, "S", "Slice dimension (0/1/2)", {"dim"}, 0); diff --git a/src/cmd/sense-calib.cpp b/src/cmd/sense-calib.cpp index 66232b41..670bed79 100644 --- a/src/cmd/sense-calib.cpp +++ b/src/cmd/sense-calib.cpp @@ -22,15 +22,7 @@ int main_sense_calib(args::Subparser &parser) auto noncart = reader.readTensor(HD5::Keys::Noncartesian); traj.checkDims(FirstN<3>(noncart.dimensions())); auto const basis = ReadBasis(coreOpts.basisFile.Get()); - Sz3 const shape = traj.matrixForFOV(senseOpts.fov.Get()); - Cx5 const channels = SENSE::LoresChannels(senseOpts, coreOpts, traj, noncart, basis); - for (Index ii = 0; ii < 3; ii++) { - if (shape[ii] > channels.dimension(ii + 2)) { - Log::Fail("Requested SENSE FOV {} could not be satisfied with FOV {} and oversampling {}", - senseOpts.fov.Get().transpose(), traj.FOV().transpose(), coreOpts.osamp.Get()); - } - } - auto maps = SENSE::UniformNoise(senseOpts.λ.Get(), shape, channels); + auto maps = SENSE::UniformNoise(senseOpts.λ.Get(), SENSE::LoresChannels(senseOpts, coreOpts, traj, noncart, basis)); if (frame) { if (frame.Get() < 0 || frame.Get() >= maps.dimension(1)) { Log::Fail("Requested frame {} is outside valid range 0-{}", frame.Get(), maps.dimension(1)); diff --git a/src/cmd/sense-sim.cpp b/src/cmd/sense-sim.cpp index bd0ee7a0..f70ed001 100644 --- a/src/cmd/sense-sim.cpp +++ b/src/cmd/sense-sim.cpp @@ -28,9 +28,8 @@ int main_sense_sim(args::Subparser &parser) birdcage(shape, Eigen::Array3f::Constant(voxel_size.Get()), nchan.Get(), coil_rings.Get(), coil_r.Get(), coil_r.Get()); // Normalize - Cx3 rss(shape); - rss.device(Threads::GlobalDevice()) = ConjugateSum(sense, sense).sqrt(); - sense /= Tile(rss, nchan.Get()); + sense /= ConjugateSum(sense, sense).sqrt().reshape(AddFront(shape, 1, 1)).broadcast(Sz5{nchan.Get(), 1, 1, 1, 1}); + Log::Print("lolwut"); auto const fname = OutName("", iname.Get(), "sense", "h5"); HD5::Writer writer(fname); writer.writeTensor(HD5::Keys::SENSE, sense.dimensions(), sense.data(), HD5::Dims::SENSE); diff --git a/src/sense/sense.cpp b/src/sense/sense.cpp index 9f392ded..cbe9f159 100644 --- a/src/sense/sense.cpp +++ b/src/sense/sense.cpp @@ -46,21 +46,35 @@ auto LoresChannels(Opts &opts, CoreOpts &coreOpts, Trajectory const &inTraj, Cx5 auto const maxCoord = Maximum(NoNaNs(traj.points()).abs()); NoncartesianTukey(maxCoord * 0.75, maxCoord, 0.f, traj.points(), lores); Cx5 const channels(Tensorfy(lsmr.run(lores.data()), nufft->ishape)); - return channels; + + Sz3 const shape = traj.matrixForFOV(opts.fov.Get()); + for (Index ii = 0; ii < 3; ii++) { + if (shape[ii] > channels.dimension(ii + 2)) { + Log::Fail("Requested SENSE FOV {} could not be satisfied with FOV {} and oversampling {}", opts.fov.Get().transpose(), + traj.FOV().transpose(), coreOpts.osamp.Get()); + } + } + + Cx5 const cropped = Crop(channels, AddFront(shape, channels.dimension(0), channels.dimension(1))); + + return cropped; } -auto UniformNoise(float const λ, Sz3 const shape, Cx5 const &channels) -> Cx5 +auto UniformNoise(float const λ, Cx5 const &channels) -> Cx5 { - Cx5 cropped = Crop(channels, AddFront(shape, channels.dimension(0), channels.dimension(1))); - Cx4 rss(LastN<4>(cropped.dimensions())); - rss.device(Threads::GlobalDevice()) = ConjugateSum(cropped, cropped).sqrt(); + Sz5 const shape = channels.dimensions(); + Index const nC = shape[0]; + Cx4 rss(LastN<4>(shape)); + rss.device(Threads::GlobalDevice()) = ConjugateSum(channels, channels).sqrt(); if (λ > 0.f) { Log::Print("SENSE λ {}", λ); rss.device(Threads::GlobalDevice()) = rss + rss.constant(λ); } - Log::Debug("Normalizing channel images"); - cropped.device(Threads::GlobalDevice()) = cropped / TileToMatch(rss, cropped.dimensions()); - return cropped; + Log::Debug("Normalizing {} channel images", nC); + Cx5 normalized(shape); + normalized.device(Threads::GlobalDevice()) = + channels / rss.reshape(AddFront(LastN<4>(shape), 1)).broadcast(Sz5{nC, 1, 1, 1, 1}); + return normalized; } auto Choose(Opts &opts, CoreOpts &core, Trajectory const &traj, Cx5 const &noncart) -> Cx5 @@ -68,14 +82,7 @@ auto Choose(Opts &opts, CoreOpts &core, Trajectory const &traj, Cx5 const &nonca Sz3 const shape = traj.matrixForFOV(opts.fov.Get()); if (opts.type.Get() == "auto") { Log::Print("SENSE Self-Calibration"); - auto const channels = LoresChannels(opts, core, traj, noncart); - for (Index ii = 0; ii < 3; ii++) { - if (shape[ii] > channels.dimension(ii + 2)) { - Log::Fail("Requested SENSE FOV {} could not be satisfied with FOV {} and oversampling {}", opts.fov.Get().transpose(), - traj.FOV().transpose(), core.osamp.Get()); - } - } - return UniformNoise(opts.λ.Get(), shape, channels); + return UniformNoise(opts.λ.Get(), LoresChannels(opts, core, traj, noncart)); } else if (opts.type.Get() == "espirit") { Log::Fail("Not supported right now"); // auto channels = LoresChannels(opts, core, traj, noncart); diff --git a/src/sense/sense.hpp b/src/sense/sense.hpp index abd37a2c..b50de3c1 100644 --- a/src/sense/sense.hpp +++ b/src/sense/sense.hpp @@ -25,7 +25,7 @@ auto LoresChannels( Opts &opts, CoreOpts &coreOpts, Trajectory const &inTraj, Cx5 const &noncart, Basis const &basis = IdBasis()) -> Cx5; //! Normalizes by RSS with optional regularization -auto UniformNoise(float const λ, Sz3 const shape, Cx5 const &channels) -> Cx5; +auto UniformNoise(float const λ, Cx5 const &channels) -> Cx5; //! Convenience function called from recon commands to get SENSE maps auto Choose(Opts &opts, CoreOpts &core, Trajectory const &t, Cx5 const &noncart) -> Cx5; diff --git a/src/tensorOps.hpp b/src/tensorOps.hpp index 058bef48..2095d9fe 100644 --- a/src/tensorOps.hpp +++ b/src/tensorOps.hpp @@ -91,31 +91,6 @@ inline decltype(auto) FirstToLast4(T const &x) return x.shuffle(indices); } -template -inline decltype(auto) Tile(T &&x, Index const N) -{ - Eigen::IndexList, int, int, int> res; - res.set(1, x.dimension(0)); - res.set(2, x.dimension(1)); - res.set(3, x.dimension(2)); - Eigen::IndexList, Eigen::type2index<1>, Eigen::type2index<1>> brd; - brd.set(0, N); - return x.reshape(res).broadcast(brd); -} - -template -inline decltype(auto) TileToMatch(T &&x, U const &dims) -{ - using FixedOne = Eigen::type2index<1>; - Eigen::IndexList res; - res.set(1, dims[1]); - res.set(2, dims[2]); - res.set(3, dims[3]); - Eigen::IndexList brd; - brd.set(0, dims[0]); - return x.reshape(res).broadcast(brd); -} - template inline decltype(auto) Contract(T1 const &a, T2 const &b) { diff --git a/test/op/sense.cpp b/test/op/sense.cpp index f6778d25..f9ab3955 100644 --- a/test/op/sense.cpp +++ b/test/op/sense.cpp @@ -20,8 +20,7 @@ TEST_CASE("ops-sense") u.setRandom(); // The maps need to be normalized for the Dot test maps.setRandom(); - Cx4 const rss = ConjugateSum(maps, maps).sqrt(); - maps = maps / Tile(rss, channels); + maps = maps / ConjugateSum(maps, maps).sqrt().reshape(Sz5{1, 1, mapSz, mapSz, mapSz}).broadcast(Sz5{channels, 1, 1, 1, 1}); SenseOp sense(maps, 1); y = sense.forward(u);