Commit 8feece6
Make it work with automatic optimization and jit!
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 4, 2024
1 parent a52972c commit 8feece6
Showing 1 changed file with 62 additions and 117 deletions.
179 changes: 62 additions & 117 deletions project/algorithms/jax_algo.py
@@ -103,35 +103,40 @@ def to_channels_last[T: jax.Array | torch.Tensor](tensor: T) -> T:
class JaxFunction(torch.autograd.Function):
params_treedef: ClassVar

@staticmethod
def loss_function(
params: VariableDict,
x: jax.Array,
y: jax.Array,
):
logits = CNN().apply(params, x)
assert isinstance(logits, jax.Array)
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
assert isinstance(loss, jax.Array)
return loss, logits

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
ctx: torch.autograd.function.NestedIOFunction,
x: torch.Tensor,
y: torch.Tensor,
params_treedef: PyTreeDef,
loss_fn: Callable[[VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]],
loss_value_and_grad_fn: Callable[
[VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]
],
*params: torch.Tensor,
):
ctx.save_for_backward(x, y, *params)
ctx.params_treedef = params_treedef # type: ignore
jax_x = torch_to_jax(x)
jax_y = torch_to_jax(y)
jax_params = tuple(map(torch_to_jax, params))
jax_params = jax.tree.unflatten(params_treedef, jax_params)
jax_loss, jax_logits = JaxFunction.loss_function(jax_params, x=jax_x, y=jax_y)
loss = jax_to_torch(jax_loss)
logits = jax_to_torch(jax_logits)

needs_grad: tuple[bool, ...] = ctx.needs_input_grad # type: ignore
x_needs_grad, y_needs_grad, _, _, _, *params_need_grad = needs_grad
# todo: broaden a bit:
assert not x_needs_grad
assert not y_needs_grad
if all(params_need_grad):
# We're going to need to do the backward pass, so do it right away and save the grads
# in the context.
(loss, logits), param_grads = loss_value_and_grad_fn(jax_params, jax_x, jax_y)
flattened_param_grads = jax.tree.leaves(param_grads)
torch_grads = tuple(map(jax_to_torch, flattened_param_grads))
ctx.save_for_backward(*torch_grads)
else:
assert not any(params_need_grad)
loss, logits = loss_fn(jax_params, jax_x, jax_y)
loss = jax_to_torch(loss)
logits = jax_to_torch(logits)
return loss, logits

@staticmethod
@@ -140,33 +145,21 @@ def backward(
grad_loss: torch.Tensor,
grad_logits: torch.Tensor,
):
x: torch.Tensor
params: tuple[torch.Tensor, ...]
x, y, *params = ctx.saved_tensors # type: ignore
params_treedef: PyTreeDef = ctx.params_treedef # type: ignore
jax_x = torch_to_jax(x)
jax_y = torch_to_jax(y)
# jax_grad_output = torch_to_jax(grad_output) # TODO: Can we even pass this to jax.grad?

structured_params = jax.tree.unflatten(params_treedef, params)
jax_params = jax.tree.map(torch_to_jax, structured_params)
x_needs_grad, y_needs_grad, _, _, _, *params_needs_grad = ctx.needs_input_grad
# todo: broaden this a bit in case we need the grad of the input.
# todo: Figure out how to do jax.grad for a function that outputs a matrix or vector.
assert not x_needs_grad
assert not y_needs_grad

grad_input = None
grad_y = None
if all(params_needs_grad):
params_grads = ctx.saved_tensors
else:
assert not any(params_needs_grad)
params_grads = tuple(None for _ in params_needs_grad)

# todo: broaden this a bit in case we need the grad of the input.
assert ctx.needs_input_grad == (
False, # input
False, # y
False, # params_treedef
*(True for _ in params),
), ctx.needs_input_grad
jax_params_grad, logits = jax.grad(JaxFunction.loss_function, argnums=0, has_aux=True)(
jax_params, jax_x, jax_y
)
torch_params_grad = jax.tree.map(jax_to_torch, jax_params_grad)
torch_flat_params_grad = jax.tree.leaves(torch_params_grad)
return grad_input, grad_y, None, *torch_flat_params_grad
return grad_input, grad_y, None, None, None, *params_grads
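
For readers unfamiliar with this pattern, here is a minimal, self-contained sketch of the technique the new forward/backward pair uses: forward runs a jitted JAX value-and-grad, converts and stashes the resulting parameter gradients in the ctx, and backward simply hands those precomputed gradients back to PyTorch. Everything below (JaxLoss, the numpy-based conversion helpers, the toy linear loss) is illustrative only; the actual file converts tensors with its DLPack-based torch_to_jax / jax_to_torch helpers and uses the Flax CNN.

import jax
import jax.numpy as jnp
import numpy as np
import torch

def torch_to_jax_array(t: torch.Tensor) -> jax.Array:
    # Simplified CPU-only conversion; the project uses DLPack-based helpers instead.
    return jnp.asarray(t.detach().cpu().numpy())

def jax_to_torch_tensor(a: jax.Array) -> torch.Tensor:
    return torch.from_numpy(np.array(a))

def loss_fn(w: jax.Array, x: jax.Array, y: jax.Array) -> jax.Array:
    return jnp.mean((x @ w - y) ** 2)  # toy stand-in for the CNN + cross-entropy

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

class JaxLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor):
        loss, w_grad = value_and_grad_fn(
            torch_to_jax_array(w), torch_to_jax_array(x), torch_to_jax_array(y)
        )
        # The backward pass already happened on the JAX side: stash its result.
        ctx.save_for_backward(jax_to_torch_tensor(w_grad))
        return jax_to_torch_tensor(loss)

    @staticmethod
    def backward(ctx, grad_loss: torch.Tensor):
        (w_grad,) = ctx.saved_tensors
        # No gradients for x and y; scale the saved parameter gradient by the incoming one.
        return None, None, grad_loss * w_grad

w = torch.zeros(3, requires_grad=True)
x, y = torch.randn(8, 3), torch.randn(8)
JaxLoss.apply(x, y, w).backward()  # w.grad now holds the JAX-computed gradient

Because the gradients are computed eagerly in forward, a plain loss.backward() followed by optimizer.step() (i.e. Lightning's automatic optimization) is enough to drive the whole thing.
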


class JaxAlgorithm(Algorithm):
@@ -199,46 +192,28 @@ def __init__(
map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
)

# We will do the backward pass ourselves, and PL will only be used to synchronize stuff
# between workers, do logging, etc.
self.automatic_optimization = True

# def on_fit_start(self):
# Setting those here, because otherwise we get pickling errors when running with multiple
# GPUs.

# def loss_fn(params: VariableDict, x: jax.Array, y: jax.Array):
# logits = self.network.apply(params, x)
# assert isinstance(logits, jax.Array)
# loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
# assert isinstance(loss, jax.Array)
# return loss, logits

# self.forward_pass = loss_fn
# self.backward_pass = value_and_grad(self.forward_pass)

# if not self.hp.debug:
# self.forward_pass = jit(self.forward_pass)
# self.backward_pass = jit(self.backward_pass)

# def jax_params(self) -> VariableDict:
# # View the torch parameters as jax Arrays
# jax_parameters = jax.tree.map(torch_to_jax, list(self.parameters()))
# # Reconstruct the original object structure.
# jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
# return jax_params_tuple

# def on_before_batch_transfer(
# self, batch: tuple[torch.Tensor, torch.Tensor], dataloader_idx: int
# ):
# # Convert the batch to jax Arrays.
# x, y = batch
# # Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
# # channels-first tensors, so we have to do a transpose here.
# x = to_channels_last(x)
# # View the torch inputs as jax Arrays.
# x, y = torch_to_jax(x), torch_to_jax(y)
# return x, y
def on_fit_start(self):
# Setting those here, because otherwise we get pickling errors when running with multiple
# GPUs.
def loss_function(
params: VariableDict,
x: jax.Array,
y: jax.Array,
):
logits = self.network.apply(params, x)
assert isinstance(logits, jax.Array)
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
assert isinstance(loss, jax.Array)
return loss, logits

self.forward_pass = loss_function
self.backward_pass = value_and_grad(self.forward_pass, argnums=0, has_aux=True)

if not self.hp.debug:
self.forward_pass = jit(self.forward_pass)
self.backward_pass = jit(self.backward_pass)
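
As a reference for the calling convention used in shared_step below, here is a tiny sketch (toy linear "network" instead of the Flax CNN, made-up shapes) of what the jitted forward_pass and backward_pass return when the loss function has an auxiliary output: value_and_grad(..., argnums=0, has_aux=True) yields ((loss, logits), param_grads), with gradients taken only with respect to the first argument, the parameters.

import jax
import jax.numpy as jnp

def loss_function(params: jax.Array, x: jax.Array, y: jax.Array):
    logits = x @ params                 # stand-in for self.network.apply(params, x)
    loss = jnp.mean((logits - y) ** 2)  # stand-in for the softmax cross-entropy
    return loss, logits                 # (value, aux) pair, hence has_aux=True

forward_pass = jax.jit(loss_function)
backward_pass = jax.jit(jax.value_and_grad(loss_function, argnums=0, has_aux=True))

params = jnp.zeros((3, 10))
x, y = jnp.ones((4, 3)), jnp.zeros((4, 10))
loss, logits = forward_pass(params, x, y)
(loss, logits), param_grads = backward_pass(params, x, y)  # param_grads has params' shape

Creating and jitting these callables in on_fit_start rather than in __init__ is what, per the comment above, avoids the pickling errors seen with multiple GPUs.
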

def shared_step(
self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
@@ -247,52 +222,22 @@ def shared_step(
# Convert the batch to jax Arrays.
# Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
# channels-first tensors, so we have to do a transpose here.

x = to_channels_last(x)
loss, logits = JaxFunction.apply(x, y, self.params_treedef, *self.parameters()) # type: ignore

loss: torch.Tensor
logits: torch.Tensor
loss, logits = JaxFunction.apply( # type: ignore
x, y, self.params_treedef, self.forward_pass, self.backward_pass, *self.parameters()
)

assert isinstance(logits, torch.Tensor)
if phase == "train":
assert loss.requires_grad
self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
acc = logits.argmax(-1).eq(y).float().mean()
self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True)
return loss
# View the torch inputs as jax Arrays.
x, y = torch_to_jax(x), torch_to_jax(y)

jax_params = self.jax_params()
if phase != "train":
# Only use the forward pass.
loss, logits = self.forward_pass(jax_params, x, y)
else:
optimizer = self.optimizers()
assert isinstance(optimizer, torch.optim.Optimizer)

# Perform the backward pass
(loss, logits), jax_grads = self.backward_pass(jax_params, x, y)
distributed = torch.distributed.is_initialized()

with torch.no_grad():
# 'convert' the gradients to pytorch
torch_grads = jax.tree.map(jax_to_torch, jax_grads)
# Update the torch parameters tensors in-place using the jax grads.
for param, grad in zip(self.parameters(), jax.tree.leaves(torch_grads)):
if distributed:
torch.distributed.all_reduce(grad, op=torch.distributed.ReduceOp.AVG)
if param.grad is None:
param.grad = grad
else:
param.grad += grad
optimizer.step()
optimizer.zero_grad()

# IDEA: What about a hacky .backward method on a torch Tensor, that calls the backward pass
# and sets the grads? Could we then use automatic optimization?
torch_loss = jax_to_torch(loss)
torch_y = batch[1]
accuracy = jax_to_torch(logits).argmax(-1).eq(torch_y).float().mean()
self.log(f"{phase}/accuracy", accuracy, prog_bar=True, sync_dist=True)
self.log(f"{phase}/loss", torch_loss, prog_bar=True, sync_dist=True)
return torch_loss

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=self.hp.lr)
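
A side note on the params_treedef that shared_step passes to JaxFunction.apply above: autograd only tracks flat torch tensors, so the nested Flax parameter dict is flattened once into (leaves, treedef) and rebuilt inside forward with jax.tree.unflatten. A minimal sketch of that round trip follows; the dict below is a made-up stand-in for the CNN's variables.

import jax
import jax.numpy as jnp

params = {"params": {"Dense_0": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}}}
leaves, params_treedef = jax.tree.flatten(params)     # leaves are plain arrays, ordered by sorted dict keys
rebuilt = jax.tree.unflatten(params_treedef, leaves)  # same nested structure as `params`
assert (rebuilt["params"]["Dense_0"]["kernel"] == params["params"]["Dense_0"]["kernel"]).all()
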
