TorchSharp: Missing torch.load

Similar to the discussion of the missing torch.save: PyTorch's torch.load is defined in torch/serialization.py, whose docstring documents the available loading options.

Example:

        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
        # Load a module with 'ascii' encoding for unpickling
        >>> torch.load('module.pt', encoding='ascii')

Currently, LibTorchSharp implements only one of the possible loading options listed above.

Pickling is overkill for .NET, as discussed in that issue.

As I am still learning: is there a need for TorchSharp to provide more of the loading options that torch.load offers, beyond Module.load?

I am raising this issue because I fail to load a state_dict saved through exportsd.py back into TorchSharp using Module.Load.
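
For context, the export step on the Python side looks roughly like the following sketch (exportsd.py is the helper script from the TorchSharp repository; MyModel and the checkpoint file name are hypothetical):

    import torch
    import exportsd  # helper script shipped with the TorchSharp repository

    model = MyModel()  # hypothetical: an instance of the original model
    model.load_state_dict(torch.load('checkpoint.pt'))  # hypothetical checkpoint

    # Write the weights in the binary format that TorchSharp can read back.
    with open('model_weights.dat', 'wb') as f:
        exportsd.save_state_dict(model.to('cpu').state_dict(), f)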

I did not get any error message; the process simply crashes.

Suggestion: error messages when loading fails would help make loading a state_dict more reliable.

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 17 (8 by maintainers)

Most upvoted comments

@GeorgeS2019 – I suggest adding a print statement (on your machine) to exportsd.py, something like:

    for entry in sd:
        print(entry)  # print each state_dict key as it is written
        stream.write(leb128.u.encode(len(entry)))
        stream.write(bytes(entry, 'utf-8'))
        _write_tensor(sd[entry], stream)

and see what the names of all the state_dict entries are, then compare that to your .NET module that you are loading the weights into.
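
If modifying exportsd.py is inconvenient, the same information can be read from the original checkpoint directly; a minimal sketch, assuming the file contains a plain state_dict (the file name is hypothetical):

    import torch

    # List every state_dict key together with its tensor shape.
    sd = torch.load('checkpoint.pt', map_location='cpu')
    for name, tensor in sd.items():
        print(name, tuple(tensor.shape))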

Anything that looks like a <String,Tensor> dictionary and was saved using the format that exportsd.py also uses should be possible to load, but when loading, the keys come from the model instance (either a custom module or Sequential) that the weights are being loaded into. On the saving side, the keys likewise come from the original model.

Thus, the two have to exactly match – that’s the key here. Without seeing the model definition on both sides, it’s hard to help debug it. The best I can do, and I will try to get that into the next release, is to improve the error messages so that they are more informative.
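
To illustrate the kind of check that would produce informative messages, here is a sketch of a Python-side comparison of the two key sets (the helper is illustrative, not part of TorchSharp):

    def compare_keys(saved_keys, model_keys):
        # Keys the model declares but the saved file does not provide.
        missing = sorted(set(model_keys) - set(saved_keys))
        # Keys the saved file provides but the model does not declare.
        unexpected = sorted(set(saved_keys) - set(model_keys))
        if missing or unexpected:
            raise KeyError(f'missing from file: {missing}; '
                           f'unexpected in file: {unexpected}')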

Thanks for that information. In terms of how TorchSharp model serialization works, it loads and saves model parameters (weights), not models. That means that in order to load weights, you have to have an exact copy of the original model defined in .NET, and an instance of the model created (presumably with random or empty weights).

Parameters should be represented as fields in the model, as should buffers (tensors that are used by the model, but not affected by training), and they must have exactly the same name as in the original model.
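
To show where those names come from on the Python side, a minimal sketch (class and field names are illustrative):

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # The field name 'fc1' produces the keys 'fc1.weight' and 'fc1.bias'.
            self.fc1 = nn.Linear(4, 2)
            # Buffers are saved as well; the registered name becomes the key.
            self.register_buffer('running_scale', torch.ones(1))

    print(list(TinyModel().state_dict().keys()))
    # ['fc1.weight', 'fc1.bias', 'running_scale']

The .NET model must declare fields (parameters and buffers) with exactly these names for loading to succeed.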

I’ll construct some negative unit tests for this and see if I can improve the error messages to be more informative.