Skip to content

Commit

Permalink
Merge pull request #26 from lgrcia/obliquity-inference
Browse files Browse the repository at this point in the history
 add star obliquity for inference + switch to uv
  • Loading branch information
lgrcia authored Feb 3, 2025
2 parents a1bb3f3 + f39d0ad commit 7c6670f
Show file tree
Hide file tree
Showing 15 changed files with 794 additions and 621 deletions.
42 changes: 22 additions & 20 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ on:
jobs:
tests:
runs-on: ${{ matrix.os }}
needs: ["build"]
strategy:
fail-fast: false
matrix:
Expand All @@ -39,47 +40,48 @@ jobs:
fetch-depth: 0
submodules: true

- name: "Init: Python"
uses: actions/setup-python@v5
- name: Install uv and set the python version
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}

- name: "Install: dependencies"
run: |
python -m pip install -U pip
python -m pip install -U nox
- name: Install the project
run: uv sync --extra test --extra comparison

- name: "Tests: run"
run: |
python -m nox --non-interactive --error-on-missing-interpreter \
--session "${{matrix.session}}-${{matrix.python-version}}"
- name: Run tests
run: uv run pytest tests

build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
name: Install Python

- name: Install uv and set the python version
uses: astral-sh/setup-uv@v5
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install -U build twine
- name: Build the distribution
run: python -m build .

- name: Install the project
run: uv sync --dev

- name: Build the project
run: uv build

- name: Check the distribution
run: python -m twine check --strict dist/*
run: |
uv tool install twine
uvx twine check --strict dist/*
- uses: actions/upload-artifact@v4
with:
path: dist/*

publish:
environment:
name: pypi
url: https://pypi.org/p/starspotter
url: https://pypi.org/p/spotter
permissions:
id-token: write
needs: [tests, build]
Expand Down
6 changes: 1 addition & 5 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
To install *spotter* from pypi

```bash
pip install starspotter
```

```{danger}
`starspotter` not `spotter` ! Because the name wasn't available...
pip install spotter
```

As *spotter* is still under development, we recommend installing the latest version from the GitHub repository. To do so, clone the repository and install the package using pip:
Expand Down
27 changes: 12 additions & 15 deletions docs/notebooks/ensemble.ipynb

Large diffs are not rendered by default.

366 changes: 134 additions & 232 deletions docs/notebooks/introduction.ipynb

Large diffs are not rendered by default.

514 changes: 257 additions & 257 deletions docs/notebooks/rotation.ipynb

Large diffs are not rendered by default.

216 changes: 216 additions & 0 deletions docs/notebooks/spot_crossing.ipynb

Large diffs are not rendered by default.

68 changes: 33 additions & 35 deletions docs/notebooks/surface_gp.ipynb

Large diffs are not rendered by default.

27 changes: 0 additions & 27 deletions noxfile.py

This file was deleted.

21 changes: 6 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,27 +1,17 @@
[project]
name = "starspotter"
version = "0.0.2"
description = "Stellar contamination estimates from rotational light curves"
name = "spotter"
version = "0.0.3"
description = "Forward models of fluxes and spectra time-series of non-uniform stars."
authors = [{ name = "Lionel Garcia" }, { name = "Benjamin Rackham" }]
license = "MIT"
readme = "readme.md"
requires-python = ">=3.9"
packages = [{ include = "spotter" }]
requires-python = ">=3.10"
dependencies = ["numpy", "healpy", "jax", "jaxlib", "equinox", "tinygp"]

[project.optional-dependencies]
dev = ["black", "pytest", "nox"]
test = ["pytest", "pytest-xdist"]
comparison = [
"jaxoplanet@git+https://github.com/exoplanet-dev/jaxoplanet#feat-starry-out-of-experimental",
]
compare_starry = [
"starry",
"exoplanet-core",
"numpy<1.22",
"xarray<2023.10.0",
"tqdm",
]
comparison = ["jaxoplanet>=0.0.3"]
docs = [
"matplotlib",
"sphinx",
Expand All @@ -34,6 +24,7 @@ docs = [
"toml",
"ipywidgets",
"sphinx-autoapi<3.2.0",
"ipykernel",
]

[build-system]
Expand Down
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# spotter

<p align="center">
<img src="docs/_static/spotter.png" width="270">
<img src="https://spotter.readthedocs.io/en/latest/_static/spotter.png" width="270">
</p>

<p align="center">
Expand Down
1 change: 0 additions & 1 deletion spotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def render(y, inc=None, u=None, phase=0.0, obl=0.0):


def amplitude(N_or_y, inc=None, u=None, undersampling: int = 3) -> callable:

N, _ = _N_or_Y_to_N_n(N_or_y)
resolution = hp.nside2resol(N)
X = vec(N)
Expand Down
52 changes: 49 additions & 3 deletions spotter/light_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from jax.typing import ArrayLike

from spotter import core
from spotter import core, utils
from spotter.star import Star, transited_star


Expand Down Expand Up @@ -70,8 +70,46 @@ def impl(star, time):
)


def transit_design_matrix(star, x, y, z, r, time=None):
X = design_matrix(star, time)

from jax.scipy.spatial.transform import Rotation

_z, _y, _x = core.vec(star.sides).T
v = jnp.stack((_x, _y, _z), axis=-1)

phase = star.phase(time)
_rv = Rotation.from_rotvec([phase, 0.0, 0.0]).apply(v)
rv = jnp.where(phase == 0.0, v, _rv)

inc_angle = -jnp.pi / 2 + star.inc if star.inc is not None else 0.0
_inc_angle = jnp.where(inc_angle == 0.0, 1.0, inc_angle)
_rv = Rotation.from_rotvec([0.0, _inc_angle, 0.0]).apply(rv)
rv = jnp.where(inc_angle == 0.0, rv, _rv)

if star.obl is not None:
obl_angle = jnp.where(star.obl == 0.0, 1.0, star.obl)
_rv = Rotation.from_rotvec([0.0, 0.0, obl_angle]).apply(rv)
rv = jnp.where(obl_angle == 0.0, rv, _rv)

_x, _y, _ = rv.T

distance = jnp.linalg.norm(
jnp.array([_x, _y]) - jnp.array([x, -y])[:, None], axis=0
)

transited_y = utils.sigmoid(distance - r, 1000.0)

return X * jnp.where(z >= 0, transited_y, jnp.ones_like(transited_y))


def transit_light_curve(
star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0, time: float = 0.0
star: Star,
x: float = 0.0,
y: float = 0.0,
z: float = 0.0,
r: float = 0.0,
time: float = 0.0,
):
"""Light curve of a transited Star. The x-axis cross the star in the horizontal direction (→),
and the y-axis cross the star in the vertical up direction (↑).
Expand All @@ -93,4 +131,12 @@ def transit_light_curve(
ArrayLike
Light curve array.
"""
return light_curve(transited_star(star, y, x, r), star.phase(time))

def impl(star, time):
return jnp.einsum(
"ij,ij->i", transit_design_matrix(star, x, y, z, r, time), star.y
)

return (
jnp.vectorize(impl, excluded=(0,), signature="()->(n)")(star, time).T / jnp.pi
)
58 changes: 54 additions & 4 deletions spotter/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def __init__(
self.radius = radius if radius is not None else 1.0
self.wv = wv

@property
def N(self):
"""Return the number of sides of the star map."""
return self.sides

@property
def x(self):
"""Return the xyz coordinates of the star pixels."""
Expand Down Expand Up @@ -118,7 +123,9 @@ def from_sides(cls, sides: int, **kwargs):
y = np.ones(core._N_or_Y_to_N_n(sides)[1])
return cls(y, **kwargs)

def phase(self, time: ArrayLike) -> ArrayLike:
def phase(self, time: ArrayLike | None) -> ArrayLike:
if time is None:
return 0.0
return (
2 * jnp.pi * time / self.period
if self.period is not None
Expand Down Expand Up @@ -175,6 +182,31 @@ def set(self, **kwargs):
current.update(kwargs)
return Star(**current)

def spot(self, lat: float, lon: float, radius: float, sharpness: float = 20):
"""Return a healpix map with a spot.
Parameters
----------
lat : float
Latitude of the spot, in radians.
lon : float
Longitude of the spot, in radians.
radius : float
Radius of the spot, in radians.
sharpness : float, optional
Sharpness of the spot, by default 20
Returns
-------
ArrayLike
healpix map with a spot.
"""
return core.spot(self.sides, lat, lon, radius, sharpness=sharpness)

@property
def coords(self):
"""Return the coordinates of the star pixels."""
return core.vec(self.sides)


def show(star: Star, phase: ArrayLike = 0.0, ax=None, **kwargs):
"""Show the star map. If `star.y` is 2D, the first map is shown.
Expand Down Expand Up @@ -214,14 +246,22 @@ def video(star: Star, duration: int = 4, fps: int = 10, **kwargs):
viz.video(
star.y[0],
star.inc if star.inc is not None else np.pi / 2,
star.obl if star.obl is not None else 0.0,
star.u[0] if star.u is not None else None,
duration=duration,
fps=fps,
**kwargs,
)


def transited_star(star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0):
def transited_star(
star: Star,
x: float = 0.0,
y: float = 0.0,
z: float = 0.0,
r: float = 0.0,
time: float = None,
):
"""Return a star transited by a circular opaque disk
Parameters
Expand All @@ -245,9 +285,16 @@ def transited_star(star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0):
_z, _y, _x = core.vec(star.sides).T
v = jnp.stack((_x, _y, _z), axis=-1)

if time is not None:
phase = star.phase(time)
_rv = Rotation.from_rotvec([phase, 0.0, 0.0]).apply(v)
rv = jnp.where(phase == 0.0, v, _rv)
else:
rv = v

inc_angle = -jnp.pi / 2 + star.inc if star.inc is not None else 0.0
_inc_angle = jnp.where(inc_angle == 0.0, 1.0, inc_angle)
_rv = Rotation.from_rotvec([0.0, _inc_angle, 0.0]).apply(v)
_rv = Rotation.from_rotvec([0.0, _inc_angle, 0.0]).apply(rv)
rv = jnp.where(inc_angle == 0.0, v, _rv)

if star.obl is not None:
Expand All @@ -260,4 +307,7 @@ def transited_star(star: Star, x: float = 0.0, y: float = 0.0, r: float = 0.0):
distance = jnp.linalg.norm(
jnp.array([_x, _y]) - jnp.array([x, -y])[:, None], axis=0
)
return utils.sigmoid(distance - r, 1000.0) * star

spotted_star = utils.sigmoid(distance - r, 1000.0) * star

return star.set(y=jnp.where(z < 0, star.y, spotted_star.y))
10 changes: 6 additions & 4 deletions spotter/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def show(y, inc=np.pi / 2, obl=0.0, u=None, phase=0.0, ax=None, **kwargs):
graticule(inc, obl, phase, ax=ax)


def video(y, inc=None, u=None, duration=4, fps=10, **kwargs):
def video(y, inc=None, obl=0.0, u=None, duration=4, fps=10, **kwargs):
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython import display
Expand All @@ -152,7 +152,9 @@ def video(y, inc=None, u=None, duration=4, fps=10, **kwargs):
inc = inc or 0.0

fig, ax = plt.subplots(figsize=(3, 3))
im = plt.imshow(core.render(y, inc, u, 0.0), extent=(-1, 1, -1, 1), **kwargs)
im = plt.imshow(
core.render(y, inc, u, 0.0, obl=obl), extent=(-1, 1, -1, 1), **kwargs
)
plt.axis("off")
plt.tight_layout()
ax.set_frame_on(False)
Expand All @@ -162,10 +164,10 @@ def video(y, inc=None, u=None, duration=4, fps=10, **kwargs):
def update(frame):
a = im.get_array()
phase = np.pi * 2 * frame / frames
a = core.render(y, inc, u, phase)
a = core.render(y, inc, u, phase, obl)
for art in list(ax.lines):
art.remove()
graticule(inc, ax=ax, theta=phase, white_contour=False)
graticule(inc, ax=ax, theta=phase, white_contour=False, obl=obl)

im.set_array(a)
return [im]
Expand Down
5 changes: 3 additions & 2 deletions tests/starry_comparison/test_flux.py → tests/test_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
@pytest.mark.parametrize("u", ([], [0.1, 0.4]))
def test_starry(deg, u):
pytest.importorskip("jaxoplanet")
from jaxoplanet.experimental.starry import Surface, Ylm, rotation
from jaxoplanet.experimental.starry.light_curves import surface_light_curve
from jaxoplanet.starry import Surface, Ylm
from jaxoplanet.starry.core import rotation
from jaxoplanet.starry.light_curves import surface_light_curve

y = np.array([1, *(1e-2 * np.random.randn((deg + 1) ** 2 - 1))])

Expand Down

0 comments on commit 7c6670f

Please sign in to comment.