diff --git a/brainpy/dnn/conv.py b/brainpy/dnn/conv.py index dc39b597..af4ae41d 100644 --- a/brainpy/dnn/conv.py +++ b/brainpy/dnn/conv.py @@ -179,7 +179,7 @@ def update(self, x): if self.mask is not None: try: lax.broadcast_shapes(self.w.shape, self.mask.shape) - except: + except (ValueError, TypeError): raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") w = w * self.mask y = lax.conv_general_dilated(lhs=bm.as_jax(x), @@ -566,7 +566,7 @@ def update(self, x): if self.mask is not None: try: lax.broadcast_shapes(self.w.shape, self.mask.shape) - except: + except (ValueError, TypeError): raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") w = w * self.mask y = lax.conv_transpose(lhs=bm.as_jax(x), diff --git a/brainpy/dnn/interoperation_flax.py b/brainpy/dnn/interoperation_flax.py index 0b6a1324..6c50a5e1 100644 --- a/brainpy/dnn/interoperation_flax.py +++ b/brainpy/dnn/interoperation_flax.py @@ -26,7 +26,7 @@ try: import flax # noqa from flax.linen.recurrent import RNNCellBase -except: +except (ImportError, ModuleNotFoundError): flax = None RNNCellBase = object diff --git a/brainpy/dyn/neurons/lif.py b/brainpy/dyn/neurons/lif.py index f7ddf886..28d693c3 100644 --- a/brainpy/dyn/neurons/lif.py +++ b/brainpy/dyn/neurons/lif.py @@ -299,7 +299,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") else: spike = V >= self.V_th @@ -509,7 +509,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: @@ -785,7 +785,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") else: spike = V >= self.V_th @@ -1142,7 +1142,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: @@ -1497,7 +1497,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike else: @@ -1843,7 +1843,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike_no_grad spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient @@ -2142,7 +2142,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") else: spike = V >= self.V_th @@ -2431,7 +2431,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: @@ -2734,7 +2734,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike else: @@ -3054,7 +3054,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike_no_grad spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient @@ -3417,7 +3417,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") I1 += spike * (self.R1 * I1 + self.A1 - I1) I2 += spike * (self.R2 * I2 + self.A2 - I2) V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike @@ -3810,7 +3810,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") I1 += spike * (self.R1 * I1 + self.A1 - I1) I2 += spike * (self.R2 * I2 + self.A2 - I2) V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 0390fd2a..80f8ca84 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -285,7 +285,7 @@ def update(self, x): # update synaptic variables self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += self.a * x + self.h.value = self.h.value + self.a * x return self.g.value def return_info(self): @@ -552,7 +552,7 @@ def dg(self, g, t, h): def update(self, x): # update synaptic variables self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += x + self.h.value = self.h.value + x return self.g.value def return_info(self): @@ -737,7 +737,7 @@ def update(self, pre_spike): t = share.load('t') dt = share.load('dt') self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, dt=dt) - self.x += pre_spike + self.x.value = self.x.value + pre_spike return self.g.value def return_info(self): diff --git a/brainpy/dynsys.py b/brainpy/dynsys.py index 6d56e391..a27e9892 100644 --- a/brainpy/dynsys.py +++ b/brainpy/dynsys.py @@ -966,11 +966,18 @@ def _slice_to_num(slice_: slice, length: int): step = slice_.step if step is None: step = 1 + if step == 0: + raise ValueError("slice step cannot be zero") # number num = 0 - while start < stop: - start += step - num += 1 + if step > 0: + while start < stop: + start += step + num += 1 + else: + while start > stop: + start += step + num += 1 return num diff --git a/brainpy/helpers.py b/brainpy/helpers.py index 8cdae877..5a7f7c62 100644 --- a/brainpy/helpers.py +++ b/brainpy/helpers.py @@ -116,6 +116,9 @@ def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs): missing_keys = [] unexpected_keys = [] for name, node in nodes.items(): + if name not in state_dict: + missing_keys.append(name) + continue r = node.load_state(state_dict[name], **kwargs) if r is not None: missing, unexpected = r diff --git a/brainpy/integrators/ode/adaptive_rk.py b/brainpy/integrators/ode/adaptive_rk.py index c22847b7..ee6864f3 100644 --- a/brainpy/integrators/ode/adaptive_rk.py +++ b/brainpy/integrators/ode/adaptive_rk.py @@ -67,6 +67,7 @@ import jax.numpy as jnp +from brainpy import _errors as errors from brainpy.integrators import constants as C, utils from brainpy.integrators.ode import common from brainpy.integrators.ode.base import ODEIntegrator @@ -456,7 +457,7 @@ class BogackiShampine(AdaptiveRKIntegrator): A = [(), (0.5,), (0., 0.75), - ('2/9', '1/3', '4/0'), ] + ('2/9', '1/3', '4/9'), ] B1 = ['2/9', '1/3', '4/9', 0] B2 = ['7/24', 0.25, '1/3', 0.125] C = [0, 0.5, 0.75, 1] diff --git a/brainpy/integrators/ode/generic.py b/brainpy/integrators/ode/generic.py index 49924c06..8ee6c099 100644 --- a/brainpy/integrators/ode/generic.py +++ b/brainpy/integrators/ode/generic.py @@ -134,7 +134,7 @@ def set_default_odeint(method): raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') global _DEFAULT_DDE_METHOD - _DEFAULT_ODE_METHOD = method + _DEFAULT_DDE_METHOD = method def get_default_odeint(): diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py index 7e486650..3b14e883 100644 --- a/brainpy/math/ndarray.py +++ b/brainpy/math/ndarray.py @@ -272,7 +272,7 @@ def value(self): @value.setter def value(self, value): - self_value = self._check_tracer() + self_value = self._value if isinstance(value, Array): value = value.value diff --git a/brainpy/math/object_transform/base.py b/brainpy/math/object_transform/base.py index 1f46973d..f6a6b976 100644 --- a/brainpy/math/object_transform/base.py +++ b/brainpy/math/object_transform/base.py @@ -773,7 +773,7 @@ def update(self, *args, **kwargs) -> 'NodeDict': self[k] = v elif isinstance(arg, tuple): assert len(arg) == 2 - self[arg[0]] = args[1] + self[arg[0]] = arg[1] for k, v in kwargs.items(): self[k] = v return self @@ -781,8 +781,8 @@ def update(self, *args, **kwargs) -> 'NodeDict': def __setitem__(self, key, value) -> 'NodeDict': if self.check_unique: exist = self.get(key, None) - if id(exist) != id(value): - raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.') + if exist is not None and id(exist) != id(value): + raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {exist}.') super().__setitem__(key, value) return self diff --git a/brainpy/math/object_transform/variables.py b/brainpy/math/object_transform/variables.py index f7d4d4c6..60671a14 100644 --- a/brainpy/math/object_transform/variables.py +++ b/brainpy/math/object_transform/variables.py @@ -109,7 +109,7 @@ def size_without_batch(self): return self.size else: sizes = self.size - return sizes[:self.batch_size] + sizes[self.batch_axis + 1:] + return sizes[:self.batch_axis] + sizes[self.batch_axis + 1:] @property def batch_axis(self) -> Optional[int]: @@ -390,7 +390,7 @@ def update(self, *args, **kwargs) -> 'VarDict': self[k] = v elif isinstance(arg, tuple): assert len(arg) == 2 - self[arg[0]] = args[1] + self[arg[0]] = arg[1] for k, v in kwargs.items(): self[k] = v return self diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index a820dc48..332881e5 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -109,8 +109,8 @@ def __init__( ''' ) if seed is not None: - NoLongerSupportError('"seed" is no longer supported. ' - 'Please shuffle your data by yourself.') + raise NoLongerSupportError('"seed" is no longer supported. ' + 'Please shuffle your data by yourself.') # jit settings if isinstance(self._origin_jit, bool): @@ -317,7 +317,7 @@ def fit( fit_t1 = time.time() aux = {} for k, v in fit_epoch_metric.items(): - aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) + aux[k] = np.mean(np.asarray(v)) if k not in report_train_metric: report_train_metric[k] = [] detailed_train_metric[k] = [] @@ -420,7 +420,7 @@ def fit( test_t1 = time.time() aux = {} for k, v in test_epoch_metric.items(): - aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) + aux[k] = np.mean(np.asarray(v)) if k not in report_test_metric: report_test_metric[k] = [] detailed_test_metric[k] = [] diff --git a/pyproject.toml b/pyproject.toml index dc5164c3..40d1cb99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,10 @@ build-backend = "setuptools.build_meta" name = "brainpy" description = "BrainPy: Brain Dynamics Programming in Python" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" authors = [ { name = "BrainPy Team", email = "chao.brain@qq.com" } ] -license = { text = "GPL-3.0 license" } classifiers = [ "Natural Language :: English", "Operating System :: OS Independent", @@ -39,7 +38,7 @@ dependencies = [ "jax", "tqdm", "brainstate>=0.2.7", - "brainunit", + "brainunit>=0.2.0", "brainevent>=0.0.7", "braintools>=0.0.9", 'brainpy_state>=0.0.3', diff --git a/requirements.txt b/requirements.txt index b8e3db57..09fbf6cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -numpy +numpy>=1.15 brainunit brainevent>=0.0.7 -braintools>=0.1.0 +braintools>=0.0.9 brainstate>=0.2.7 -brainpy_state>=0.0.2 +brainpy_state>=0.0.3 jax tqdm