
Single-unit activations confer inductive biases for emergent circuit solutions to cognitive tasks

RNN architectures and training procedure

For each of the six architectures (ReLU, sigmoid and tanh activation functions, each with and without Dale's constraint), we trained 100 fully connected RNNs, each with N = 100 units, to solve cognitive tasks. The RNN dynamics are described by the equation

$$\tau \dot{{\bf{y}}}=-{\bf{y}}+f\;({W}_{{\rm{rec}}}{\bf{y}}+{W}_{{\rm{inp}}}{\bf{u}}),$$

(1)

where f is the activation function (ReLU, sigmoid or tanh). The sigmoid function is defined as \(\text{sigmoid}(x)=1/(1+{e}^{-7.5x})\), with slope 7.5. The ReLU function is defined as ReLU(x) = max(0, x), and the tanh function is \(\tanh (x)=({e}^{x}-{e}^{-x})/({e}^{x}+{e}^{-x})\).

The RNNs are trained by minimizing the loss function

$$\begin{array}{l}\,\text{Loss}\,=\langle \parallel {\bf{o}}[\,\text{mask}\,]-\hat{{\bf{o}}}[\,\text{mask}\,]{\parallel }_{2}^{2}\rangle +{\lambda }_{r}\langle | | {\bf{y}}| {| }_{2}^{2}\rangle \\\qquad\quad+{\lambda }_{\perp }\langle \parallel {W}_{\,\text{inp}}^{T}{W}_{{\rm{inp}}}-\text{diag}({W}_{{\rm{inp}}}^{T}{W}_{{\rm{inp}}}){\parallel }_{2}^{2}\rangle.\end{array}$$

(2)

We initialize the RNN connectivity matrices as described previously3. In networks without Dale’s constraint, the elements of the recurrent connectivity matrix were sampled from a Gaussian distribution \({W}_{ij}^{{\prime} } \sim N(\mu ,{\sigma }^{2})\) with \(\mu =1/\sqrt{N}\), σ = 1/N. The spectral radius of the recurrent connectivity was then adjusted using the formula \({W}_{{\rm{rec}}}=\frac{\text{s.r.}}{\mathop{\max }\nolimits_{k}| {\lambda }_{k}| }{W}_{\text{rec}\,}^{{\prime} }\), where the new spectral radius s.r. = 1.2, and \(\mathop{\max }\nolimits_{k}| {\lambda }_{k}|\) is the eigenvalue of \({W}_{\,\text{rec}\,}^{{\prime} }\) with the largest norm.

For networks with Dale’s constraint, the weights were sampled differently for the excitatory or inhibitory units. We sampled excitatory weights as the absolute values of random variables drawn from a normal distribution \(N({\mu }_{E},{\sigma }_{E}^{2})\) with \({\mu }_{E}=1/\sqrt{N}\), σE = 1/N. Inhibitory weights were sampled as the negative absolute values of random variables from \(N({\mu }_{I},{\sigma }_{I}^{2})\), with \({\mu }_{I}={R}_{E/I}/\sqrt{N}\), σI = 1/N, where RE/I is the ratio of the number of excitatory and inhibitory neurons. We used RE/I = 4 for ReLU and sigmoid RNNs and RE/I = 1 for Dale-constrained tanh RNNs. We adjusted the spectral radius of the recurrent connectivity matrix using the same procedure as for the networks without Dale’s constraint.

In all networks, the input Winp and output Wout connectivity matrices were initialized by sampling raw values from a Gaussian distribution N(μ, σ2), \(\mu =1/\sqrt{N}\), σ = 1/N and then taking the absolute value of the elements to enforce non-negativity. Regardless of whether Dale’s constraint was applied, the elements of Winp and Wout were constrained to remain non-negative throughout training.
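As an illustration, a minimal numpy sketch of this initialization (without Dale's constraint) could look as follows; the function name, the random seed and the CDDM channel counts (six inputs, two outputs) are our assumptions rather than the exact implementation.

```python
import numpy as np

def init_connectivity(N=100, spectral_radius=1.2, n_inp=6, n_out=2, seed=0):
    """Sketch of the initialization described above (no Dale's constraint)."""
    rng = np.random.default_rng(seed)
    mu, sigma = 1 / np.sqrt(N), 1 / N
    W_rec = rng.normal(mu, sigma, size=(N, N))
    # Rescale so that the eigenvalue with the largest norm equals the target spectral radius.
    W_rec *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W_rec)))
    # Input and output weights: absolute values of Gaussian samples, kept non-negative.
    W_inp = np.abs(rng.normal(mu, sigma, size=(N, n_inp)))
    W_out = np.abs(rng.normal(mu, sigma, size=(n_out, N)))
    return W_inp, W_rec, W_out
```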

All connectivity matrices (Winp, Wrec and Wout) were trained simultaneously using the Adam optimizer in PyTorch with the default hyperparameters: learning rate α = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 10−8. While training the networks with Dale’s constraint, if any element of these matrices switched sign, it was set to zero to ensure that none of the constraints were violated.

The RNN output was obtained by running the RNN’s dynamics forward for a given batch of inputs. We discretize the RNN dynamics using the first-order Euler scheme with a time-step dt = 1 ms and add a noise term in the discretized equation to obtain

$${{\bf{y}}}_{t+1}=(1-\gamma ){{\bf{y}}}_{t}+\gamma f\left({W}_{{\rm{rec}}}{{\bf{y}}}_{t}+{W}_{{\rm{inp}}}\left({{\bf{u}}}_{t}+\sqrt{2\gamma {\sigma }_{\,\text{inp}\,}^{2}}{{\mathbf{\zeta }}}_{t}\right)+\sqrt{2\gamma {\sigma }_{\,\text{rec}\,}^{2}}{{\mathbf{\xi }}}_{t}\right).$$

(3)

Here γ = dt/τ, and ξt and ζt are random vectors with elements sampled from the standard normal distribution N(0, 1). The hyperparameters used for RNN training are provided in Extended Data Table 3. RNNs were trained on the CDDM and Go/NoGo tasks with λr = 0.5 for niter = 5,000 iterations. RNNs were trained on the memory number task first with λr = 0 for niter = 6,000 iterations and then with λr = 0.3 for an additional niter = 6,000 iterations. The code for RNN training is available as the trainRNNbrain package via GitHub at https://github.com/engellab/trainRNNbrain (ref. 46).
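For concreteness, a minimal PyTorch sketch of one discretized update of equation (3) is shown below; the batch-first tensor shapes and the function name are our assumptions.

```python
import torch

def rnn_step(y, u, W_rec, W_inp, f, gamma, sigma_inp, sigma_rec):
    """One Euler step of equation (3). y: (batch, N) state; u: (batch, n_inp) input."""
    inp_noise = (2 * gamma * sigma_inp ** 2) ** 0.5 * torch.randn_like(u)
    rec_noise = (2 * gamma * sigma_rec ** 2) ** 0.5 * torch.randn_like(y)
    pre_activation = y @ W_rec.T + (u + inp_noise) @ W_inp.T + rec_noise
    return (1 - gamma) * y + gamma * f(pre_activation)
```

A full forward pass iterates this step over all time steps of a trial and feeds the states through Wout to evaluate the masked loss in equation (2).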

CDDM task

The task structure is presented in Fig. 1a. Two mutually exclusive context channels signal either ‘motion’ or ‘colour’ context. For a given context, a constant input with an amplitude of 1 is supplied through the corresponding channel for the entire trial duration. Sensory stimuli with two modalities (‘motion’ and ‘colour’) are each supplied through two corresponding input channels, encoding momentary evidence for choosing either the right or left response. Within each sensory modality, the mean difference between inputs on the two channels represents the stimulus coherence, with values ranging from −1 to +1. During training, we used a discrete set of 15 coherences for each sensory modality: c = {0, ±0.01, ±0.03, ±0.06, ±0.13, ±0.25, ±0.5, ±1}. The coherence c was translated into two sensory inputs as [(1 + c)/2, (1 − c)/2]. Each trial lasted 300 time steps, and the (6, 300)-dimensional input-stream array was determined by the triplet (binary context, motion coherence and colour coherence), generating Nbatch = 2 × 15 × 15 = 450 distinct trial conditions.

On each trial, the target output was set to 0 for each time step t < 100 ms. During the decision period t > 200 ms, the target was set as follows: if the relevant coherence (for example, coherence of ‘motion’ stimuli on a ‘motion’ context trial) was positive, the target for the ‘right’ output channel was set to 1 from 200 ms onwards. If the relevant coherence was negative, the target for the ‘left’ output channel was set to 1 instead. If the relevant coherence was 0, both output targets were set to 0. The target was specified for only a subset of time steps, forming a training mask covering the intervals (0−100) ms and (200−300) ms: the first interval enforces no decision output before stimulus onset, and excluding the intermediate (100−200) ms interval allows the network to integrate the stimulus without penalty before committing to a decision during (200−300) ms.
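The trial construction described above could be sketched as follows; the stimulus onset (here assumed at 100 ms, consistent with the training mask) and the function name are our assumptions.

```python
import numpy as np

def cddm_trial(context, c_motion, c_colour, T=300):
    """Sketch of one CDDM trial: (6, T) inputs, (2, T) targets, and the loss mask."""
    inputs = np.zeros((6, T))
    inputs[context, :] = 1.0                           # context channel: 0 = motion, 1 = colour
    inputs[2:4, 100:] = [[(1 + c_motion) / 2], [(1 - c_motion) / 2]]   # motion evidence
    inputs[4:6, 100:] = [[(1 + c_colour) / 2], [(1 - c_colour) / 2]]   # colour evidence
    targets = np.zeros((2, T))
    c_rel = c_motion if context == 0 else c_colour     # coherence of the relevant modality
    if c_rel > 0:
        targets[0, 200:] = 1.0                         # 'right' channel from 200 ms onwards
    elif c_rel < 0:
        targets[1, 200:] = 1.0                         # 'left' channel from 200 ms onwards
    mask = np.r_[0:100, 200:T]                         # time steps on which the loss is evaluated
    return inputs, targets, mask
```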

Go/NoGo and memory number tasks

The structure of these tasks is presented in Fig. 1b,c. For both tasks, we used 11 uniformly spaced input values \({\mathcal{I}}\), ranging from 0 to 1, delivered through the first input channel. The ‘Go Cue’ input is delivered through the second channel and activated only at time tGoCue at the end of the trial, signalling that the RNN is required to respond. Finally, a constant bias input with an amplitude of 1 is supplied via the third channel throughout the entire trial duration. In the Go/NoGo task, the input value \({\mathcal{I}}\) was provided for the entire trial duration of 60 ms. The target output is determined as 0 before and \(\Theta ({\mathcal{I}}-0.5)\) after the Go Cue onset, where Θ is the Heaviside step function (Fig. 1b). If the input value was exactly 0.5, the network was required to output 0.5 after the Go Cue. In the memory number task, the input value \({\mathcal{I}}\) was present only for 10 ms, with the randomized stimulus onset time tstim ∼ U(0, 20) ms (Fig. 1c). The target output value was set to 0 before the Go Cue and the input value \({\mathcal{I}}\) afterwards. The onset of the Go Cue was set to tGoCue = 30 ms for the Go/NoGo task and tGoCue = 70 ms for the memory number task.
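A minimal sketch of memory number trial construction is shown below, assuming a total trial duration of 80 ms and that the Go Cue input stays on from tGoCue until the end of the trial (both assumptions on our part).

```python
import numpy as np

def memory_number_trial(value, T=80, t_go=70, seed=None):
    """Sketch of one memory number trial: (3, T) inputs and a scalar target trace."""
    rng = np.random.default_rng(seed)
    inputs = np.zeros((3, T))
    t_stim = rng.integers(0, 20)              # stimulus onset t_stim ~ U(0, 20) ms
    inputs[0, t_stim:t_stim + 10] = value     # input value shown for 10 ms
    inputs[1, t_go:] = 1.0                    # Go Cue (assumed held on after t_go)
    inputs[2, :] = 1.0                        # constant bias input
    target = np.zeros(T)
    target[t_go:] = value                     # report the remembered value after the Go Cue
    return inputs, target
```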

RNNs with shuffled connectivity

For each of the analysed RNNs, we produced another RNN with shuffled connectivity as a control. To shuffle the connectivity, we randomly permute each row Ri in the input matrix Winp (ith row contains all inputs to unit i). We also randomly permute non-diagonal elements of each column in the recurrent matrix Wrec (ith column contains all outputs of unit i). We keep the diagonal elements in Wrec unchanged to preserve self-excitation of each unit.
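A minimal numpy sketch of this shuffling procedure (function name ours):

```python
import numpy as np

def shuffle_connectivity(W_inp, W_rec, seed=None):
    """Sketch: permute the inputs to each unit and the off-diagonal outputs of each unit."""
    rng = np.random.default_rng(seed)
    W_inp_shuf = np.array([rng.permutation(row) for row in W_inp])   # permute each row of W_inp
    W_rec_shuf = W_rec.copy()
    N = W_rec.shape[0]
    for i in range(N):
        off_diag = np.setdiff1d(np.arange(N), i)                     # keep self-connection W_rec[i, i]
        W_rec_shuf[off_diag, i] = rng.permutation(W_rec[off_diag, i])
    return W_inp_shuf, W_rec_shuf
```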

Analysis of population trajectories

We analysed 50 RNNs with the best task performance from each architecture. We simulated each RNN (including the corresponding control RNNs) to acquire a tensor of neural responses Z with dimensionality (N, T, K), where N is the number of units in the network, T is the number of time steps in a trial, and K is the number of trials. We reshape the neural response tensor Z to obtain a matrix X with dimensionality (N, TK). We then obtain a denoised matrix F with dimensionality (nPC, TK) by projecting matrix X onto the first nPC = 10 PCs along the first dimension, capturing more than 93% of variance in each instance across all RNNs and tasks. Reshaping matrix F back into a three-dimensional tensor, we obtain a denoised tensor \(\hat{Z}\) with dimensionality (nPC, T, K) containing reduced population trajectories. We further normalized the reduced trajectory tensor \(\hat{Z}\) by its variance, so that the reduced trajectory tensors have the same scale across all RNNs.
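A sketch of this dimensionality reduction step is given below; the exact normalization by ‘variance’ is our interpretation, here taken as the total variance of the projected data.

```python
import numpy as np
from sklearn.decomposition import PCA

def reduced_trajectories(Z, n_pc=10):
    """Sketch: reduce the unit dimension of the (N, T, K) response tensor to n_pc PCs."""
    N, T, K = Z.shape
    X = Z.reshape(N, T * K)                            # (N, TK)
    F = PCA(n_components=n_pc).fit_transform(X.T).T    # (n_pc, TK) denoised matrix
    F = F / F.var()                                    # normalize scale across RNNs
    return F, F.reshape(n_pc, T, K)                    # matrix F and reduced trajectory tensor
```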

To obtain an MDS embedding of the reduced trajectories, we compute a distance matrix between reduced trajectory tensors \(\hat{{Z}_{i}}\) and \(\hat{{Z}_{j}}\) for each pair of RNNs i and j. First, we obtain the optimal linear transformation between the matrices Fi and Fj corresponding to \(\hat{{Z}_{i}}\) and \(\hat{{Z}_{j}}\) using linear least squares regression with the function numpy.linalg.lstsq in python. We perform two regression analyses: first regressing Fi onto Fj and then Fj onto Fi, resulting in two linear transformations Mij and Mji, and two scores, \({\text{score}}_{1}=\parallel {F}_{i}{M}_{ij}-{F}_{j}{\parallel }_{2}\) and \({\text{score}}_{2}=\parallel {F}_{j}{M}_{ji}-{F}_{i}{\parallel }_{2}\). We then compute the distance between the two trajectory tensors as the average of the two scores: dij = dji = (score1 + score2)/2. We use these pairwise distances to compute the MDS embedding with the function sklearn.manifold.MDS from the sklearn package in python.
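A sketch of the pairwise distance computation and embedding follows; we assume here that each F is arranged as (TK, nPC), that is, samples by components, so that the least-squares fit returns an (nPC, nPC) transformation.

```python
import numpy as np
from sklearn.manifold import MDS

def trajectory_distance(F_i, F_j):
    """Sketch of the symmetrized regression distance between two reduced trajectory matrices."""
    M_ij, *_ = np.linalg.lstsq(F_i, F_j, rcond=None)   # regress F_i onto F_j
    M_ji, *_ = np.linalg.lstsq(F_j, F_i, rcond=None)   # regress F_j onto F_i
    score1 = np.linalg.norm(F_i @ M_ij - F_j)
    score2 = np.linalg.norm(F_j @ M_ji - F_i)
    return (score1 + score2) / 2

# Given the full (n_RNN, n_RNN) distance matrix D, the embedding is obtained with
# MDS(n_components=2, dissimilarity='precomputed').fit_transform(D).
```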

Analysis of single-unit selectivity

For each RNN (including the control RNNs), we start with the same neural response tensor Z as for the analysis of population trajectories. We reshape Z to obtain matrix X with dimensionality (N, TK). We then obtain a denoised matrix G with dimensionality (N, nPC) by projecting matrix X onto the first nPC = 10 PCs along the second dimension, capturing more than 90% of variance in each instance across all RNNs and tasks. We further normalize the resulting single-unit selectivity matrix G by its variance, so that single-unit selectivity matrices have the same scale across all RNNs.

To obtain an MDS embedding, we compute a distance matrix between the single-unit selectivity matrices Gi and Gj for each pair of RNNs i and j. To compute the distance between Gi and Gj, we view each RNN unit as a point in nPC-dimensional selectivity space. We then register the point configurations of the two RNNs with an optimal orthogonal transformation that permits one-to-many mapping. To register the points, we use the ICP registration algorithm (‘ICP registration’ section). Since there is no one-to-one correspondence between units in the two RNNs, we perform the ICP registration twice: registering Gi to Gj and then Gj to Gi, producing score1 and score2. We then set the distances dij = dji = (score1 + score2)/2. Since the ICP registration often converges to local minima, we run the registration procedure 60 times for each pair of point clouds to increase the probability of an accurate estimate of the distance, and take the best result, corresponding to the minimal point-cloud mismatch.

Fixed-point finder

To find fixed points of an RNN, we use a custom fixed-point finder algorithm. For each constant input u, we search for fixed points by solving F(y, u) = −y + f(Wrecy + Winpu) = 0, where F is the right-hand side of equation (1), using the scipy.optimize.fsolve function from the scipy.optimize package in python. We accept a point y* as a fixed point if \(\parallel F({{\bf{y}}}^{* },{\bf{u}}){\parallel }_{2}^{2}\leqslant 1{0}^{-12}\). The fsolve function also takes the Jacobian matrix J(y, u) = ∂F(y, u)/∂y of the RNN as an additional argument to enhance the efficiency of the optimization process. We initialize the search at a value y0 sampled randomly from the RNN trajectories: we choose a random trajectory k from the K trials, and then a random time step t from the interval (nt/2, nt), that is, from the second half of the trial. We then add Gaussian noise ξ ∼ N(0, 0.01) to each coordinate of the sampled point to obtain the initial condition y0.

To find multiple fixed points for the same input u, we search for fixed points starting from multiple initial conditions within an iterative loop. On each iteration of this loop, we sample a new initial condition and perform the minimization to find a fixed point. We then compare this newly found fixed point \({{\bf{y}}}_{{\rm{new}}}^{* }\) to all previously found fixed points \({{\bf{y}}}_{{\rm{old}}}^{* }\). If the distance \(\parallel {{\bf{y}}}_{{\rm{new}}}^{* }-{{\bf{y}}}_{{\rm{old}}}^{* }{\parallel }_{2}\leqslant 1{0}^{-7}\), then we discard the new fixed point because it lies too close to one of the previously found fixed points. This iterative loop continues until either 100 distinct fixed points were found in total or no new fixed points were found for 100 consecutive iterations.
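A minimal sketch of a single fixed-point search is shown below; f_prime (the elementwise derivative of the activation function) and the function name are our additions.

```python
import numpy as np
from scipy.optimize import fsolve

def find_fixed_point(y0, u, W_rec, W_inp, f, f_prime):
    """Sketch: solve F(y, u) = -y + f(W_rec y + W_inp u) = 0 starting from y0."""
    def F(y):
        return -y + f(W_rec @ y + W_inp @ u)

    def jac(y):
        # J(y, u) = dF/dy = -I + diag(f'(pre)) W_rec
        pre = W_rec @ y + W_inp @ u
        return -np.eye(len(y)) + f_prime(pre)[:, None] * W_rec

    y_star = fsolve(F, y0, fprime=jac)
    accepted = np.sum(F(y_star) ** 2) <= 1e-12     # acceptance criterion from the text
    return y_star, accepted
```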

We determine the fixed-point type (stable or unstable) by computing the principal eigenvalue λ0 of the Jacobian J(y, u) evaluated at the fixed point. We classify the fixed point as stable if \({\rm{Re}}({\lambda }_{0})\leqslant 0\) and otherwise as unstable.

Analysis of fixed points

For each RNN (including the control RNNs), we computed fixed points for each combination of input stimuli using a custom fixed-point finder algorithm (‘Fixed-point finder’ section), obtaining a fixed-point configuration, which is a set of stable and unstable fixed points for different combinations of inputs. We collect the coordinates of all fixed points in a matrix P with dimensions (Np, N), where Np is the total number of fixed points (both stable and unstable) across all the inputs and N is the number of units. We reduce the second dimension of the matrix P by projecting the fixed points onto the first nPC = 7 PCs. We further normalized the resulting matrix by its variance, so that these fixed-point configurations have the same scale across all RNNs, obtaining a matrix \(\hat{P}\) for each RNN. Throughout the transformations, we keep each fixed point tagged by its type and the corresponding input for which it was computed.

To obtain an MDS embedding, we compute a distance matrix between fixed-point configurations \({\hat{P}}_{i}\) and \({\hat{P}}_{j}\) for each pair of RNNs i and j. To compute the distances between the two projected fixed-point configurations \({\hat{P}}_{i}\) and \({\hat{P}}_{j}\), we compute an optimal orthogonal transformation between the two sets of projected fixed points using orthogonal Procrustes with ICP registration (‘ICP registration’ section). When matching the projected fixed points, we restricted matches to the fixed points with the same tag (of the same type and obtained for the same input). We perform the ICP registration twice, registering \({\hat{P}}_{i}\) to \({\hat{P}}_{j}\) and then \({\hat{P}}_{j}\) to \({\hat{P}}_{i}\), resulting in two scores, score1 and score2. We then set the distances dij = dji = (score1 + score2)/2. Using the distance matrix, we then obtain an MDS embedding. To register each pair of point clouds, we run the registration procedure ntries = 60 times and then take the result corresponding to the minimal fixed-point cloud mismatch.

Analysis of trajectory endpoint configurations

For each RNN (including the control RNNs), we use the same neural response tensor Z as for the analysis of population trajectories. We then restrict the data to the last time step of each trial, resulting in a (K, N)-dimensional matrix S for each RNN containing the trajectory endpoint configuration. We further project the trial endpoints in S onto the first nPC = 10 PCs, obtaining a (K, nPC)-dimensional matrix \(\hat{S}\). Finally, we normalize each trajectory endpoint configuration matrix \(\hat{S}\) by its variance, so that these endpoint configurations have the same scale across all RNNs. We compute the distance between two matrices \({\hat{S}}_{i}\) and \({\hat{S}}_{j}\) for RNNs i and j using the same procedure as for the population trajectory matrices F (‘Analysis of population trajectories’ section). Using the distance matrix, we then obtain an MDS embedding.

ICP registration

To register the point clouds (‘Analysis of single-unit selectivity’ and ‘Analysis of fixed points’ sections), we use an ICP algorithm, which proceeds in four steps:

  1. Initialization: define a random orthogonal matrix A that transforms each point of the source point cloud Psource into Psource A.

  2. Point matching: for each point in the target point cloud Ptarget, find the closest point in the transformed source point cloud Psource A. Construct a new matrix \({\hat{P}}_{{\rm{source}}}\) where the ith point is the point from Psource A closest to the ith point in Ptarget (points in \({\hat{P}}_{{\rm{source}}}\) may repeat).

  3. Transformation update: update the transformation matrix A to minimize the distance between \({\hat{P}}_{{\rm{source}}}\) and Ptarget using the orthogonal Procrustes method.

  4. Iteration: repeat steps 2 and 3 until convergence.

This algorithm iteratively refines the transformation to achieve optimal alignment between the source and target point clouds. Since this optimization is non-convex, it may converge to a local optimum. Therefore, we perform each optimization ntries = 60 times, starting from random initializations, and keep the solution with the minimal mean squared error as the distance between the source and target point clouds.

To compute distances between the fixed-point configurations (‘Analysis of fixed points’ section), we modify the point matching step by restricting possible matches only to the points obtained for the same inputs and of the same type (stable or unstable).
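A minimal sketch of one ICP run (steps 1-4 above) is given below; the stopping rule based on stabilized matches is our simplification of ‘until convergence’.

```python
import numpy as np
from scipy.linalg import orthogonal_procrustes
from scipy.stats import ortho_group

def icp_register(P_source, P_target, max_iter=100, seed=None):
    """Sketch of one ICP run; returns the mean squared mismatch after alignment."""
    d = P_source.shape[1]
    A = ortho_group.rvs(d, random_state=seed)            # 1. random orthogonal initialization
    idx_prev = None
    for _ in range(max_iter):
        transformed = P_source @ A
        # 2. match each target point to its nearest transformed source point (repeats allowed)
        dists = np.linalg.norm(P_target[:, None, :] - transformed[None, :, :], axis=2)
        idx = np.argmin(dists, axis=1)
        # 3. update A with orthogonal Procrustes between matched source points and targets
        A, _ = orthogonal_procrustes(P_source[idx], P_target)
        if idx_prev is not None and np.array_equal(idx, idx_prev):
            break                                        # 4. stop once the matching stabilizes
        idx_prev = idx
    return np.mean((P_source[idx] @ A - P_target) ** 2)

# The distance between two point clouds is the minimum of icp_register(...)
# over ntries = 60 random initializations.
```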

The code for the RNN analyses and the relevant datasets are available via GitHub at https://github.com/engellab/ActivationMattersRNN (ref. 47).

Latent circuit inference

To identify the circuit mechanism supporting the CDDM task execution in an RNN, we fit its responses and task behaviour with the latent circuit model12. We model RNN responses y as a linear embedding of dynamics x generated by a low-dimensional RNN

$$\tau \dot{{\bf{x}}}=-{\bf{x}}+f\;({w}_{{\rm{rec}}}{\bf{x}}+{w}_{{\rm{inp}}}{\bf{u}}),$$

(4)

which we refer to as the latent circuit. Here f is the activation function matching the activation function of the RNN. We also require the latent circuit to reproduce task behaviour via the output connectivity woutx.

To fit the latent circuit model, we first sample RNN trajectories Z, forming a (N, T, K)-dimensional tensor. We then reduce the dimensionality of Z using PCA to NPC = 30, resulting in a tensor z with (NPC, T, K) dimensions, capturing more than 99% of variance in Z for all RNNs we analysed. We then infer the latent circuit parameters wrec, winp, wout and an orthonormal embedding matrix Q by minimizing the loss function

$$\,\text{Loss}\,=\langle \parallel {\bf{o}}-\hat{{\bf{o}}}{\parallel }_{2}^{2}\rangle +{\lambda }_{{\rm{emb}}}\langle \parallel Q{\bf{x}}-{\bf{z}}{\parallel }_{2}^{2}\rangle +{\lambda }_{w}\left(\langle | {w}_{{\rm{inp}}}{| }^{2}\rangle +\langle | {w}_{{\rm{rec}}}{| }^{2}\rangle +\langle | {w}_{{\rm{out}}}{| }^{2}\rangle \right)$$

(5)

Here, 〈  〉 denotes the mean over all dimensions of a tensor. Tensor x has the dimensionality (n, T, K), where n is the number of nodes in the latent circuit. This tensor x contains the activity of the latent circuit across K trials and T time steps per trial, and y is the corresponding activity tensor for the RNN. The (NPC, n)-dimensional orthonormal matrix Q embeds trajectories of the latent circuit x to match the RNN activity z, such that z ≈ Qx. Finally, o is the target circuit output, and \(\hat{{\bf{o}}}={w}_{{\rm{out}}}\bf{x}\) is the output produced by the latent circuit.

During optimization, we constrain the input matrix such that each input channel is connected to at most one latent node. To this end, we apply to the input matrix a mask, in which 1 indicates that the weight is allowed to change during training, and 0 indicates that the weight is fixed at 0. We design the mask such that each column has a single 1. Moreover, we constrain the elements of winp and wout matrices to be non-negative.
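A sketch of the loss in equation (5) is given below; the flattened (features, time × trials) layout and the variable names are our assumptions, and we assume winp has already been multiplied elementwise by the binary input mask described above.

```python
import torch

def latent_circuit_loss(o_target, x, z, w_out, w_inp, w_rec, Q, lam_emb, lam_w):
    """Sketch of equation (5). x: (n, T*K) latent activity; z: (N_PC, T*K) reduced RNN activity."""
    o_hat = w_out @ x                                     # latent-circuit output
    loss_out = torch.mean((o_target - o_hat) ** 2)        # behavioural (output) term
    loss_emb = lam_emb * torch.mean((Q @ x - z) ** 2)     # embedding term, z ≈ Q x
    loss_w = lam_w * (torch.mean(w_inp ** 2) + torch.mean(w_rec ** 2) + torch.mean(w_out ** 2))
    return loss_out + loss_emb + loss_w
```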

We fitted latent circuit models to the ten RNNs with the best CDDM task performance from each architecture. For each RNN, we fit an 8-node latent circuit model at least 30 times, starting from random initializations, and take the best-fitting circuit as the converged solution. The hyperparameters for the latent circuit fitting are provided in Extended Data Table 4. The code for latent circuit fitting is available via GitHub at https://github.com/engellab/latent_circuit_inference (ref. 48).

Alignment of RNN dynamics with the output subspace

The norm of the readout matrix can affect the dynamics that emerge in RNNs through training44. In RNNs initialized with large readout norms, the network dynamics evolved in a subspace distinct from the output subspace spanned by the rows of the readout matrix44. The angle between the dynamics and output subspaces was large, and such dynamics were termed oblique. By contrast, in RNNs initialized with small readout norms, the angle between dynamics and readout subspaces was relatively small, and such dynamics were termed aligned.

In our networks, the weights of the output matrix were initialized with σ = 1/N, corresponding to the small readout norm associated with aligned dynamics44. To quantify whether the resulting dynamics in our networks were aligned or oblique, we computed a generalized correlation measure ρ (ref. 44) for the 50 top-performing networks of each architecture during the epochs when RNNs were required to produce output. The generalized correlation measure is defined as \(\rho =\frac{\parallel {W}_{\,\text{out}}^{T}X{\parallel }_{{\rm{F}}}}{\parallel {W}_{{\rm{out}}}{\parallel }_{{\rm{F}}}\parallel X{\parallel }_{{\rm{F}}}}\), where X is the (N, Tout, K) tensor containing the population activity of N units during the task epochs in which the networks were required to produce outputs (Tout time steps in total) in K trials, and ∥·∥F denotes the Frobenius norm.
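A short sketch of this measure follows; we assume that the activity tensor is flattened over time and trials before taking the Frobenius norms, and we use this paper's convention that Wout has shape (noutputs, N).

```python
import numpy as np

def generalized_correlation(W_out, X):
    """Sketch of rho: alignment between population activity and the readout subspace.
    W_out: (n_out, N) readout matrix; X: (N, T_out, K) activity during output epochs."""
    X_flat = X.reshape(X.shape[0], -1)            # flatten time and trial dimensions: (N, T_out * K)
    projected = np.linalg.norm(W_out @ X_flat)    # Frobenius norm of activity projected onto the readout
    return projected / (np.linalg.norm(W_out) * np.linalg.norm(X_flat))
```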

We found that the dynamics in our networks lie along a continuum: neither fully aligned with the readout subspace nor strongly oblique (Extended Data Table 2). In addition, the generalized correlation measure ρ was both task and architecture dependent. The dynamics were most aligned with the output subspace for ReLU networks trained on the Go/NoGo task. Furthermore, tanh networks tended to produce more oblique dynamics than sigmoid and ReLU RNNs. Since the initialization procedure and noise magnitude for inputs and recurrence were the same for all networks, this result further supports the conclusion that tanh networks rely on dynamics distinct from those of ReLU and sigmoid RNNs.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
