coremltools: Apple's ANE-Optimised `MultiHeadAttention` Export Fails With Flexible Input Shape

šŸžBug Description

I'm trying to export Apple's ANE-optimised MultiHeadAttention layer defined here.

The layer exports successfully with a fixed input shape, but fails with flexible shapes. I'm using this layer in a custom sequence model, so flexible shapes are essential.

The error thrown is an AssertionError: input shapes incompatible.

Stack Trace

Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  73%|▋| 129/177 [00:00<00:00, 8531.06 o

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[28], line 2
      1 flexible_shape = ct.Shape(shape = (1, 512, 1, ct.RangeDim(1, 448)))
----> 2 mlmod_flexible_shape = ct.convert(
      3     jit,
      4     inputs = [
      5         ct.TensorType("q", flexible_shape),
      6         ct.TensorType("k", flexible_shape),
      7         ct.TensorType("v", flexible_shape),
      8     ]
      9 )

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/_converters_entry.py:444, in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
    441 if specification_version is None:
    442     specification_version = _set_default_specification_version(exact_target)
--> 444 mlmodel = mil_convert(
    445     model,
    446     convert_from=exact_source,
    447     convert_to=exact_target,
    448     inputs=inputs,
    449     outputs=outputs_as_tensor_or_image_types, # None or list[ct.ImageType/ct.TensorType]
    450     classifier_config=classifier_config,
    451     transforms=tuple(transforms),
    452     skip_model_load=skip_model_load,
    453     compute_units=compute_units,
    454     package_dir=package_dir,
    455     debug=debug,
    456     specification_version=specification_version,
    457 )
    459 if exact_target == 'milinternal':
    460     return mlmodel # Returns the MIL program

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:187, in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    148 @_profile
    149 def mil_convert(
    150     model,
   (...)
    154     **kwargs
    155 ):
    156     """
    157     Convert model from a specified frontend `convert_from` to a specified
    158     converter backend `convert_to`.
   (...)
    185         See `coremltools.converters.convert`
    186     """
--> 187     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:211, in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    208     weights_dir = _tempfile.TemporaryDirectory()
    209     kwargs["weights_dir"] = weights_dir.name
--> 211 proto, mil_program = mil_convert_to_proto(
    212                         model,
    213                         convert_from,
    214                         convert_to,
    215                         registry,
    216                         **kwargs
    217                      )
    219 _reset_conversion_state()
    221 if convert_to == 'milinternal':

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:281, in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    278 kwargs.setdefault("convert_to", convert_to)
    279 frontend_converter = frontend_converter_type()
--> 281 prog = frontend_converter(model, **kwargs)
    283 if convert_to.lower() != "neuralnetwork":
    284     passes = kwargs.get("transforms", list())

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/converter.py:109, in TorchFrontend.__call__(self, *args, **kwargs)
    106 def __call__(self, *args, **kwargs):
    107     from .frontend.torch import load
--> 109     return load(*args, **kwargs)

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py:57, in load(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)
     55 inputs = _convert_to_torch_inputtype(inputs)
     56 converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols, specification_version)
---> 57 return _perform_torch_convert(converter, debug)

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py:96, in _perform_torch_convert(converter, debug)
     94 def _perform_torch_convert(converter, debug):
     95     try:
---> 96         prog = converter.convert()
     97     except RuntimeError as e:
     98         if debug and "convert function" in str(e):

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py:281, in TorchConverter.convert(self)
    278 self.convert_const()
    280 # Add the rest of the operations
--> 281 convert_nodes(self.context, self.graph)
    283 graph_outputs = [self.context[name] for name in self.graph.outputs]
    285 # An output can be None when it's a None constant, which happens
    286 # in Fairseq MT.

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:89, in convert_nodes(context, graph)
     84     raise RuntimeError(
     85         "PyTorch convert function for op '{}' not implemented.".format(node.kind)
     86     )
     88 context.prepare_for_conversion(node)
---> 89 add_op(context, node)
     91 # We've generated all the outputs the graph needs, terminate conversion.
     92 if _all_outputs_present(context, graph):

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py:1120, in einsum(context, node)
   1118 b = context[node.inputs[1]][1]
   1119 equation = context[node.inputs[0]].val
-> 1120 x = build_einsum_mil(a, b, equation, node.name)
   1121 context.add(x)

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/frontend/_utils.py:164, in build_einsum_mil(a_var, b_var, equation, name)
    162         x = mb.einsum(values=(a_var, b_var), equation=equation, name=name)
    163     else:
--> 164         x = mb.einsum(values=(b_var, a_var), equation=equation_rev, name=name)
    165 elif vec_chw_whu_chu in [parsed_vectors, parsed_vectors_rev]:
    166     if parsed_vectors == vec_chw_whu_chu:

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/mil/ops/registry.py:176, in SSAOpRegistry.register_op.<locals>.class_wrapper.<locals>.add_op(cls, **kwargs)
    173 else:
    174     op_cls_to_add = op_reg[op_type]
--> 176 return cls._add_op(op_cls_to_add, **kwargs)

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/mil/builder.py:182, in Builder._add_op(cls, op_cls, **kwargs)
    180 curr_block()._insert_op_before(new_op, before_op=before_op)
    181 new_op.build_nested_blocks()
--> 182 new_op.type_value_inference()
    183 if len(new_op.outputs) == 1:
    184     return new_op.outputs[0]

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/mil/operation.py:253, in Operation.type_value_inference(self, overwrite_output)
    243 def type_value_inference(self, overwrite_output=False):
    244     """
    245     Perform type inference and auto_val computation based on new input Vars
    246     in kwargs. If self._output_vars is None then we generate _output_vars;
   (...)
    251     existing _output_vars
    252     """
--> 253     output_types = self.type_inference()
    254     if not isinstance(output_types, tuple):
    255         output_types = (output_types,)

File ~/miniconda3/envs/rosetta/lib/python3.8/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/linear.py:290, in einsum.type_inference(self)
    287 print(f"x, y shapes: {x_shape, y_shape}")
    289 assert len(x_shape) == len(y_shape), "inputs not of the same rank"
--> 290 assert x_shape[-1] == y_shape[-3], "input shapes incompatible"
    291 if x_shape[-2] != 1 and y_shape[-2] != 1:
    292     assert x_shape[-2] == y_shape[-2], "input shapes incompatible"

AssertionError: input shapes incompatible

To Reproduce

import torch  # 1.13.1
import numpy as np  # 1.22.3
import coremltools as ct  # 6.2

from ane_transformers.reference.multihead_attention import MultiHeadAttention

N = 10
x = torch.rand(1, 512, 1, N)

layer = MultiHeadAttention(512, n_head=8, dropout=0.0).eval()
jit = torch.jit.trace(layer, (x, x, x))


# Fixed input shape - works
mlmod_fixed_shape = ct.convert(
    jit,
    inputs = [
        ct.TensorType("q", x.shape),
        ct.TensorType("k", x.shape),
        ct.TensorType("v", x.shape),
    ]
)


# Flexible input shape - fails
flexible_shape = ct.Shape(shape = (1, 512, 1, ct.RangeDim(1, 448)))
mlmod_flexible_shape = ct.convert(
    jit,
    inputs = [
        ct.TensorType("q", flexible_shape),
        ct.TensorType("k", flexible_shape),
        ct.TensorType("v", flexible_shape),
    ]
)


# Enumerated shape (not ideal, but better than fixed) also throws the same `AssertionError`
enumerated_shapes = ct.EnumeratedShapes(
    [(1, 512, 1, i) for i in np.array(list(range(1, 449)))[::4]]
)
mlmodel_enumerated_shape = ct.convert(
    jit,
    inputs = [
        ct.TensorType("q", enumerated_shapes),
        ct.TensorType("k", enumerated_shapes),
        ct.TensorType("v", enumerated_shapes),
    ],
)

System environment (please complete the following information):

  • coremltools version: 6.2
  • torch version: 1.13.1
  • numpy version: 1.22.3
  • OS: macOS 13.0, MacBook Pro 16-inch, 2021

Additional context

I'm quite certain that the shape error happens inside one of the einsum operations in the layer definition. While debugging, I printed the equation and input shapes of every einsum op being converted (I did this by adding two print statements right below these lines). The error occurs in one of the later einsum ops, not right away.

Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bkhc,bchq->bkhq
x, y shapes: ((1, is148, 1, 64), (1, 64, 1, is147))
Einsum conversion equation: bchk,bkhq->bchq
x, y shapes: ((1, 64, 1, is149), (1, is148, 1, is147))
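
For reference, the same information can also be logged without editing the coremltools source, e.g. by wrapping build_einsum_mil before calling ct.convert. This is only a rough sketch, not the exact prints I added; the module paths match the traceback above, and the a/b shapes may print in the opposite order to the type_inference prints:

import coremltools.converters.mil.frontend._utils as _utils
import coremltools.converters.mil.frontend.torch.ops as _torch_ops

_orig_build_einsum_mil = _utils.build_einsum_mil

def _logged_build_einsum_mil(a_var, b_var, equation, name):
    # Log every einsum the torch frontend lowers to MIL, before type inference runs.
    print(f"Einsum conversion equation: {equation}")
    print(f"a, b shapes: {a_var.shape, b_var.shape}")
    return _orig_build_einsum_mil(a_var, b_var, equation, name)

# torch/ops.py imports the function by name (a from-import), so patch both references.
_utils.build_einsum_mil = _logged_build_einsum_mil
_torch_ops.build_einsum_mil = _logged_build_einsum_mil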

Perhaps this issue is tangentially related: https://github.com/apple/coremltools/issues/1754

Most upvoted comments

Hello, author of ml-ane-transformers here 👋 Great to hear that you are attempting to extend the reference implementation's capabilities with flexible shapes. Some notes:

  • 50k+ different shapes needed while ct.EnumeratedShapes accommodates 128 max: I recommend enumerating only sequence_length=[2**n for n in range(7,17)], advancing from one sequence_length to the next one right before you overflow, and using masking (decoder_k_mask) to disable attention on unused indices (see the sketch after this list).
  • Benchmarking the flexible-shape models you were able to generate: I recommend using model = coremltools.models.MLModel(path_to_mlpackage_file) and model.predict(dict_of_inputs_with_varying_sequence_length) for each sequence length with a timer around this call. There is non-zero overhead with this way of benchmarking, but the marginal overhead will be negligible for the larger variants.
  • Verifying output correctness: I recommend a PSNR check similar to this one in our unit tests.
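
A rough sketch of the bucketing idea from the first bullet (illustrative only: the bucket list, padding helper, and input names are assumptions based on the reproduction script above, and per that script this conversion currently still hits the einsum AssertionError until the underlying issue is fixed):

import time
import numpy as np
import coremltools as ct

# Power-of-two sequence-length buckets: 128 ... 65536 (well under the 128-shape limit).
seq_lengths = [2**n for n in range(7, 17)]
enumerated = ct.EnumeratedShapes(shapes=[(1, 512, 1, s) for s in seq_lengths])

mlmod = ct.convert(
    jit,  # the traced MultiHeadAttention from the reproduction script
    inputs=[
        ct.TensorType("q", enumerated),
        ct.TensorType("k", enumerated),
        ct.TensorType("v", enumerated),
    ],
)

def pad_to_bucket(x, buckets=tuple(seq_lengths)):
    # Zero-pad the last (sequence) axis up to the next enumerated length.
    # In a full model the padded positions would be disabled via the layer's
    # k_mask / qk_mask inputs; that part is omitted here.
    n = x.shape[-1]
    target = next(b for b in buckets if b >= n)
    return np.pad(x, ((0, 0), (0, 0), (0, 0), (0, target - n)))

# Rough per-bucket latency check (includes Python / I/O overhead).
for n in (100, 1000, 10000):
    x = pad_to_bucket(np.random.rand(1, 512, 1, n).astype(np.float32))
    t0 = time.perf_counter()
    mlmod.predict({"q": x, "k": x, "v": x})
    print(f"seq_len {n} (padded to {x.shape[-1]}): {time.perf_counter() - t0:.4f}s")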

Yes, that's right, the static-shaped model should already be ANE resident, and if you update such a model with enumerated shapes it should run on the ANE. Unless the process of making the model dynamic-shaped with enumerated shapes introduces some dynamic layers (e.g. converts a static reshape to a fully dynamic reshape), in which case we may lose the ANE residency.

I found that rewriting _attention_fn in MultiHeadAttention works well. In your use case, bkhc,bchq->bkhq with b = h = 1, so you can compute it as kc,cq->kq and reshape the tensors.

    def _attention_fn(self, q, k, v, qk_mask, k_mask, return_weights):
        ...
        attn_weights = [aw.softmax(dim=1) for aw in attn_weights]  # n_head * (batch_size, src_seq_len, 1, tgt_seq_len)
        mh_w = [self.dropout(aw) for aw in attn_weights]  # n_head * (batch_size, src_seq_len, 1, tgt_seq_len)
        # Drop the singleton batch and head dims so each per-head einsum is rank-2
        mh_w = [wi.reshape(wi.shape[1], wi.shape[3]) for wi in mh_w]  # n_head * (src_seq_len, tgt_seq_len)
        mh_v = [vi.reshape(vi.shape[1], vi.shape[3]) for vi in mh_v]  # n_head * (d_v/n_head, src_seq_len)
        attn = [
            torch.einsum('kq,ck->cq', wi, vi)
            for wi, vi in zip(mh_w, mh_v)
        ]  # n_head * (d_v/n_head, tgt_seq_len)
        # Restore (batch_size, d_v/n_head, 1, tgt_seq_len) before concatenating heads
        attn = [a.reshape(1, a.shape[0], 1, a.shape[1]) for a in attn]
        attn = torch.cat(attn, dim=1)  # (batch_size, d_v, 1, tgt_seq_len)

        if return_weights:
            return attn, attn_weights
        return attn, None

Here is my test code.

import torch  # 1.13.1
import numpy as np  # 1.22.3
import coremltools as ct  # 6.2

from ane_transformers.reference.multihead_attention import MultiHeadAttention

N = 10
x = torch.rand(1, 512, 1, N)

with torch.no_grad():
    layer = MultiHeadAttention(512, n_head=8, dropout=0.0).eval()
    jit = torch.jit.trace(layer, (x, x, x))


    # Flexible input shape
    flexible_shape = ct.Shape(shape = (1, 512, 1, ct.RangeDim(1, 448)))
    mlmod_flexible_shape = ct.convert(
        jit,
        inputs = [
            ct.TensorType("q", flexible_shape),
            ct.TensorType("k", flexible_shape),
            ct.TensorType("v", flexible_shape),
        ]
    )

    out = layer(x, x, x)
    out_dict = mlmod_flexible_shape.predict({'q': x.detach().numpy().astype(np.float32),
                                             'k': x.detach().numpy().astype(np.float32),
                                             'v': x.detach().numpy().astype(np.float32)})
    np.allclose(out[0], out_dict['var_451'], rtol=0.001, atol=0.001)  # OK

I think coremltools MIL einsum has something wrong with flexible shapes. One possible solution is https://github.com/apple/coremltools/pull/1863.

FWIW, I have the exact same issue (it fails on the same line) when trying to plug the translated DistilBERT model from ml-ane-transformers into the https://github.com/huggingface/exporters exporter code.

I am guessing from the info above that it's because the HF exporter wraps the model in an extra module layer with a bunch of conditional logic for models of different types, so the input shape becomes 'flexible'.

Does ct.EnumeratedShapes allow for support of up to 50,000+ possible shapes, i.e. (1, 1), (1, 2), … (1, 51865)? We are accruing tokens from Whisper's encoder - which requires continuing to run token prediction and passing in an accrued sequence of additional tokens until you hit an end-of-transcript token, or the max vocab size (51865).

Does it seem feasible to use the enumerated shapes path?

My understanding is that (please correct me if I am mistaken):

  • Whisper's encoder has a default maximum sequence length of 1500 tokens.
  • Whisper's decoder has a default maximum sequence length of 448 tokens.
  • 448 is the value @rsomani95 intended to use for flexible shapes based on the first message on this thread.
  • @vade Could you please confirm that your use case for flexible shapes is for something other than autoregressive decoding?

Right, the default size model is expected to utilize the ANE, even with flexible shapes, but if you run the model with a different shape, it would likely only run on CPU/GPU.

Commenting out the assert allows the conversion to finish, but where the fixed-shape model uses the ANE, the flexible-shape model does not.

Instead of using ct.RangeDim, can you try using the ct.EnumeratedShapes option? That should possibly use the ANE.

Yes, I can now reproduce the issue (after locally cloning and using the ml-ane-transformers repository). Thanks for updating the code.

The issue here is that x_shape and y_shape contain symbolic dimensions (e.g. is148, is149), so the equality assert in einsum.type_inference fails at conversion time even though the dimensions would match at runtime.
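
For illustration, a minimal sketch of the kind of relaxation that would avoid the failure (an assumption on my part, not necessarily what the PR linked above, apple/coremltools#1863, implements): only enforce the dimension check when both sides are concrete.

# Hypothetical patch around the assert in einsum.type_inference
# (coremltools/converters/mil/mil/ops/defs/iOS15/linear.py, see traceback above).
from coremltools.converters.mil.mil.types.symbolic import is_symbolic

def dims_definitely_incompatible(d0, d1):
    # Symbolic dims (e.g. is148, is149) cannot be proven equal or unequal at
    # conversion time, so only reject when both are concrete and differ.
    return not is_symbolic(d0) and not is_symbolic(d1) and d0 != d1

assert not dims_definitely_incompatible(x_shape[-1], y_shape[-3]), "input shapes incompatible"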