Skip to content

Commit

Permalink
Added methods for tango_signals to automatically infer the datatype i…
Browse files Browse the repository at this point in the history
…f not passed as an argument.
  • Loading branch information
burkeds committed Aug 20, 2024
1 parent 6d41773 commit 648aa5c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 54 deletions.
39 changes: 36 additions & 3 deletions src/ophyd_async/tango/signal/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

from __future__ import annotations

from enum import Enum, IntEnum
from typing import Optional, Type, Union

import numpy.typing as npt

from ophyd_async.core import DEFAULT_TIMEOUT, SignalR, SignalRW, SignalW, SignalX, T
from ophyd_async.tango._backend import TangoTransport
from tango import AttrWriteType, CmdArgType
from ophyd_async.tango._backend._tango_transport import TangoTransport, get_python_type
from tango import AttrDataFormat, AttrWriteType, CmdArgType, DevState
from tango import DeviceProxy as SyncDeviceProxy
from tango.asyncio import DeviceProxy

Expand Down Expand Up @@ -138,14 +141,19 @@ def tango_signal_x(

# --------------------------------------------------------------------
def tango_signal_auto(
datatype: Type[T],
datatype: Optional[Type[T]] = None,
*,
trl: str,
device_proxy: Optional[DeviceProxy] = None,
timeout: float = DEFAULT_TIMEOUT,
name: str = "",
) -> Union[SignalW, SignalX, SignalR, SignalRW]:
device_trl, tr_name = trl.rsplit("/", 1)
syn_proxy = SyncDeviceProxy(device_trl)

if datatype is None:
datatype = infer_python_type(trl)

backend = _make_backend(datatype, trl, trl, device_proxy)

if tr_name not in syn_proxy.get_attribute_list():
Expand All @@ -170,3 +178,28 @@ def tango_signal_auto(

if tr_name in device_proxy.get_pipe_list():
raise NotImplementedError("Pipes are not supported")


# --------------------------------------------------------------------
def infer_python_type(trl: str):
device_trl, tr_name = trl.rsplit("/", 1)
syn_proxy = SyncDeviceProxy(device_trl)

if tr_name in syn_proxy.get_command_list():
config = syn_proxy.get_command_config(tr_name)
isarray, py_type, _ = get_python_type(config.in_type)
elif tr_name in syn_proxy.get_attribute_list():
config = syn_proxy.get_attribute_config(tr_name)
isarray, py_type, _ = get_python_type(config.data_type)
if py_type is Enum:
enum_dict = {label: i for i, label in enumerate(config.enum_labels)}
py_type = IntEnum("TangoEnum", enum_dict)
if config.data_format in [AttrDataFormat.SPECTRUM, AttrDataFormat.IMAGE]:
isarray = True
else:
raise RuntimeError(f"Cannot find {tr_name} in {device_trl}")

if py_type is CmdArgType.DevState:
py_type = DevState

return npt.NDArray[py_type] if isarray else py_type
124 changes: 73 additions & 51 deletions tests/tango/test_tango_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,18 @@ async def test_backend_get_put_monitor_cmd(
# --------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy",
"pv, tango_type, d_format, py_type, initial_value, put_value, use_dtype, use_proxy",
[
(pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy)
(
pv,
tango_type,
d_format,
py_type,
initial_value,
put_value,
use_dtype,
use_proxy,
)
for (
pv,
tango_type,
Expand All @@ -422,9 +431,15 @@ async def test_backend_get_put_monitor_cmd(
initial_value,
put_value,
) in ATTRIBUTES_SET
for use_dtype in [True, False]
for use_proxy in [True, False]
],
ids=[
f"{x[0]}_{use_dtype}_{use_proxy}"
for x in ATTRIBUTES_SET
for use_dtype in [True, False]
for use_proxy in [True, False]
],
ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]],
)
async def test_tango_signal_r(
echo_device: str,
Expand All @@ -434,11 +449,13 @@ async def test_tango_signal_r(
py_type: Type[T],
initial_value: T,
put_value: T,
use_dtype: bool,
use_proxy: bool,
):
await prepare_device(echo_device, pv, initial_value)
source = echo_device + "/" + pv
proxy = await DeviceProxy(echo_device) if use_proxy else None
py_type = py_type if use_dtype else None

timeout = 0.1
signal = tango_signal_r(
Expand All @@ -456,9 +473,18 @@ async def test_tango_signal_r(
# --------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy",
"pv, tango_type, d_format, py_type, initial_value, put_value, use_dtype, use_proxy",
[
(pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy)
(
pv,
tango_type,
d_format,
py_type,
initial_value,
put_value,
use_dtype,
use_proxy,
)
for (
pv,
tango_type,
Expand All @@ -467,9 +493,15 @@ async def test_tango_signal_r(
initial_value,
put_value,
) in ATTRIBUTES_SET
for use_dtype in [True, False]
for use_proxy in [True, False]
],
ids=[
f"{x[0]}_{use_dtype}_{use_proxy}"
for x in ATTRIBUTES_SET
for use_dtype in [True, False]
for use_proxy in [True, False]
],
ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]],
)
async def test_tango_signal_w(
echo_device: str,
Expand All @@ -479,11 +511,13 @@ async def test_tango_signal_w(
py_type: Type[T],
initial_value: T,
put_value: T,
use_dtype: bool,
use_proxy: bool,
):
await prepare_device(echo_device, pv, initial_value)
source = echo_device + "/" + pv
proxy = await DeviceProxy(echo_device) if use_proxy else None
py_type = py_type if use_dtype else None

timeout = 0.1
signal = tango_signal_w(
Expand Down Expand Up @@ -514,9 +548,18 @@ async def test_tango_signal_w(
# --------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy",
"pv, tango_type, d_format, py_type, initial_value, put_value, use_dtype, use_proxy",
[
(pv, tango_type, d_format, py_type, initial_value, put_value, use_proxy)
(
pv,
tango_type,
d_format,
py_type,
initial_value,
put_value,
use_dtype,
use_proxy,
)
for (
pv,
tango_type,
Expand All @@ -525,9 +568,15 @@ async def test_tango_signal_w(
initial_value,
put_value,
) in ATTRIBUTES_SET
for use_dtype in [True, False]
for use_proxy in [True, False]
],
ids=[
f"{x[0]}_{use_dtype}_{use_proxy}"
for x in ATTRIBUTES_SET
for use_dtype in [True, False]
for use_proxy in [True, False]
],
ids=[f"{x[0]}_{use_proxy}" for x in ATTRIBUTES_SET for use_proxy in [True, False]],
)
async def test_tango_signal_rw(
echo_device: str,
Expand All @@ -537,11 +586,13 @@ async def test_tango_signal_rw(
py_type: Type[T],
initial_value: T,
put_value: T,
use_dtype: bool,
use_proxy: bool,
):
await prepare_device(echo_device, pv, initial_value)
source = echo_device + "/" + pv
proxy = await DeviceProxy(echo_device) if use_proxy else None
py_type = py_type if use_dtype else None

timeout = 0.1
signal = tango_signal_rw(
Expand All @@ -561,48 +612,10 @@ async def test_tango_signal_rw(
assert_close(location["readback"], put_value)


# --------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
"pv, tango_type, d_format, py_type, initial_value, put_value",
COMMANDS_SET,
ids=[x[0] for x in COMMANDS_SET],
)
async def test_tango_signal_x(
echo_device: str,
pv: str,
tango_type: str,
d_format: AttrDataFormat,
py_type: Type[T],
initial_value: T,
put_value: T,
):
source = echo_device + "/" + pv
timeout = 0.1
signal = tango_signal_auto(
datatype=py_type,
trl=source,
device_proxy=None,
name="test_signal",
timeout=timeout,
)
await signal.connect()
assert signal
reading = await signal.read()
assert reading["test_signal"]["value"] is None

await signal.set(put_value, wait=True, timeout=0.1)
reading = await signal.read()
value = reading["test_signal"]["value"]
if isinstance(value, np.ndarray):
value = value.tolist()
assert_close(value, put_value)


# --------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("use_proxy", [True, False])
async def test_tango_signal_x_none(tango_test_device: str, use_proxy: bool):
async def test_tango_signal_x(tango_test_device: str, use_proxy: bool):
proxy = await DeviceProxy(tango_test_device) if use_proxy else None
timeout = 0.1
signal = tango_signal_x(
Expand Down Expand Up @@ -694,8 +707,6 @@ async def _test_signal(dtype, proxy):


# --------------------------------------------------------------------


@pytest.mark.asyncio
@pytest.mark.parametrize(
"pv, tango_type, d_format, py_type, initial_value, put_value, use_dtype, use_proxy",
Expand Down Expand Up @@ -772,7 +783,7 @@ async def _test_signal(dtype, proxy):
# --------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("use_proxy", [True, False])
async def test_tango_signal_auto_cmds_none(tango_test_device: str, use_proxy: bool):
async def test_tango_signal_auto_cmds_void(tango_test_device: str, use_proxy: bool):
proxy = await DeviceProxy(tango_test_device) if use_proxy else None
signal = tango_signal_auto(
datatype=None,
Expand All @@ -783,3 +794,14 @@ async def test_tango_signal_auto_cmds_none(tango_test_device: str, use_proxy: bo
await signal.connect()
assert signal
await signal.trigger(wait=True)


# --------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tango_signal_auto_badtrl(tango_test_device: str):
with pytest.raises(RuntimeError) as exc_info:
tango_signal_auto(
datatype=None,
trl=tango_test_device + "/" + "badtrl",
)
assert f"Cannot find badtrl in {tango_test_device}" in str(exc_info.value)

0 comments on commit 648aa5c

Please sign in to comment.