
Torch Inference Workflow

This workflow is used to perform inference on Torch models.

Constructor Arguments

  • model_source (ModelSource): The source of the model. For available model sources, see here.
  • model_args (Optional[dict[str, Any]]): The arguments used to load the model. These vary by model source; for more information, see here.

PyTorch requires a model's class definitions to be on the classpath (that is, importable from the Python path) wherever the model is loaded. scikit-learn models converted with the sk2torch library work by default, since sk2torch's classes are already on the classpath. If your model is defined in a custom module, you will need to add that module to the classpath. Refer to our Torch Iris Classification Example to see this in action.
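Saving a whole model with `torch.save` pickles it, and pickle stores classes by reference (module name plus class name) rather than by value. That is why the defining module must be importable at load time. The stdlib sketch below uses plain `pickle` as a stand-in for `torch.save`/`torch.load`. The module name `custom_model_module` and the toy `IrisClassificationModel` class are hypothetical.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the custom module that defines the model class.
MODULE_NAME = "custom_model_module"
custom = types.ModuleType(MODULE_NAME)


class IrisClassificationModel:
    """Toy stand-in for a custom torch.nn.Module subclass."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3, 0.4]


# Attribute the class to the hypothetical module and make that module importable.
IrisClassificationModel.__module__ = MODULE_NAME
custom.IrisClassificationModel = IrisClassificationModel
sys.modules[MODULE_NAME] = custom


def try_load(blob: bytes):
    """Unpickle a saved model; return None if its class cannot be imported."""
    try:
        return pickle.loads(blob)
    except (ModuleNotFoundError, AttributeError):
        return None


# The pickle records "custom_model_module" + "IrisClassificationModel", not the class code.
blob = pickle.dumps(IrisClassificationModel())
loaded = try_load(blob)  # module importable: loading succeeds

# Simulate a serving environment where the defining module is not importable.
del sys.modules[MODULE_NAME]
missing = try_load(blob)  # module missing: loading fails, so None is returned
```

The same failure mode applies to `torch.load` on a fully pickled model, which is why the custom module has to be on the path of whatever process runs the workflow.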

Additional Installations

Since this workflow uses the torch library, you'll need to install infernet-ml[torch_inference]. The "[torch_inference]" optional dependencies are provided for convenience; alternatively, you can install those packages directly.

To install via pip:

pip install "infernet-ml[torch_inference]"

Input Format

Input format is the following dictionary:

    {
        "dtype": str,
        "values": list[Any]
    }

  • dtype (str): The data type of the input, for example "float". It must be one of the names listed under Supported Data Types below.
  • values (list[Any]): The input values, as a (possibly nested) list. Its shape should match the input shape the model expects.
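Because `values` is passed straight to `torch.tensor`, its nesting determines the tensor's shape, and `torch.tensor` rejects ragged lists. A minimal stdlib sketch of a pre-flight shape check follows. `nested_shape` and `check_values_shape` are hypothetical helpers, not part of infernet-ml. The expected shape `(1, 4)` assumes the Iris model from the example, which receives a batch of 4-feature rows.

```python
from typing import Any


def nested_shape(values: Any) -> tuple[int, ...]:
    """Return the shape of a rectangular nested list, e.g. [[1, 2, 3, 4]] -> (1, 4)."""
    if not isinstance(values, list):
        return ()  # a scalar leaf contributes no dimension
    if not values:
        return (0,)
    child_shapes = {nested_shape(v) for v in values}
    if len(child_shapes) != 1:
        raise ValueError("values must be rectangular; sublists have different shapes")
    return (len(values),) + child_shapes.pop()


def check_values_shape(values: list[Any], expected_shape: tuple[int, ...]) -> None:
    """Hypothetical pre-flight check: raise if the shape differs from the model's input."""
    shape = nested_shape(values)
    if shape != expected_shape:
        raise ValueError(f"expected shape {expected_shape}, got {shape}")


# One Iris sample (4 features) wrapped in a batch of size 1, as in the example input.
sample = {"dtype": "float", "values": [[1.0380048, 0.5586108, 1.1037828, 1.712096]]}
check_values_shape(sample["values"], expected_shape=(1, 4))  # passes silently
```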

The input is pre-processed as follows:

    def do_preprocessing(self, input_data: dict[str, Any]) -> torch.Tensor:
        # look up the torch dtype by name; unknown names yield None, so torch infers the dtype
        dtype = DTYPES.get(input_data["dtype"], None)
        values = input_data["values"]
        return torch.tensor(values, dtype=dtype)

This is intentionally kept generic. If you need specific pre-processing, subclass TorchInferenceWorkflow and override the do_preprocessing method.

Supported Data Types

    DTYPES = {
        "float": torch.float,
        "double": torch.double,
        "cfloat": torch.cfloat,
        "cdouble": torch.cdouble,
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "uint8": torch.uint8,
        "int8": torch.int8,
        "short": torch.short,
        "long": torch.long,
        "bool": torch.bool,
    }
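Since do_preprocessing uses `DTYPES.get(..., None)`, a name outside this table, such as "float32", does not raise. `torch.tensor` instead receives `dtype=None` and infers the dtype from the values. Note that "float" already maps to `torch.float`, which is 32-bit. The sketch below is a hypothetical strict check (`require_supported_dtype` is not part of infernet-ml) built from the names in the table.

```python
# Names accepted by DTYPES, copied from the table above.
SUPPORTED_DTYPE_NAMES = frozenset({
    "float", "double", "cfloat", "cdouble", "half", "bfloat16",
    "uint8", "int8", "short", "long", "bool",
})


def require_supported_dtype(name: str) -> str:
    """Hypothetical strict lookup: raise instead of silently falling back to None."""
    if name not in SUPPORTED_DTYPE_NAMES:
        raise ValueError(
            f"unsupported dtype {name!r}; expected one of {sorted(SUPPORTED_DTYPE_NAMES)}"
        )
    return name


require_supported_dtype("float")  # returns "float"
```

In a subclass of TorchInferenceWorkflow, a check like this could run at the start of an overridden do_preprocessing before delegating to `super().do_preprocessing(input_data)`.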


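Example

    from infernet_ml.workflows.inference.torch_inference_workflow import TorchInferenceWorkflow
    from infernet_ml.utils.model_loader import ModelSource
    # Placeholder: replace `your_module` with the module that defines the model class.
    from your_module import IrisClassificationModel

    workflow = TorchInferenceWorkflow(
        model_source=ModelSource.HUGGINGFACE_HUB,  # assumed enum member for the Hugging Face Hub
        model_args={
            "repo_id": "Ritual-Net/iris-classification",
            "filename": "iris.torch",
        },
    )
    workflow.setup()  # assumed: loads the model before inference is run
    results = workflow.inference({
        "values": [[1.0380048, 0.5586108, 1.1037828, 1.712096]],
        "dtype": "float",
    })
    print(f"results: {results}")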