-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
66 lines (43 loc) · 1.32 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Helper functions.
author: Aaditya Chandrasekhar ([email protected])
"""
import dataclasses
import numpy as np
import jax.numpy as jnp
@dataclasses.dataclass
class Extent:
    """An interval described by its lower and upper bound.

    Attributes:
      min: Lower bound of the interval.
      max: Upper bound of the interval.
    """
    min: float
    max: float

    @property
    def range(self) -> float:
        """Width of the interval, i.e. `max - min`."""
        lower, upper = self.min, self.max
        return upper - lower

    @property
    def center(self) -> float:
        """Midpoint of the interval, i.e. `(min + max) / 2`."""
        total = self.min + self.max
        return 0.5 * total

    def scale(
        self,
        scale_val: float
    ) -> 'Extent':
        """Return a new `Extent` with both bounds multiplied by `scale_val`.

        Scaling is about the origin (0), not about `center`. Note that a
        negative `scale_val` yields an `Extent` whose `min` exceeds its `max`.

        Args:
          scale_val: Multiplicative factor applied to both bounds.

        Returns:
          A fresh `Extent`; `self` is not modified.
        """
        scaled_lo = self.min * scale_val
        scaled_hi = self.max * scale_val
        return Extent(scaled_lo, scaled_hi)

    def translate(
        self,
        dx: float,
    ) -> 'Extent':
        """Return a new `Extent` with both bounds shifted by `dx`.

        Args:
          dx: Offset added to both bounds.

        Returns:
          A fresh `Extent` of the same width; `self` is not modified.
        """
        shifted_lo = self.min + dx
        shifted_hi = self.max + dx
        return Extent(shifted_lo, shifted_hi)
def normalize(x: jnp.ndarray, extent: Extent) -> jnp.ndarray:
    """Map `x` linearly so that `extent.min` -> 0 and `extent.max` -> 1.

    Args:
      x: Values to normalize.
      extent: Bounds defining the linear map.

    Returns:
      `(x - extent.min) / extent.range`. Values outside the extent land outside
      [0, 1]. A zero-width extent (`range == 0`) divides by zero.
    """
    shifted = x - extent.min
    return shifted / extent.range
def unnormalize(x: jnp.ndarray, extent: Extent) -> jnp.ndarray:
    """Undo `normalize`: map 0 -> `extent.min` and 1 -> `extent.max` linearly.

    Args:
      x: Linearly normalized values.
      extent: Bounds that were used for normalization.

    Returns:
      `x * extent.range + extent.min`.
    """
    width = extent.range
    return x * width + extent.min
def inverse_sigmoid(y: jnp.ndarray) -> jnp.ndarray:
    """Inverse of the logistic sigmoid (the logit function).

    The sigmoid f: x -> y is
        f(x) = 1 / (1 + exp(-x)),
    so its inverse g: y -> x is
        g(y) = ln(y / (1 - y)).
    For details see https://tinyurl.com/y7mr76hm

    Args:
      y: Values expected to lie in the open interval (0, 1).

    Returns:
      `g(y)`, computed elementwise. For array inputs, y == 0 gives -inf,
      y == 1 gives +inf, and values outside [0, 1] give NaN.
    """
    # NOTE(review): for a plain Python float y == 1.0 the division below raises
    # ZeroDivisionError before jnp.log is reached; array inputs yield inf.
    one_minus_y = 1. - y
    odds = y / one_minus_y
    return jnp.log(odds)