From bba139ccdae6ecd204102ce7b92b44b84162a5ba Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 12 Mar 2026 00:51:50 +0800 Subject: [PATCH 1/3] chore: update version to 2.7.7, enhance README and requirements, and refactor event handling --- brainpy/__init__.py | 10 ++-------- brainpy/dnn/linear.py | 24 ++++++++++++++---------- brainpy/initialize/base.py | 4 +++- brainpy/math/compat_numpy.py | 3 +-- brainpy/math/event/csr_matmat.py | 2 +- brainpy/math/event/csr_matvec.py | 2 +- brainpy/math/jitconn/event_matvec.py | 8 ++++---- brainpy/math/jitconn/matvec.py | 4 ++-- brainpy/state/README.md | 4 ++-- brainpy/state/__init__.py | 5 +++-- pyproject.toml | 3 +-- requirements.txt | 2 +- 12 files changed, 35 insertions(+), 36 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 8b9d35c5..0f0b1b4f 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -__version__ = "2.7.6" +__version__ = "2.7.7" __version_info__ = tuple(map(int, __version__.split("."))) from brainpy import _errors as errors @@ -133,20 +133,14 @@ synouts, # synaptic output synplast, # synaptic plasticity ) -from brainpy.math.object_transform.base import ( - Base as Base, -) +from brainpy.math.object_transform.base import Base as Base from brainpy.math.object_transform.collectors import ( ArrayCollector as ArrayCollector, Collector as Collector, ) -from brainpy.deprecations import deprecation_getattr - optimizers = optim - # New package from brainpy import state - diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 87dd0435..a38e5d38 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -19,8 +19,12 @@ import jax import jax.numpy as jnp import numpy as np -from brainevent import csr_on_pre, csr2csc_on_post -from brainevent import dense_on_pre, dense_on_post +from brainevent import ( + update_csr_on_binary_pre, + update_csr_on_binary_post, + update_dense_on_binary_pre, + update_dense_on_binary_post, +) from brainpy import connect, initialize as init from brainpy import math as bm @@ -226,11 +230,11 @@ def stdp_update( if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] - self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max) + self.W.value = update_dense_on_binary_pre(self.W.value, spike, trace, w_min, w_max) if on_post is not None: spike = on_post['spike'] trace = on_post['trace'] - self.W.value = dense_on_post(self.W.value, trace, spike, w_min, w_max) + self.W.value = update_dense_on_binary_post(self.W.value, trace, spike, w_min, w_max) Linear = Dense @@ -321,11 +325,11 @@ def stdp_update( if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] - self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + self.weight.value = update_dense_on_binary_pre(self.weight.value, spike, trace, w_min, w_max) if on_post is not None: spike = on_post['spike'] trace = on_post['trace'] - self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max) + self.weight.value = update_dense_on_binary_post(self.weight.value, trace, spike, w_min, w_max) class OneToOne(Layer, SupportSTDP): @@ -449,11 +453,11 @@ def stdp_update( if on_pre is not None: spike = on_pre['spike'] trace = on_pre['trace'] - self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + self.weight.value = update_dense_on_binary_pre(self.weight.value, spike, trace, w_min, w_max) if on_post is not None: spike = on_post['spike'] trace = on_post['trace'] - self.weight.value = dense_on_post(self.weight.value, trace, spike, w_min, w_max) + self.weight.value = update_dense_on_binary_post(self.weight.value, trace, spike, w_min, w_max) class _CSRLayer(Layer, SupportSTDP): @@ -500,7 +504,7 @@ def stdp_update( if on_pre is not None: # update on presynaptic spike spike = on_pre['spike'] trace = on_pre['trace'] - self.weight.value = csr_on_pre( + self.weight.value = update_csr_on_binary_pre( self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max, shape=(spike.shape[0], trace.shape[0]), ) @@ -512,7 +516,7 @@ def stdp_update( ) spike = on_post['spike'] trace = on_post['trace'] - self.weight.value = csr2csc_on_post( + self.weight.value = update_csr_on_binary_post( self.weight.value, self._pre_ids, self._post_indptr, self.w_indices, trace, spike, w_min, w_max, shape=(trace.shape[0], spike.shape[0]), diff --git a/brainpy/initialize/base.py b/brainpy/initialize/base.py index 96dde401..a14e9792 100644 --- a/brainpy/initialize/base.py +++ b/brainpy/initialize/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +# -*- coding: utf-8 -*- + import abc __all__ = [ diff --git a/brainpy/math/compat_numpy.py b/brainpy/math/compat_numpy.py index 870543ee..250a447d 100644 --- a/brainpy/math/compat_numpy.py +++ b/brainpy/math/compat_numpy.py @@ -36,7 +36,7 @@ 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', - 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod', + 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'prod', 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', 'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf', 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', @@ -397,7 +397,6 @@ def msort(a): floor = _compatible_with_brainpy_array(jnp.floor) ceil = _compatible_with_brainpy_array(jnp.ceil) trunc = _compatible_with_brainpy_array(jnp.trunc) -fix = _compatible_with_brainpy_array(jnp.fix) prod = _compatible_with_brainpy_array(jnp.prod) sum = _compatible_with_brainpy_array(jnp.sum) diff --git a/brainpy/math/event/csr_matmat.py b/brainpy/math/event/csr_matmat.py index 63158dc9..c33d1103 100644 --- a/brainpy/math/event/csr_matmat.py +++ b/brainpy/math/event/csr_matmat.py @@ -59,7 +59,7 @@ def csrmm( if isinstance(matrix, Array): matrix = matrix.value - matrix = brainevent.EventArray(matrix) + matrix = brainevent.BinaryArray(matrix) csr = brainevent.CSR((data, indices, indptr), shape=shape) if transpose: return matrix @ csr diff --git a/brainpy/math/event/csr_matvec.py b/brainpy/math/event/csr_matvec.py index d64e9b10..8be45244 100644 --- a/brainpy/math/event/csr_matvec.py +++ b/brainpy/math/event/csr_matvec.py @@ -84,7 +84,7 @@ def csrmv( if isinstance(events, Array): events = events.value - events = brainevent.EventArray(events) + events = brainevent.BinaryArray(events) csr = brainevent.CSR((data, indices, indptr), shape=shape) if transpose: return events @ csr diff --git a/brainpy/math/jitconn/event_matvec.py b/brainpy/math/jitconn/event_matvec.py index b22fcb00..042bad16 100644 --- a/brainpy/math/jitconn/event_matvec.py +++ b/brainpy/math/jitconn/event_matvec.py @@ -49,8 +49,8 @@ def event_mv_prob_homo( if isinstance(weight, Array): weight = weight.value - events = brainevent.EventArray(events) - csr = brainevent.JITCHomoR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel) + events = brainevent.BinaryArray(events) + csr = brainevent.JITCScalarR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel) if transpose: return events @ csr else: @@ -75,7 +75,7 @@ def event_mv_prob_uniform( seed = np.random.randint(0, 1000000000) if isinstance(events, Array): events = events.value - events = brainevent.EventArray(events) + events = brainevent.BinaryArray(events) if isinstance(w_low, Array): w_low = w_low.value if isinstance(w_high, Array): @@ -106,7 +106,7 @@ def event_mv_prob_normal( seed = np.random.randint(0, 1000000000) if isinstance(events, Array): events = events.value - events = brainevent.EventArray(events) + events = brainevent.BinaryArray(events) if isinstance(w_mu, Array): w_mu = w_mu.value if isinstance(w_sigma, Array): diff --git a/brainpy/math/jitconn/matvec.py b/brainpy/math/jitconn/matvec.py index cd13f5c6..094f758e 100644 --- a/brainpy/math/jitconn/matvec.py +++ b/brainpy/math/jitconn/matvec.py @@ -96,7 +96,7 @@ def mv_prob_homo( if isinstance(weight, Array): weight = weight.value - csr = brainevent.JITCHomoR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel) + csr = brainevent.JITCScalarR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel) if transpose: return vector @ csr else: @@ -290,7 +290,7 @@ def get_homo_weight_matrix( """ if seed is None: seed = np.random.randint(0, 1000000000) - csr = brainevent.JITCHomoR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel) + csr = brainevent.JITCScalarR((weight, conn_prob, seed), shape=shape, corder=outdim_parallel) if transpose: csr = csr.T return csr.todense() diff --git a/brainpy/state/README.md b/brainpy/state/README.md index 4d610359..229742af 100644 --- a/brainpy/state/README.md +++ b/brainpy/state/README.md @@ -1,4 +1,4 @@ -# `brainpy.state` - State-based Brain Dynamics Programming +# `brainpy.state` ## Overview @@ -37,7 +37,7 @@ pip install brainpy -U For development or to install the state module separately: ```bash -pip install brainpy_state -U +pip install brainpy.state -U ``` ## Usage diff --git a/brainpy/state/__init__.py b/brainpy/state/__init__.py index 4f9491a4..407d5954 100644 --- a/brainpy/state/__init__.py +++ b/brainpy/state/__init__.py @@ -16,5 +16,6 @@ from brainpy_state import * from brainpy_state import __all__ - - +if __name__ == '__main__': + print(LIF) + print(__all__) diff --git a/pyproject.toml b/pyproject.toml index 5db3f845..dc5164c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -41,7 +40,7 @@ dependencies = [ "tqdm", "brainstate>=0.2.7", "brainunit", - "brainevent>=0.0.4", + "brainevent>=0.0.7", "braintools>=0.0.9", 'brainpy_state>=0.0.3', ] diff --git a/requirements.txt b/requirements.txt index 726f19bb..b8e3db57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy brainunit -brainevent>=0.0.4 +brainevent>=0.0.7 braintools>=0.1.0 brainstate>=0.2.7 brainpy_state>=0.0.2 From 6b41da158cad0c06877b885979e96b9a2218fedd Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 12 Mar 2026 01:02:43 +0800 Subject: [PATCH 2/3] fix: update required Python version to 3.11 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dc5164c3..0897e574 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "brainpy" description = "BrainPy: Brain Dynamics Programming in Python" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" authors = [ { name = "BrainPy Team", email = "chao.brain@qq.com" } ] From c3555eef7c8723973f1cfcf77b322b9ee5136f39 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 12 Mar 2026 01:11:44 +0800 Subject: [PATCH 3/3] fix: update synaptic variable updates and error handling in multiple files --- brainpy/dnn/conv.py | 4 ++-- brainpy/dnn/interoperation_flax.py | 2 +- brainpy/dyn/neurons/lif.py | 24 +++++++++++----------- brainpy/dyn/synapses/abstract_models.py | 6 +++--- brainpy/dynsys.py | 13 +++++++++--- brainpy/helpers.py | 3 +++ brainpy/integrators/ode/adaptive_rk.py | 3 ++- brainpy/integrators/ode/generic.py | 2 +- brainpy/math/ndarray.py | 2 +- brainpy/math/object_transform/base.py | 6 +++--- brainpy/math/object_transform/variables.py | 4 ++-- brainpy/train/back_propagation.py | 8 ++++---- pyproject.toml | 1 - requirements.txt | 6 +++--- 14 files changed, 47 insertions(+), 37 deletions(-) diff --git a/brainpy/dnn/conv.py b/brainpy/dnn/conv.py index dc39b597..af4ae41d 100644 --- a/brainpy/dnn/conv.py +++ b/brainpy/dnn/conv.py @@ -179,7 +179,7 @@ def update(self, x): if self.mask is not None: try: lax.broadcast_shapes(self.w.shape, self.mask.shape) - except: + except (ValueError, TypeError): raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") w = w * self.mask y = lax.conv_general_dilated(lhs=bm.as_jax(x), @@ -566,7 +566,7 @@ def update(self, x): if self.mask is not None: try: lax.broadcast_shapes(self.w.shape, self.mask.shape) - except: + except (ValueError, TypeError): raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") w = w * self.mask y = lax.conv_transpose(lhs=bm.as_jax(x), diff --git a/brainpy/dnn/interoperation_flax.py b/brainpy/dnn/interoperation_flax.py index 0b6a1324..6c50a5e1 100644 --- a/brainpy/dnn/interoperation_flax.py +++ b/brainpy/dnn/interoperation_flax.py @@ -26,7 +26,7 @@ try: import flax # noqa from flax.linen.recurrent import RNNCellBase -except: +except (ImportError, ModuleNotFoundError): flax = None RNNCellBase = object diff --git a/brainpy/dyn/neurons/lif.py b/brainpy/dyn/neurons/lif.py index f7ddf886..28d693c3 100644 --- a/brainpy/dyn/neurons/lif.py +++ b/brainpy/dyn/neurons/lif.py @@ -299,7 +299,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") else: spike = V >= self.V_th @@ -509,7 +509,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: @@ -785,7 +785,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") else: spike = V >= self.V_th @@ -1142,7 +1142,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: @@ -1497,7 +1497,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike else: @@ -1843,7 +1843,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike_no_grad spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient @@ -2142,7 +2142,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") else: spike = V >= self.V_th @@ -2431,7 +2431,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient if self.ref_var: @@ -2734,7 +2734,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike else: @@ -3054,7 +3054,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") w += self.b * spike_no_grad spike_ = spike_no_grad > 0. # will be used in other place, like Delta Synapse, so stop its gradient @@ -3417,7 +3417,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") I1 += spike * (self.R1 * I1 + self.A1 - I1) I2 += spike * (self.R2 * I2 + self.A2 - I2) V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike @@ -3810,7 +3810,7 @@ def update(self, x=None): elif self.spk_reset == 'hard': V += (self.V_reset - V) * spike_no_grad else: - raise ValueError + raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.") I1 += spike * (self.R1 * I1 + self.A1 - I1) I2 += spike * (self.R2 * I2 + self.A2 - I2) V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 0390fd2a..80f8ca84 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -285,7 +285,7 @@ def update(self, x): # update synaptic variables self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += self.a * x + self.h.value = self.h.value + self.a * x return self.g.value def return_info(self): @@ -552,7 +552,7 @@ def dg(self, g, t, h): def update(self, x): # update synaptic variables self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += x + self.h.value = self.h.value + x return self.g.value def return_info(self): @@ -737,7 +737,7 @@ def update(self, pre_spike): t = share.load('t') dt = share.load('dt') self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, dt=dt) - self.x += pre_spike + self.x.value = self.x.value + pre_spike return self.g.value def return_info(self): diff --git a/brainpy/dynsys.py b/brainpy/dynsys.py index 6d56e391..a27e9892 100644 --- a/brainpy/dynsys.py +++ b/brainpy/dynsys.py @@ -966,11 +966,18 @@ def _slice_to_num(slice_: slice, length: int): step = slice_.step if step is None: step = 1 + if step == 0: + raise ValueError("slice step cannot be zero") # number num = 0 - while start < stop: - start += step - num += 1 + if step > 0: + while start < stop: + start += step + num += 1 + else: + while start > stop: + start += step + num += 1 return num diff --git a/brainpy/helpers.py b/brainpy/helpers.py index 8cdae877..5a7f7c62 100644 --- a/brainpy/helpers.py +++ b/brainpy/helpers.py @@ -116,6 +116,9 @@ def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs): missing_keys = [] unexpected_keys = [] for name, node in nodes.items(): + if name not in state_dict: + missing_keys.append(name) + continue r = node.load_state(state_dict[name], **kwargs) if r is not None: missing, unexpected = r diff --git a/brainpy/integrators/ode/adaptive_rk.py b/brainpy/integrators/ode/adaptive_rk.py index c22847b7..ee6864f3 100644 --- a/brainpy/integrators/ode/adaptive_rk.py +++ b/brainpy/integrators/ode/adaptive_rk.py @@ -67,6 +67,7 @@ import jax.numpy as jnp +from brainpy import _errors as errors from brainpy.integrators import constants as C, utils from brainpy.integrators.ode import common from brainpy.integrators.ode.base import ODEIntegrator @@ -456,7 +457,7 @@ class BogackiShampine(AdaptiveRKIntegrator): A = [(), (0.5,), (0., 0.75), - ('2/9', '1/3', '4/0'), ] + ('2/9', '1/3', '4/9'), ] B1 = ['2/9', '1/3', '4/9', 0] B2 = ['7/24', 0.25, '1/3', 0.125] C = [0, 0.5, 0.75, 1] diff --git a/brainpy/integrators/ode/generic.py b/brainpy/integrators/ode/generic.py index 49924c06..8ee6c099 100644 --- a/brainpy/integrators/ode/generic.py +++ b/brainpy/integrators/ode/generic.py @@ -134,7 +134,7 @@ def set_default_odeint(method): raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') global _DEFAULT_DDE_METHOD - _DEFAULT_ODE_METHOD = method + _DEFAULT_DDE_METHOD = method def get_default_odeint(): diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py index 7e486650..3b14e883 100644 --- a/brainpy/math/ndarray.py +++ b/brainpy/math/ndarray.py @@ -272,7 +272,7 @@ def value(self): @value.setter def value(self, value): - self_value = self._check_tracer() + self_value = self._value if isinstance(value, Array): value = value.value diff --git a/brainpy/math/object_transform/base.py b/brainpy/math/object_transform/base.py index 1f46973d..f6a6b976 100644 --- a/brainpy/math/object_transform/base.py +++ b/brainpy/math/object_transform/base.py @@ -773,7 +773,7 @@ def update(self, *args, **kwargs) -> 'NodeDict': self[k] = v elif isinstance(arg, tuple): assert len(arg) == 2 - self[arg[0]] = args[1] + self[arg[0]] = arg[1] for k, v in kwargs.items(): self[k] = v return self @@ -781,8 +781,8 @@ def update(self, *args, **kwargs) -> 'NodeDict': def __setitem__(self, key, value) -> 'NodeDict': if self.check_unique: exist = self.get(key, None) - if id(exist) != id(value): - raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.') + if exist is not None and id(exist) != id(value): + raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {exist}.') super().__setitem__(key, value) return self diff --git a/brainpy/math/object_transform/variables.py b/brainpy/math/object_transform/variables.py index f7d4d4c6..60671a14 100644 --- a/brainpy/math/object_transform/variables.py +++ b/brainpy/math/object_transform/variables.py @@ -109,7 +109,7 @@ def size_without_batch(self): return self.size else: sizes = self.size - return sizes[:self.batch_size] + sizes[self.batch_axis + 1:] + return sizes[:self.batch_axis] + sizes[self.batch_axis + 1:] @property def batch_axis(self) -> Optional[int]: @@ -390,7 +390,7 @@ def update(self, *args, **kwargs) -> 'VarDict': self[k] = v elif isinstance(arg, tuple): assert len(arg) == 2 - self[arg[0]] = args[1] + self[arg[0]] = arg[1] for k, v in kwargs.items(): self[k] = v return self diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index a820dc48..332881e5 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -109,8 +109,8 @@ def __init__( ''' ) if seed is not None: - NoLongerSupportError('"seed" is no longer supported. ' - 'Please shuffle your data by yourself.') + raise NoLongerSupportError('"seed" is no longer supported. ' + 'Please shuffle your data by yourself.') # jit settings if isinstance(self._origin_jit, bool): @@ -317,7 +317,7 @@ def fit( fit_t1 = time.time() aux = {} for k, v in fit_epoch_metric.items(): - aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) + aux[k] = np.mean(np.asarray(v)) if k not in report_train_metric: report_train_metric[k] = [] detailed_train_metric[k] = [] @@ -420,7 +420,7 @@ def fit( test_t1 = time.time() aux = {} for k, v in test_epoch_metric.items(): - aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) + aux[k] = np.mean(np.asarray(v)) if k not in report_test_metric: report_test_metric[k] = [] detailed_test_metric[k] = [] diff --git a/pyproject.toml b/pyproject.toml index 0897e574..f11afbef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ requires-python = ">=3.11" authors = [ { name = "BrainPy Team", email = "chao.brain@qq.com" } ] -license = { text = "GPL-3.0 license" } classifiers = [ "Natural Language :: English", "Operating System :: OS Independent", diff --git a/requirements.txt b/requirements.txt index b8e3db57..09fbf6cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -numpy +numpy>=1.15 brainunit brainevent>=0.0.7 -braintools>=0.1.0 +braintools>=0.0.9 brainstate>=0.2.7 -brainpy_state>=0.0.2 +brainpy_state>=0.0.3 jax tqdm