Commit e584051

updates
chaoming0625 committed Jan 7, 2024
1 parent 7ff688d commit e584051
Showing 3 changed files with 39 additions and 23 deletions.
15 changes: 8 additions & 7 deletions brainpy/_src/dnn/function.py
@@ -75,7 +75,7 @@ class Flatten(Layer):

  def __init__(
      self,
-      start_dim: int = 1,
+      start_dim: int = 0,
      end_dim: int = -1,
      name: Optional[str] = None,
      mode: bm.Mode = None,
@@ -86,11 +86,11 @@ def __init__(
    self.end_dim = end_dim

  def update(self, x):
-    # if isinstance(self.mode, bm.BatchingMode):
-    #   return x.reshape((x.shape[0], -1))
-    # else:
-    #   return x.flatten()
-    return bm.flatten(x, self.start_dim, self.end_dim)
+    if self.mode.is_child_of(bm.BatchingMode):
+      start_dim = (self.start_dim + 1) if self.start_dim >= 0 else (x.ndim + self.start_dim + 1)
+    else:
+      start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+    return bm.flatten(x, start_dim, self.end_dim)

  def __repr__(self) -> str:
    return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
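
For context, a minimal usage sketch (not part of this commit) of the new start_dim handling in Flatten; the module path is the one shown in this diff, and the input shapes are illustrative assumptions:

import brainpy.math as bm
from brainpy._src.dnn.function import Flatten  # path taken from this diff

# Non-batching mode: with the new default start_dim=0, every axis is flattened.
layer = Flatten(mode=bm.nonbatching_mode)
out = layer.update(bm.ones((10, 20, 3)))     # -> shape (600,)

# Batching mode: start_dim is shifted by one, so the leading batch axis is kept.
layer = Flatten(mode=bm.batching_mode)
out = layer.update(bm.ones((8, 10, 20, 3)))  # -> shape (8, 600)
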
@@ -153,7 +153,8 @@ def __init__(self, dim: int, sizes: Sequence[int], mode: bm.Mode = None, name: s
      raise TypeError("unflattened_size must be tuple or list, but found type {}".format(type(sizes).__name__))

  def update(self, x):
-    return bm.unflatten(x, self.dim, self.sizes)
+    dim = self.dim + 1 if self.mode.is_batch_mode() else self.dim
+    return bm.unflatten(x, dim, self.sizes)

  def __repr__(self):
    return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})'
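
A similar sketch (not part of this commit) for the layer above, assuming it is the Unflatten class in the same file; shapes are illustrative:

import brainpy.math as bm
from brainpy._src.dnn.function import Unflatten  # class name assumed from the file above

# Batching mode: dim is shifted past the leading batch axis before unflattening.
layer = Unflatten(dim=0, sizes=(20, 3), mode=bm.batching_mode)
out = layer.update(bm.ones((8, 60)))   # axis 1 is expanded -> shape (8, 20, 3)

# Non-batching mode: dim is used as given.
layer = Unflatten(dim=0, sizes=(20, 3), mode=bm.nonbatching_mode)
out = layer.update(bm.ones((60,)))     # -> shape (20, 3)
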
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/tests/test_function.py
@@ -29,7 +29,7 @@ def test_flatten_non_batching_mode(self):

    output = layer.update(input)

-    expected_shape = (10, 60)
+    expected_shape = (600,)
    self.assertEqual(output.shape, expected_shape)
    bm.clear_buffer_memory()

45 changes: 30 additions & 15 deletions brainpy/_src/initialize/generic.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

-from typing import Union, Callable, Optional, Sequence
+import functools
+from typing import Union, Callable, Optional, Sequence, Any

import jax
import jax.numpy as jnp
@@ -19,6 +19,9 @@
  'delay',
]

+DType = Any
+
+

def _check_none(x, allow_none: bool = False):
  pass
@@ -39,7 +42,8 @@ def parameter(
    sizes: Shape,
    allow_none: bool = True,
    allow_scalar: bool = True,
-    sharding: Optional[Sharding] = None
+    sharding: Optional[Sharding] = None,
+    dtype: DType = None,
):
  """Initialize parameters.
@@ -59,6 +63,8 @@
      Whether allow the parameter is a scalar value.
    sharding: Sharding
      The axes for automatic array sharding.
+    dtype: DType
+      The data type of the parameter.

  Returns
  -------
@@ -80,20 +86,22 @@
    return param

  if callable(param):
-    v = bm.jit(param,
+    v = bm.jit(functools.partial(param, dtype=dtype),
               static_argnums=0,
               out_shardings=bm.sharding.get_sharding(sharding))(sizes)
    return _check_var(v)  # TODO: checking the Variable need to be traced

  elif isinstance(param, (np.ndarray, jnp.ndarray)):
-    param = bm.asarray(param)
+    param = param
  elif isinstance(param, bm.Variable):
    param = param
  elif isinstance(param, bm.Array):
    param = param
  else:
    raise ValueError(f'Unknown param type {type(param)}: {param}')

+  if dtype is not None:
+    if param.dtype != dtype:
+      param.value = param.astype(dtype)
  if allow_scalar:
    if param.shape == () or param.shape == (1,):
      return param
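
A minimal sketch (not part of this commit) of how the new dtype argument reaches a callable initializer in parameter(); the stand-in lambda and the shapes are assumptions, and the module path is the one shown in this diff:

import jax.numpy as jnp
import brainpy.math as bm
from brainpy._src.initialize.generic import parameter  # path taken from this diff

# The callable is wrapped with functools.partial(param, dtype=dtype), so it must
# accept a dtype keyword in addition to the shape argument.
init = lambda shape, dtype=None: bm.zeros(shape, dtype=dtype)
w = parameter(init, sizes=(4, 8), dtype=jnp.float16)   # -> a (4, 8) float16 array
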
@@ -109,6 +117,7 @@ def variable_(
    batch_axis: int = 0,
    axis_names: Optional[Sequence[str]] = None,
    batch_axis_name: Optional[str] = None,
+    dtype: DType = None,
):
  """Initialize a :math:`~.Variable` from a callable function or a data.
@@ -122,7 +131,8 @@
                  sizes=sizes,
                  batch_axis=batch_axis,
                  axis_names=axis_names,
-                  batch_axis_name=batch_axis_name)
+                  batch_axis_name=batch_axis_name,
+                  dtype=dtype)


def variable(
@@ -132,6 +142,7 @@ def variable(
    batch_axis: int = 0,
    axis_names: Optional[Sequence[str]] = None,
    batch_axis_name: Optional[str] = None,
+    dtype: DType = None,
):
  """Initialize variables.
@@ -152,6 +163,8 @@
      The name for each axis. These names should match the given ``axes``.
    batch_axis_name: str
      The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given.
+    dtype: DType
+      The data type of the variable.

  Returns
  -------
@@ -175,15 +188,15 @@
    if sizes is None:
      raise ValueError('"varshape" cannot be None when data is a callable function.')
    if isinstance(batch_or_mode, bm.NonBatchingMode):
-      data = bm.Variable(init(sizes), axis_names=axis_names)
+      data = bm.Variable(init(sizes, dtype=dtype), axis_names=axis_names)
    elif isinstance(batch_or_mode, bm.BatchingMode):
      new_shape = sizes[:batch_axis] + (batch_or_mode.batch_size,) + sizes[batch_axis:]
-      data = bm.Variable(init(new_shape), batch_axis=batch_axis, axis_names=axis_names)
+      data = bm.Variable(init(new_shape, dtype=dtype), batch_axis=batch_axis, axis_names=axis_names)
    elif batch_or_mode in (None, False):
-      data = bm.Variable(init(sizes), axis_names=axis_names)
+      data = bm.Variable(init(sizes, dtype=dtype), axis_names=axis_names)
    elif isinstance(batch_or_mode, int):
      new_shape = sizes[:batch_axis] + (int(batch_or_mode),) + sizes[batch_axis:]
-      data = bm.Variable(init(new_shape), batch_axis=batch_axis, axis_names=axis_names)
+      data = bm.Variable(init(new_shape, dtype=dtype), batch_axis=batch_axis, axis_names=axis_names)
    else:
      raise ValueError(f'Unknown batch_size_or_mode: {batch_or_mode}')

@@ -192,21 +205,23 @@ def variable(
    if bm.shape(init) != sizes:
      raise ValueError(f'The shape of "data" {bm.shape(init)} does not match with "var_shape" {sizes}')
    if isinstance(batch_or_mode, bm.NonBatchingMode):
-      data = bm.Variable(init, axis_names=axis_names)
+      data = bm.Variable(init, axis_names=axis_names, dtype=dtype)
    elif isinstance(batch_or_mode, bm.BatchingMode):
      data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
                                   batch_or_mode.batch_size,
                                   axis=batch_axis),
                         batch_axis=batch_axis,
-                         axis_names=axis_names)
+                         axis_names=axis_names,
+                         dtype=dtype)
    elif batch_or_mode in (None, False):
-      data = bm.Variable(init, axis_names=axis_names)
+      data = bm.Variable(init, axis_names=axis_names, dtype=dtype)
    elif isinstance(batch_or_mode, int):
      data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
                                   int(batch_or_mode),
                                   axis=batch_axis),
                         batch_axis=batch_axis,
-                         axis_names=axis_names)
+                         axis_names=axis_names,
+                         dtype=dtype)
    else:
      raise ValueError('Unknown batch_size_or_mode.')
  return bm.sharding.partition_by_axname(data, axis_names)
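
Finally, a sketch (not part of this commit) of variable() with the new dtype argument; the batch size and shapes are illustrative, and the argument names are assumed from the snippet above:

import jax.numpy as jnp
import brainpy.math as bm
from brainpy._src.initialize.generic import variable  # path taken from this diff

# Callable init with an integer batch size: the batch axis is inserted at batch_axis,
# and the requested dtype is forwarded to the initializer as init(shape, dtype=dtype).
v = variable(bm.zeros, batch_or_mode=4, sizes=(3, 5), dtype=jnp.float16)
# v is a bm.Variable with shape (4, 3, 5), batch_axis=0 and float16 entries
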