lightweight_mmm: ValueError: Normal distribution got invalid loc parameter.

Hi! I’m trying to reproduce the demo presented at PyData Global 2022 with my own MMM data: https://github.com/takechanman1228/mmm_pydata_global_2022/blob/main/simple_end_to_end_demo_pydataglobal.ipynb

import jax.numpy as jnp
from lightweight_mmm import lightweight_mmm, preprocessing

data = data.tail(150)
data_size = len(data)

n_media_channels = len(mdsp_cols)
n_extra_features = len(control_vars)
media_data = data[mdsp_cols].to_numpy()
extra_features = data[control_vars].to_numpy()
target = data['y'].to_numpy()
costs = data[mdsp_cols].sum().to_numpy()
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)

# The *_train arrays come from a train/test split that is not shown in this post.
media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(costs)
mmm = lightweight_mmm.LightweightMMM(model_name="hill_adstock")

number_warmup=1000
number_samples=1000

mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    media_names = mdsp_cols,
    seed=105)

The following error is raised:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3869/3074029020.py in <module>
     12     number_samples=number_samples,
     13     media_names = mdsp_cols,
---> 14     seed=105)

/opt/conda/lib/python3.7/site-packages/lightweight_mmm/lightweight_mmm.py in fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
    370         transform_function=self._model_transform_function,
    371         weekday_seasonality=weekday_seasonality,
--> 372         custom_priors=custom_priors)
    373 
    374     self.custom_priors = custom_priors

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    595         else:
    596             if self.chain_method == "sequential":
--> 597                 states, last_state = _laxmap(partial_map_fn, map_args)
    598             elif self.chain_method == "parallel":
    599                 states, last_state = pmap(partial_map_fn)(map_args)

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _laxmap(f, xs)
    158     for i in range(n):
    159         x = jit(_get_value_from_index)(xs, i)
--> 160         ys.append(f(x))
    161 
    162     return tree_map(lambda *args: jnp.stack(args), *ys)

/opt/conda/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    384                 init_params,
    385                 model_args=args,
--> 386                 model_kwargs=kwargs,
    387             )
    388         sample_fn, postprocess_fn = self._get_cached_fns()

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    705             )
    706         init_params = self._init_state(
--> 707             rng_key_init_model, model_args, model_kwargs, init_params
    708         )
    709         if self._potential_fn and init_params is None:

/opt/conda/lib/python3.7/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    657                 model_args=model_args,
    658                 model_kwargs=model_kwargs,
--> 659                 forward_mode_differentiation=self._forward_mode_differentiation,
    660             )
    661             if self._init_fn is None:

/opt/conda/lib/python3.7/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    674             with numpyro.validation_enabled(), trace() as tr:
    675                 # validate parameters
--> 676                 substituted_model(*model_args, **model_kwargs)
    677                 # validate values
    678                 for site in tr.values():

/opt/conda/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

/opt/conda/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

/opt/conda/lib/python3.7/site-packages/lightweight_mmm/models.py in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
    433 
    434   numpyro.sample(
--> 435       name="target", fn=dist.Normal(loc=mu, scale=sigma), obs=target_data)

/opt/conda/lib/python3.7/site-packages/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     97             if result is not None:
     98                 return result
---> 99         return super().__call__(*args, **kwargs)
    100 
    101 

/opt/conda/lib/python3.7/site-packages/numpyro/distributions/continuous.py in __init__(self, loc, scale, validate_args)
   1700         batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
   1701         super(Normal, self).__init__(
-> 1702             batch_shape=batch_shape, validate_args=validate_args
   1703         )
   1704 

/opt/conda/lib/python3.7/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
    177                         raise ValueError(
    178                             "{} distribution got invalid {} parameter.".format(
--> 179                                 self.__class__.__name__, param
    180                             )
    181                         )

ValueError: Normal distribution got invalid loc parameter.

I’ve checked for null values in my observation data, and none are present. I also removed zero-cost channels from the model after finding that a few had zero cost after scaling, following the answer in https://github.com/google/lightweight_mmm/issues/115. Finally, I tried reducing the number of rows and columns fed into the model, but none of this got me past the error. Please let me know how I can diagnose this model. Thanks in advance.

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 24

Most upvoted comments

For the scalers, since they’re dividing by the mean, it sounds like you may have some channels with zero impressions? That would produce the NaNs when applying the scaler to the media data.
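To illustrate the point above, here is a minimal sketch of how mean-division scaling turns an all-zero channel into NaNs. It uses plain NumPy on made-up toy data rather than the actual `CustomScaler`, so treat it as an approximation of what the scaler does, not its implementation.

```python
import numpy as np

# Toy media matrix (hypothetical data): 4 weeks x 3 channels.
# The last channel has zero impressions in every week.
media = np.array([
    [10.0, 5.0, 0.0],
    [20.0, 0.0, 0.0],
    [30.0, 5.0, 0.0],
    [40.0, 10.0, 0.0],
])

# Dividing each column by its mean mimics a mean-based scaler.
col_means = media.mean(axis=0)
with np.errstate(divide="ignore", invalid="ignore"):
    scaled = media / col_means  # 0 / 0 yields NaN in the all-zero column

# Channels whose mean is zero are exactly the ones that end up as NaN.
zero_mean_channels = np.flatnonzero(col_means == 0)
nan_channels = np.flatnonzero(np.isnan(scaled).any(axis=0))
print("zero-mean channels:", zero_mean_channels)
print("channels with NaNs after scaling:", nan_channels)
```

Running the same two checks on the real scaled arrays before calling `fit` should show whether this is the cause.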

I figured it all out. It came down to division by zero in the scalers, but my problem was not columns of zeros. I had 3D arrays (with sub-geo data), and the division was happening along another axis on which all the values were 0 (a certain country and media source combination always returned 0), which led to 0/0. A simple solution is to fill the NaNs with 0 before feeding the scaled data to training.
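A rough sketch of the 3D case described above, assuming a `(weeks, channels, geos)` layout and random toy data (both are assumptions for illustration). It uses NumPy; `jnp.nan_to_num` should behave the same way for JAX arrays.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical layout: (weeks, channels, geos).
media = rng.uniform(1.0, 10.0, size=(6, 2, 3))
media[:, 1, 2] = 0.0  # channel 1 never ran in geo 2

# A mean-based scaler divides by the mean over the time axis,
# giving one divisor per (channel, geo) pair.
means = media.mean(axis=0)
with np.errstate(divide="ignore", invalid="ignore"):
    scaled = media / means  # 0 / 0 for the (channel 1, geo 2) pair

# Locate the offending (channel, geo) pairs, then fill the NaNs with 0.
nan_pairs = np.argwhere(np.isnan(scaled).any(axis=0))
cleaned = np.nan_to_num(scaled, nan=0.0)
print("pairs producing NaN:", nan_pairs.tolist())
```

Note that no single column of the original data is entirely zero across all geos here, which is why a plain per-channel check can miss this.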

OK, I now have this working on data representing spend per channel.

  1. Use sum(channel_spend) as cost for channel in costs
  2. Ensure that there are no zero-sum features in any of the data (media_data AND extra_features)
  3. Don’t use an enormous amount of data (shrink rows and columns to a reasonable amount).
     a. If you use an enormous dataset, chances are you will run into: RuntimeError: Cannot find valid initial parameters. Please check your model again.
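Steps 1 and 2 above could be sketched roughly as follows. The helper `drop_zero_sum_columns` and the channel names are hypothetical, and the check assumes non-negative data (so a zero sum means an all-zero column).

```python
import numpy as np

def drop_zero_sum_columns(values, names):
    """Return (values, names) without the columns whose sum is zero."""
    values = np.asarray(values, dtype=float)
    keep = values.sum(axis=0) != 0
    kept_names = [name for name, k in zip(names, keep) if k]
    return values[:, keep], kept_names

# Toy spend data (hypothetical): 3 weeks x 3 channels, "radio" is all zeros.
media = np.array([
    [100.0, 0.0, 30.0],
    [150.0, 0.0, 0.0],
    [120.0, 0.0, 45.0],
])
media_names = ["tv", "radio", "search"]

# Step 2: remove zero-sum features (apply the same to extra_features).
media, media_names = drop_zero_sum_columns(media, media_names)

# Step 1: use the total spend per remaining channel as its cost.
costs = media.sum(axis=0)
print(media_names, costs)
```

Doing the filtering before computing `costs` keeps the cost vector aligned with the channels that are actually passed to the model.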

@michevan Somewhat related to this bug, I noticed that switching from a dummy-encoded seasonality feature space (one column for each week number and holiday, with 0-1 indicator values) to Prophet-generated seasonality features (trend, holiday, season) helped get the model to run. Do you have any recommendations on what format the extra features (notably for capturing seasonality and holidays) should be in?

It’s worth trying with the full dataset, but you’re probably right that it’s too large. I’d try different combinations and see what works; it’s usually best to start with just a few channels, get a working model, and then add more iteratively.

Thank you! And yes, those dimensions look logically sound to me, so that doesn’t seem to be the issue!

One thing that I notice (not exactly what’s causing your issue, but it may help) is that you have too many features in your model. Very roughly, you have 3x27 parameters for your media channels plus at least 67 more parameters for your extra features. That is already 148 parameters (there are a few more internally, like the seasonality components), for which you only have 126 target data points. You should probably reduce your number of features by a factor of roughly 5 to 10 to get good model convergence, and this might also (hopefully) help surface whatever is causing the invalid loc parameter here.
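The rough count above works out like this. The figure of 3 parameters per channel is taken from the comment itself, and the variable names are just for illustration.

```python
# Figures quoted in the comment above.
n_media_channels = 27
params_per_channel = 3   # rough per-channel count used in the comment
n_extra_features = 67
n_observations = 126

# Lower bound on the parameter count (seasonality terms etc. not included).
n_params = params_per_channel * n_media_channels + n_extra_features
obs_per_param = n_observations / n_params

print(f"parameters (lower bound): {n_params}")
print(f"observations per parameter: {obs_per_param:.2f}")
```

With fewer than one observation per parameter, cutting the feature count is a sensible first step regardless of the NaN issue.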

Perfect, thank you! I see there are still a bunch of zeros here, and that’s probably what is causing the issue. The default prior for each channel is a half-normal distribution whose scale is set by the values you’re passing here, so when a value is zero the prior becomes degenerate. Changing it to a small non-zero value should fix the issue. More generally, though, media channels in an MMM should usually have non-zero costs, especially if you want to compute ROIs later in the process and perform channel optimization.
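One way to apply the suggestion above is to floor the scaled costs before passing them as `media_prior`. This is only a sketch: the cost values and the floor of `1e-3` are arbitrary assumptions, and the scaling line merely imitates the mean-divide-and-multiply-by-0.15 setup from the question in plain NumPy.

```python
import numpy as np

# Toy total spend per channel (hypothetical); two channels have zero cost.
costs = np.array([1200.0, 0.0, 350.0, 0.0])

# Imitates the cost scaling from the question: divide by mean, multiply by 0.15.
scaled_costs = costs / costs.mean() * 0.15

# Replace zero entries with a small positive floor so every prior scale is > 0.
floor = 1e-3  # arbitrary choice, tune for your data
media_prior = np.where(scaled_costs > 0, scaled_costs, floor)

print("zero-cost channels:", np.flatnonzero(costs == 0).tolist())
print("media_prior:", media_prior)
```

As the comment notes, removing or fixing the zero-cost channels is usually preferable to flooring them if you plan to compute ROIs later.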