From 8feece6555649f9a0fc12ef06dbf23f4124706bc Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Tue, 4 Jun 2024 21:05:58 +0000
Subject: [PATCH] Make it work with automatic optimization and jit!

Signed-off-by: Fabrice Normandin
---
 project/algorithms/jax_algo.py | 179 ++++++++++++---------------------
 1 file changed, 62 insertions(+), 117 deletions(-)

diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py
index 071367f8..f355da59 100644
--- a/project/algorithms/jax_algo.py
+++ b/project/algorithms/jax_algo.py
@@ -103,35 +103,40 @@ def to_channels_last[T: jax.Array | torch.Tensor](tensor: T) -> T:
 class JaxFunction(torch.autograd.Function):
     params_treedef: ClassVar

-    @staticmethod
-    def loss_function(
-        params: VariableDict,
-        x: jax.Array,
-        y: jax.Array,
-    ):
-        logits = CNN().apply(params, x)
-        assert isinstance(logits, jax.Array)
-        loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
-        assert isinstance(loss, jax.Array)
-        return loss, logits
-
     @staticmethod
     def forward(
-        ctx: torch.autograd.function.FunctionCtx,
+        ctx: torch.autograd.function.NestedIOFunction,
         x: torch.Tensor,
         y: torch.Tensor,
         params_treedef: PyTreeDef,
+        loss_fn: Callable[[VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]],
+        loss_value_and_grad_fn: Callable[
+            [VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]
+        ],
         *params: torch.Tensor,
     ):
-        ctx.save_for_backward(x, y, *params)
-        ctx.params_treedef = params_treedef  # type: ignore
         jax_x = torch_to_jax(x)
         jax_y = torch_to_jax(y)
         jax_params = tuple(map(torch_to_jax, params))
         jax_params = jax.tree.unflatten(params_treedef, jax_params)
-        jax_loss, jax_logits = JaxFunction.loss_function(jax_params, x=jax_x, y=jax_y)
-        loss = jax_to_torch(jax_loss)
-        logits = jax_to_torch(jax_logits)
+
+        needs_grad: tuple[bool, ...] = ctx.needs_input_grad  # type: ignore
+        x_needs_grad, y_needs_grad, _, _, _, *params_need_grad = needs_grad
+        # todo: broaden a bit:
+        assert not x_needs_grad
+        assert not y_needs_grad
+        if all(params_need_grad):
+            # We're going to need to do the backward pass, so do it right away and save the grads
+            # in the context.
+            (loss, logits), param_grads = loss_value_and_grad_fn(jax_params, jax_x, jax_y)
+            flattened_param_grads = jax.tree.leaves(param_grads)
+            torch_grads = tuple(map(jax_to_torch, flattened_param_grads))
+            ctx.save_for_backward(*torch_grads)
+        else:
+            assert not any(params_need_grad)
+            loss, logits = loss_fn(jax_params, jax_x, jax_y)
+        loss = jax_to_torch(loss)
+        logits = jax_to_torch(logits)
         return loss, logits

     @staticmethod
@@ -140,33 +145,21 @@ def backward(
         grad_loss: torch.Tensor,
         grad_logits: torch.Tensor,
     ):
-        x: torch.Tensor
-        params: tuple[torch.Tensor, ...]
-        x, y, *params = ctx.saved_tensors  # type: ignore
-        params_treedef: PyTreeDef = ctx.params_treedef  # type: ignore
-        jax_x = torch_to_jax(x)
-        jax_y = torch_to_jax(y)
-        # jax_grad_output = torch_to_jax(grad_output)  # TODO: Can we even pass this to jax.grad?
-
-        structured_params = jax.tree.unflatten(params_treedef, params)
-        jax_params = jax.tree.map(torch_to_jax, structured_params)
+        x_needs_grad, y_needs_grad, _, _, _, *params_needs_grad = ctx.needs_input_grad
+        # todo: broaden this a bit in case we need the grad of the input.
+        # todo: Figure out how to do jax.grad for a function that outputs a matrix or vector.
+        assert not x_needs_grad
+        assert not y_needs_grad
         grad_input = None
         grad_y = None
+        if all(params_needs_grad):
+            params_grads = ctx.saved_tensors
+        else:
+            assert not any(params_needs_grad)
+            params_grads = tuple(None for _ in params_needs_grad)

-        # todo: broaden this a bit in case we need the grad of the input.
-        assert ctx.needs_input_grad == (
-            False,  # input
-            False,  # y
-            False,  # params_treedef
-            *(True for _ in params),
-        ), ctx.needs_input_grad
-        jax_params_grad, logits = jax.grad(JaxFunction.loss_function, argnums=0, has_aux=True)(
-            jax_params, jax_x, jax_y
-        )
-        torch_params_grad = jax.tree.map(jax_to_torch, jax_params_grad)
-        torch_flat_params_grad = jax.tree.leaves(torch_params_grad)
-        return grad_input, grad_y, None, *torch_flat_params_grad
+        return grad_input, grad_y, None, None, None, *params_grads


 class JaxAlgorithm(Algorithm):
@@ -199,46 +192,28 @@ def __init__(
             map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
         )

-        # We will do the backward pass ourselves, and PL will only be used to synchronize stuff
-        # between workers, do logging, etc.
         self.automatic_optimization = True

-    # def on_fit_start(self):
-    # Setting those here, because otherwise we get pickling errors when running with multiple
-    # GPUs.
-
-    # def loss_fn(params: VariableDict, x: jax.Array, y: jax.Array):
-    #     logits = self.network.apply(params, x)
-    #     assert isinstance(logits, jax.Array)
-    #     loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
-    #     assert isinstance(loss, jax.Array)
-    #     return loss, logits
-
-    # self.forward_pass = loss_fn
-    # self.backward_pass = value_and_grad(self.forward_pass)
-
-    # if not self.hp.debug:
-    #     self.forward_pass = jit(self.forward_pass)
-    #     self.backward_pass = jit(self.backward_pass)
-
-    # def jax_params(self) -> VariableDict:
-    #     # View the torch parameters as jax Arrays
-    #     jax_parameters = jax.tree.map(torch_to_jax, list(self.parameters()))
-    #     # Reconstruct the original object structure.
-    #     jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
-    #     return jax_params_tuple
-
-    # def on_before_batch_transfer(
-    #     self, batch: tuple[torch.Tensor, torch.Tensor], dataloader_idx: int
-    # ):
-    #     # Convert the batch to jax Arrays.
-    #     x, y = batch
-    #     # Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
-    #     # channels-first tensors, so we have to do a transpose here.
-    #     x = to_channels_last(x)
-    #     # View the torch inputs as jax Arrays.
-    #     x, y = torch_to_jax(x), torch_to_jax(y)
-    #     return x, y
+    def on_fit_start(self):
+        # Setting those here, because otherwise we get pickling errors when running with multiple
+        # GPUs.
+        def loss_function(
+            params: VariableDict,
+            x: jax.Array,
+            y: jax.Array,
+        ):
+            logits = self.network.apply(params, x)
+            assert isinstance(logits, jax.Array)
+            loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
+            assert isinstance(loss, jax.Array)
+            return loss, logits
+
+        self.forward_pass = loss_function
+        self.backward_pass = value_and_grad(self.forward_pass, argnums=0, has_aux=True)
+
+        if not self.hp.debug:
+            self.forward_pass = jit(self.forward_pass)
+            self.backward_pass = jit(self.backward_pass)

     def shared_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
@@ -247,8 +222,15 @@ def shared_step(
         # Convert the batch to jax Arrays.
         # Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
         # channels-first tensors, so we have to do a transpose here.
         x = to_channels_last(x)
-        loss, logits = JaxFunction.apply(x, y, self.params_treedef, *self.parameters())  # type: ignore
+
+        loss: torch.Tensor
+        logits: torch.Tensor
+        loss, logits = JaxFunction.apply(  # type: ignore
+            x, y, self.params_treedef, self.forward_pass, self.backward_pass, *self.parameters()
+        )
+
         assert isinstance(logits, torch.Tensor)
         if phase == "train":
             assert loss.requires_grad
         acc = logits.argmax(-1).eq(y).float().mean()
         self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True)
         return loss
-        # View the torch inputs as jax Arrays.
-        x, y = torch_to_jax(x), torch_to_jax(y)
-
-        jax_params = self.jax_params()
-        if phase != "train":
-            # Only use the forward pass.
-            loss, logits = self.forward_pass(jax_params, x, y)
-        else:
-            optimizer = self.optimizers()
-            assert isinstance(optimizer, torch.optim.Optimizer)
-
-            # Perform the backward pass
-            (loss, logits), jax_grads = self.backward_pass(jax_params, x, y)
-            distributed = torch.distributed.is_initialized()
-
-            with torch.no_grad():
-                # 'convert' the gradients to pytorch
-                torch_grads = jax.tree.map(jax_to_torch, jax_grads)
-                # Update the torch parameters tensors in-place using the jax grads.
-                for param, grad in zip(self.parameters(), jax.tree.leaves(torch_grads)):
-                    if distributed:
-                        torch.distributed.all_reduce(grad, op=torch.distributed.ReduceOp.AVG)
-                    if param.grad is None:
-                        param.grad = grad
-                    else:
-                        param.grad += grad
-            optimizer.step()
-            optimizer.zero_grad()
-
-        # IDEA: What about a hacky .backward method on a torch Tensor, that calls the backward pass
-        # and sets the grads? Could we then use automatic optimization?
-        torch_loss = jax_to_torch(loss)
-        torch_y = batch[1]
-        accuracy = jax_to_torch(logits).argmax(-1).eq(torch_y).float().mean()
-        self.log(f"{phase}/accuracy", accuracy, prog_bar=True, sync_dist=True)
-        self.log(f"{phase}/loss", torch_loss, prog_bar=True, sync_dist=True)
-        return torch_loss

     def configure_optimizers(self):
         return torch.optim.SGD(self.parameters(), lr=self.hp.lr)
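
Note: the pattern this patch implements (run the jitted JAX value-and-grad eagerly inside torch.autograd.Function.forward, save the converted gradients with ctx.save_for_backward, and hand them back in backward) can be illustrated with a small standalone script. The sketch below is only an illustration of that idea, not part of the patch: it uses a toy quadratic loss and a plain numpy round-trip instead of the repo's Flax CNN and dlpack-based torch_to_jax/jax_to_torch helpers, and the names quadratic_loss and JaxScalarFn are made up for the example.

    import jax
    import jax.numpy as jnp
    import numpy as np
    import torch


    def quadratic_loss(w: jax.Array, x: jax.Array) -> jax.Array:
        # Toy scalar loss: 0.5 * sum((w * x) ** 2), so dL/dw = w * x**2.
        return 0.5 * jnp.sum((w * x) ** 2)


    # Jitted forward + backward in one call, analogous to `self.backward_pass` in the patch.
    value_and_grad_fn = jax.jit(jax.value_and_grad(quadratic_loss, argnums=0))


    class JaxScalarFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            # Convert to JAX, compute the loss and its gradient right away, and stash the
            # gradient as a torch tensor so that `backward` only has to replay it.
            jax_w = jnp.asarray(w.detach().cpu().numpy())
            jax_x = jnp.asarray(x.detach().cpu().numpy())
            loss, grad_w = value_and_grad_fn(jax_w, jax_x)
            ctx.save_for_backward(torch.from_numpy(np.asarray(grad_w)))
            return torch.as_tensor(np.asarray(loss))

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            (grad_w,) = ctx.saved_tensors
            # Chain rule: scale the saved JAX gradient by the incoming gradient.
            # One entry per forward input (w, x); x does not need a gradient.
            return grad_output * grad_w, None


    if __name__ == "__main__":
        w = torch.ones(3, requires_grad=True)
        x = torch.tensor([1.0, 2.0, 3.0])
        loss = JaxScalarFn.apply(w, x)
        loss.backward()
        print(loss.item(), w.grad)  # w.grad == w * x**2 == [1., 4., 9.]

Because the gradients are computed during the forward call, a plain loss.backward() (and therefore Lightning's automatic optimization) works without the manual optimizer stepping that the deleted shared_step code used to do by hand.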