How to Build a Machine Learning Model in Rust

Machine learning (ML) has become an increasingly important tool for data-driven decision making and predictive analytics. While Python and R have dominated the ML ecosystem, Rust is emerging as a compelling alternative, offering performance, safety, and concurrency benefits.

In this guide, we'll walk through the process of building and training a machine learning model in Rust. We'll cover the key concepts, tools, and code you need to get started with ML in Rust. By the end, you'll have a working ML model and the foundation to tackle more advanced projects.

Machine Learning Process Overview

Before diving into the Rust specifics, let's review the key steps in a typical machine learning workflow:

  1. Data Preparation – ML models learn patterns from data. The first step is collecting relevant data and converting it into a format the model can understand, usually a numeric matrix of features and labels.

  2. Model Training – The model is then trained by showing it many examples from the training data and gradually adjusting its internal parameters to minimize prediction errors. There are many types of models suited for different problems.

  3. Model Evaluation – The trained model is tested on a separate set of data to estimate its real-world performance and diagnose issues like overfitting or underfitting. Metrics like accuracy, precision, and recall are used to quantify performance.
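
To make "gradually adjusting its internal parameters" concrete, here is a minimal sketch that uses only the standard library. It assumes a one-parameter model y ≈ w * x trained by gradient descent on the mean squared error; the function name fit_slope and its arguments are illustrative and not taken from any crate. The decision tree we train later learns differently, by choosing splits rather than following gradients.

```rust
/// Fits `y ≈ w * x` by gradient descent on the mean squared error.
/// (Hypothetical helper for illustration; not part of any crate.)
fn fit_slope(xs: &[f64], ys: &[f64], learning_rate: f64, steps: usize) -> f64 {
    assert_eq!(xs.len(), ys.len(), "xs and ys must have the same length");
    assert!(!xs.is_empty(), "need at least one example");

    let n = xs.len() as f64;
    let mut w: f64 = 0.0; // the single internal parameter, starting at zero

    for _ in 0..steps {
        // Gradient of mean((w*x - y)^2) with respect to w:
        // mean(2 * (w*x - y) * x)
        let grad = xs
            .iter()
            .zip(ys.iter())
            .map(|(&x, &y)| 2.0 * (w * x - y) * x)
            .sum::<f64>()
            / n;

        // Move the parameter a small step against the gradient.
        w -= learning_rate * grad;
    }
    w
}
```

For example, fit_slope(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0], 0.05, 200) should return a value very close to 2.0, since the data follows y = 2x exactly.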

A few key ML concepts to know:

  • Features – The input variables describing each data point, e.g. age, height, weight.
  • Labels – The output we want to predict, e.g. disease or not.
  • Training vs Test Data – The dataset is split into separate subsets for training and final evaluation. This avoids overly optimistic performance estimates.
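
As a rough illustration of these three concepts, the sketch below pairs features with a label and cuts a dataset into training and test portions. It relies only on the standard library; the Sample type and the train_test_split function are hypothetical names chosen for this example. The split is a plain cut with no shuffling, so an ordered file should be randomized first.

```rust
/// One data point: a feature vector plus the label we want to predict.
/// (Hypothetical type for illustration; not part of any crate.)
#[derive(Debug, Clone, PartialEq)]
struct Sample {
    features: Vec<f32>, // e.g. [age, height, weight]
    label: usize,       // e.g. 1 = disease, 0 = no disease
}

/// Splits `samples` into (train, test), keeping the first `ratio` share
/// of the rows for training. No shuffling is performed.
fn train_test_split(samples: &[Sample], ratio: f32) -> (Vec<Sample>, Vec<Sample>) {
    assert!((0.0..=1.0).contains(&ratio), "ratio must be within [0, 1]");

    // Number of rows that go into the training set, rounded down.
    let n_train = (samples.len() as f32 * ratio).floor() as usize;

    let (train, test) = samples.split_at(n_train);
    (train.to_vec(), test.to_vec())
}
```

With 10 samples and a ratio of 0.8, the first 8 rows end up in the training set and the last 2 in the test set. The model never sees those 2 rows during training, which is what keeps the evaluation honest.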

With these core concepts in mind, let's see how to implement the ML workflow in Rust.

Preparing Data in Rust

Most real-world ML datasets are stored as structured text like CSV or JSON files. Rust has good support for reading these formats and converting them into numeric arrays.

For this example, we'll use the Titanic dataset from Kaggle. It contains passenger information like age, sex, ticket class, and whether they survived, with the goal of predicting survival from the other features.

First, add the csv and ndarray crates to your Cargo.toml dependencies:

[dependencies]
csv = "1.1"
ndarray = "0.15"

Then we can read the CSV data and convert it to an ndarray matrix:

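use std::fs::File;
use csv::ReaderBuilder;
use ndarray::Array2;

// Reads a CSV file whose columns are all numeric into a 2D matrix.
fn read_csv_data(file_path: &str) -> Array2<f32> {
    let file = File::open(file_path).expect("Could not open file");
    let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);

    // The header row tells us how many columns each record has.
    let num_cols = reader.headers().expect("Could not read header row").len();

    // Parse every field as f32 and append it to one flat, row-major vector.
    let mut data: Vec<f32> = Vec::new();
    for record in reader.records() {
        let record = record.expect("Could not read record");
        for field in record.iter() {
            data.push(field.trim().parse::<f32>().expect("Non-numeric field in CSV"));
        }
    }

    Array2::from_shape_vec((data.len() / num_cols, num_cols), data)
        .expect("Error converting data to matrix")
}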

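This code reads the CSV file, uses the header row to determine the number of columns, parses each field as a float, and collects the results into a flat vector. Finally, it reshapes the vector into a 2D matrix with dimensions (num_rows, num_cols). Every field must be numeric for this to work: the raw Kaggle file also contains text columns such as passenger names, so those need to be dropped or encoded as numbers first.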
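
To see the same steps without any crates, here is a small sketch that parses CSV text into a flat row-major buffer and returns errors instead of panicking. It is an illustration under simplifying assumptions, not a replacement for the csv crate: fields are split naively on commas (no quoted fields), and the names parse_numeric_csv and at are hypothetical.

```rust
/// Parses numeric CSV text (with a header row) into a flat row-major buffer.
/// Returns (data, num_rows, num_cols).
/// Naive sketch: splits on ',' and does not handle quoted fields.
fn parse_numeric_csv(text: &str) -> Result<(Vec<f32>, usize, usize), String> {
    // Skip blank lines so a trailing newline does not count as a row.
    let mut lines = text.lines().filter(|l| !l.trim().is_empty());

    // The header row tells us how many columns to expect.
    let header = lines.next().ok_or("empty input")?;
    let num_cols = header.split(',').count();

    let mut data = Vec::new();
    let mut num_rows = 0;
    for (i, line) in lines.enumerate() {
        let fields: Vec<&str> = line.split(',').collect();
        if fields.len() != num_cols {
            return Err(format!(
                "data row {}: expected {} fields, found {}",
                i + 1,
                num_cols,
                fields.len()
            ));
        }
        for field in fields {
            let value = field
                .trim()
                .parse::<f32>()
                .map_err(|e| format!("data row {}: cannot parse {:?}: {}", i + 1, field, e))?;
            data.push(value);
        }
        num_rows += 1;
    }
    Ok((data, num_rows, num_cols))
}

/// Looks up element (row, col) in a row-major buffer with `num_cols` columns.
fn at(data: &[f32], num_cols: usize, row: usize, col: usize) -> f32 {
    data[row * num_cols + col]
}
```

The at helper shows what the reshape amounts to: element (row, col) of the matrix lives at position row * num_cols + col in the flat vector. A text field such as a name makes parse_numeric_csv return an Err rather than crash, which is usually friendlier when a dataset has not been fully cleaned yet.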

We also need to split the matrix into input features and output labels. We'll use the linfa crate for this, which provides helpful abstractions for ML in Rust:

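use linfa::Dataset;
use ndarray::{Array2, Axis, Ix1};

// Splits the matrix into features (every column except `label_col`) and labels.
fn split_data(data: &Array2<f32>, label_col: usize) -> Dataset<f32, usize, Ix1> {
    let ncols = data.ncols();

    // Indices of all columns other than the label column
    let feature_cols: Vec<usize> = (0..ncols).filter(|&c| c != label_col).collect();

    let features = data.select(Axis(1), &feature_cols);
    let labels = data.column(label_col).mapv(|v| v as usize);

    Dataset::new(features, labels)
}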

Here we separate the feature columns from the label column of the full data matrix and return a Dataset struct containing both. With the data prepared, we're ready to train a model!

Training Models with Linfa

The linfa crate provides a scikit-learn inspired API for ML in Rust. It includes common models, datasets, preprocessing utilities, and abstractions for model fitting and prediction.

For illustrative purposes, we'll train a decision tree classifier, but linfa supports many other algorithms like random forests, SVMs, naive Bayes, k-nearest neighbors, and more.

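To get started, add linfa, linfa-trees, and linfa-datasets to your dependencies. The Dataset type with a one-dimensional label array used in split_data assumes linfa 0.6 or later:

[dependencies]
linfa = "0.6"
linfa-trees = "0.6"

[dependencies.linfa-datasets]
version = "0.6"
features = ["iris"]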

Then we can import the dataset and model:

use linfa::prelude::*;
use linfa_trees::DecisionTree;
use linfa_datasets::iris;

Linfa provides a number of built-in datasets for testing and examples, like the famous iris flower dataset. To show the full pipeline, though, we'll bring in our own Titanic data:

fn main() {
    // Read the Titanic data and split it into features and labels
    // (the label is assumed to be in column 0 of the prepared file)
    let data = read_csv_data("data/titanic.csv");
    let dataset = split_data(&data, 0);

    // Hold out 20% of the rows for testing
    let (train, test) = dataset.split_with_ratio(0.8);

    // Train decision tree model
    let model = DecisionTree::params()
        .fit(&train)
        .unwrap();

    // Make predictions on the test set and compare them with the true labels
    let preds = model.predict(&test);
    let cm = preds.confusion_matrix(&test).unwrap();

    println!("Test Accuracy: {:.2}", cm.accuracy());
}

This code reads in the Titanic CSV data, splits it 80/20 into train and test sets, fits a decision tree to the training set, makes predictions on the test set, and finally prints the accuracy score (the proportion of correct predictions).

Most of the work happens in two calls: DecisionTree::params() creates a set of hyperparameters with default values, and .fit() trains a decision tree on the given dataset and returns the fitted model. Linfa models follow this same pattern: configure with params(), then train with fit().
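
The params()-then-fit() pattern can be sketched without any crates. The toy model below is an assumption for illustration only: MajorityClassifier, MajorityParams, and the fallback setting are hypothetical names, not linfa types, and "training" here just means remembering the most common label.

```rust
use std::collections::HashMap;

/// Hyperparameters for a toy majority-class model (hypothetical, not linfa's).
#[derive(Debug, Clone)]
struct MajorityParams {
    /// Label to predict if the training set is empty.
    fallback: usize,
}

/// The fitted model: it only remembers the most frequent training label.
#[derive(Debug, PartialEq)]
struct MajorityClassifier {
    label: usize,
}

impl MajorityClassifier {
    /// Step 1: start from a parameter set with default values.
    fn params() -> MajorityParams {
        MajorityParams { fallback: 0 }
    }

    /// Predicts the same label for each of `n_rows` inputs.
    fn predict(&self, n_rows: usize) -> Vec<usize> {
        vec![self.label; n_rows]
    }
}

impl MajorityParams {
    /// Builder-style setter; returns `self` so calls can be chained.
    fn fallback(mut self, label: usize) -> Self {
        self.fallback = label;
        self
    }

    /// Step 2: look at the training labels and produce a fitted model.
    fn fit(&self, labels: &[usize]) -> MajorityClassifier {
        let mut counts: HashMap<usize, usize> = HashMap::new();
        for &l in labels {
            *counts.entry(l).or_insert(0) += 1;
        }
        // Pick the label with the highest count; ties go to the smaller label.
        let label = counts
            .into_iter()
            .max_by(|a, b| a.1.cmp(&b.1).then(b.0.cmp(&a.0)))
            .map(|(label, _)| label)
            .unwrap_or(self.fallback);
        MajorityClassifier { label }
    }
}
```

Calling MajorityClassifier::params().fit(&[0, 1, 0, 0, 1]) yields a model that always predicts 0. Keeping the parameters and the fitted model as separate types means a model cannot be used for prediction before it has been trained, which is one reason this pattern is a natural fit for Rust.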

Evaluating Model Performance

Accuracy is a good starting point, but it doesn't tell the full story. On an imbalanced dataset, a model can achieve high accuracy by just predicting the most common class! To get a more complete picture, we need metrics like:

  • Precision – What percent of positive predictions were correct? Higher precision means fewer false positives.
  • Recall – What percent of actual positives were correctly predicted? Higher recall means fewer false negatives.
  • F1 score – The harmonic mean of precision and recall, a balanced measure of model performance.

We can calculate these from linfa's confusion matrix:

use linfa::prelude::*;

// Make predictions on the test set
let preds = model.predict(&test);

// Compare the predictions with the true labels
let cm = preds.confusion_matrix(&test).unwrap();

println!("Precision: {:.2}", cm.precision());
println!("Recall: {:.2}", cm.recall());
println!("F1 Score: {:.2}", cm.f1_score());

We can also visualize performance with tools like plotters to create ROC curves, confusion matrices, and more.
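
Precision, recall, and F1 all come from four counts: true positives, false positives, true negatives, and false negatives. As a minimal, dependency-free sketch (not linfa's implementation), assuming binary labels where 1 marks the positive class, they can be computed like this. The names BinaryCounts, count_outcomes, and ratio are hypothetical.

```rust
/// Counts of the four outcomes for a binary classifier (1 = positive).
/// (Hypothetical type for illustration; not linfa's ConfusionMatrix.)
#[derive(Debug, Default, PartialEq)]
struct BinaryCounts {
    tp: usize,  // predicted positive, actually positive
    fp: usize,  // predicted positive, actually negative
    tn: usize,  // predicted negative, actually negative
    fn_: usize, // predicted negative, actually positive
}

/// Tallies outcomes; any label other than 1 is treated as negative.
fn count_outcomes(preds: &[usize], truth: &[usize]) -> BinaryCounts {
    assert_eq!(preds.len(), truth.len(), "inputs must have equal length");
    let mut c = BinaryCounts::default();
    for (&p, &t) in preds.iter().zip(truth) {
        match (p, t) {
            (1, 1) => c.tp += 1,
            (1, _) => c.fp += 1,
            (_, 1) => c.fn_ += 1,
            _ => c.tn += 1,
        }
    }
    c
}

/// Division that returns 0.0 instead of NaN when the denominator is zero.
fn ratio(num: usize, den: usize) -> f64 {
    if den == 0 { 0.0 } else { num as f64 / den as f64 }
}

impl BinaryCounts {
    fn accuracy(&self) -> f64 {
        ratio(self.tp + self.tn, self.tp + self.tn + self.fp + self.fn_)
    }
    fn precision(&self) -> f64 {
        ratio(self.tp, self.tp + self.fp)
    }
    fn recall(&self) -> f64 {
        ratio(self.tp, self.tp + self.fn_)
    }
    /// Harmonic mean of precision and recall (0.0 if both are zero).
    fn f1(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
    }
}
```

This also shows why accuracy alone can mislead. If 8 of 10 test rows are negative and the model predicts negative for everything, accuracy is 0.8 while precision, recall, and F1 are all 0.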

Next Steps

Congrats, you've trained your first ML model in Rust! There's still plenty more to learn:

  • Explore other model types like RandomForest, SVM, KNN, etc. Linfa has many to choose from.
  • Models have tunable hyperparameters that affect performance and training. Experiment with different values passed to params().
  • Feature engineering is the process of transforming raw data into more informative representations. Try one-hot encoding categorical variables or scaling numeric features.
  • To use your trained model in a live setting, you'll need to serialize it to disk, integrate it into an API or service, and monitor its performance over time. The ONNX Runtime is a good option for deployment.
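
As a starting point for the feature engineering ideas above, here is a small standard-library sketch of one-hot encoding and min-max scaling. The function names one_hot and min_max_scale and the category lists in the usage note are assumptions made for this example; they are not part of linfa or any other crate.

```rust
/// One-hot encodes `value` against a fixed list of known categories.
/// Unknown values map to an all-zero vector.
fn one_hot(value: &str, categories: &[&str]) -> Vec<f32> {
    categories
        .iter()
        .map(|&c| if c == value { 1.0 } else { 0.0 })
        .collect()
}

/// Rescales a column to the [0, 1] range (min-max scaling).
/// A constant column maps to all zeros to avoid dividing by zero.
fn min_max_scale(column: &[f32]) -> Vec<f32> {
    let min = column.iter().copied().fold(f32::INFINITY, f32::min);
    let max = column.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let range = max - min;
    column
        .iter()
        .map(|&v| if range > 0.0 { (v - min) / range } else { 0.0 })
        .collect()
}
```

For instance, one_hot("female", &["male", "female"]) gives [0.0, 1.0], and min_max_scale(&[10.0, 20.0, 30.0]) gives [0.0, 0.5, 1.0]. Encoding text columns this way is one route to the all-numeric matrix that the CSV loader in this guide expects.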

Conclusion

Rust is a fast, safe, and concurrent language well-suited for Machine Learning. In this guide, we:

  • Learned the key concepts and steps of the ML process
  • Prepared a real-world dataset for modeling in Rust
  • Trained and evaluated a decision tree model using the linfa crate

I encourage you to check out the Rust ML organization to discover more crates for machine learning, as well as the Machine Learning with Rust book to go deeper into the Rust ML ecosystem.

Now it's your turn – pick a dataset, select a model, and start training! The power and potential of machine learning in Rust awaits. Happy coding!
