Skip to content

Commit

Permalink
Merge pull request #567 from jinningwang/findidx
Browse files Browse the repository at this point in the history
  • Loading branch information
cuihantao authored Oct 4, 2024
2 parents c7e7304 + 3adc110 commit d534ced
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 42 deletions.
64 changes: 40 additions & 24 deletions andes/core/model/modeldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import logging
from collections import OrderedDict
from typing import Iterable, Sized

import numpy as np
from andes.core.model.modelcache import ModelCache
from andes.core.param import (BaseParam, DataParam, IdxParam, NumParam,
TimerParam)
from andes.shared import pd
from andes.utils.func import validate_keys_values

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -277,7 +277,7 @@ def find_param(self, prop):

return out

def find_idx(self, keys, values, allow_none=False, default=False):
def find_idx(self, keys, values, allow_none=False, default=False, allow_all=False):
"""
Find `idx` of devices whose values match the given pattern.
Expand All @@ -288,49 +288,65 @@ def find_idx(self, keys, values, allow_none=False, default=False):
values : array, array of arrays, Sized
Values for the corresponding key to search for. If keys is a str, values should be an array of
elements. If keys is a list, values should be an array of arrays, each corresponds to the key.
allow_none : bool, Sized
allow_none : bool, Sized, optional
Allow key, value to be not found. Used by groups.
default : bool
default : bool, optional
Default idx to return if not found (missing)
allow_all : bool, optional
If True, returns a list of lists where each nested list contains all the matches for the
corresponding search criteria.
Returns
-------
list
indices of devices
"""
if isinstance(keys, str):
keys = (keys,)
if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable):
raise ValueError(f"value must be a string, scalar or an iterable, got {values}")
if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)):
values = (values,)
Notes
-----
- Only the first match is returned by default.
- If all matches are needed, set `allow_all` to True.
Examples
--------
>>> # Use example case of IEEE 14-bus system with PVD1
>>> ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx'))
>>> # To find the idx of `PVD1` with `name` of 'PVD1_1' and 'PVD1_2'
>>> ss.PVD1.find_idx(keys='name', values=['PVD1_1', 'PVD1_2'])
[1, 2]
>>> # To find the idx of `PVD1` connected to bus 4
>>> ss.PVD1.find_idx(keys='bus', values=[4])
[1]
elif isinstance(keys, Sized):
if not isinstance(values, Iterable):
raise ValueError(f"value must be an iterable, got {values}")
>>> # To find ALL the idx of `PVD1` with `gammap` equals to 0.1
>>> ss.PVD1.find_idx(keys='gammap', values=[0.1], allow_all=True)
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
if len(values) > 0 and not isinstance(values[0], Iterable):
raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}")
>>> # To find the idx of `PVD1` with `gammap` equals to 0.1 and `name` of 'PVD1_1'
>>> ss.PVD1.find_idx(keys=['gammap', 'name'], values=[[0.1], ['PVD1_1']])
[1]
"""

if len(keys) != len(values):
raise ValueError("keys and values must have the same length")
keys, values = validate_keys_values(keys, values)

v_attrs = [self.__dict__[key].v for key in keys]

idxes = []
for v_search in zip(*values):
v_idx = None
v_idx = []
for pos, v_attr in enumerate(zip(*v_attrs)):
if all([i == j for i, j in zip(v_search, v_attr)]):
v_idx = self.idx.v[pos]
break
if v_idx is None:
v_idx.append(self.idx.v[pos])
if not v_idx:
if allow_none is False:
raise IndexError(f'{list(keys)}={v_search} not found in {self.class_name}')
else:
v_idx = default
v_idx = [default]

idxes.append(v_idx)
if allow_all:
idxes.append(v_idx)
else:
idxes.append(v_idx[0])

return idxes
73 changes: 57 additions & 16 deletions andes/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from andes.core.service import BackRef
from andes.utils.func import list_flatten
from andes.utils.func import list_flatten, validate_keys_values

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -243,30 +243,71 @@ def set(self, src: str, idx, attr, value):

return True

def find_idx(self, keys, values, allow_none=False, default=None):
def find_idx(self, keys, values, allow_none=False, default=None, allow_all=False):
"""
Find indices of devices that satisfy the given `key=value` condition.
This method iterates over all models in this group.
Parameters
----------
keys : str, array-like, Sized
A string or an array-like of strings containing the names of parameters for the search criteria.
values : array, array of arrays, Sized
Values for the corresponding key to search for. If keys is a str, values should be an array of
elements. If keys is a list, values should be an array of arrays, each corresponding to the key.
allow_none : bool, optional
Allow key, value to be not found. Used by groups. Default is False.
default : bool, optional
Default idx to return if not found (missing). Default is None.
allow_all : bool, optional
Return all matches if set to True. Default is False.
Returns
-------
list
Indices of devices.
"""

keys, values = validate_keys_values(keys, values)

n_mdl, n_pair = len(self.models), len(values[0])

indices_found = []
# `indices_found` contains found indices returned from all models of this group
for model in self.models.values():
indices_found.append(model.find_idx(keys, values, allow_none=True, default=default))

out = []
for idx, idx_found in enumerate(zip(*indices_found)):
if not allow_none:
if idx_found.count(None) == len(idx_found):
missing_values = [item[idx] for item in values]
raise IndexError(f'{list(keys)} = {missing_values} not found in {self.class_name}')

real_idx = default
for item in idx_found:
if item is not None:
real_idx = item
indices_found.append(model.find_idx(keys, values, allow_none=True, default=default, allow_all=True))

# --- find missing pairs ---
i_val_miss = []
for i in range(n_pair):
idx_cross_mdls = [indices_found[j][i] for j in range(n_mdl)]
if all(item == [default] for item in idx_cross_mdls):
i_val_miss.append(i)

if (not allow_none) and i_val_miss:
miss_pairs = []
for i in i_val_miss:
miss_pairs.append([values[j][i] for j in range(len(keys))])
raise IndexError(f'{keys} = {miss_pairs} not found in {self.class_name}')

# --- output ---
out_pre = []
for i in range(n_pair):
idx_cross_mdls = [indices_found[j][i] for j in range(n_mdl)]
if all(item == [default] for item in idx_cross_mdls):
out_pre.append([default])
continue
for item in idx_cross_mdls:
if item != [default]:
out_pre.append(item)
break
out.append(real_idx)

if allow_all:
out = out_pre
else:
out = [item[0] for item in out_pre]

return out

def _check_src(self, src: str):
Expand Down
4 changes: 2 additions & 2 deletions andes/models/misc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def in1d(self, addr, v_code):
"""

if v_code == 'x':
return np.in1d(self.xidx, addr)
return np.isin(self.xidx, addr)
if v_code == 'y':
return np.in1d(self.yidx, addr)
return np.isin(self.yidx, addr)

raise NotImplementedError("v_code <%s> not recognized" % v_code)

Expand Down
48 changes: 48 additions & 0 deletions andes/utils/func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import operator
from typing import Iterable, Sized

from andes.shared import np

Expand Down Expand Up @@ -36,3 +37,50 @@ def interp_n2(t, x, y):
"""

return y[:, 0] + (t - x[0]) * (y[:, 1] - y[:, 0]) / (x[1] - x[0])


def validate_keys_values(keys, values):
"""
Validate the inputs for the func `find_idx`.
Parameters
----------
keys : str, array-like, Sized
A string or an array-like of strings containing the names of parameters for the search criteria.
values : array, array of arrays, Sized
Values for the corresponding key to search for. If keys is a str, values should be an array of
elements. If keys is a list, values should be an array of arrays, each corresponds to the key.
Returns
-------
tuple
Sanitized keys and values
Raises
------
ValueError
If the inputs are not valid.
"""
if isinstance(keys, str):
keys = (keys,)
if not isinstance(values, (int, float, str, np.floating)) and not isinstance(values, Iterable):
raise ValueError(f"value must be a string, scalar or an iterable, got {values}")

if len(values) > 0 and not isinstance(values[0], (list, tuple, np.ndarray)):
values = (values,)

elif isinstance(keys, Sized):
if not isinstance(values, Iterable):
raise ValueError(f"value must be an iterable, got {values}")

if len(values) > 0 and not isinstance(values[0], Iterable):
raise ValueError(f"if keys is an iterable, values must be an iterable of iterables. got {values}")

if len(keys) != len(values):
raise ValueError("keys and values must have the same length")

if isinstance(values[0], Iterable):
if not all([len(val) == len(values[0]) for val in values]):
raise ValueError("All items in values must have the same length")

return keys, values
1 change: 1 addition & 0 deletions docs/source/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ v1.9.3 (2024-04-XX)
- Adjust `BusFreq.Tw.default` to 0.1.
- Add parameter from_csv=None in TDS.run() to allow loading data from CSV files at TDS begining.
- Fix `TDS.init()` and `TDS._csv_step()` to fit loading from CSV when `Output` exists.
- Add parameter `allow_all=False` to `ModelData.find_idx()` `GroupBase.find_idx()` to allow searching all matches.

v1.9.2 (2024-03-25)
-------------------
Expand Down
17 changes: 17 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_group_access(self):
[6, 7, 8, 1])

# --- find_idx ---
# same Model
self.assertListEqual(ss.DG.find_idx('name', ['PVD1_1', 'PVD1_2']),
ss.PVD1.find_idx('name', ['PVD1_1', 'PVD1_2']),
)
Expand All @@ -82,6 +83,22 @@ def test_group_access(self):
[('PVD1_1', 'PVD1_2'),
(1.0, 1.0)]))

# cross Model, given results
self.assertListEqual(ss.StaticGen.find_idx(keys='bus',
values=[1, 2, 3, 4]),
[1, 2, 3, 6])
self.assertListEqual(ss.StaticGen.find_idx(keys='bus',
values=[1, 2, 3, 4],
allow_all=True),
[[1], [2], [3], [6]])

self.assertListEqual(ss.StaticGen.find_idx(keys='bus',
values=[1, 2, 3, 4, 2024],
allow_none=True,
default=2011,
allow_all=True),
[[1], [2], [3], [6], [2011]])

# --- get_field ---
ff = ss.DG.get_field('f', list(ss.DG._idx2model.keys()), 'v_code')
self.assertTrue(any([item == 'y' for item in ff]))
53 changes: 53 additions & 0 deletions tests/test_model_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,56 @@ def test_model_set(self):
ss.GENROU.set("M", np.array(["GENROU_4"]), "v", 6.0)
np.testing.assert_equal(ss.GENROU.M.v[3], 6.0)
self.assertEqual(ss.TDS.Teye[omega_addr[3], omega_addr[3]], 6.0)

def test_find_idx(self):
ss = andes.load(andes.get_case('ieee14/ieee14_pvd1.xlsx'))
mdl = ss.PVD1

# not allow all matches
self.assertListEqual(mdl.find_idx(keys='gammap', values=[0.1], allow_all=False),
[1])

# allow all matches
self.assertListEqual(mdl.find_idx(keys='gammap', values=[0.1], allow_all=True),
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

# multiple values
self.assertListEqual(mdl.find_idx(keys='name', values=['PVD1_1', 'PVD1_2'],
allow_none=False, default=False),
[1, 2])
# non-existing value
self.assertListEqual(mdl.find_idx(keys='name', values=['PVD1_999'],
allow_none=True, default=False),
[False])

# non-existing value is not allowed
with self.assertRaises(IndexError):
mdl.find_idx(keys='name', values=['PVD1_999'],
allow_none=False, default=False)

# multiple keys
self.assertListEqual(mdl.find_idx(keys=['gammap', 'name'],
values=[[0.1, 0.1], ['PVD1_1', 'PVD1_2']]),
[1, 2])

# multiple keys, with non-existing values
self.assertListEqual(mdl.find_idx(keys=['gammap', 'name'],
values=[[0.1, 0.1], ['PVD1_1', 'PVD1_999']],
allow_none=True, default='CURENT'),
[1, 'CURENT'])

# multiple keys, with non-existing values not allowed
with self.assertRaises(IndexError):
mdl.find_idx(keys=['gammap', 'name'],
values=[[0.1, 0.1], ['PVD1_1', 'PVD1_999']],
allow_none=False, default=999)

# multiple keys, values are not iterable
with self.assertRaises(ValueError):
mdl.find_idx(keys=['gammap', 'name'],
values=[0.1, 0.1])

# multiple keys, items length are inconsistent in values
with self.assertRaises(ValueError):
mdl.find_idx(keys=['gammap', 'name'],
values=[[0.1, 0.1], ['PVD1_1']])

0 comments on commit d534ced

Please sign in to comment.