How to Convert a Keras SavedModel into a Browser-based Web App

Deep learning has revolutionized what's possible with machine learning, powering advances in computer vision, natural language processing, and much more. However, deploying deep learning models can be challenging, often requiring specialized backend infrastructure and DevOps expertise.

Fortunately, an exciting alternative has emerged in recent years: running deep learning models directly in web browsers using JavaScript. With tools like TensorFlow.js, it's now possible to integrate machine learning into any website, without needing backend servers to run the models. This approach has several compelling benefits:

  • Offloading computation from your servers to your users' devices
  • Enabling real-time, interactive demos of your ML models
  • Preserving user privacy by running models locally in the browser rather than sending sensitive data to servers

In this post, we'll walk through the end-to-end process of converting a Keras model to run in the browser with TensorFlow.js. We'll cover the key steps: training and exporting a Keras model, converting it to the TensorFlow.js Layers format, loading it and running inference in a web page, and designing an intuitive user interface around the model. Let's dive in!

Training and Exporting a Keras SavedModel

The first step is to define, train, and save your model using Keras. Here's a simple example of training an image classification model on the CIFAR-10 dataset:

from tensorflow import keras
from tensorflow.keras import layers

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Scale pixel values to [0, 1] so training matches the browser-side preprocessing later
x_train, x_test = x_train / 255.0, x_test / 255.0

model = keras.Sequential(
    [
        layers.Input((32, 32, 3)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dense(10),
    ]
)

model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)

model.save('my_model')

This trains a small convolutional neural network to ~70% validation accuracy on CIFAR-10. The key part for our purposes is the last line, which saves the trained model in the SavedModel format using model.save('my_model'). The SavedModel is a serialized version of the model containing the model architecture, trained weights, and compilation information, allowing the model to be reinstantiated later.

Converting to TensorFlow.js Layers Format

Now that we have a trained SavedModel, the next step is to convert it to a web-friendly format. We'll use the TensorFlow.js Layers format, which stores the model architecture in a JSON file and the trained weights in accompanying binary files.

To convert the model, we'll use the tensorflowjs Python package. First install it:

pip install tensorflowjs

Then convert the SavedModel like this:

tensorflowjs_converter --input_format=keras_saved_model my_model web_model

This converts the SavedModel in the my_model directory into a web_model directory containing model.json and one or more binary weight files. (Note that --input_format=keras is reserved for Keras HDF5 .h5 files; since we saved in the SavedModel format, keras_saved_model is the right choice.)

Loading the Model in a Web Page

With the converted model files in hand, we're ready to load the model in a web page using TensorFlow.js. First, add a script tag to load the TensorFlow.js library:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>

Then load the model using tf.loadLayersModel():

const model = await tf.loadLayersModel('web_model/model.json');

Since model loading is asynchronous (it has to download the model files), we use the await keyword to wait for it to complete. This requires the enclosing function to be declared async.
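
In practice that usually means wrapping the load in an async initialization function (or using a <script type="module">, where top-level await is allowed). A minimal sketch, assuming the web_model/ directory is served alongside the page:

let model;

async function init() {
  // Download model.json and the weight shards, then build the model
  model = await tf.loadLayersModel('web_model/model.json');
  console.log('Model loaded');
}

init();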

In addition to the URL, tf.loadLayersModel() accepts an options object that customizes how the model files are fetched:

const model = await tf.loadLayersModel('web_model/model.json', {
  onProgress: updateProgressBar, // Optional callback to track loading progress
  weightPathPrefix: 'https://example.com/web_model/', // Optional URL prefix for loading weights from a different origin
  requestInit: {credentials: 'include'}, // Optional request initialization options
  fetchFunc: customFetchFunction, // Optional custom fetch implementation
  weightUrlConverter: customUrlConverter, // Optional function for converting weight URLs
});
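
One best practice is to cache the loaded model in IndexedDB so that repeat visits read it from local storage instead of re-downloading the weights. TensorFlow.js supports an indexeddb:// URL scheme in both model.save() and tf.loadLayersModel(); here is a minimal sketch (the cifar10-model cache name is arbitrary):

const MODEL_URL = 'web_model/model.json';
const CACHE_KEY = 'indexeddb://cifar10-model';

async function loadModelWithCache() {
  try {
    // Try the locally cached copy first
    return await tf.loadLayersModel(CACHE_KEY);
  } catch (e) {
    // Cache miss: fetch over the network, then store for next time
    const model = await tf.loadLayersModel(MODEL_URL);
    await model.save(CACHE_KEY);
    return model;
  }
}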

Once the model is loaded, we can inspect its architecture:

model.summary(); // Prints a layer-by-layer summary to the console (and returns nothing)
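
It can also be worth sanity-checking the input and output shapes that survived the conversion:

console.log(model.inputs[0].shape);  // [null, 32, 32, 3] (null is the batch dimension)
console.log(model.outputs[0].shape); // [null, 10]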

Preprocessing Inputs

Before we can run the model in the browser, we need to preprocess any input data to match the model's expected format. For our CIFAR-10 image classification model, this involves normalizing the pixel values and reshaping the images to the model's input shape.

Here's how we can use TensorFlow.js to preprocess an input image:

const preprocessImage = async (imageFile) => {
  const imageElement = document.createElement('img');
  const promise = new Promise((resolve, reject) => {
    imageElement.onload = () => resolve(imageElement);
    imageElement.onerror = reject; // Fail fast on unreadable image files
  });
  imageElement.src = URL.createObjectURL(imageFile);
  await promise;
  URL.revokeObjectURL(imageElement.src); // Release the temporary object URL

  const imageTensor = tf.browser.fromPixels(imageElement);
  const resizedImage = tf.image.resizeBilinear(imageTensor, [32, 32]);
  const normalizedImage = resizedImage.div(255);
  const inputImage = normalizedImage.reshape([1, 32, 32, 3]);
  return inputImage;
};

This function takes an image file selected by the user, creates an <img> element to load it, converts the loaded image to a tensor, resizes it to 32×32 pixels, normalizes the pixel values to [0, 1], and reshapes it to the model's input shape of [1, 32, 32, 3].
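
One caveat: TensorFlow.js does not garbage-collect tensors, so the intermediate imageTensor, resizedImage, and normalizedImage tensors above stay in (GPU) memory until disposed. Since tf.tidy() cannot wrap async code, a common pattern is to keep the image loading outside and wrap only the synchronous tensor work, which disposes every intermediate except the returned tensor. A sketch of that portion:

const inputImage = tf.tidy(() => {
  // Everything allocated inside this callback is disposed automatically,
  // except the tensor we return
  const imageTensor = tf.browser.fromPixels(imageElement);
  const resized = tf.image.resizeBilinear(imageTensor, [32, 32]);
  return resized.div(255).reshape([1, 32, 32, 3]);
});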

Running Inference

With the model loaded and inputs preprocessed, we‘re finally ready to run inference in the browser!

const prediction = model.predict(inputImage);

The prediction will be a tensor containing the model's output logits for each of the 10 classes. We can get the predicted class index by finding the argmax:

const predictedIndex = prediction.argMax(1).dataSync()[0];

And look up the corresponding class name:

const classNames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'];
const predictedClassName = classNames[predictedIndex];
console.log(`Predicted class: ${predictedClassName}`);

Alternatively, we can get the top-k predicted classes and their probabilities:

const k = 3;
const probs = tf.softmax(prediction);
const topK = probs.topk(k);
const topkIndices = topK.indices.dataSync();
const topkProbs = topK.values.dataSync();

// topkIndices is a typed array, whose .map would coerce strings back to numbers,
// so convert it to a regular array first
const topkClassNames = Array.from(topkIndices).map(index => classNames[index]);
console.log(`Top ${k} predictions:`);
for (let i = 0; i < k; i++) {
  console.log(`${topkClassNames[i]}: ${topkProbs[i].toFixed(3)}`);
}

Designing a User Interface

The final step is to integrate the model into an intuitive web app UI. At a minimum, the interface will need:

  • A way for users to select an image file to classify
  • A button to trigger inference
  • A place to display the model's predictions

Here's a minimal example using HTML and JavaScript:

<input type="file" id="image-input" accept="image/*">
<button id="predict-button">Predict</button>
<div id="predictions"></div>

<script>
// Assumes `model`, `classNames`, and `preprocessImage` are defined as shown earlier
const imageInput = document.getElementById('image-input');
const predictButton = document.getElementById('predict-button');
const predictions = document.getElementById('predictions');

predictButton.addEventListener('click', async () => {
  const imageFile = imageInput.files[0];
  if (!imageFile) {
    alert('Please select an image first.');
    return;
  }

  const inputImage = await preprocessImage(imageFile);
  const prediction = model.predict(inputImage);
  const predictedIndex = prediction.argMax(1).dataSync()[0];
  const predictedClassName = classNames[predictedIndex];

  predictions.textContent = `Predicted class: ${predictedClassName}`;
});
</script>
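
One detail for long-running apps: every predict() and argMax() call allocates new tensors that TensorFlow.js will not free on its own. The prediction step in the click handler could be wrapped in tf.tidy() so those intermediates are released once a plain number has been extracted; a sketch:

const predictedIndex = tf.tidy(
  // The prediction and argMax tensors are disposed when the callback returns,
  // since the return value is a plain number rather than a tensor
  () => model.predict(inputImage).argMax(1).dataSync()[0]
);
inputImage.dispose(); // created outside the tidy scope, so dispose it manually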

For a more polished UI, you could use a frontend framework like React or Angular, and add features like:

  • A preview of the selected image (see the sketch after this list)
  • A progress bar during model loading
  • Nicer formatting of the top-k predictions, perhaps with bars representing the probabilities
  • The ability to select between multiple models
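
To illustrate the first of these, a live preview of the selected image takes only a few lines. This sketch assumes a hypothetical <img id="preview"> element in the page:

imageInput.addEventListener('change', () => {
  const file = imageInput.files[0];
  if (file) {
    // Display the chosen file in the assumed <img id="preview"> element
    document.getElementById('preview').src = URL.createObjectURL(file);
  }
});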

Limitations and Considerations

Running ML models in the browser is a powerful technique, but it does have some limitations and tradeoffs to consider:

  • Large models may take a long time to load on slow connections or devices
  • Old browsers may not support all the required web APIs
  • Applications processing sensitive user data need to consider privacy carefully
  • On-device computation may drain the battery more quickly on mobile devices
  • WebGL driver bugs can cause models to produce incorrect results on some devices (one mitigation is sketched after this list)
  • Debugging models can be tricky without direct access to the runtime environment
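
On the WebGL point, one possible mitigation is to detect initialization failure and fall back to the slower but more predictable CPU backend. A sketch (run inside an async function):

// tf.setBackend() resolves to false if the requested backend cannot initialize
const usingWebGL = await tf.setBackend('webgl');
if (!usingWebGL) {
  console.warn('WebGL backend unavailable; falling back to CPU');
  await tf.setBackend('cpu');
}
await tf.ready();
console.log(`Active backend: ${tf.getBackend()}`);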

Despite these challenges, in-browser machine learning offers an exciting new paradigm for deploying intelligent web apps. With some creativity and careful engineering, the sky's the limit for what's possible!

Conclusion

In this post, we walked through the process of converting a Keras model to run in a web browser using TensorFlow.js. The key steps were:

  1. Train and save a Keras model
  2. Convert the SavedModel to the TensorFlow.js Layers format
  3. Load the converted model in a web page
  4. Preprocess input data in the browser
  5. Run inference and display the predictions

We also touched on some best practices for designing web apps around ML models and discussed the limitations and tradeoffs involved.

Hopefully this post has given you a taste of what's possible with in-browser ML and the tools to get started. So what are you waiting for? Go forth and build some awesome intelligent web apps! And if you have any questions or comments, feel free to share them below. Happy hacking!
