lightweight_mmm: ValueError: Normal distribution got invalid loc parameter.
Hi! I’m trying to reproduce the demo presented at PyData Global 2022 with some of my own MMM data: https://github.com/takechanman1228/mmm_pydata_global_2022/blob/main/simple_end_to_end_demo_pydataglobal.ipynb
data = data.tail(150)
data_size = len(data)
n_media_channels = len(mdsp_cols)
n_extra_features = len(control_vars)
media_data = data[mdsp_cols].to_numpy()
extra_features = data[control_vars].to_numpy()
target = data['y'].to_numpy()
costs = data[mdsp_cols].sum().to_numpy()
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)
media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(costs)
mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")
number_warmup = 1000
number_samples = 1000
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    media_names=mdsp_cols,
    seed=105)
The below error is displayed:
ValueError Traceback (most recent call last)
/tmp/ipykernel_3869/3074029020.py in <module>
12 number_samples=number_samples,
13 media_names = mdsp_cols,
---> 14 seed=105)
/opt/conda/lib/python3.7/site-packages/lightweight_mmm/lightweight_mmm.py in fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
370 transform_function=self._model_transform_function,
371 weekday_seasonality=weekday_seasonality,
--> 372 custom_priors=custom_priors)
373
374 self.custom_priors = custom_priors
/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
595 else:
596 if self.chain_method == "sequential":
--> 597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
599 states, last_state = pmap(partial_map_fn)(map_args)
/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
158 for i in range(n):
159 x = jit(_get_value_from_index)(xs, i)
--> 160 ys.append(f(x))
161
162 return tree_map(lambda *args: jnp.stack(args), *ys)
/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
384 init_params,
385 model_args=args,
--> 386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
705 )
706 init_params = self._init_state(
--> 707 rng_key_init_model, model_args, model_kwargs, init_params
708 )
709 if self._potential_fn and init_params is None:
/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
657 model_args=model_args,
658 model_kwargs=model_kwargs,
--> 659 forward_mode_differentiation=self._forward_mode_differentiation,
660 )
661 if self._init_fn is None:
/opt/conda/lib/python3.7/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
674 with numpyro.validation_enabled(), trace() as tr:
675 # validate parameters
--> 676 substituted_model(*model_args, **model_kwargs)
677 # validate values
678 for site in tr.values():
/opt/conda/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/opt/conda/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/opt/conda/lib/python3.7/site-packages/lightweight_mmm/models.py in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
433
434 numpyro.sample(
--> 435 name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)
/opt/conda/lib/python3.7/site-packages/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
97 if result is not None:
98 return result
---> 99 return super().__call__(*args, **kwargs)
100
101
/opt/conda/lib/python3.7/site-packages/numpyro/distributions/continuous.py in __init__(self, loc, scale, validate_args)
1700 batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
1701 super(Normal, self).__init__(
-> 1702 batch_shape=batch_shape, validate_args=validate_args
1703 )
1704
/opt/conda/lib/python3.7/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
177 raise ValueError(
178 "{} distribution got invalid {} parameter.".format(
--> 179 self.__class__.__name__, param
180 )
181 )
ValueError: Normal distribution got invalid loc parameter.
I’ve checked for null values in my observation data, and none are present. In addition, I removed zero-cost channels from the data model after checking that a few had zero-cost after scaling and referred to the answer in https://github.com/google/lightweight_mmm/issues/115 as such. I also tried scaling down the number of rows and the number of columns that are fed into the model, but none of those have helped get past this error. Please let me know what I can do to diagnose this model. Thanks in advance.
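One way to narrow this down is to check the scaled arrays, not only the raw data, for NaN or infinite values just before calling `fit`. The sketch below uses plain NumPy and a hypothetical helper named `count_non_finite`; the toy arrays are made-up stand-ins for `media_data_train` and `target_train`.

```python
import numpy as np


def count_non_finite(arrays):
    """Return {name: (n_nan, n_inf)} for each named array (hypothetical helper)."""
    report = {}
    for name, arr in arrays.items():
        arr = np.asarray(arr, dtype=float)
        report[name] = (int(np.isnan(arr).sum()), int(np.isinf(arr).sum()))
    return report


# Toy stand-in for scaled media: the second column is all zeros, so dividing
# by its mean gives 0/0 = NaN.
with np.errstate(divide="ignore", invalid="ignore"):
    toy_media = np.array([[1.0, 0.0], [3.0, 0.0]]) / np.array([2.0, 0.0])

report = count_non_finite({"media": toy_media, "target": np.array([1.0, 2.0])})
print(report)
```

Any array with a non-zero count is a candidate source of the invalid `loc` parameter, since a NaN in the inputs propagates into the mean of the likelihood.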
About this issue
- State: closed
- Created 2 years ago
- Comments: 24
For the scalers, since they’re dividing by the mean, it sounds like you may have some channels with zero impressions? That would produce the NaNs when applying the scaler to the media data.
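As a rough illustration of that check, the sketch below lists the channels whose mean is exactly zero, which is where a divide-by-mean scaler would produce 0/0. It uses plain NumPy rather than the library's scaler; the helper name `zero_mean_columns` and the channel names are hypothetical.

```python
import numpy as np


def zero_mean_columns(media, names):
    """Return the names of columns whose mean over time is exactly zero."""
    means = np.asarray(media, dtype=float).mean(axis=0)
    return [name for name, mean in zip(names, means) if mean == 0]


# Toy data with shape (time, channel); "radio" never has any impressions.
toy_media = np.array([[10.0, 0.0], [20.0, 0.0], [30.0, 0.0]])
suspects = zero_mean_columns(toy_media, ["tv", "radio"])
print(suspects)
```

Running this on the unscaled media array before `fit_transform` would show which channels to drop or investigate.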
I figured it all out. It came down to zero division in the scalers, but my problem was not columns of zeros. I had 3D arrays (with sub-geo data), and the division was occurring over another axis along which all the values were 0 (a certain country and media source combination was always returning 0). This led to 0/0. A simple solution is to fill the NaNs with 0 before feeding the scaled data to the training.
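A minimal sketch of that 3D case, assuming the array is laid out as (time, channel, geo) and the scaler divides by the mean over the time axis. It uses plain NumPy as a stand-in for the library's scaler, and the shapes and values are made up.

```python
import numpy as np

# Toy media array with shape (time=4, channel=2, geo=3).
media = np.ones((4, 2, 3))
# One channel/geo combination is always zero, as described above.
media[:, 1, 2] = 0.0

# Dividing by the per-(channel, geo) mean over time gives 0/0 = NaN there.
with np.errstate(invalid="ignore"):
    scaled = media / media.mean(axis=0)

n_nan_before = int(np.isnan(scaled).sum())

# Fill the NaNs with 0 before passing the scaled data to training.
cleaned = np.nan_to_num(scaled, nan=0.0)
n_nan_after = int(np.isnan(cleaned).sum())

print(n_nan_before, n_nan_after)
```

Note that no single 2D column is all zeros here, which is why a per-column check on flattened data can miss the problem.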
Ok I now have this working on data representing spend across channel.
costs
RuntimeError: Cannot find valid initial parameters. Please check your model again.
@michevan Somewhat related to this bug, I noticed switching from a dummy-encoded seasonality feature-space (one column for each week number and holiday with 0-1 indicator values) to the prophet-generated seasonality features (trend, holiday, season) helped to get the model to run. Do you have any recommendations on what format the extra features (notably for seasonality and holiday capture) should be in?
It’s worth trying with the full dataset, but you’re probably right that it’s too large. I’d try different combinations and see what works; it’s usually best to start with just a few channels, get a working model, and then add more iteratively.
Thank you! And yes, those dimensions look logically sound to me, so that doesn’t seem to be the issue!
One thing that I notice (not exactly what’s causing your issue, but it may help) is that you have too many features in your model. Very roughly, you have 3x27 parameters for your media channels plus at least 67 more parameters for your extra features, and this is already 148 features (there are a few more internally like the seasonality components), for which you only have 126 target data points. You probably should reduce your number of features by a factor of like 5 or 10 in order to get good model convergence, and this might also (hopefully) help surface whatever issue is causing the invalid loc parameter here too.
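The rough count above can be written out explicitly. The numbers come straight from the comment (3 parameters per media channel, 27 channels, 67 extra features, 126 observations); the variable names are only illustrative, and the count ignores the internal seasonality components.

```python
# Figures taken from the comment above; names are illustrative only.
params_per_channel = 3
n_media_channels = 27
n_extra_features = 67
n_observations = 126

n_params = params_per_channel * n_media_channels + n_extra_features
print(n_params, n_observations, n_params > n_observations)
```

With more parameters than observations, cutting features is a sensible first step before looking for subtler causes.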
Perfect, thank you! I see there are still a bunch of zeros here, and that’s probably what is causing the issue. The default prior for each channel is a half-normal distribution whose scale (the standard deviation of the underlying zero-mean normal) equals the values you’re passing here, so when that value is zero the prior is ill-defined. Changing to a small non-zero value should fix the issue. More generally though, media channels in an MMM should usually have non-zero costs, especially if you want to compute ROIs later in the process and perform channel optimization.
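A minimal sketch of replacing zero costs with a small non-zero value before passing them as `media_prior`. The helper name `floor_zero_priors` and the floor of `1e-3` are assumptions for illustration; an appropriate floor depends on the scale of your costs.

```python
import numpy as np


def floor_zero_priors(costs, floor=1e-3):
    """Replace non-positive cost entries with a small positive floor (assumed value)."""
    costs = np.asarray(costs, dtype=float)
    return np.where(costs > 0, costs, floor)


# Toy scaled costs with one zero entry.
toy_costs = np.array([0.4, 0.0, 1.2])
safe_costs = floor_zero_priors(toy_costs)
print(safe_costs)
```

This only works around the prior definition; a channel with genuinely zero spend is usually better dropped from the model, as noted above.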