TensorFlow for Mobile and IoT and TensorFlow.js – Deep Learning with TensorFlow 2 and Keras – Second Edition


TensorFlow for Mobile and IoT and TensorFlow.js

In this chapter we will learn the basics of TensorFlow for Mobile and IoT (Internet of Things). We will briefly present TensorFlow Mobile and introduce TensorFlow Lite in more detail. TensorFlow Mobile and TensorFlow Lite are open source deep learning frameworks for on-device inference. Some examples of Android, iOS, and Raspberry Pi applications will be discussed, together with examples of deploying pretrained models such as MobileNet v1, v2, and v3 (image classification models designed for mobile and embedded vision applications), PoseNet (a vision model that estimates the poses of people in an image or video), DeepLab (an image segmentation model that assigns semantic labels, for example, dog, cat, or car, to every pixel in the input image), and MobileNet SSD (an object detection model that detects multiple objects with bounding boxes). This chapter will conclude with an example of federated learning, a new machine learning framework distributed over millions of mobile devices that is designed to respect user privacy.

TensorFlow Mobile

TensorFlow Mobile is a framework for producing code on iOS and Android. The key idea is to have a platform that allows you to have light models that don't consume too many device resources, such as battery or memory. Typical examples of applications are image recognition on the device, speech recognition, or gesture recognition. TensorFlow Mobile was quite popular until 2018 but was then progressively abandoned in favor of TensorFlow Lite.

TensorFlow Lite

TensorFlow Lite is a lightweight platform designed by TensorFlow. This platform is focused on mobile and embedded platforms such as Android, iOS, and Raspberry Pi. The main goal is to enable machine learning inference directly on the device by focusing on three main characteristics: (1) small binary and model size to save on memory, (2) low energy consumption to save on the battery, and (3) low latency for efficiency. It goes without saying that battery and memory are two important resources for mobile and embedded devices. In order to achieve these goals, TensorFlow Lite uses a number of techniques such as quantization, FlatBuffers, a mobile interpreter, and a mobile converter, which we are going to review briefly in the following sections.


Quantization

Quantization refers to a set of techniques that constrain an input made of continuous values (such as real numbers) into a discrete set (such as integers). The key idea is to reduce the space occupancy of Deep Learning (DL) models by representing the internal weights with integers instead of real numbers. Of course, this implies trading some model accuracy for space gains. However, it has been empirically shown in many situations that a quantized model does not suffer from a significant decay in accuracy. TensorFlow Lite is internally built around a set of core operators supporting both quantized and floating-point operations.
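
To make the idea concrete, the following plain-Python sketch maps a handful of floating-point weights to 8-bit integers with a scale and a zero point, then maps them back to see how much precision is lost. It is only an illustration of the arithmetic: the function names are hypothetical, and it is not the scheme TensorFlow Lite implements internally.

```python
def quantize(weights, num_bits=8):
    """Map floats to signed integers with an affine (scale + zero point) scheme."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    w_min, w_max = min(weights), max(weights)
    # Guard against a constant tensor, where the range would be zero.
    scale = (w_max - w_min) / (qmax - qmin) or 1.0
    zero_point = round(qmin - w_min / scale)
    q = [max(qmin, min(qmax, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Recover approximate float values from the integer representation."""
    return [(v - zero_point) * scale for v in q]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize(weights)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, zero_point, max_error)
```

Each weight now fits in one byte instead of four, and the reconstruction error stays below one quantization step (the value of scale).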

Model quantization is a toolkit for applying quantization. This operation is applied to the representations of weights and, optionally, to the activations for both storage and computation. There are two types of quantization available:

  • Post-training quantization quantizes weights and the result of activations post training.
  • Quantization-aware training allows for the training of networks that can be quantized with minimal accuracy drop (only available for specific CNNs). Since this is a relatively experimental technique, we are not going to discuss it in this chapter but the interested reader can find more information in [1].

TensorFlow Lite supports reducing the precision of values from full floats to half-precision floats (float16) or 8-bit integers. TensorFlow reports multiple trade-offs in terms of accuracy, latency, and space for selected CNN models (see Figure 1, source: https://www.tensorflow.org/lite/performance/model_optimization):

Figure 1: Trade-offs for various quantized CNN models
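
The space side of the float16 trade-off can be checked with nothing more than Python's standard struct module, which supports IEEE 754 half-precision values through the "e" format character. The sketch below is an assumption-laden illustration with made-up weight values; it does not use TensorFlow Lite itself.

```python
import struct

weights = [0.1, -1.2345678, 3.14159265, 1000.123]

# Pack the same values as 32-bit floats and as 16-bit (half-precision) floats.
as_float32 = struct.pack(f"<{len(weights)}f", *weights)
as_float16 = struct.pack(f"<{len(weights)}e", *weights)

# Unpack the half-precision bytes to see how much precision was lost.
restored = struct.unpack(f"<{len(weights)}e", as_float16)

print(len(as_float32), len(as_float16))  # bytes used by each representation
for original, approx in zip(weights, restored):
    print(original, approx, abs(original - approx))
```

The half-precision buffer is half the size, while the rounding error grows with the magnitude of the value, which is why the accuracy impact depends on the model.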


FlatBuffers

FlatBuffers (https://google.github.io/flatbuffers/) is an open source format optimized to serialize data on mobile and embedded devices. The format was originally created at Google for game development and other performance-critical applications. FlatBuffers supports access to serialized data without parsing/unpacking for fast processing. The format is designed for memory efficiency and speed by avoiding unnecessary multiple copies in memory. FlatBuffers works across multiple platforms and languages such as C++, C#, C, Go, Java, JavaScript, Lobster, Lua, TypeScript, PHP, Python, and Rust.
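
The benefit of reading serialized data without unpacking it can be illustrated with a toy binary layout. The sketch below is not the FlatBuffers format or API; the layout (an element count followed by float32 values) and the function names are hypothetical, chosen only to show a single field being read in place from a buffer.

```python
import struct

# Hypothetical fixed layout: a uint32 element count followed by float32 values.
HEADER = struct.Struct("<I")
FLOAT = struct.Struct("<f")


def serialize(values):
    """Write the count and the values into one contiguous byte buffer."""
    return HEADER.pack(len(values)) + b"".join(FLOAT.pack(v) for v in values)


def read_value(buffer, index):
    """Read a single element in place, without decoding the rest of the buffer."""
    (count,) = HEADER.unpack_from(buffer, 0)
    if not 0 <= index < count:
        raise IndexError(index)
    (value,) = FLOAT.unpack_from(buffer, HEADER.size + index * FLOAT.size)
    return value


buffer = memoryview(serialize([0.5, 1.5, 2.5, 3.5]))
print(read_value(buffer, 2))
```

Because the offset of each element is known from the layout, only the bytes of the requested element are decoded, and the memoryview avoids copying the buffer.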

Mobile converter

A model generated with TensorFlow needs to be converted into a TensorFlow Lite model. The converter can introduce optimizations for improving the binary size and performance. For instance, the converter can trim away all the nodes in a computational graph that are not directly related to inference, but instead were needed for training.
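
The trimming step can be pictured as a reachability problem: starting from the nodes that produce the inference outputs, keep everything they depend on and drop the rest. The following sketch works on a hypothetical toy graph expressed as a Python dictionary; it illustrates the idea only and is not how the TensorFlow Lite converter is implemented.

```python
def prune_for_inference(graph, outputs):
    """Keep only the nodes that the inference outputs depend on.

    `graph` maps each node name to the list of nodes it takes as inputs.
    """
    keep, stack = set(), list(outputs)
    while stack:
        node = stack.pop()
        if node not in keep:
            keep.add(node)
            stack.extend(graph[node])
    # Preserve the original node order for the nodes that survive.
    return {name: inputs for name, inputs in graph.items() if name in keep}


# Hypothetical toy graph: the loss and optimizer nodes only matter for training.
graph = {
    "input": [],
    "weights": [],
    "matmul": ["input", "weights"],
    "softmax": ["matmul"],
    "labels": [],
    "loss": ["softmax", "labels"],
    "optimizer_step": ["loss", "weights"],
}

inference_graph = prune_for_inference(graph, outputs=["softmax"])
print(list(inference_graph))
```

Only the nodes feeding the softmax output survive; the labels, loss, and optimizer nodes are discarded.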

Mobile optimized interpreter

TensorFlow Lite runs on a highly optimized interpreter that executes the underlying computational graphs (see Chapter 2, TensorFlow 1.x and 2.x), which in turn are used to describe the machine learning models. Internally, the interpreter uses multiple techniques to optimize the computational graph by inducing a static graph order and by ensuring better memory allocation. The Interpreter Core takes ~100 KB alone or ~300 KB with all supported kernels.
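
A static graph order means that the sequence in which operations run is fixed once, ahead of time, so that nothing has to be scheduled during inference. One common way to derive such an order is a topological sort. The sketch below applies it to a hypothetical toy graph; it is an illustration of the concept under that assumption, not the interpreter's actual code.

```python
from collections import deque


def static_execution_order(graph):
    """Return an order in which every node runs after all of its inputs.

    `graph` maps each node name to the list of nodes it takes as inputs.
    """
    pending = {node: len(inputs) for node, inputs in graph.items()}
    consumers = {node: [] for node in graph}
    for node, inputs in graph.items():
        for source in inputs:
            consumers[source].append(node)

    ready = deque(node for node, count in pending.items() if count == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for consumer in consumers[node]:
            pending[consumer] -= 1
            if pending[consumer] == 0:
                ready.append(consumer)

    if len(order) != len(graph):
        raise ValueError("graph contains a cycle")
    return order


graph = {
    "softmax": ["dense"],
    "dense": ["flatten", "dense_weights"],
    "flatten": ["input"],
    "input": [],
    "dense_weights": [],
}
order = static_execution_order(graph)
print(order)
```

Once the order is known, the lifetime of every intermediate tensor is known too, which is what makes it possible to plan memory allocation in advance.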

Supported platforms

On Android, TensorFlow Lite inference can be performed using either Java or C++. On iOS, TensorFlow Lite inference can run in Swift and Objective-C. On Linux platforms (such as Raspberry Pi), inferences run in C++ and Python. TensorFlow Lite for microcontrollers is an experimental port of TensorFlow Lite designed to run machine learning models on microcontrollers based on Arm Cortex-M (https://developer.arm.com/ip-products/processors/cortex-m) Series processors including Arduino Nano 33 BLE Sense (https://store.arduino.cc/usa/nano-33-ble-sense-with-headers), SparkFun Edge (https://www.sparkfun.com/products/15170), and the STM32F746 Discovery kit (https://www.st.com/en/evaluation-tools/32f746gdiscovery.html). These microcontrollers are frequently used for IoT applications.


Architecture

The architecture of TensorFlow Lite is described in Figure 2 (from https://www.tensorflow.org/lite/convert/index). As you can see, both tf.keras (that is, the high-level API of TensorFlow 2.x) and low-level APIs are supported. A standard TensorFlow 2.x model can be converted by using the TFLite Converter and then saved in the TFLite FlatBuffer format (with the .tflite extension), which is then executed by the TFLite interpreter on available devices (GPUs, CPUs) and on native device APIs. The concrete function in Figure 2 defines a graph that can be converted to a TensorFlow Lite model or be exported to a SavedModel.

Using TensorFlow Lite

Using TensorFlow Lite involves the following steps:

  1. Model selection: A standard TensorFlow 2.x model is selected for solving a specific task. This can be either a custom-built model or a pretrained model.
  2. Model conversion: The selected model is converted with the TensorFlow Lite converter, generally invoked with a few lines of Python code.
  3. Model deployment: The converted model is deployed on the chosen device, either a phone or an IoT device and then run by using the TensorFlow Lite interpreter. As discussed, APIs are available for multiple languages.
  4. Model optimization: The model can be optionally optimized by using the TensorFlow Lite optimization framework.

Figure 2: TensorFlow Lite internal architecture

A generic example of application

In this section we are going to see how to convert a model to TensorFlow Lite and then run it. Note that training can still be performed by TensorFlow in the environment that best fits your needs. However, inference runs on the mobile device. Let's see how with the following code fragment in Python:

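import tensorflow as tf

# Directory containing the SavedModel to convert (placeholder path).
saved_model_dir = "path/to/saved_model"
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)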

The code is self-explanatory. A standard TensorFlow 2.x model is opened and converted by using tf.lite.TFLiteConverter.from_saved_model(saved_model_dir). Pretty simple! Note that no specific installation is required. We simply use the tf.lite API (https://www.tensorflow.org/api_docs/python/tf/lite). It is also possible to apply a number of optimizations. For instance, post-training quantization can be enabled through the converter's default optimization setting:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Enable the default optimizations, which include post-training quantization.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_quant_model)

Once the model is converted it can be copied onto the specific device. Of course, this step is different for each different device. Then the model can run by using the language you prefer. For instance, in Java the invocation happens with the following code snippet:

try (Interpreter interpreter = new Interpreter(tensorflow_lite_model_file)) {
  interpreter.run(input, output);
}

Again, pretty simple! What is very useful is that the same steps can be followed for a heterogeneous collection of Mobile and IoT devices.

Using GPUs and accelerators

Modern phones frequently have accelerators on board that allow floating-point matrix operations to be performed faster. In this case, the interpreter can use the concept of Delegate, and specifically GpuDelegate(), to use GPUs. Let's look at an example in Java:

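GpuDelegate delegate = new GpuDelegate();
Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate);
Interpreter interpreter = new Interpreter(tensorflow_lite_model_file, options);
try {
  interpreter.run(input, output);
} finally {
  // Release the native resources held by the interpreter and the delegate.
  interpreter.close();
  delegate.close();
}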

Again, the code is self-commenting. A new GpuDelegate() is created and then it is used by the Interpreter to run the model on a GPU.

An example of application

In this section, we are going to use TensorFlow Lite for building an example application that is later deployed on Android. We will use Android Studio (https://developer.android.com/studio/) to compile the code. The first step is to clone the repo with:

git clone https://github.com/tensorflow/examples

Then we open an existing project (see Figure 4) with the path examples/lite/examples/image_classification/android.

To do so, you need to install Android Studio from https://developer.android.com/studio/install and an appropriate distribution of Java. In my case, I selected the Android Studio macOS distribution, and installed Java via brew with the following commands:

brew tap adoptopenjdk/openjdk
brew cask install homebrew/cask-versions/adoptopenjdk8

After that you can launch the sdkmanager and install the required packages. In my case, I decided to use the internal emulator and deploy the application on a virtual device emulating a Google Pixel 3 XL. The required packages are reported in Figure 3:

Figure 3: Required packages to use a Google Pixel 3 XL emulator

Then start Android Studio and select Open an existing Android Studio project as shown in Figure 4:

Figure 4: Opening a new Android project

Open the AVD Manager (under the Tools menu) and follow the instructions for how to create a virtual device, such as the one shown in Figure 5:

Figure 5: Creating a virtual device

Pretrained models in TensorFlow Lite

In many interesting use cases, it is possible to use a pretrained model that is already suitable for mobile computation. This is a field of active research with new proposals coming pretty much every month. TensorFlow Lite comes with a set of prebuilt models that are ready to use (https://www.tensorflow.org/lite/models/). As of October 2019, these include:

  • Image classification: Used to identify multiple classes of objects such as places, plants, animals, activities, and people.
  • Object detection: Used to detect multiple objects with bounding boxes.
  • Pose estimation: Used to estimate poses with single or multiple people.
  • Smart reply: Used to create reply suggestions for conversational chat messages.
  • Segmentation: Used to identify the shape of objects together with semantic labels for people, places, animals, and many additional classes.
  • Style transfer: Used to apply artistic styles to any given image.
  • Text classification: Used to assign different categories to textual content.
  • Question and answer: Used to provide answers to questions provided by users.

In this section, we will discuss all the optimized pretrained models available in TensorFlow Lite out-of-the-box as of November 2019. These models can be used for a large number of mobile and edge computing use cases. Compiling the example code is pretty simple.

You just import a new project from each example directory and Android Studio will use Gradle (https://gradle.org/) for syncing the code with the latest version in the repo and for compiling. If you compile all the examples, you should be able to see them in the emulator (see Figure 6). Remember to select Build | Make Project, and Android Studio will do the rest.

Edge computing is a distributed computing model that brings computation and data closer to the location where it is needed.

Figure 6: Emulated Google Pixel 3 XL with TensorFlow Lite example applications

Image classification

As of November 2019, the list of available models for pretrained classification is rather large, and it offers the opportunity to trade space, accuracy, and performance as shown in Figure 7 (source: https://www.tensorflow.org/lite/guide/hosted_models):

Figure 7: Space, accuracy, and performance trade-offs for various mobile models

MobileNet v1 is a quantized CNN model described by Benoit Jacob et al. [2]. MobileNet v2 is an advanced model proposed by Google [3]. Online, you can also find floating-point models, which offer the best balance between model size and performance. Note that GPU acceleration requires the use of floating-point models. More recently, AutoML models for mobile have been proposed based on an automated mobile neural architecture search (MNAS) approach [4], beating the models handcrafted by humans.

We will discuss AutoML in Chapter 14, An Introduction to AutoML, and the interested reader can refer to MNAS documentation in the references [4] for applications to mobile.

Object detection

TensorFlow Lite comes with a pretrained model that can detect multiple objects within an image, with bounding boxes. The model recognizes 80 different classes of objects. The network is based on a pretrained quantized COCO SSD MobileNet v1 model. For each object, the model provides the class, the confidence of detection, and the vertices of the bounding box (https://www.tensorflow.org/lite/models/object_detection/overview).
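
In practice, an application rarely shows every detection: low-confidence results are usually discarded first. The sketch below shows one way to post-process the three outputs mentioned above (class, confidence, bounding box). The Detection type, the box layout, the threshold, and the sample values are all hypothetical and are not taken from the TensorFlow Lite API.

```python
from dataclasses import dataclass


@dataclass
class Detection:
    label: str
    score: float  # confidence of the detection, between 0 and 1
    box: tuple    # assumed layout: (top, left, bottom, right), normalized to [0, 1]


def keep_confident(detections, threshold=0.5):
    """Drop detections below the threshold and sort the rest by confidence."""
    kept = [d for d in detections if d.score >= threshold]
    return sorted(kept, key=lambda d: d.score, reverse=True)


# Hypothetical model output for one image.
raw = [
    Detection("dog", 0.62, (0.10, 0.15, 0.80, 0.60)),
    Detection("cat", 0.31, (0.05, 0.55, 0.40, 0.95)),
    Detection("person", 0.91, (0.00, 0.20, 1.00, 0.70)),
]

for detection in keep_confident(raw):
    print(detection.label, detection.score, detection.box)
```

The threshold is a tuning knob: a higher value gives fewer false positives at the cost of missing some objects.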

Pose estimation

TensorFlow Lite includes a pretrained model for detecting parts of human bodies in an image or a video. For instance, it is possible to detect noses, left/right eyes, hips, ankles, and many other parts. Each detection comes with an associated confidence score (https://www.tensorflow.org/lite/models/pose_estimation/overview).

Smart reply

TensorFlow Lite also has a pretrained model for generating replies to chat messages. These replies are contextualized and similar to what is available on Gmail (https://www.tensorflow.org/lite/models/smart_reply/overview).


Segmentation

TensorFlow Lite also has a pretrained model (https://www.tensorflow.org/lite/models/segmentation/overview) for image segmentation, where the goal is to assign a semantic label (for example, person, dog, cat) to every pixel in the input image. Segmentation is based on the DeepLab algorithm [5].

Style transfer

TensorFlow Lite supports artistic style transfer (see Chapter 5, Advanced Convolutional Neural Networks) via a combination of a MobileNetV2-based neural network, which reduces the input style image to a 100-dimension style vector, and a style transform model, which applies the style vector to a content image to create the stylized image (https://www.tensorflow.org/lite/models/style_transfer/overview).

Text classification

TensorFlow Lite comes with a model for text classification and sentiment analysis (https://www.tensorflow.org/lite/models/text_classification/overview) trained on the Large Movie Review Dataset v1.0 (http://ai.stanford.edu/~amaas/data/sentiment/) with IMDb movie reviews that are positive or negative. An example of text classification is given in Figure 8:

Figure 8: An example of Text classification on Android with TensorFlow Lite

Question and answer

TensorFlow Lite also includes (https://www.tensorflow.org/lite/models/bert_qa/overview) a pretrained model for answering questions based on text fragments. The model is based on a compressed variant of BERT [6] (see Chapter 7, Word Embeddings) called MobileBERT [7], which runs 4x faster and has 4x smaller size. An example of Q&A is given in Figure 9:

Figure 9: An example of Q&A on Android with TensorFlow Lite and BERT

A note about using mobile GPUs

This section concludes the overview of pretrained models for mobile devices and IoT. Note that modern phones are equipped with internal GPUs. For instance, on the Pixel 3, TensorFlow Lite GPU inference runs 2–7x faster than CPU inference for many models (see Figure 10, source: https://medium.com/tensorflow/tensorflow-lite-now-faster-with-mobile-gpus-developer-preview-e15797e6dee7):

Figure 10: GPU speed-up over CPU for various learning models running on various phones

An overview of federated learning at the edge

As discussed, edge computing is a distributed computing model that brings computation and data closer to the location where it is needed.

Now, let's introduce Federated Learning (FL) [8] at the edge, starting with two use cases.

Suppose you built an app for playing music on mobile devices and then you want to add recommendation features aimed at helping users to discover new songs they might like. Is there a way to build a distributed model that leverages each user's experience without disclosing any private data?

Suppose you are a car manufacturer producing millions of cars connected via 5G networks, and then you want to build a distributed model for optimizing each car's fuel consumption. Is there a way to build such a model without disclosing the driving behavior of each user?

Traditional machine learning requires you to have a centralized repository for training data either on your desktop, or in your datacenter, or in the cloud. Federated learning pushes the training phase to the edge by distributing the computation among millions of mobile devices. These devices are ephemeral in that they are not always available for the learning process and they can disappear silently (for instance, a mobile phone can be switched off all of a sudden). The key idea is to leverage the CPUs and GPUs of each mobile phone that is made available for an FL computation. Each mobile device forming a part of a distributed FL training downloads a (pretrained) model from a central server and performs local optimization based on the local training data collected on that specific device. This process is similar to the transfer learning process (see Chapter 5, Advanced Convolutional Neural Networks), but it is distributed at the edge. Each locally updated model is then sent back by millions of edge devices to a central server to build an averaged shared model.
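
The averaging step on the server can be sketched in a few lines of plain Python. The example below weights each client's model by the number of local examples it trained on, in the spirit of Federated Averaging [8]. The function name, the flat-list representation of the weights, and the sample values are assumptions made for illustration; this is not the TensorFlow Federated API.

```python
def federated_average(client_updates):
    """Average client weights, weighting each client by its number of examples.

    `client_updates` is a list of (weights, num_examples) pairs, where
    `weights` is a flat list of floats with the same length for every client.
    """
    total_examples = sum(n for _, n in client_updates)
    size = len(client_updates[0][0])
    averaged = [0.0] * size
    for weights, num_examples in client_updates:
        share = num_examples / total_examples
        for i, w in enumerate(weights):
            averaged[i] += share * w
    return averaged


# Hypothetical updates from three devices that trained on different amounts of data.
updates = [
    ([1.0, 2.0], 10),
    ([3.0, 4.0], 30),
    ([5.0, 6.0], 60),
]
print(federated_average(updates))
```

Note that only model weights and example counts reach the server; the training examples themselves never leave the devices.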

Of course, there are many issues to be considered. Let's review them:

  1. Battery usage: Each mobile device that is part of an FL computation should save as much as possible on local battery usage.
  2. Encrypted communication: Each mobile device belonging to an FL computation has to use encrypted communication with the central server to update the locally built model.
  3. Efficient communication: Typically, deep learning models are optimized with optimization algorithms such as SGD (see Chapter 1, Neural Network Foundations with TensorFlow 2.0, and Chapter 15, The Math Behind Deep Learning). However, FL works with millions of devices and there is therefore a strong need to minimize the communication patterns. Google introduced a Federated Averaging algorithm [8], which is reported to reduce the amount of communication 10x-100x when compared with vanilla SGD. Plus, compression techniques [9] reduce the communication costs by an additional 100x with random rotations and quantization.
  4. Ensure user privacy: This is probably the most important point. All local training data acquired at the edge must stay at the edge. This means that the training data acquired on a mobile device cannot be sent to a central server. Equally important, any user behavior learned in locally trained models must be anonymized so that it is not possible to understand any specific action performed by specific individuals.

Figure 11 shows a typical FL architecture (source [10]). An FL Server sends a model and a training plan to millions of devices. The training plan includes information on how frequently updates are expected and other metadata.

Each device runs the local training and sends a model update back to the global services. Note that each device has an FL runtime providing federated learning services to an app process that stores data in a local example store. The FL runtime fetches the training examples from the example store:

Figure 11: An example of federated learning architecture

TensorFlow FL APIs

The TensorFlow Federated (TFF) platform has two layers:

  • Federated learning (FL), a high-level interface that works well with tf.keras and non-tf.keras models. In the majority of situations you will use this API for distributed training that is privacy preserving.
  • Federated core (FC), a low-level interface that is highly customizable and allows you to interact with low level communications and with federated algorithms. You will need this API only if you intend to implement new and sophisticated distributed learning algorithms. This topic is rather advanced, and we are not going to cover it in this book. If you wish to learn more, you can find more information online (https://www.tensorflow.org/federated/federated_core).

The FL API has three key parts:

  1. Models: Used to wrap existing models for enabling federated learning. This can be achieved via tff.learning.from_keras_model(), or via subclassing tff.learning.Model. For instance, you can have the following code fragment:
    keras_model = …
    keras_federated_model = tff.learning.from_compiled_keras_model(keras_model, ..)
  2. Builders: This is the layer where the federated computation happens. There are two phases: compilation, where the learning algorithm is serialized into an abstract representation of the computation, and execution, where the represented computation is run.
  3. Datasets: This is a large collection of data that can be used to simulate federated learning locally – a step useful for the initial fine tuning.

We conclude this overview by mentioning that you can find a detailed description (https://www.tensorflow.org/federated/federated_learning) of the APIs online, and also a number of coding examples. The suggestion is to start by using the Colab notebook made available by Google (https://colab.research.google.com/github/tensorflow/federated/blob/v0.10.1/docs/tutorials/federated_learning_for_image_classification.ipynb). The framework allows us to simulate the distributed training before running it in a real environment. The library in charge of federated learning is tensorflow_federated. Figure 12 shows all the steps used in federated learning with multiple nodes, and it might be useful to better understand what has been discussed in this section. The next section will introduce TensorFlow.js, a variant of TensorFlow that can be used natively in JavaScript:

Figure 12: An example of federated learning with multiple nodes (source: https://upload.wikimedia.org/wikipedia/commons/e/e2/Federated_learning_process_central_case.png)


TensorFlow.js

TensorFlow.js is a JavaScript library for machine learning models that can work either in vanilla mode or via Node.js. In this section we are going to review both of them.

Vanilla TensorFlow.js

TensorFlow.js is a JavaScript library for training and using Machine Learning (ML) models in a browser. It is derived from deeplearn.js, an open source, hardware-accelerated library for doing Deep Learning (DL) in JavaScript, and is now a companion library to TensorFlow.

The most common use of TensorFlow.js is to make pretrained ML/DL models available on the browser. This can help in situations where it may not be feasible to send client data back to the server due to network bandwidth or security concerns. However, TensorFlow.js is a full stack ML platform, and it is possible to build and train an ML/DL model from scratch, as well as fine-tune an existing pretrained model with new client data.

An example of a TensorFlow.js application is the TensorFlow Projector (https://projector.tensorflow.org), which allows a client to visualize their own data (as word vectors) in 3-dimensional space, using one of several dimensionality reduction algorithms provided. There are a few other examples of TensorFlow.js applications listed on the TensorFlow.js demo page (https://www.tensorflow.org/js/demos).

Similarly to TensorFlow, TensorFlow.js also provides two main APIs – the Ops API, which exposes low-level tensor operations such as matrix multiplication, and the Layers API, which exposes Keras-style high-level building blocks for neural networks.

At the time of writing, TensorFlow.js runs on three different backends. The fastest (and also the most complex) is the WebGL backend, which provides access to WebGL's low-level 3D graphics APIs and can take advantage of GPU hardware acceleration. The other popular backend is the Node.js backend, which allows the use of TensorFlow.js in server-side applications. Finally, as a fallback, there is the CPU-based implementation in plain JavaScript that will run in any browser.

In order to gain a better understanding of how to write a TensorFlow.js application, we will walk through an example of classifying MNIST digits using a Convolutional Neural Network (CNN) provided by the TensorFlow.js team (https://storage.googleapis.com/tfjs-examples/mnist/dist/index.html).

The steps here are similar to a normal supervised model development pipeline – load the data, define, train, and evaluate the model.

JavaScript works inside a browser environment, within an HTML page. The HTML file (named index.html) below represents this HTML page. Notice the two imports for TensorFlow.js (tf.min.js) and the TensorFlow.js visualization library (tfjs-vis.umd.min.js) – these provide library functions that we will use in our application. JavaScript code for our application comes from data.js and script.js files, located in the same directory as our index.html file:

<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <!-- Import TensorFlow.js -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
  <!-- Import tfjs-vis -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
  <!-- Import the data file -->
  <script src="data.js" type="module"></script>
  <!-- Import the main script file -->
  <script src="script.js" type="module"></script>
</head>
<body>
</body>
</html>

For deployment, we will deploy these three files (index.html, data.js, and script.js) on a web server, but for development we can start a web server up by calling a simple one bundled with the Python distribution. This will start up a web server on port 8000 on localhost, and the index.html file can be rendered on the browser at http://localhost:8000:

python -m http.server

The next step is to load the data. Fortunately, Google provides a JavaScript script that we have called directly from our index.html file. It downloads the images and labels from GCP storage and returns shuffled and normalized batches of image and label pairs for training and testing. We can download this to the same folder as the index.html file using the following command:

wget https://raw.githubusercontent.com/tensorflow/tfjs-examples/master/mnist-core/data.js

Model definition, training, and evaluation code is all specified inside the script.js file. The function to define and build the network is shown in the following code block. As you can see, it is very similar to the way you would build a sequential model with tf.keras. The only difference is the way you specify the arguments, as a dictionary of name-value pairs instead of a list of parameters. The model is a sequential model, that is, a list of layers. Finally, the model is compiled with the Adam optimizer:

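function getModel() {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;
  const NUM_OUTPUT_CLASSES = 10;
  const model = tf.sequential();
  // First convolutional block: 8 filters of size 5x5, then 2x2 max pooling.
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({
    poolSize: [2, 2], strides: [2, 2]
  }));
  // Second convolutional block: 16 filters of size 5x5, then 2x2 max pooling.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({
    poolSize: [2, 2], strides: [2, 2]
  }));
  // Flatten the feature maps and classify with a softmax layer.
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({
    units: NUM_OUTPUT_CLASSES,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  }));
  const optimizer = tf.train.adam();
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  return model;
}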
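
The layer sizes of this network can be checked by hand. The short Python sketch below walks through the two convolution and pooling stages and counts the trainable parameters. It assumes that the convolutions use no padding, so each 5x5 convolution shrinks the image by 4 pixels per side pair; the helper name is hypothetical.

```python
def conv_output_size(size, kernel, stride=1):
    """Output width/height of a convolution or pooling layer without padding."""
    return (size - kernel) // stride + 1


size, channels = 28, 1
total_params = 0

# (kernel size, number of filters) for the two convolutional layers.
for kernel, filters in [(5, 8), (5, 16)]:
    total_params += kernel * kernel * channels * filters + filters  # weights + biases
    size = conv_output_size(size, kernel)              # 5x5 convolution, stride 1
    size = conv_output_size(size, kernel=2, stride=2)  # 2x2 max pooling, stride 2
    channels = filters
    print("after conv + pool:", (size, size, channels))

flattened = size * size * channels
total_params += flattened * 10 + 10  # dense softmax layer with 10 outputs
print("flattened features:", flattened)
print("trainable parameters:", total_params)
```

Under these assumptions the model is tiny by deep learning standards, which is one reason it trains comfortably inside a browser.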

The model is then trained for 10 epochs with batches from the training dataset and validated inline using batches from the test dataset. Best practice is to create a separate validation dataset from the training set. However, in order to keep our focus on the more important aspect of showing how to use TensorFlow.js to design an end-to-end DL pipeline, we are using the external data.js file provided by Google, which provides functions to return only a training and a test batch. In our example, we will use the test dataset for validation as well as evaluation later. This is likely to give us better accuracies compared to what we would have achieved with an unseen (during training) test set, but that is unimportant for an illustrative example such as this one:

async function train(model, data) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'Model Training', styles: { height: '1000px' }
  };
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;
  // Reshape the flat training images into [batch, height, width, channels].
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });
  // The test batch doubles as validation data in this example.
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });
  return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 10,
    shuffle: true,
    callbacks: fitCallbacks
  });
}
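
The best practice mentioned earlier, holding out part of the training set for validation instead of reusing the test set, can be sketched as follows. The example is in plain Python and uses integer indices as stand-ins for the 5,500 training examples; the function name, the 10% fraction, and the seed are assumptions.

```python
import random


def train_validation_split(examples, validation_fraction=0.1, seed=42):
    """Shuffle the examples and hold out a fraction of them for validation."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    cut = len(shuffled) - round(len(shuffled) * validation_fraction)
    return shuffled[:cut], shuffled[cut:]


# Stand-in for the 5,500 training examples (indices instead of images).
train, validation = train_validation_split(range(5500))
print(len(train), len(validation))
```

With such a split, the test set is only touched once, at evaluation time, so the reported accuracy is not biased by decisions made during training.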

Once the model finishes training, we want to make predictions and evaluate the model on its predictions. The following functions will do the predictions and compute the overall accuracy for each of the classes over all the test set examples, as well as produce a confusion matrix across all the test set samples:

const classNames = [
  'Zero', 'One', 'Two', 'Three', 'Four',
  'Five', 'Six', 'Seven', 'Eight', 'Nine'];

// Run the model on a test batch and return predicted and true class indices.
function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const testData = data.nextTestBatch(testDataSize);
  const testxs = testData.xs.reshape(
    [testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
  // argMax over the last axis turns one-hot labels and class
  // probabilities into class indices.
  const labels = testData.labels.argMax([-1]);
  const preds = model.predict(testxs).argMax([-1]);
  return [preds, labels];
}

async function showAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const classAccuracy = await tfvis.metrics.perClassAccuracy(
    labels, preds);
  const container = {name: 'Accuracy', tab: 'Evaluation'};
  tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
}

async function showConfusion(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const confusionMatrix = await tfvis.metrics.confusionMatrix(
    labels, preds);
  const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
  tfvis.render.confusionMatrix(
      container, {values: confusionMatrix}, classNames);
}
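
To make the two evaluation quantities concrete, here is a small plain-JavaScript sketch of how a confusion matrix and per-class accuracies can be computed from arrays of class indices. It only illustrates the underlying arithmetic: the function names and toy data are hypothetical, and it is not the tfjs-vis implementation, whose return formats may differ:

```javascript
// Build a numClasses x numClasses confusion matrix:
// rows are the true labels, columns are the predicted labels.
function confusionMatrix(labels, preds, numClasses) {
  const matrix = Array.from({ length: numClasses }, () =>
    new Array(numClasses).fill(0));
  labels.forEach((label, i) => {
    matrix[label][preds[i]] += 1;
  });
  return matrix;
}

// Per-class accuracy: correct predictions for a class divided by the number
// of examples whose true label is that class (0 if the class never occurs).
function perClassAccuracy(matrix) {
  return matrix.map((row, cls) => {
    const total = row.reduce((sum, count) => sum + count, 0);
    return total === 0 ? 0 : row[cls] / total;
  });
}

// Toy data: six examples over three classes.
const toyLabels = [0, 0, 1, 1, 2, 2];
const toyPreds = [0, 1, 1, 1, 2, 0];
const toyMatrix = confusionMatrix(toyLabels, toyPreds, 3);
console.log(toyMatrix);
console.log(perClassAccuracy(toyMatrix));
```

Off-diagonal entries count the mistakes: in the toy data, one example of class 0 was predicted as class 1, and one example of class 2 was predicted as class 0.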

Finally, the run() function will call all these functions in sequence to build an end-to-end ML pipeline:

import {MnistData} from './data.js';
async function run() { 
  const data = new MnistData();
  await data.load();
  await showExamples(data);
  const model = getModel();
  tfvis.show.modelSummary({name: 'Model Architecture'}, model);
  await train(model, data);
  await showAccuracy(model, data);
  await showConfusion(model, data);
}

document.addEventListener('DOMContentLoaded', run);

Refreshing the browser location http://localhost:8000/index.html will invoke the run() method above. The table below shows the model architecture, and the plots below that show the progress of the training.

On the left are the loss and accuracy values observed on the training data at the end of each batch, and on the right are the same loss and accuracy values observed on the training dataset (blue) and validation dataset (red) at the end of each epoch:

In addition, the following figure shows the accuracies across different classes for predictions from our trained model on the test dataset, as well as the confusion matrix of predicted versus actual classes for test dataset samples:

We have seen how to use TensorFlow.js within the browser. The next section will explain how to convert a model from Keras into TensorFlow.js.

Converting models

Sometimes it is convenient to convert a model that has already been created with tf.keras. This can be done offline with the following command, which takes a Keras model from /tmp/model.h5 and writes a TensorFlow.js model (a model.json file plus binary weight files) into /tmp/tfjs_model:

tensorflowjs_converter --input_format=keras /tmp/model.h5 /tmp/tfjs_model
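
Once converted, the model can be loaded back in TensorFlow.js with tf.loadLayersModel(). The sketch below assumes that the converter wrote a model.json file into the output directory and that the directory is reachable at some URL. The helper names modelJsonUrl and loadConvertedModel, and the URL in the usage comment, are hypothetical; the tf namespace is passed in as a parameter so that the snippet does not depend on how TensorFlow.js was loaded:

```javascript
// Build the URL of the model.json manifest inside a converted-model directory.
// The directory URL is an assumption; adjust it to wherever the files are served.
function modelJsonUrl(baseUrl) {
  return baseUrl.replace(/\/+$/, '') + '/model.json';
}

// Load a converted Keras model. `tf` is the TensorFlow.js namespace
// (for example, the global created by the tfjs <script> tag).
async function loadConvertedModel(tf, baseUrl) {
  // tf.loadLayersModel() reads model.json and the weight files it references.
  return tf.loadLayersModel(modelJsonUrl(baseUrl));
}

// Usage (not executed here; the URL is hypothetical):
//   const model = await loadConvertedModel(tf, 'http://localhost:8000/tfjs_model');
//   model.summary();
```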

The next section will explain how to use pretrained models in TensorFlow.js.

Pretrained models

TensorFlow.js comes with a significant number of pretrained models for deep learning with image, video, and text. The models are hosted on NPM, so it's very simple to use them if you are familiar with Node development.

The following table summarizes what is available as of November 2019 (source: https://github.com/tensorflow/tfjs-models):

| Type | Model | Details | Install |
|---|---|---|---|
| Images | MobileNet (https://github.com/tensorflow/tfjs-models/tree/master/mobilenet) | Classify images with labels from the ImageNet database. | npm i @tensorflow-models/mobilenet |
| Images | PoseNet (https://github.com/tensorflow/tfjs-models/tree/master/posenet) | A machine learning model that allows for real-time human pose estimation in the browser; see a detailed description here: https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-tensorflow-js-7dd0bc881cd5. | npm i @tensorflow-models/posenet |
| Images | Coco SSD (https://github.com/tensorflow/tfjs-models/tree/master/coco-ssd) | Object detection model that aims to localize and identify multiple objects in a single image; based on the TensorFlow object detection API (https://github.com/tensorflow/models/blob/master/research/object_detection/README.md). | npm i @tensorflow-models/coco-ssd |
| Images | BodyPix (https://github.com/tensorflow/tfjs-models/tree/master/body-pix) | Real-time person and body-part segmentation in the browser using TensorFlow.js. | npm i @tensorflow-models/body-pix |
| Images | DeepLab v3 (https://github.com/tensorflow/tfjs-models/tree/master/deeplab) | Semantic segmentation. | npm i @tensorflow-models/deeplab |
| Audio | Speech Commands (https://github.com/tensorflow/tfjs-models/tree/master/speech-commands) | Classify 1-second audio snippets from the speech commands dataset (https://github.com/tensorflow/docs/blob/master/site/en/r1/tutorials/sequences/audio_recognition.md). | npm i @tensorflow-models/speech-commands |
| Text | Universal Sentence Encoder (https://github.com/tensorflow/tfjs-models/tree/master/universal-sentence-encoder) | Encode text into a 512-dimensional embedding to be used as input to natural language processing tasks such as sentiment classification and textual similarity. | npm i @tensorflow-models/universal-sentence-encoder |
| Text | Text Toxicity | Score the perceived impact a comment might have on a conversation, from "Very toxic" to "Very healthy". | npm i @tensorflow-models/toxicity |
| General Utilities | KNN Classifier (https://github.com/tensorflow/tfjs-models/tree/master/knn-classifier) | This package provides a utility for creating a classifier using the K-nearest neighbors algorithm; it can be used for transfer learning. | npm i @tensorflow-models/knn-classifier |

Each pretrained model can be directly used from HTML. For instance, this is an example with the KNN Classifier:

    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <!-- Load MobileNet -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
    <!-- Load KNN Classifier -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
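
To give an intuition for what a K-nearest neighbors classifier does with feature vectors (which, in the setup above, would typically come from a model such as MobileNet), here is a minimal plain-JavaScript sketch. The class TinyKnn, the squared Euclidean distance, and the toy vectors are assumptions made for illustration; this is not the API of the @tensorflow-models/knn-classifier package, which may use a different distance measure:

```javascript
// Squared Euclidean distance between two vectors of equal length.
function squaredDistance(a, b) {
  return a.reduce((sum, value, i) => sum + (value - b[i]) ** 2, 0);
}

// Minimal k-nearest-neighbors classifier over numeric feature vectors.
class TinyKnn {
  constructor() {
    this.examples = []; // each entry: { vector: number[], label: string }
  }

  addExample(vector, label) {
    this.examples.push({ vector, label });
  }

  // Predict by majority vote among the k stored examples closest to `vector`.
  predict(vector, k = 3) {
    const nearest = this.examples
      .map((ex) => ({ label: ex.label, dist: squaredDistance(ex.vector, vector) }))
      .sort((a, b) => a.dist - b.dist)
      .slice(0, k);
    const votes = new Map();
    for (const { label } of nearest) {
      votes.set(label, (votes.get(label) || 0) + 1);
    }
    // Pick the label with the most votes (the closest label wins ties).
    let best = null;
    for (const [label, count] of votes) {
      if (best === null || count > best.count) best = { label, count };
    }
    return best === null ? null : best.label;
  }
}

// Toy 2-D "embeddings" for two classes.
const knn = new TinyKnn();
knn.addExample([0, 0], 'cat');
knn.addExample([0, 1], 'cat');
knn.addExample([5, 5], 'dog');
knn.addExample([6, 5], 'dog');
console.log(knn.predict([0.2, 0.4]));
console.log(knn.predict([5.5, 4.8]));
```

Because new classes are added simply by storing more labeled vectors, no retraining of the feature extractor is needed, which is why this approach is convenient for transfer learning.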

The next section will explain how to use TensorFlow.js in Node.js.

Node.js

In this section we will give an overview of how to use TensorFlow.js with Node.js. Let's start:

The CPU package is imported with the following line of code, which will work for all Mac, Linux, and Windows platforms:

import * as tf from '@tensorflow/tfjs-node'

The GPU package is imported with the following line of code (as of November 2019 this will work only on a GPU in a CUDA environment):

import * as tf from '@tensorflow/tfjs-node-gpu'

An example of Node.js code for defining and compiling a simple dense model is reported below. The code is self-explanatory:

const model = tf.sequential();
// A single dense layer mapping 400 input features to 1 output.
model.add(tf.layers.dense({ units: 1, inputShape: [400] }));
model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd',
  metrics: ['MAE']
});
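
The loss and the metric named in the compile step are simple averages over the prediction errors. The following plain-JavaScript sketch shows the arithmetic on ordinary arrays; the function names and toy values are introduced here for illustration only and are not TensorFlow.js calls:

```javascript
// Mean squared error: average of the squared differences
// (the quantity minimized by the 'meanSquaredError' loss).
function meanSquaredError(yTrue, yPred) {
  const sum = yTrue.reduce((acc, y, i) => acc + (y - yPred[i]) ** 2, 0);
  return sum / yTrue.length;
}

// Mean absolute error: average of the absolute differences
// (the quantity reported by the 'MAE' metric).
function meanAbsoluteError(yTrue, yPred) {
  const sum = yTrue.reduce((acc, y, i) => acc + Math.abs(y - yPred[i]), 0);
  return sum / yTrue.length;
}

const yTrue = [1, 2, 3, 4];
const yPred = [1, 2, 5, 0];
console.log(meanSquaredError(yTrue, yPred));  // (0 + 0 + 4 + 16) / 4 = 5
console.log(meanAbsoluteError(yTrue, yPred)); // (0 + 0 + 2 + 4) / 4 = 1.5
```

Squaring penalizes large errors more heavily than small ones, which is why the two numbers can differ noticeably on the same predictions.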

Training can then start with the typical Node.js asynchronous invocation:

// Random data stands in for a real dataset in this example.
const xs = tf.randomUniform([10000, 400]);
const ys = tf.randomUniform([10000, 1]);
const valXs = tf.randomUniform([1000, 400]);
const valYs = tf.randomUniform([1000, 1]);
async function train() {
  await model.fit(xs, ys, {
    epochs: 100,
    validationData: [valXs, valYs],
  });
}
train();

In this section, we have discussed how to use TensorFlow.js with both vanilla JavaScript and Node.js, with sample applications for the browser and for backend computation.


Summary

In this chapter we have discussed how to use TensorFlow Lite for mobile devices and IoT and deployed real applications on Android devices. We also talked about federated learning for distributed learning across thousands or even millions of mobile devices, taking privacy concerns into account. The last section of the chapter was devoted to TensorFlow.js for using TensorFlow with vanilla JavaScript or with Node.js.

The next chapter is about AutoML, a set of techniques used to enable domain experts who are unfamiliar with machine learning technologies to use ML techniques easily.


References

  1. Quantization-aware training, https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize
  2. Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference, Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko (Submitted on 15 Dec 2017), https://arxiv.org/abs/1712.05877
  3. MobileNetV2: Inverted Residuals and Linear Bottlenecks, Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen (Submitted on 13 Jan 2018 (v1), last revised 21 Mar 2019 (v4)), https://arxiv.org/abs/1801.04381
  4. MnasNet: Platform-Aware Neural Architecture Search for Mobile, Mingxing Tan, Bo Chen, Ruoming Pang, Vijay Vasudevan, Mark Sandler, Andrew Howard, Quoc V. Le, https://arxiv.org/abs/1807.11626
  5. DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs, Liang-Chieh Chen, George Papandreou, Iasonas Kokkinos, Kevin Murphy, Alan L. Yuille, May 2017, https://arxiv.org/pdf/1606.00915.pdf
  6. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova (Submitted on 11 Oct 2018 (v1), last revised 24 May 2019 (v2)), https://arxiv.org/abs/1810.04805
  7. MobileBERT: Task-Agnostic Compression of BERT by Progressive Knowledge Transfer, Anonymous authors, paper under double-blind review, ICLR 2020 Conference Blind Submission, 25 Sep 2019, https://openreview.net/pdf?id=SJxjVaNKwB
  8. Communication-Efficient Learning of Deep Networks from Decentralized Data, H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Agüera y Arcas (Submitted on 17 Feb 2016 (v1), last revised 28 Feb 2017 (v3)), https://arxiv.org/abs/1602.05629
  9. Federated Learning: Strategies for Improving Communication Efficiency, Jakub Konečný, H. Brendan McMahan, Felix X. Yu, Peter Richtárik, Ananda Theertha Suresh, Dave Bacon (Submitted on 18 Oct 2016 (v1), last revised 30 Oct 2017 (v2)), https://arxiv.org/abs/1610.05492
  10. Towards Federated Learning at Scale: System Design, Keith Bonawitz et al., 22 March 2019, https://arxiv.org/pdf/1902.01046.pdf