fortuna: bug: SWAG's `state.mean.size` is `1` leading to `TypeError: len() of unsized object`.
Bug Report
Fortuna version: Latest
```python
prob_model.load_state("../swag_checkpoints/2023-07-25 14:59:40.237855/checkpoint_18600/checkpoint")
state = prob_model.posterior.state.get()
# SWAGState(step=array(18600, dtype=int32), apply_fn=None, params=FrozenDict({
#     model: {
#         params: {
#             dfe_subnet: {
#                 BatchNorm_0: {
#                     bias: array([-0.13133389, -0.14736553, -0.14047779, -0.12409671, -0.11933165,
#                            -0.16984864, -0.13965459, -0.07937623, -0.11898279, -0.1386996 ,
#                            -0.13736989, -0.11246286, -0.15424594, -0.10375523, -0.10800011,
#                            -0.14000903, -0.15316793, -0.13276398, -0.11146024, -0.16203304,
#                            -0.14830959, -0.13227627, -0.11291285, -0.11979104, -0.08990214,
#                            -0.13557586, -0.15480955, -0.17320064, -0.14736709, -0.12703426, ...

state.mean
# array(-0.01478862, dtype=float32)
```
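The failure mode is easy to reproduce in isolation: `state.mean` here is a 0-d array, and `len()` is undefined for 0-d NumPy arrays. A minimal sketch (the `mean` variable below just mimics the scalar observed above, it is not pulled from Fortuna):

```python
import numpy as np

# 0-d array, mimicking the scalar state.mean seen in this report
mean = np.array(-0.01478862, dtype=np.float32)

# len() raises for 0-d arrays -- the same error as in the traceback
try:
    len(mean)
except TypeError as e:
    print(e)  # len() of unsized object

# .size and np.atleast_1d both handle the 0-d case gracefully
print(mean.size)                 # 1
print(len(np.atleast_1d(mean)))  # 1
```

This is why the title notes `state.mean.size` is `1`: the array exists but has no length, so any `len()`-based parameter count crashes.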
This leads to a `TypeError` when running `prob_model.predictive.sample()`, at line 212 of `fortuna/prob_model/posterior/swag/swag_posterior.py`:
```python
    207 if state.mutable is not None and inputs_loader is None and inputs is None:
    208     raise ValueError(
    209         "The posterior state contains mutable objects. Please pass `inputs_loader` or `inputs`."
    210     )

--> 212 n_params = len(state.mean)  # TypeError: len() of unsized object
    213 rank = state.dev.shape[-1]
    214 which_params = decode_encoded_tuple_of_lists_of_strings_to_array(
    215     state._encoded_which_params
    216 )
```
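One way to make the crashing line robust is to count parameters via `.size` rather than `len()`. This is only a sketch of a defensive workaround, not Fortuna's actual fix, and the `n_params_of` helper is hypothetical; the underlying bug may well be that `state.mean` was saved as a scalar instead of the flattened parameter vector, in which case the checkpoint save/load path is the real culprit:

```python
import numpy as np

def n_params_of(mean) -> int:
    """Hypothetical helper: count parameters in a SWAG mean vector.

    Uses .size instead of len() so a 0-d array (as observed in this
    report) returns 1 instead of raising TypeError.
    """
    return int(np.asarray(mean).size)

print(n_params_of(np.array(-0.01478862)))  # 1, where len() would raise
print(n_params_of(np.zeros(32)))           # 32, same as len() on a 1-d array
```

Even with this workaround, a SWAG mean of size 1 for a full network is almost certainly wrong, so the scalar value itself deserves investigation.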
I'm not sure whether I'm doing something wrong here. Thanks!
About this issue
- State: closed
- Created a year ago
- Comments: 24 (24 by maintainers)
Thanks for trying this out. I think this is indeed a bug and I think I know where the issue is. Give me a moment and I’ll push the fix. 😄