Skip to content

Commit

Permalink
color transform for singleband and rgb
Browse files Browse the repository at this point in the history
  • Loading branch information
biserhong committed Jan 2, 2025
1 parent 21f2db2 commit 4577289
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 12 deletions.
12 changes: 6 additions & 6 deletions terracotta/handlers/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,18 @@ def get_band_future(band_key: str) -> Future:
out_ranges.append(band_stretch_range)
out_arrays.append(band_data)

out = np.ma.stack(out_arrays, axis=0)
band_data = np.ma.stack(out_arrays, axis=0)

if color_transform:
band_stretch_range_arr = [np.array(band_rng, dtype=band_data.dtype) for band_rng in out_ranges]
band_stretch_range_arr = np.ma.stack(band_stretch_range_arr, axis=0)
out_ranges = [np.array(band_rng, dtype=band_data.dtype) for band_rng in out_ranges]
out_ranges = np.ma.stack(out_ranges, axis=0)

band_stretch_range_arr = image.apply_color_transform(band_stretch_range_arr, color_transform)
band_data = image.apply_color_transform(out, color_transform)
out_ranges = image.apply_color_transform(out_ranges, color_transform, band_range)
band_data = image.apply_color_transform(band_data, color_transform, band_range)

out_arrays = []
for k in range(band_data.shape[0]):
out_arrays.append(image.to_uint8(band_data[k], *band_stretch_range_arr[k]))
out_arrays.append(image.to_uint8(band_data[k], *out_ranges[k]))

out = np.ma.stack(out_arrays, axis=-1)
return image.array_to_png(out)
11 changes: 11 additions & 0 deletions terracotta/handlers/singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import collections

import numpy as np

from terracotta import get_settings, get_driver, image, xyz
from terracotta.profile import trace

Expand All @@ -26,6 +28,7 @@ def singleband(
*,
colormap: Union[str, Mapping[Number, RGBA], None] = None,
stretch_range: Optional[Tuple[NumberOrString, NumberOrString]] = None,
color_transform: Optional[str] = None,
tile_size: Optional[Tuple[int, int]] = None
) -> BinaryIO:
"""Return singleband image as PNG"""
Expand Down Expand Up @@ -73,6 +76,14 @@ def singleband(

cmap_or_palette = cast(Optional[str], colormap)

if color_transform:
stretch_range_ = np.array(stretch_range_, dtype=tile_data.dtype)
stretch_range_ = np.ma.stack(stretch_range_, axis=0)

stretch_range_ = image.apply_color_transform(stretch_range_, color_transform, band_range)
tile_data = np.expand_dims(tile_data, axis=0)
tile_data = image.apply_color_transform(tile_data, color_transform, band_range)[0]

out = image.to_uint8(tile_data, *stretch_range_)

return image.array_to_png(out, colormap=cmap_or_palette)
12 changes: 9 additions & 3 deletions terracotta/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,21 @@ def to_uint8(data: Array, lower_bound: Number, upper_bound: Number) -> Array:
def apply_color_transform(
masked_data: Array,
color_transform: str,
out_dtype: type = np.uint16,
band_range: list,
) -> Array:
"""Apply gamma correction to the input array and scale it to the output dtype."""
arr = to_math_type(masked_data)

if band_range:
arr = contrast_stretch(masked_data, band_range, (0, 1))
elif np.issubdtype(masked_data.dtype, np.integer):
arr = to_math_type(masked_data)
else:
raise exceptions.InvalidArgumentsError("No band range given and array is not of integer type")


for func in parse_operations(color_transform):
arr = func(arr)

arr = scale_dtype(arr, out_dtype)
return arr


Expand Down
18 changes: 18 additions & 0 deletions terracotta/server/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from typing import Any, Union

from color_operations import parse_operations


class StringOrNumber(fields.Field):
"""
Expand Down Expand Up @@ -45,3 +47,19 @@ def validate_stretch_range(data: Any) -> None:
if isinstance(data, str):
if not re.match("^p\\d+$", data):
raise ValidationError("Percentile format is `p<digits>`")


def validate_color_transform(data: Any) -> None:
"""
Validate that the color transform is a string and can be parsed by `color_operations`.
"""
if not isinstance(data, str):
raise ValidationError("Color transform needs to be a string")

if "saturation" in data:
raise ValidationError("Saturation is currently not supported")

try:
parse_operations(data)
except (ValueError, KeyError):
raise ValidationError("Invalid color transform")
5 changes: 3 additions & 2 deletions terracotta/server/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE
from flask import request, send_file, Response

from terracotta.server.fields import StringOrNumber, validate_stretch_range
from terracotta.server.fields import StringOrNumber, validate_stretch_range, validate_color_transform
from terracotta.server.flask_api import TILE_API


Expand Down Expand Up @@ -69,8 +69,9 @@ class Meta:
),
)
color_transform = fields.String(
validate=validate_color_transform,
missing=None,
description="Gamma factor to perform gamma correction."
description="Color transform DSL string from color-operations.",
)
tile_size = fields.List(
fields.Integer(),
Expand Down
10 changes: 9 additions & 1 deletion terracotta/server/singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from flask import request, send_file, Response

from terracotta.server.fields import StringOrNumber, validate_stretch_range
from terracotta.server.fields import StringOrNumber, validate_stretch_range, validate_color_transform
from terracotta.server.flask_api import TILE_API
from terracotta.cmaps import AVAILABLE_CMAPS

Expand Down Expand Up @@ -65,6 +65,14 @@ class Meta:
"hex strings.",
)

color_transform = fields.String(
validate=validate_color_transform,
missing=None,
example="gamma 1 1.5, sigmoidal 1 15 0.5",
description="Color transform DSL string from color-operations."
"All color operations for singleband should specify band 1.",
)

tile_size = fields.List(
fields.Integer(),
validate=validate.Length(equal=2),
Expand Down

0 comments on commit 4577289

Please sign in to comment.