Skip to content

Commit

Permalink
* Override more methods in the PhotonicsState class
Browse files Browse the repository at this point in the history
* Add a test for state retrieval with kernel that accepts arguments
  • Loading branch information
khalatepradnya committed Sep 23, 2024
1 parent 85ee7b0 commit c61a21a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/runtime/cudaq/algorithms/py_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ state pyGetStateRemote(py::object kernel, py::args args) {
}

state pyGetStateLibraryMode(py::object kernel, py::args args) {

cudaq::info("Size of arguments = {}", args.size());

/// TODO: Pack / unpack arguments
return details::extractState([&]() mutable {
if (0 == args.size())
Expand Down
5 changes: 4 additions & 1 deletion python/tests/handlers/test_photonics_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def kernel():

state = cudaq.get_state(kernel)
state.dump()
# TODO: Add check for 'state' object
assert 4 == state.__len__()


def test_qudit_list():
Expand Down Expand Up @@ -93,6 +93,9 @@ def kernel(theta: float):

result = cudaq.sample(kernel, 0.5)
result.dump()

state = cudaq.get_state(kernel, 0.5)
state.dump()


def test_target_change():
Expand Down
35 changes: 24 additions & 11 deletions runtime/cudaq/qis/managers/photonics/PhotonicsExecutionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ struct PhotonicsState : public cudaq::SimulationState {
PhotonicsState(qpp::ket &&data, std::size_t lvl)
: state(std::move(data)), levels(lvl) {}

/// TODO: Rename the API to be generic
std::size_t getNumQubits() const override {
throw "not supported for this photonics simulator";
return (std::log2(state.size()) / std::log2(levels));
}

std::complex<double> overlap(const cudaq::SimulationState &other) override {
Expand All @@ -39,7 +40,11 @@ struct PhotonicsState : public cudaq::SimulationState {

std::complex<double>
getAmplitude(const std::vector<int> &basisState) override {
/// TODO: Check basisState.size() matches qudit count
if (getNumQubits() != basisState.size())
throw std::runtime_error(fmt::format(
"[photonics] getAmplitude with an invalid number of bits in the "
"basis state: expected {}, provided {}.",
getNumQubits(), basisState.size()));

// Convert the basis state to an index value
const std::size_t idx = std::accumulate(
Expand All @@ -50,27 +55,35 @@ struct PhotonicsState : public cudaq::SimulationState {
}

Tensor getTensor(std::size_t tensorIdx = 0) const override {
throw "not supported for this photonics simulator";
if (tensorIdx != 0)
throw std::runtime_error("[photonics] invalid tensor requested.");
return Tensor{
reinterpret_cast<void *>(
const_cast<std::complex<double> *>(state.data())),
std::vector<std::size_t>{static_cast<std::size_t>(state.size())},
getPrecision()};
}

std::vector<Tensor> getTensors() const override {
throw "not supported for this photonics simulator";
}
// /// @brief Return all tensors that represent this state
std::vector<Tensor> getTensors() const override { return {getTensor()}; }

std::size_t getNumTensors() const override {
throw "not supported for this photonics simulator";
}
// /// @brief Return the number of tensors that represent this state.
std::size_t getNumTensors() const override { return 1; }

std::complex<double>
operator()(std::size_t tensorIdx,
const std::vector<std::size_t> &indices) override {
throw "not supported for this photonics simulator";
if (tensorIdx != 0)
throw std::runtime_error("[photonics] invalid tensor requested.");
if (indices.size() != 1)
throw std::runtime_error("[photonics] invalid element extraction.");

return state[indices[0]];
}

std::unique_ptr<SimulationState>
createFromSizeAndPtr(std::size_t size, void *ptr, std::size_t) override {
throw "not supported for this photonics simulator";
;
}

void dump(std::ostream &os) const override { os << state << "\n"; }
Expand Down

0 comments on commit c61a21a

Please sign in to comment.