jax: Upgrading from 0.2.3 seems to break something with custom VJP
I tried updating to master as well, but that didn’t help.
I will go through the changelog for 0.2.4, but could someone point me to where something might have broken? I checked that the pytrees are in fact the same. (One pytree contains tracers, the other contains some object(), but I think that's a JAX implementation detail.)
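The two treedefs in the error message below print identically, yet JAX still reports a mismatch. One general pitfall worth ruling out when comparing pytrees by eye (an assumption here, not the confirmed cause of this issue): `PyTreeDef` equality also compares the aux data of custom nodes with `==`, so two structures can print the same text while comparing unequal if the aux-data objects do not define `__eq__`. The sketch below uses hypothetical `Box` and `Tag` classes to show the effect and a `same_structure` helper that checks real equality rather than the printed form.

```python
from jax import tree_util


class Tag:
    """Hypothetical aux-data object with no __eq__, so == falls back to identity."""

    def __repr__(self):
        return "Tag()"


class Box:
    """Hypothetical container: `value` is a leaf, `tag` is static aux data."""

    def __init__(self, value, tag):
        self.value = value
        self.tag = tag


tree_util.register_pytree_node(
    Box,
    lambda box: ((box.value,), box.tag),          # flatten -> (children, aux_data)
    lambda tag, children: Box(children[0], tag),  # unflatten(aux_data, children)
)


def same_structure(a, b):
    """True only if the two PyTreeDefs are equal, aux data included."""
    return tree_util.tree_structure(a) == tree_util.tree_structure(b)


a = Box(1.0, Tag())
b = Box(2.0, Tag())
print(tree_util.tree_structure(a))  # same text as tree_structure(b)
print(same_structure(a, b))         # False: the two Tag() objects are not ==
```

Sharing one `Tag` instance, or giving `Tag` value-based `__eq__` and `__hash__`, makes the structures compare equal.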
$ python demos/encoding.py
[*] Inferring...
$ pip install jax==0.2.4
Processing /home/neil/.cache/pip/wheels/ce/f6/a8/2075fcce214c29511994904934727cae7c800b21f48524f673/jax-0.2.4-py3-none-any.whl
Requirement already satisfied: numpy>=1.12 in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from jax==0.2.4) (1.20.0.dev0+0bd548e)
Requirement already satisfied: absl-py in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from jax==0.2.4) (0.10.0)
Requirement already satisfied: opt-einsum in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from jax==0.2.4) (3.3.0)
Requirement already satisfied: six in /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages (from absl-py->jax==0.2.4) (1.15.0)
Installing collected packages: jax
Attempting uninstall: jax
Found existing installation: jax 0.2.3
Uninstalling jax-0.2.3:
Successfully uninstalled jax-0.2.3
Successfully installed jax-0.2.4
$ python demos/encoding.py
[*] Inferring...
Traceback (most recent call last):
File "demos/encoding.py", line 143, in <module>
encoding_demo()
File "demos/encoding.py", line 109, in encoding_demo
training_pt = solution.train(2000)
File "/home/neil/src/cmm/cmm/structure/solution/solution.py", line 111, in train
augmented, trajectory = method(None,
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 213, in f_jitted
out = xla.xla_call(
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 1174, in bind
return call_bind(self, fun, *args, **params)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 1165, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 1177, in process
return trace.process_call(self, fun, tracers, params)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/core.py", line 576, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/xla.py", line 556, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 247, in memoized_fun
ans = call(fun, *args)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/xla.py", line 632, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1183, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1164, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/tjax/fixed_point/iterated_function.py", line 102, in sample_trajectory
return scan(f, self.initial_augmented(initial_state), None, iteration_limit)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1251, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1238, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 72, in _initial_style_jaxpr
jaxpr, out_avals, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 67, in _initial_style_open_jaxpr
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1154, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1164, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/tjax/fixed_point/iterated_function.py", line 97, in f
new_state, trajectory = self.sampled_state_trajectory(theta, augmented)
File "/home/neil/src/cmm/cmm/structure/solution/runner.py", line 66, in sampled_state_trajectory
return self._sampled_state_trajectory(theta, augmented.current_state)
File "/home/neil/src/cmm/cmm/structure/solution/runner.py", line 90, in _sampled_state_trajectory
rl_result = self.rl_inference.infer(state.parameter_states, state.rng)
File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 92, in infer
rl_state = while_loop(cond_fun, self._body_fun, rl_state)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 286, in while_loop
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 272, in _create_jaxpr
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 72, in _initial_style_jaxpr
jaxpr, out_avals, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 67, in _initial_style_open_jaxpr
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1154, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1164, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/neil/src/cmm/cmm/structure/rl/inference.py", line 157, in _body_fun
weights_bar, primals = f(rl_state.parameter_states.weights)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 756, in grad_f_aux
(_, aux), g = value_and_grad_f(*args, **kwargs)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 819, in value_and_grad_f
g = vjp_py(np.ones((), dtype=dtype))
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 1791, in _vjp_pullback_wrapper
ans = fun(*args)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 120, in unbound_vjp
arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 220, in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 634, in _custom_lin_transpose
cts_in = bwd.call_wrapped(*res, *cts_out)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/tjax/shims.py", line 114, in new_bwd
input_bar = bwd(*static_args, internal_residuals, output_bar)
File "/home/neil/src/cmm/cmm/encoding/inference.py", line 171, in _infer_encoding_configuration_bwd
observation_bars, weights_bars = vmapped_f_vjp(encoding_regularizers_bars)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 1230, in batched_fun
out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/batching.py", line 36, in batch
return batched_fun.call_wrapped(*in_vals)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 156, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/api.py", line 1791, in _vjp_pullback_wrapper
ans = fun(*args)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 120, in unbound_vjp
arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 220, in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/interpreters/ad.py", line 634, in _custom_lin_transpose
cts_in = bwd.call_wrapped(*res, *cts_out)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/linear_util.py", line 169, in call_wrapped
ans = gen.send(ans)
File "/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/jax/custom_derivatives.py", line 544, in _flatten_bwd
raise TypeError(msg.format(in_tree2, in_tree)) from None
TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef(tuple, [PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunction'>[()], [*,*,*,*,*,*,PyTreeDef(<class 'cmm.encoding.element.EncodingElement'>[(Path(('module', 'observation.x', 'encoding')),)], [PyTreeDef(<class 'cmm.pss.space.exponential_family.ExpFamSpace'>[(Path(('module', 'observation.x', 'encoding', 'space')), NormalUnitVariance(shape=(), num_parameters=5))], []),PyTreeDef(<class 'cmm.encoding.cluster.CodeCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.PresenceScoreCluster'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster')),)], [*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), (1,), 'encoding_decay', True)], [])]),PyTreeDef(<class 'cmm.encoding.cluster.ValueCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.MechanicalCluster'>[()], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), (1, 4), 'encoding', True, OuterMatching(shape=[(1,), (4,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), (1, 5), 'encoding', True, OuterMatching(shape=[(1,), (5,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'explanation_link')), (5, 4), 'encoding', False, OuterMatching(shape=[(5,), (4,)]))], []),PyTreeDef(<class 
'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_gating_link')), (4, 1), 'encoding', True, OuterMatching(shape=[(4,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 0>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_gating_link')), (5, 1), 'encoding', True, OuterMatching(shape=[(5,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 0>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'demand_link')), (4, 5), 'encoding', False, OuterMatching(shape=[(4,), (5,)]))], []),*,*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), (5,), 'deduction', False)], [])])]),PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunctionParameters'>[()], [PyTreeDef(<class 'cmm.pss.observation.Observation'>[()], [*,*]),*,PyTreeDef(<class 'cmm.structure.foundation.parallel.ParallelStructure'>[()], [PyTreeDef(dict[[Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), Path(('module', 'observation.x', 'encoding', 'code_gating_link')), Path(('module', 'observation.x', 'encoding', 'demand_link')), Path(('module', 'observation.x', 'encoding', 'explanation_link')), Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), Path(('module', 'observation.x', 'encoding', 'value_gating_link'))]], [*,*,*,*,*,*,*,*])])]),PyTreeDef(<class 'cmm.encoding.configuration.EncodingState'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingDifferentiand'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingComparand'>[()], [*,*,*,*,*,*,*,*]),*,*]),PyTreeDef(<class 
'tjax.generator.Generator'>[()], [*])])]) for primal input structure PyTreeDef(tuple, [PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunction'>[()], [*,*,*,*,*,*,PyTreeDef(<class 'cmm.encoding.element.EncodingElement'>[(Path(('module', 'observation.x', 'encoding')),)], [PyTreeDef(<class 'cmm.pss.space.exponential_family.ExpFamSpace'>[(Path(('module', 'observation.x', 'encoding', 'space')), NormalUnitVariance(shape=(), num_parameters=5))], []),PyTreeDef(<class 'cmm.encoding.cluster.CodeCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.PresenceScoreCluster'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster')),)], [*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), (1,), 'encoding_decay', True)], [])]),PyTreeDef(<class 'cmm.encoding.cluster.ValueCluster'>[()], [*]),PyTreeDef(<class 'cmm.encoding.cluster.MechanicalCluster'>[()], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), (1, 4), 'encoding', True, OuterMatching(shape=[(1,), (4,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), (1, 5), 'encoding', True, OuterMatching(shape=[(1,), (5,)]), <PresenceViewNormalization.each_source_has_unit_output_weights: 1>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'explanation_link')), (5, 4), 'encoding', False, OuterMatching(shape=[(5,), (4,)]))], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'code_gating_link')), (4, 1), 'encoding', True, OuterMatching(shape=[(4,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 
0>)], []),PyTreeDef(<class 'cmm.link.presence_view_link.SquaredPresenceViewLink'>[(Path(('module', 'observation.x', 'encoding', 'value_gating_link')), (5, 1), 'encoding', True, OuterMatching(shape=[(5,), (1,)]), <PresenceViewNormalization.each_target_has_unit_input_weights: 0>)], []),PyTreeDef(<class 'cmm.link.value_view_link.ValueViewLink'>[(Path(('module', 'observation.x', 'encoding', 'demand_link')), (4, 5), 'encoding', False, OuterMatching(shape=[(4,), (5,)]))], []),*,*,PyTreeDef(<class 'cmm.structure.parameter.bias.Bias'>[(Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), (5,), 'deduction', False)], [])])]),PyTreeDef(<class 'cmm.encoding.iterated_function.EncodingIteratedFunctionParameters'>[()], [PyTreeDef(<class 'cmm.pss.observation.Observation'>[()], [*,*]),*,PyTreeDef(<class 'cmm.structure.foundation.parallel.ParallelStructure'>[()], [PyTreeDef(dict[[Path(('module', 'observation.x', 'encoding', 'code_consumption_link')), Path(('module', 'observation.x', 'encoding', 'code_gating_link')), Path(('module', 'observation.x', 'encoding', 'demand_link')), Path(('module', 'observation.x', 'encoding', 'explanation_link')), Path(('module', 'observation.x', 'encoding', 'natural_explanation_bias')), Path(('module', 'observation.x', 'encoding', 'presence_score_cluster', 'decay')), Path(('module', 'observation.x', 'encoding', 'value_consumption_link')), Path(('module', 'observation.x', 'encoding', 'value_gating_link'))]], [*,*,*,*,*,*,*,*])])]),PyTreeDef(<class 'cmm.encoding.configuration.EncodingState'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingDifferentiand'>[()], [PyTreeDef(<class 'cmm.encoding.configuration.EncodingComparand'>[()], [*,*,*,*,*,*,*,*]),*,*]),PyTreeDef(<class 'tjax.generator.Generator'>[()], [*])])]).
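The error above states the contract a custom VJP `bwd` rule has to satisfy: it must return a tuple with one entry per primal argument, and each entry must have the same pytree structure as that argument. A minimal sketch of a rule that satisfies it, using a hypothetical `scaled_sum` function whose first argument is a dict pytree (not code from this project):

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_sum(params, x):
    """Hypothetical primal function with two arguments: a dict pytree and an array."""
    return params["scale"] * jnp.sum(x)


def scaled_sum_fwd(params, x):
    # Return the primal output plus the residuals needed by the backward pass.
    return scaled_sum(params, x), (params, x)


def scaled_sum_bwd(res, g):
    params, x = res
    # One cotangent per primal argument, each mirroring that argument's pytree
    # structure: a dict with the same keys for `params`, an array for `x`.
    params_bar = {"scale": g * jnp.sum(x)}
    x_bar = g * params["scale"] * jnp.ones_like(x)
    return params_bar, x_bar


scaled_sum.defvjp(scaled_sum_fwd, scaled_sum_bwd)

params = {"scale": 2.0}
x = jnp.arange(3.0)
params_grad, x_grad = jax.grad(scaled_sum, argnums=(0, 1))(params, x)
```

Returning only `params_bar`, or a cotangent for `params` with different keys or node types, violates this contract and triggers the kind of `TypeError` shown above.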
About this issue
- State: open
- Created 4 years ago
- Comments: 23 (23 by maintainers)
Commits related to this issue
- make the custom_vjp bwd None handling more robust (fixes #4673), committed to google/jax by mattjj 4 years ago
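The fix concerns how a `bwd` rule may return `None` in place of an argument's cotangent to mean "zero cotangent". As far as we know, JAX handles this by expanding each `None` into placeholder leaves with the argument's structure before comparing structures; the `object()` leaves mentioned in the report are plausibly such placeholders, though that is an inference. A minimal sketch, with a hypothetical `penalty` function whose `config` argument is treated as non-differentiable (how robustly nested `None`s are handled may vary across JAX versions):

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def penalty(config, w):
    """Hypothetical function; `config` is treated as non-differentiable."""
    return config["offset"] + jnp.dot(w, w)


def penalty_fwd(config, w):
    # Only `w` is needed to compute the backward pass.
    return penalty(config, w), w


def penalty_bwd(w, g):
    # None stands in for the whole `config` argument (zero cotangent),
    # so the rule need not build a dict of zeros matching its structure.
    return None, 2.0 * g * w


penalty.defvjp(penalty_fwd, penalty_bwd)

config = {"offset": 1.0}
w = jnp.array([1.0, -2.0, 3.0])
w_grad = jax.grad(penalty, argnums=1)(config, w)  # equals 2 * w
```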
I’d like to improve the errors and robustness here, so let me leave this issue open until I do 😃
Glad you’re unblocked!
Yeah I needed Python 3.8 to repro!
A quick Bing search answered my question!