netket: Error with GCNN

Hi,

I was trying to reproduce the (very nice!) tutorial on GCNNs locally with the following script:

import netket as nk
import numpy as np
import jax

# Basis Vectors that define the positioning of the unit cell
basis_vectors = [[0, 1], [np.sqrt(3) / 2, -1 / 2]]

# Locations of atoms within the unit cell
atom_positions = [[0, 0], [np.sqrt(3) / 6, 1 / 2]]

# Number of unit cells in each direction
dimensions = [3, 3]

# Define the graph
graph = nk.graph.Lattice(
    basis_vectors=basis_vectors, atoms_coord=atom_positions, extent=dimensions
)

# Lattice translation operations
symmetries = graph.automorphisms()

# Hilbert space of spins on the graph
hi = nk.hilbert.Spin(s=1 / 2, N=graph.n_nodes, total_sz=0)

ha = nk.operator.Heisenberg(hilbert=hi, graph=graph)

# Feature dimensions of hidden layers, from first to last
feature_dims = (8, 8, 8, 8)

# Number of layers
num_layers = 4

# Define the GCNN
ma = nk.models.GCNN(symmetries=symmetries, layers=num_layers, features=feature_dims)

# Metropolis exchange sampling
sa = nk.sampler.MetropolisExchange(hi, graph=graph, n_chains=16)

# Optimizer
op = nk.optim.Sgd(learning_rate=0.01)
sr = nk.optim.SR(0.1)

# Variational monte carlo driver
gs = nk.VMC(ha, op, sa, ma, n_samples=1000, n_discard=100, sr=sr)

# Run the optimization for 300 iterations
gs.run(n_iter=300, out="test")

but with the latest master version I get the following error:

Traceback (most recent call last):
  File "/Users/giuscarl/Works/NetKet/netket/Examples/Heisenberg1d/heis_gconv.py", line 58, in <module>
    gs = nk.VMC(ha, op, sa, ma, n_samples=1000, n_discard=100, sr=sr)
  File "/Users/giuscarl/Works/NetKet/netket/netket/driver/vmc.py", line 70, in __init__
    variational_state = MCState(*args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/variational/mc_state.py", line 192, in __init__
    self.init(seed, dtype=sampler.dtype)
  File "/Users/giuscarl/Works/NetKet/netket/netket/variational/mc_state.py", line 221, in init
    variables = self._init_fun({"params": key}, dummy_input)
  File "/Users/giuscarl/Works/NetKet/netket/netket/variational/mc_state.py", line 160, in <lambda>
    lambda model, *args, **kwargs: model.init(*args, **kwargs), model
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 884, in init
    _, v_out = self.init_with_output(rngs, *args, method=method, **kwargs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 862, in init_with_output
    return self.apply(
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 841, in apply
    return apply(fn, mutable=mutable)(variables, rngs=rngs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/core/scope.py", line 608, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 834, in <lambda>
    fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/models/equivariant.py", line 137, in __call__
    x = self.equivariant_layers[layer](x)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/nn/linear.py", line 418, in __call__
    kernel = self.full_kernel(kernel)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/nn/linear.py", line 388, in full_kernel
    result = result.transpose(2, 0, 3, 1).reshape(
TypeError: transpose() takes from 1 to 2 positional arguments but 5 were given
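
If it helps, the failure looks like a pure jax API issue: on the jax version installed in this environment, the transpose method of a jax array only seems to accept a single tuple of axes, while the code in netket/nn/linear.py (full_kernel) passes the axes as separate integers. A minimal sketch of the difference (the array shape is chosen arbitrarily, just to illustrate):

import jax.numpy as jnp

result = jnp.ones((2, 3, 4, 5))

# What full_kernel currently does; on this jax version it raises
# "TypeError: transpose() takes from 1 to 2 positional arguments but 5 were given".
# result.transpose(2, 0, 3, 1)

# Passing the axes as a single tuple is accepted here:
result.transpose((2, 0, 3, 1))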

Is this perhaps already being addressed in one of the open PRs about GCNNs and related changes?

In any case, I think this means we do not have enough tests covering the GCNNs, since AFAIK all tests are currently passing. @chrisrothUT

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 16 (9 by maintainers)

Most upvoted comments

I would still change it, since in principle we are also supporting that “old” version of jax…
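
A version-agnostic way to write that line would be to pass the axes as a single tuple, e.g. through jnp.transpose, which both the old and the new jax API accept. A sketch of the idea (placeholder shape, not the actual patch):

import jax.numpy as jnp

x = jnp.ones((2, 3, 4, 5))  # placeholder shape, just for illustration

# Works on old and new jax alike: the axes are passed as one tuple.
y = jnp.transpose(x, (2, 0, 3, 1))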