netket: Error with GCNN
Hi,
I was trying to reproduce the (very nice!) tutorial on GCNNs locally with the following script:
import netket as nk
import numpy as np
import jax
# Basis vectors that define the positions of the unit cells
basis_vectors = [[0, 1], [np.sqrt(3) / 2, -1 / 2]]
# Locations of atoms within the unit cell
atom_positions = [[0, 0], [np.sqrt(3) / 6, 1 / 2]]
# Number of unit cells in each direction
dimensions = [3, 3]
# Define the graph
graph = nk.graph.Lattice(
    basis_vectors=basis_vectors, atoms_coord=atom_positions, extent=dimensions
)
# Symmetry operations of the lattice (graph automorphisms)
symmetries = graph.automorphisms()
# Hilbert space of spins on the graph
hi = nk.hilbert.Spin(s=1 / 2, N=graph.n_nodes, total_sz=0)
ha = nk.operator.Heisenberg(hilbert=hi, graph=graph)
# Feature dimensions of hidden layers, from first to last
feature_dims = (8, 8, 8, 8)
# Number of layers
num_layers = 4
# Define the GCNN
ma = nk.models.GCNN(symmetries=symmetries, layers=num_layers, features=feature_dims)
# Metropolis sampling with exchange moves (conserves total Sz)
sa = nk.sampler.MetropolisExchange(hi, graph=graph, n_chains=16)
# Optimizer
op = nk.optim.Sgd(learning_rate=0.01)
sr = nk.optim.SR(0.1)
# Variational Monte Carlo driver
gs = nk.VMC(ha, op, sa, ma, n_samples=1000, n_discard=100, sr=sr)
# Run the optimization for 300 iterations
gs.run(n_iter=300, out="test")
But with the latest master version, I get the following error:
Traceback (most recent call last):
  File "/Users/giuscarl/Works/NetKet/netket/Examples/Heisenberg1d/heis_gconv.py", line 58, in <module>
    gs = nk.VMC(ha, op, sa, ma, n_samples=1000, n_discard=100, sr=sr)
  File "/Users/giuscarl/Works/NetKet/netket/netket/driver/vmc.py", line 70, in __init__
    variational_state = MCState(*args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/variational/mc_state.py", line 192, in __init__
    self.init(seed, dtype=sampler.dtype)
  File "/Users/giuscarl/Works/NetKet/netket/netket/variational/mc_state.py", line 221, in init
    variables = self._init_fun({"params": key}, dummy_input)
  File "/Users/giuscarl/Works/NetKet/netket/netket/variational/mc_state.py", line 160, in <lambda>
    lambda model, *args, **kwargs: model.init(*args, **kwargs), model
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 884, in init
    _, v_out = self.init_with_output(rngs, *args, method=method, **kwargs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 862, in init_with_output
    return self.apply(
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 841, in apply
    return apply(fn, mutable=mutable)(variables, rngs=rngs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/core/scope.py", line 608, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 834, in <lambda>
    fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/models/equivariant.py", line 137, in __call__
    x = self.equivariant_layers[layer](x)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/nn/linear.py", line 418, in __call__
    kernel = self.full_kernel(kernel)
  File "/Users/giuscarl/miniconda3/envs/netket_env/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/giuscarl/Works/NetKet/netket/netket/nn/linear.py", line 388, in full_kernel
    result = result.transpose(2, 0, 3, 1).reshape(
TypeError: transpose() takes from 1 to 2 positional arguments but 5 were given
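The message suggests that, in the installed jax version, the array `transpose` method accepts only a single `axes` argument (a signature like `transpose(a, axes=None)`), so four separate integers overflow its positional parameters (self plus four axes makes five). A minimal, self-contained reproduction of that Python error, using a hypothetical `OldStyleArray` stand-in backed by NumPy rather than jax itself:

```python
import numpy as np


class OldStyleArray:
    """Hypothetical stand-in for an array type whose transpose takes one axes argument."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def transpose(self, axes=None):
        # Only a single tuple of axes is accepted here, not *axes.
        return OldStyleArray(np.transpose(self.data, axes))


x = OldStyleArray(np.zeros((2, 3, 4, 5)))

# Tuple form: works with this signature.
y = x.transpose((2, 0, 3, 1))
print(y.data.shape)  # (4, 2, 5, 3)

# Varargs form: raises the same kind of TypeError seen in the traceback.
message = ""
try:
    x.transpose(2, 0, 3, 1)
except TypeError as err:
    message = str(err)
print(message)
```

Under this assumption, the varargs call in `full_kernel` fails on such a version while the tuple form would succeed.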
I'm not sure whether this is already being addressed in one of the open PRs about GCNNs and related features.
In any case, I think this means our tests do not cover GCNNs sufficiently, since, AFAIK, all tests are currently passing. @chrisrothUT
About this issue
- State: closed
- Created 3 years ago
- Comments: 16 (9 by maintainers)
I would still change it, since in principle we are also supporting that “old” version of jax…
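Passing the axes as one tuple avoids relying on varargs support in `transpose`, assuming (as the error message suggests) that the older jax array method takes a single `axes` argument. A minimal sketch written with NumPy so it runs standalone; `reorder_kernel_axes` is a hypothetical helper, not NetKet's `full_kernel`, and the `reshape` that follows it in NetKet is omitted:

```python
import numpy as np


def reorder_kernel_axes(result):
    """Hypothetical helper: reorder a 4-index array with the axes order used in full_kernel.

    The axes are passed as a single tuple, which is accepted both by NumPy and
    (assumed) by array types whose transpose only takes one axes argument.
    """
    return result.transpose((2, 0, 3, 1))


rng = np.random.default_rng(0)
result = rng.normal(size=(2, 3, 4, 5))
out = reorder_kernel_axes(result)

print(out.shape)  # (4, 2, 5, 3)
# The tuple form gives the same array as the varargs form NumPy also accepts.
print(np.array_equal(out, result.transpose(2, 0, 3, 1)))  # True
print(np.array_equal(out, np.transpose(result, (2, 0, 3, 1))))  # True
```

`jax.numpy.transpose(a, axes)` takes the same `(array, axes)` arguments, so the function form is another portable spelling.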