Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions brainpy/dnn/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def update(self, x):
if self.mask is not None:
try:
lax.broadcast_shapes(self.w.shape, self.mask.shape)
except:
except (ValueError, TypeError):
raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}")
w = w * self.mask
y = lax.conv_general_dilated(lhs=bm.as_jax(x),
Expand Down Expand Up @@ -566,7 +566,7 @@ def update(self, x):
if self.mask is not None:
try:
lax.broadcast_shapes(self.w.shape, self.mask.shape)
except:
except (ValueError, TypeError):
raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}")
w = w * self.mask
y = lax.conv_transpose(lhs=bm.as_jax(x),
Expand Down
2 changes: 1 addition & 1 deletion brainpy/dnn/interoperation_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
try:
import flax # noqa
from flax.linen.recurrent import RNNCellBase
except:
except (ImportError, ModuleNotFoundError):
flax = None
RNNCellBase = object

Expand Down
24 changes: 12 additions & 12 deletions brainpy/dyn/neurons/lif.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")

else:
spike = V >= self.V_th
Expand Down Expand Up @@ -509,7 +509,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
Expand Down Expand Up @@ -785,7 +785,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")

else:
spike = V >= self.V_th
Expand Down Expand Up @@ -1142,7 +1142,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
Expand Down Expand Up @@ -1497,7 +1497,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
w += self.b * spike

else:
Expand Down Expand Up @@ -1843,7 +1843,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
w += self.b * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
Expand Down Expand Up @@ -2142,7 +2142,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")

else:
spike = V >= self.V_th
Expand Down Expand Up @@ -2431,7 +2431,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
Expand Down Expand Up @@ -2734,7 +2734,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
w += self.b * spike

else:
Expand Down Expand Up @@ -3054,7 +3054,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
w += self.b * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
Expand Down Expand Up @@ -3417,7 +3417,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
I1 += spike * (self.R1 * I1 + self.A1 - I1)
I2 += spike * (self.R2 * I2 + self.A2 - I2)
V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike
Expand Down Expand Up @@ -3810,7 +3810,7 @@ def update(self, x=None):
elif self.spk_reset == 'hard':
V += (self.V_reset - V) * spike_no_grad
else:
raise ValueError
raise ValueError(f"Unknown spk_reset mode: {self.spk_reset}. Must be 'soft' or 'hard'.")
I1 += spike * (self.R1 * I1 + self.A1 - I1)
I2 += spike * (self.R2 * I2 + self.A2 - I2)
V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad
Expand Down
6 changes: 3 additions & 3 deletions brainpy/dyn/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def update(self, x):

# update synaptic variables
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
self.h += self.a * x
self.h.value = self.h.value + self.a * x
return self.g.value

def return_info(self):
Expand Down Expand Up @@ -552,7 +552,7 @@ def dg(self, g, t, h):
def update(self, x):
# update synaptic variables
self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
self.h += x
self.h.value = self.h.value + x
return self.g.value

def return_info(self):
Expand Down Expand Up @@ -737,7 +737,7 @@ def update(self, pre_spike):
t = share.load('t')
dt = share.load('dt')
self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, dt=dt)
self.x += pre_spike
self.x.value = self.x.value + pre_spike
return self.g.value

def return_info(self):
Expand Down
13 changes: 10 additions & 3 deletions brainpy/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,11 +966,18 @@ def _slice_to_num(slice_: slice, length: int):
step = slice_.step
if step is None:
step = 1
if step == 0:
raise ValueError("slice step cannot be zero")
# number
num = 0
while start < stop:
start += step
num += 1
if step > 0:
while start < stop:
start += step
num += 1
else:
while start > stop:
start += step
num += 1
return num


Expand Down
3 changes: 3 additions & 0 deletions brainpy/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs):
missing_keys = []
unexpected_keys = []
for name, node in nodes.items():
if name not in state_dict:
missing_keys.append(name)
continue
r = node.load_state(state_dict[name], **kwargs)
if r is not None:
missing, unexpected = r
Expand Down
3 changes: 2 additions & 1 deletion brainpy/integrators/ode/adaptive_rk.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@

import jax.numpy as jnp

from brainpy import _errors as errors
from brainpy.integrators import constants as C, utils
from brainpy.integrators.ode import common
from brainpy.integrators.ode.base import ODEIntegrator
Expand Down Expand Up @@ -456,7 +457,7 @@ class BogackiShampine(AdaptiveRKIntegrator):
A = [(),
(0.5,),
(0., 0.75),
('2/9', '1/3', '4/0'), ]
('2/9', '1/3', '4/9'), ]
B1 = ['2/9', '1/3', '4/9', 0]
B2 = ['7/24', 0.25, '1/3', 0.125]
C = [0, 0.5, 0.75, 1]
Expand Down
2 changes: 1 addition & 1 deletion brainpy/integrators/ode/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def set_default_odeint(method):
raise ValueError(f'Unsupported ODE_INT numerical method: {method}.')

global _DEFAULT_DDE_METHOD
_DEFAULT_ODE_METHOD = method
_DEFAULT_DDE_METHOD = method
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Default ODE integrator is no longer updated; code now writes to _DEFAULT_DDE_METHOD twice instead of _DEFAULT_ODE_METHOD.

This line should continue assigning to _DEFAULT_ODE_METHOD; _DEFAULT_DDE_METHOD is already set earlier in the function.



def get_default_odeint():
Expand Down
2 changes: 1 addition & 1 deletion brainpy/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def value(self):

@value.setter
def value(self, value):
self_value = self._check_tracer()
self_value = self._value

if isinstance(value, Array):
value = value.value
Expand Down
6 changes: 3 additions & 3 deletions brainpy/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,16 +773,16 @@ def update(self, *args, **kwargs) -> 'NodeDict':
self[k] = v
elif isinstance(arg, tuple):
assert len(arg) == 2
self[arg[0]] = args[1]
self[arg[0]] = arg[1]
for k, v in kwargs.items():
self[k] = v
return self

def __setitem__(self, key, value) -> 'NodeDict':
if self.check_unique:
exist = self.get(key, None)
if id(exist) != id(value):
raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.')
if exist is not None and id(exist) != id(value):
raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {exist}.')
super().__setitem__(key, value)
return self

Expand Down
4 changes: 2 additions & 2 deletions brainpy/math/object_transform/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def size_without_batch(self):
return self.size
else:
sizes = self.size
return sizes[:self.batch_size] + sizes[self.batch_axis + 1:]
return sizes[:self.batch_axis] + sizes[self.batch_axis + 1:]

@property
def batch_axis(self) -> Optional[int]:
Expand Down Expand Up @@ -390,7 +390,7 @@ def update(self, *args, **kwargs) -> 'VarDict':
self[k] = v
elif isinstance(arg, tuple):
assert len(arg) == 2
self[arg[0]] = args[1]
self[arg[0]] = arg[1]
for k, v in kwargs.items():
self[k] = v
return self
Expand Down
8 changes: 4 additions & 4 deletions brainpy/train/back_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def __init__(
'''
)
if seed is not None:
NoLongerSupportError('"seed" is no longer supported. '
'Please shuffle your data by yourself.')
raise NoLongerSupportError('"seed" is no longer supported. '
'Please shuffle your data by yourself.')

# jit settings
if isinstance(self._origin_jit, bool):
Expand Down Expand Up @@ -317,7 +317,7 @@ def fit(
fit_t1 = time.time()
aux = {}
for k, v in fit_epoch_metric.items():
aux[k] = jnp.mean(bm.as_jax(bm.asarray(v)))
aux[k] = np.mean(np.asarray(v))
if k not in report_train_metric:
report_train_metric[k] = []
detailed_train_metric[k] = []
Expand Down Expand Up @@ -420,7 +420,7 @@ def fit(
test_t1 = time.time()
aux = {}
for k, v in test_epoch_metric.items():
aux[k] = jnp.mean(bm.as_jax(bm.asarray(v)))
aux[k] = np.mean(np.asarray(v))
if k not in report_test_metric:
report_test_metric[k] = []
detailed_test_metric[k] = []
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ build-backend = "setuptools.build_meta"
name = "brainpy"
description = "BrainPy: Brain Dynamics Programming in Python"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
authors = [
{ name = "BrainPy Team", email = "chao.brain@qq.com" }
]
license = { text = "GPL-3.0 license" }
classifiers = [
"Natural Language :: English",
"Operating System :: OS Independent",
Expand Down Expand Up @@ -39,7 +38,7 @@ dependencies = [
"jax",
"tqdm",
"brainstate>=0.2.7",
"brainunit",
"brainunit>=0.2.0",
"brainevent>=0.0.7",
"braintools>=0.0.9",
'brainpy_state>=0.0.3',
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
numpy
numpy>=1.15
brainunit
brainevent>=0.0.7
braintools>=0.1.0
braintools>=0.0.9
brainstate>=0.2.7
brainpy_state>=0.0.2
brainpy_state>=0.0.3
jax
tqdm
Loading