djl: Inference tensors cannot be saved for backward.

I’m not sure whether this is a bug. I’m trying to follow the official tutorial using the PyTorch engine.

Here are my code and the exception:

String modelUrl = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
Criteria<NDList, NDList> criteria = Criteria.builder()
    .optApplication(Application.NLP.WORD_EMBEDDING)
    .setTypes(NDList.class, NDList.class)
    .optModelUrls(modelUrl)
    .optProgress(new ProgressBar())
    .build();

ZooModel<NDList, NDList> embedding = criteria.loadModel();
Predictor<NDList, NDList> embedder = embedding.newPredictor();
SequentialBlock classifier = new SequentialBlock()
    .add(
            ndList -> {
                NDArray data = ndList.singletonOrThrow();
                long batchSize = data.getShape().get(0);
                long maxLen = data.getShape().get(1);
                NDList inputs = new NDList();
                inputs.add(data.toType(DataType.INT64, false));
                inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                inputs.add(data.getManager().arange(maxLen).toType(DataType.INT64, false).broadcast(data.getShape()));
                try {
                    return embedder.predict(inputs);
                } catch (TranslateException e) {
                    throw new RuntimeException(e);
                }
            }
    )
    .add(Linear.builder().setUnits(768).build())
    .add(Activation::relu)
    .add(Dropout.builder().optRate(0.2f).build())
    .add(Linear.builder().setUnits(5).build())
    .addSingleton(nd -> nd.get(":,0"));

Model model = Model.newInstance("review_classification");
model.setBlock(classifier);

DefaultVocabulary vocabulary = DefaultVocabulary.builder()
    .addFromTextFile(embedding.getArtifact("vocab.txt"))
    .optUnknownToken("[UNK]")
    .build();

int maxTokenLen = 64;
int batchSize = 8;
int limit = Integer.MAX_VALUE;

BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset awsDataset = getDataset(batchSize, tokenizer, maxTokenLen, limit);
RandomAccessDataset[] datasets = awsDataset.randomSplit(7, 3);
RandomAccessDataset trainDataset = datasets[0];
RandomAccessDataset evalDataset = datasets[1];

SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
    trainer -> {
        TrainingResult result = trainer.getTrainingResult();
        Model trainerModel = trainer.getModel();
        float acc = result.getValidateEvaluation("Accuracy");
        trainerModel.setProperty("Accuracy", String.format("%.5f", acc));
        trainerModel.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
    }
);

DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    .addEvaluator(new Accuracy())
    .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
    .addTrainingListeners(listener);

int epoch = 2;
Trainer trainer = model.newTrainer(trainingConfig);
trainer.setMetrics(new Metrics());
Shape shape = new Shape(batchSize, maxTokenLen);
trainer.initialize(shape);
EasyTrain.fit(trainer, epoch, trainDataset, evalDataset);
System.out.println(trainer.getTrainingResult());
model.save(Paths.get("build/model"), "aws-review-rank");

[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 6
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 6
[main] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().
[main] INFO ai.djl.training.listener.LoggingTrainingListener - Load PyTorch Engine Version 1.12.1 in 0.079 ms.
Exception in thread "main" ai.djl.engine.EngineException: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
	at ai.djl.pytorch.jni.PyTorchLibrary.torchNNLinear(Native Method)
	at ai.djl.pytorch.jni.JniUtils.linear(JniUtils.java:1189)
	at ai.djl.pytorch.engine.PtNDArrayEx.linear(PtNDArrayEx.java:390)
	at ai.djl.nn.core.Linear.linear(Linear.java:183)
	at ai.djl.nn.core.Linear.forwardInternal(Linear.java:88)
	at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:126)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
	at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:209)
	at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
	at ai.djl.training.Trainer.forward(Trainer.java:175)
	at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
	at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
	at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
	at cn.amberdata.misc.djl.rankcls.Main.main(Main.java:114)
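To make the shapes in the classifier's first lambda concrete: it turns a (batchSize, maxLen) batch of token ids into three INT64 inputs for the traced DistilBERT embedding. These are the ids themselves, a tensor of ones with the same shape, and the positions 0..maxLen-1 repeated for every row. The sketch below reproduces only that construction in plain Java, using `long[][]` arrays instead of DJL NDArrays. It does not involve autograd, so it does not reproduce the exception. The class and method names are hypothetical, and the token ids are held as `int` here purely for simplicity.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch (no DJL) of the three INT64 inputs the
 * classifier's first lambda builds from a (batchSize, maxLen) batch of ids.
 */
class EmbeddingInputsSketch {

    /** Token ids widened to long, mirroring data.toType(DataType.INT64, false). */
    static long[][] tokenIds(int[][] data) {
        long[][] out = new long[data.length][];
        for (int b = 0; b < data.length; b++) {
            out[b] = Arrays.stream(data[b]).asLongStream().toArray();
        }
        return out;
    }

    /** All ones with the same shape, mirroring manager.full(shape, 1, DataType.INT64). */
    static long[][] ones(int batchSize, int maxLen) {
        long[][] out = new long[batchSize][maxLen];
        for (long[] row : out) {
            Arrays.fill(row, 1L);
        }
        return out;
    }

    /** Positions 0..maxLen-1 repeated per row, mirroring arange(maxLen).broadcast(shape). */
    static long[][] positions(int batchSize, int maxLen) {
        long[][] out = new long[batchSize][maxLen];
        for (long[] row : out) {
            for (int i = 0; i < maxLen; i++) {
                row[i] = i;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up 2 x 4 batch of token ids, for illustration only.
        int[][] data = {{101, 7592, 2088, 102}, {101, 2204, 102, 0}};
        int batchSize = data.length;
        int maxLen = data[0].length;
        System.out.println(Arrays.deepToString(tokenIds(data)));
        System.out.println(Arrays.deepToString(ones(batchSize, maxLen)));
        System.out.println(Arrays.deepToString(positions(batchSize, maxLen)));
    }
}
```

All three arrays share the (batchSize, maxLen) shape, which is why the lambda can derive everything from `data.getShape()`.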

And here are my dependencies:

        <!-- https://mvnrepository.com/artifact/ai.djl/api -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.19.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-simple -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.36</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.19.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl/basicdataset -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.19.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl/model-zoo -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.19.0</version>
        </dependency>

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 22 (11 by maintainers)

Most upvoted comments

Thanks for pointing out this bug. It is an issue in the NoopTranslator when used with the PyTorch engine; we will take a look and determine the best fix.

For now, you can adjust your code as follows to get around the exception:

  1. Define a custom translator that extends NoopTranslator and overrides its processOutput method:
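// Requires imports: ai.djl.ndarray.NDArray, ai.djl.ndarray.NDList,
// ai.djl.translate.NoopTranslator, ai.djl.translate.TranslatorContext,
// java.util.stream.Collectors
public static final class MyTranslator extends NoopTranslator {

    // Return a copy of every output NDArray so that the layers after the
    // embedder receive regular tensors that autograd can save for backward.
    @Override
    public NDList processOutput(TranslatorContext ctx, NDList input) {
        return new NDList(
                input.stream().map(NDArray::duplicate).collect(Collectors.toList()));
    }
}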
  2. When building the Criteria object, call the optTranslator builder method and pass in the custom translator:
Criteria<NDList, NDList> criteria = Criteria.builder()
    .optApplication(Application.NLP.WORD_EMBEDDING)
    .setTypes(NDList.class, NDList.class)
    .optModelUrls(modelUrl)
    .optProgress(new ProgressBar())
    .optTranslator(new MyTranslator()) // Add this line
    .build();
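The processOutput override is a plain map-and-collect: each output array is replaced by a copy, and the copies are gathered into a new NDList. The copy presumably matters because, as the exception message itself suggests, a cloned tensor is a normal tensor that autograd can use, whereas the embedder's original outputs are inference tensors. That PyTorch-specific distinction has no plain-Java equivalent. The sketch below therefore mirrors only the shape of the override, with hypothetical `float[]` arrays standing in for `NDArray` and `float[]::clone` standing in for `duplicate()`.

```java
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical plain-Java analogue of MyTranslator.processOutput's map-and-collect. */
class DuplicateOutputsSketch {

    /** Returns a new list holding an independent copy of every array. */
    static List<float[]> duplicateAll(List<float[]> outputs) {
        return outputs.stream().map(float[]::clone).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<float[]> outputs = List.of(new float[] {0.1f, 0.2f}, new float[] {0.3f});
        List<float[]> copies = duplicateAll(outputs);
        copies.get(0)[0] = 42f; // changing a copy...
        System.out.println(outputs.get(0)[0]); // ...leaves the original untouched
    }
}
```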