Skip to content

Commit

Permalink
gdal raster calc: clarify expression rewriting impl and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Feb 14, 2025
1 parent 648dfd9 commit ed4b07a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 44 deletions.
11 changes: 9 additions & 2 deletions apps/gdalalg_raster_calc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ struct GDALCalcOptions
bool checkExtent{true};
};

static bool ShouldReplace(const std::string &str, size_t from, size_t to)
static bool MatchIsCompleteVariableNameWithNoIndex(const std::string &str,
size_t from, size_t to)
{
if (to < str.size())
{
// If the character after the end of the match is:
// * alphanumeric or _ : we've matched only part of a variable name
// * [ : we've matched a variable that already has an index
// * ( : we've matched a function name
if (std::isalnum(str[to]) || str[to] == '_' || str[to] == '[' ||
str[to] == '(')
{
Expand All @@ -46,6 +51,8 @@ static bool ShouldReplace(const std::string &str, size_t from, size_t to)
}
if (from > 0)
{
// If the character before the start of the match is alphanumeric or _,
// we've matched only part of a variable name.
if (std::isalnum(str[from - 1]) || str[from - 1] == '_')
{
return false;
Expand Down Expand Up @@ -73,7 +80,7 @@ static std::string SetBandIndices(const std::string &origExpression,
{
auto end = pos + variable.size();

if (ShouldReplace(expression, pos, end))
if (MatchIsCompleteVariableNameWithNoIndex(expression, pos, end))
{
// No index specified for variable
expression = expression.substr(0, pos + variable.size()) + '[' +
Expand Down
66 changes: 24 additions & 42 deletions autotest/utilities/test_gdalalg_raster_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,58 +383,40 @@ def test_gdalalg_raster_calc_error_band_count_mismatch(calc, tmp_vsimem, bands):


@pytest.mark.parametrize(
"sources",
"expr,source,bands,expected",
[
("AX", "A"),
("A", "AX"),
("XA", "A"),
("X", "AX"),
("A_X", "A"),
("A", "A_X"),
("SIN", "S"),
],
)
@pytest.mark.parametrize(
"expr,expected",
[
("SOURCE1 + SOURCE2", ["SOURCE1[1] + SOURCE2[1]", "SOURCE1[1] + SOURCE2[2]"]),
(
"SOURCE1* 2 + SOURCE2",
["SOURCE1[1]* 2 + SOURCE2[1]", "SOURCE1[1]* 2 + SOURCE2[2]"],
),
("SOURCE1 + SOURCE2[2]", ["SOURCE1[1] + SOURCE2[2]"]),
("SOURCE2 + SOURCE1", ["SOURCE2[1] + SOURCE1[1]", "SOURCE2[2] + SOURCE1[1]"]),
("SOURCE2[2] + SOURCE1", ["SOURCE2[2] + SOURCE1[1]"]),
(
"SIN(SOURCE1) + SOURCE2",
["SIN(SOURCE1[1]) + SOURCE2[1]", "SIN(SOURCE1[1]) + SOURCE2[2]"],
),
(
"SUM(SOURCE1,SOURCE2)",
["SUM(SOURCE1[1],SOURCE2[1])", "SUM(SOURCE1[1],SOURCE2[2])"],
),
("aX + 2", "aX", 1, ["aX[1] + 2"]),
("aX + 2", "aX", 2, ["aX[1] + 2", "aX[2] + 2"]),
("aX + 2", "X", 1, ["aX + 2"]),
("aX + 2", "a", 1, ["aX + 2"]),
("2 + aX", "X", 1, ["2 + aX"]),
("2 + aX", "aX", 1, ["2 + aX[1]"]),
("B1 + B10", "B1", 1, ["B1[1] + B10"]),
("B1[1] + B10", "B1", 2, ["B1[1] + B10"]),
("B1[1] + B1", "B1", 2, ["B1[1] + B1[1]", "B1[1] + B1[2]"]),
("SIN(N) + N", "N", 1, ["SIN(N[1]) + N[1]"]),
("SUM(N,N2) + N", "N", 1, ["SUM(N[1],N2) + N[1]"]),
("SUM(N,N2) + N", "N2", 1, ["SUM(N,N2[1]) + N"]),
("A_X + X", "X", 1, ["A_X + X[1]"]),
],
)
def test_gdalalg_raster_calc_expression_rewriting(
calc, tmp_vsimem, sources, expr, expected
calc, tmp_vsimem, expr, source, bands, expected
):
# The expression rewriting isn't exposed to Python, so we
# create an VRT with an expression and a single source, and
# inspect the transformed expression in the VRT XML.
# The transformed expression need not be valid, because we
# don't actually read the VRT in GDAL.

import xml.etree.ElementTree as ET

outfile = tmp_vsimem / "out.vrt"
infile = tmp_vsimem / "input.tif"

inputs = []
for i, source in enumerate(sources):
fname = tmp_vsimem / f"{i}.tif"
with gdal.GetDriverByName("GTiff").Create(
tmp_vsimem / f"{i}.tif", 2, 2, i + 1
) as ds:
ds.GetRasterBand(1).Fill(i)
inputs.append(f"{source}={fname}")

expr = expr.replace(f"SOURCE{i + 1}", source)
expected = [expr.replace(f"SOURCE{i + 1}", source) for expr in expected]
gdal.GetDriverByName("GTiff").Create(infile, 2, 2, bands)

calc["input"] = inputs
calc["input"] = [f"{source}={infile}"]
calc["output"] = outfile
calc["calc"] = [expr]

Expand Down

0 comments on commit ed4b07a

Please sign in to comment.