Leveraging AI in Front-End Development: A New Era of Interaction

With the rapid advancement of deep learning technology, the magic of AI is gradually permeating every corner of our daily lives, unleashing infinite possibilities. Particularly in the field of front-end development, AI is like a breath of fresh air, bringing a whole new experience to traditional user interactions. Imagine browsing a webpage through gestures, playing your favorite music with just a smile, or controlling an entire application with a single voice command. These scenarios are no longer scenes from a science fiction movie but will soon become a part of our everyday reality.

Moreover, AI's extraordinary capabilities have become a fountain of inspiration for front-end product innovation. Text-to-speech (TTS), image classification, object detection: what once seemed unattainable is now within reach. And as the boundaries of technology continue to expand, the high-performance models emerging from communities like Kaggle and Hugging Face are simply breathtaking. Front-end developers who ignore these advances leave a great deal of potential untapped.

Additionally, server costs are a significant burden for companies that already use AI or plan to adopt it. Running these models in the front end shifts inference from your servers to users' devices. This can drastically reduce costs while also improving the user experience.

As a result, knowing how to apply machine learning in the browser has become an essential skill for front-end developers. In this article, I will use a series of code examples to walk you through three ways to bring AI to the front end: using pre-built libraries, using TensorFlow.js, and using ONNX Runtime Web. These approaches can make AI a powerful ally for your front-end products.

Please refer to the accompanying code examples at https://github.com/ymrdf/web-ai-examples for more details.

1 Utilizing Pre-existing Libraries

By using this approach, developers can embed powerful AI capabilities into their applications simply by calling API interfaces. Here are some useful libraries:

  • @tensorflow-models/face-detection: Detects faces in images and locates facial key points, such as the eyes, nose tip and mouth.

  • @tensorflow-models/mobilenet: A lightweight model suited for image classification and other visual tasks, optimized for mobile and low-power devices.

  • @tensorflow-models/pose-detection: Detects human poses and key points in images or videos, supporting various pose detection algorithms.

As a specific example, we can use the @tensorflow-models/face-detection library to implement face detection. Below is a simple example of creating the detector. The full implementation is in the accompanying GitHub project (https://github.com/ymrdf/web-ai-examples):

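import * as faceDetection from '@tensorflow-models/face-detection';
import * as mpFaceDetection from '@mediapipe/face_detection';

// Method of the detector class in the example project; this.options, this.detector,
// this.setStatus and EUserDetectorStatus are defined elsewhere in that project.
async createDetector() {
  try {
    this.detector = await faceDetection.createDetector(
      faceDetection.SupportedModels.MediaPipeFaceDetector,
      {
        runtime: "mediapipe", // run on the MediaPipe runtime ("tfjs" is the alternative)
        modelType: "short",   // short-range model, for faces close to the camera
        maxFaces: 1,          // track at most one face
        // Location of the MediaPipe solution files: a custom path if configured,
        // otherwise the jsDelivr CDN build matching the installed package version
        solutionPath:
          this.options.solutionPath ||
          `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@${mpFaceDetection.VERSION}`,
      }
    );
  } catch (e) {
    // Creation can fail, e.g. if the model files cannot be downloaded
    console.warn(e);
    this.setStatus(EUserDetectorStatus.faceDetectorCreateError);
  }
}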
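After you create the detector, call detector.estimateFaces() on an image or video element. It resolves to an array of detected faces, and each face has a bounding box and keypoints. What you do with that array is plain application logic. Below is a minimal, dependency-free sketch. The DetectedFace interface is an assumption: it mirrors only the box fields (xMin, yMin, width, height) of the library's result. pickLargestFace is a hypothetical helper that keeps the largest face, which is a simple way to focus on the person closest to the camera.

```typescript
// Assumed minimal subset of a face-detection result: only the bounding box
// fields used below (the real result also carries keypoints and more box fields).
interface FaceBox {
  xMin: number;
  yMin: number;
  width: number;
  height: number;
}

interface DetectedFace {
  box: FaceBox;
}

// Hypothetical helper: returns the face with the largest box area,
// or null when no face was detected.
function pickLargestFace(faces: DetectedFace[]): DetectedFace | null {
  let best: DetectedFace | null = null;
  let bestArea = -1;
  for (const face of faces) {
    const area = face.box.width * face.box.height;
    if (area > bestArea) {
      best = face;
      bestArea = area;
    }
  }
  return best;
}
```

In the component, this might look like `const faces = await this.detector.estimateFaces(video); const face = pickLargestFace(faces);`. A null result means no face is in view.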

The advantage of this method is ease of use. For developers who are unfamiliar with deep learning and model training, it significantly lowers the technical barrier: you can integrate AI capabilities with just a few lines of code. However, ready-to-use libraries are still scarce, and it is hard to find one that fits a specific need. Many of the libraries on npm and GitHub are designed to run in Node.js rather than in the browser, and their quality is uneven. The three libraries listed above are pre-trained TensorFlow models that Google packages for the browser, so their quality is reliable. Other models in the same collection are listed at https://github.com/tensorflow/tfjs-models/tree/master. Transformers.js (https://www.npmjs.com/package/@xenova/transformers) is also worth a look. Its models are transformer-based, however, so they tend to run slowly in the browser.

2 Using TensorFlow.js to Run Models

TensorFlow.js is a library developed by Google for performing machine learning in the browser or in Node.js. It allows developers to train and deploy machine learning models directly on the client side without relying on server-side computing resources, resulting in efficient data processing and real-time interaction.

Moreover, TensorFlow.js supports the use of existing TensorFlow models by converting them into a web-friendly format. Developers can leverage pre-trained models or create and train new models from scratch, making machine learning more convenient and scalable.

The TensorFlow.js converter accepts several model formats, including Keras HDF5, tf.keras SavedModel, and TensorFlow Hub modules (now hosted on Kaggle).

2.1 Running Pre-trained tf.keras Models with TensorFlow.js

Keras Applications offers many ready-to-use pre-trained models: https://keras.io/api/applications/

The main steps are:

  1. Save the Model: Export the pre-trained model to a format the converter accepts (here, Keras HDF5).

  2. Convert the Model: Use the TensorFlow.js converter to convert the model files into a TensorFlow.js-compatible format.

  3. Load the Model: Use TensorFlow.js to load and run the model in a web application.

Let's look at how to run a pre-trained model from tf.keras. I have chosen one of my favorite image classification models, InceptionV3:

2.1.1 Save the Model

First, save the pre-trained tf.keras model in Keras HDF5 format:

import tensorflow as tf
from tensorflow.keras.layers import Input

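# Build InceptionV3 on a 224x224 RGB input (its default is 299x299); images fed to the
# converted model in the browser must be resized to the same size
input_tensor = Input(shape=(224, 224, 3))
model = tf.keras.applications.InceptionV3(input_tensor=input_tensor, weights='imagenet')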

# Save the model as an HDF5 file
model.save('inceptionv3.h5')

2.1.2 Convert the Model

2.1.2.1 Install and Use TensorFlow.js Converter

Reference: https://github.com/tensorflow/tfjs/tree/master/tfjs-converter

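Create a fresh Python environment with a recent Python version, then run the command below. tensorflowjs depends on TensorFlow 2.13 or later, which requires Python 3.8 or newer.

pip install "tensorflowjs[wizard]"

The following error means pip could not find a compatible TensorFlow build for your environment. If you see it, upgrade pip and check that your Python version is supported by TensorFlow 2.13 or later: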

Could not find a version that satisfies the requirement tensorflow<3,>=2.13.0 (from tensorflowjs[wizard]) (from versions: ...)
No matching distribution found for tensorflow<3,>=2.13.0 (from tensorflowjs[wizard])

If you run into certificate (SSL) errors, add the --trusted-host options:

pip install --trusted-host pypi.python.org --trusted-host files.pythonhosted.org --trusted-host pypi.org tensorflowjs[wizard]

2.1.2.2 Convert the Model

Run tensorflowjs_wizard and follow the prompts to convert the .h5 file into a format supported by TensorFlow.js.

2.1.3 Load the Model

In your front-end project, use the following code to load and use the converted model:

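import * as tf from '@tensorflow/tfjs';

// Must match the input size the model was saved with (224 in the Python code above)
const IMAGE_SIZE = 224;

async function loadModel() {
  // Use tf.loadLayersModel for a layers-model conversion, or tf.loadGraphModel for a graph model (as here)
  const model = await tf.loadGraphModel('/inceptionv3/model.json');
  // Warm-up run with a dummy input so kernels are initialized before the first real prediction
  const dummy = tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
  const result = model.predict(dummy) as tf.Tensor;
  tf.dispose([dummy, result]);
  return model;
}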

2.1.4 Use the Model

Use the following code to make predictions with the model:

  1. Load the model by calling loadModel().
  2. Convert the image at the given path into a tensor, imageTensor.
  3. Run the model on that tensor and remove the batch dimension with tf.squeeze().
  4. Extract the top 5 classes with tf.topk().
  5. Look up the class names for those indices with imagenetClassesTopK().

No extra tf.softmax() is needed. The Keras InceptionV3 model already ends in a softmax layer, so its output is a probability distribution. A second softmax would keep the ranking but squash the probabilities toward uniform.

(Reference code: https://github.com/ymrdf/web-ai-examples/blob/main/src/tensorflow/imageRecog.ts)

export async function inference(path: string) {
  const model = await loadModel();
  // Helper from the example repo: loads the image at `path` as a [1, 224, 224, 3] tensor
  const imageTensor = await getImageTfTensorFromPath(path);

  const predictions = model.predict(imageTensor) as tf.Tensor;
  // [1, 1000] -> [1000]; values are already probabilities (the model ends in softmax)
  const probabilities = tf.squeeze(predictions);

  const top5 = tf.topk(probabilities, 5);
  const top5Indices = top5.indices.dataSync();
  const top5Values = top5.values.dataSync();
  // Release the memory held by the intermediate tensors
  tf.dispose([imageTensor, predictions, probabilities, top5.indices, top5.values]);

  // Helper from the example repo: maps class indices to ImageNet class names
  const results = imagenetClassesTopK(top5Indices, top5Values);
  // Fixed placeholder as the second value (the ONNX version returns the measured inference time here)
  return [results, 0.5];
}
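tf.softmax and tf.topk do most of the work in these post-processing steps. For intuition, here is a dependency-free TypeScript sketch of what they compute. softmax and topK are illustrative names, not TensorFlow.js APIs. Softmax only makes sense on raw logits; a model whose last layer is already a softmax outputs probabilities directly.

```typescript
// Numerically stable softmax: subtracting the max logit before Math.exp
// avoids overflow without changing the result.
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Indices and values of the k largest entries, largest first (what tf.topk returns).
function topK(values: ArrayLike<number>, k: number): { indices: number[]; values: number[] } {
  const order = Array.from(values, (_, i) => i).sort((a, b) => values[b] - values[a]);
  const indices = order.slice(0, k);
  return { indices, values: indices.map((i) => values[i]) };
}
```

For example, topK(softmax(logits), 5) yields the same indices and values as tf.topk(tf.softmax(logits), 5), up to floating-point differences.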

Below is an example output:
[Figure: example top-5 classification output]

2.2 Running Custom TensorFlow Models with TensorFlow.js

2.2.1 Training a Handwritten Digit Recognition Model

Let's start by training a simple handwritten digit recognition model. Given the straightforward nature of this image recognition task, we can modify a basic convolutional neural network (CNN) model. I've chosen LeNet (https://ieeexplore.ieee.org/document/726791), with a structure close to this:

[Figure: LeNet architecture]

After making several adjustments, including replacing average pooling layers with max pooling layers, changing the activation functions of fully connected layers to ReLU, and adding batch normalization layers after the convolutional and fully connected layers, the final model looks like this:

from tensorflow.keras import layers, models, Input

model = models.Sequential([
    Input(shape=(28, 28, 1)),  # Specify input shape with Input layer
    layers.Conv2D(6, kernel_size=(5, 5), padding="same", activation="sigmoid"),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(16, kernel_size=(5, 5), activation="sigmoid"),
    layers.BatchNormalization(),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(120, activation="relu"),
    layers.BatchNormalization(),
    layers.Dense(84, activation="relu"),
    layers.BatchNormalization(),
    layers.Dense(10, activation="softmax")
])
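It helps to check the feature-map sizes by hand, because they determine the input size of the first Dense layer. They also explain the 16 * 5 * 5 that appears later in the PyTorch version of this model. The sketch below applies the standard output-size formula floor((n + 2p - k) / s) + 1, with two assumptions:

  • padding="same" with a 5x5 kernel and stride 1 behaves like 2 pixels of padding on each side.
  • Keras pooling uses its default stride, which equals the pool size.

```typescript
// Output size of a conv/pool layer along one spatial dimension:
// floor((n + 2 * padding - kernel) / stride) + 1
function outSize(n: number, kernel: number, stride = 1, padding = 0): number {
  return Math.floor((n + 2 * padding - kernel) / stride) + 1;
}

// Trace the spatial size through the LeNet-style model above.
function traceLeNet(input = 28): number[] {
  const sizes: number[] = [input];
  let n = outSize(input, 5, 1, 2); // Conv2D(6, 5x5, padding="same"): 28 -> 28
  sizes.push(n);
  n = outSize(n, 2, 2);            // MaxPooling2D(2x2): 28 -> 14
  sizes.push(n);
  n = outSize(n, 5);               // Conv2D(16, 5x5), no padding: 14 -> 10
  sizes.push(n);
  n = outSize(n, 2, 2);            // MaxPooling2D(2x2): 10 -> 5
  sizes.push(n);
  return sizes;
}

const sizes = traceLeNet();
// Flatten: 5 x 5 spatial positions times 16 channels
const flattened = sizes[sizes.length - 1] ** 2 * 16;
```

So Flatten produces 400 values, and the first Dense(120) layer sees a 400-dimensional input.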

We then train the model with the following code:


    import tensorflow as tf
    from tensorflow.keras import datasets, models, layers, Input
    from tensorflow.keras.utils import to_categorical

    # 1. Load the dataset
    (training_images, training_labels), (test_images, test_labels) = datasets.mnist.load_data()

    # Data preprocessing
    training_images = training_images.reshape((60000, 28, 28, 1)).astype("float32") / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype("float32") / 255

    # Convert labels to one-hot encoding
    training_labels = to_categorical(training_labels)
    test_labels = to_categorical(test_labels)

    # 2. Define the model
    model = models.Sequential([
        Input(shape=(28, 28, 1)), 
        layers.Conv2D(6, kernel_size=(5, 5), padding="same", activation="sigmoid"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(16, kernel_size=(5, 5), activation="sigmoid"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(120, activation="relu"),
        layers.BatchNormalization(),
        layers.Dense(84, activation="relu"),
        layers.BatchNormalization(),
        layers.Dense(10, activation="softmax")
    ])

    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # 3. Train the model
    model.fit(training_images, training_labels, epochs=5, batch_size=64, verbose=2)

    # 4. Evaluate the model
    test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
    print(f"Test accuracy: {test_acc * 100:.2f}%, Test loss: {test_loss:.4f}")

By adjusting the learning rate to 0.05, the model achieves an accuracy of 98% on the test set. To save this model as a Keras file, add the following code:


    model.save("./numberRecog.h5")

2.2.2 Converting the Model

Convert the model by following the same process as described earlier. Run tensorflowjs_wizard and follow the prompts to generate the model files.

2.2.3 Loading the Model

Use tf.loadLayersModel or tf.loadGraphModel to load the converted model, depending on the format chosen during the conversion. Here's an example code (reference: https://github.com/ymrdf/web-ai-examples/blob/main/src/tensorflow/numberRecog.ts):


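    import * as tf from '@tensorflow/tfjs';

    const IMAGE_SIZE = 28; // MNIST digits are 28x28 grayscale images

    async function loadModel() {
      // Graph-model conversion here; use tf.loadLayersModel for a layers model
      const model = await tf.loadGraphModel('/numberRecogV1/model.json');
      // Warm-up run with a dummy [1, 28, 28, 1] input
      const dummy = tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 1]);
      const result = model.predict(dummy) as tf.Tensor;
      tf.dispose([dummy, result]);
      return model;
    }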

2.2.4 Using the Model

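    export async function inference(path: string): Promise<[Int32Array, number]> {
      const model = await loadModel();
      // Helper from the example repo: loads the image as a [1, 28, 28, 1] tensor
      const imageTensor = await getImageTfTensorFromPath(path);
      const predictions = model.predict(imageTensor) as tf.Tensor;

      // The final Dense layer already applies softmax, so the squeezed
      // output is a probability distribution over the 10 digits
      const probabilities = tf.squeeze(predictions);
      const top5 = tf.topk(probabilities, 5);
      const top5Indices = top5.indices.dataSync() as Int32Array; // most likely digit first
      tf.dispose([imageTensor, predictions, probabilities, top5.indices, top5.values]);
      return [top5Indices, 0.5]; // second value is a fixed placeholder
    }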
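The getImageTfTensorFromPath helper for this model is not shown here. The hypothetical rgbaToMnistInput below is a hedged sketch of the preprocessing an MNIST-trained model typically needs. It takes the RGBA pixels of a 28x28 image, for example ctx.getImageData(0, 0, 28, 28).data from a canvas. It returns 784 grayscale values in [0, 1], which matches the / 255 scaling in the training script. MNIST digits are light strokes on a dark background, so the sketch assumes a dark digit drawn on a light canvas and inverts it. Drop the inversion if your input already matches MNIST.

```typescript
// Hypothetical helper: RGBA pixels (4 values per pixel, 0-255) of a width x height
// image -> one grayscale channel in [0, 1], row-major, optionally inverted.
function rgbaToMnistInput(
  rgba: ArrayLike<number>,
  width = 28,
  height = 28,
  invert = true,
): Float32Array {
  const out = new Float32Array(width * height);
  for (let i = 0; i < width * height; i++) {
    const r = rgba[i * 4];
    const g = rgba[i * 4 + 1];
    const b = rgba[i * 4 + 2];
    // Luma approximation (ITU-R BT.601 weights), scaled to [0, 1]; alpha is ignored
    const gray = (0.299 * r + 0.587 * g + 0.114 * b) / 255;
    // MNIST has light digits on a dark background, so invert dark-on-light drawings
    out[i] = invert ? 1 - gray : gray;
  }
  return out;
}
```

The input has only one channel, so the same 784 values work for both layouts. The TensorFlow.js model takes them as [1, 28, 28, 1], e.g. tf.tensor4d(data, [1, 28, 28, 1]). An ONNX model takes them as [1, 1, 28, 28].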

Results will vary with your own trained model. Here are some examples:

[Figures: example predictions from the TensorFlow.js digit recognition model]

The methods described above can only load models trained with TensorFlow. They cannot run models trained with other frameworks like PyTorch. However, many resources, such as books and research papers, use PyTorch for demonstration, and pre-trained models available online are predominantly in PyTorch. Is there a way to run models trained with any framework?

(One could transfer the parameters from other models to TensorFlow, then use TensorFlow.js to build and load the model. However, this approach is quite challenging, error-prone, and not easily accessible for most people.)

3 Running Models with ONNX Runtime Web

ONNX (Open Neural Network Exchange) is an open-source deep learning model exchange format developed by Microsoft and Facebook. Its goal is to facilitate interoperability between different deep learning frameworks, allowing models to be seamlessly converted and run across various platforms and tools. ONNX supports various mainstream deep learning frameworks like PyTorch, TensorFlow, and Caffe2, simplifying the model deployment process and improving development efficiency. With ONNX, developers can easily share and reuse deep learning models in different environments, enhancing the flexibility and portability of AI projects.

ONNX Runtime Web runs ONNX models in the browser, so deep learning inference happens on the front end without a backend server. It uses WebAssembly and WebGL to run efficiently across browsers. Models from many frameworks can be exported to ONNX. This makes ONNX Runtime Web a general-purpose way to run AI models on the client side, and mastering it solves many problems.

The primary steps for running models on ONNX Runtime Web are:

  • Save the model: Export the existing model to an ONNX file.

  • Load the model: Use onnxruntime-web to load and run the model in a web application.

  • Run the model: Pass the preprocessed input to the session and read its outputs.

3.1 Running Pre-trained PyTorch Models

3.1.1 Exporting the Model

PyTorch provides several pre-trained models. For demonstration, I chose the ResNet18 model. Other models can be found here: https://pytorch.org/vision/stable/models.html.

Convert the PyTorch model to ONNX format with PyTorch's torch.onnx.export function. The script below does three things: it loads the pre-trained ResNet18 model, runs it on a sample image to check that it works, and then exports it:

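import torch
from torchvision import models, transforms
from PIL import Image

# Load the pre-trained ResNet18 model and set it to evaluation mode
# (torchvision >= 0.13; older versions use models.resnet18(pretrained=True))
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()

# Standard ImageNet preprocessing: resize, center-crop to 224x224, normalize
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Force 3 channels in case the file is grayscale or RGBA
image = Image.open("./1.jpeg").convert("RGB")

# Build a [1, 3, 224, 224] input tensor from a real image to test the model
image = transform(image).unsqueeze(0)

with torch.no_grad():
    outputs = resnet18(image)
    print(outputs.argmax(1))  # index of the predicted ImageNet class

input_names = ["actual_input_1"]
output_names = ["output1"]

# Export the model; the sample input fixes the exported input shape to [1, 3, 224, 224].
# opset_version=7 targets the older opsets supported by ONNX Runtime Web's WebGL backend
torch.onnx.export(resnet18, image, "resnet18.onnx", verbose=False, input_names=input_names, output_names=output_names, opset_version=7)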

3.1.2 Loading the Model with ONNX Runtime Web

ONNX Runtime Web's InferenceSession.create() is an asynchronous method used to load an ONNX model file specified by the path in the first parameter.

The second argument is a configuration object. This example sets two fields:

  • executionProviders: ['webgl']: Specifies the backend that runs the model. Here it is set to WebGL to accelerate computation on the GPU. The alternative is 'wasm' (WebAssembly; 'cpu' is an alias for it).

  • graphOptimizationLevel: 'all': Specifies the level of graph optimization. Setting it to 'all' enables all available graph optimizations to enhance model execution efficiency and performance.

(Reference code: https://github.com/ymrdf/web-ai-examples/blob/main/src/onnx/imageRecog.ts)

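import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('/resnet18.onnx', {
  executionProviders: ['webgl'], // GPU acceleration via WebGL
  graphOptimizationLevel: 'all', // enable all graph optimizations
});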

3.1.3 Running the Model with ONNX Runtime Web

First, build the feeds object, which maps the model's input name to the preprocessed input tensor preprocessedData. Then run the model asynchronously with session.run() and read the prediction from the output data. ResNet18 outputs raw scores (logits), so apply softmax to turn them into a probability distribution. Use tf.topk() to find the five most likely classes and their indices. Finally, convert the indices to class names with the imagenetClassesTopK function, and return them together with the inference time.

    // Assumes `session` was created as above and `preprocessedData` is the
    // [1, 3, 224, 224] float32 ort.Tensor built from the input image
    const start = new Date();
    const feeds: Record<string, ort.Tensor> = {};
    feeds[session.inputNames[0]] = preprocessedData;

    const outputData = await session.run(feeds);
    const end = new Date();
    const inferenceTime = (end.getTime() - start.getTime()) / 1000; // seconds

    // ResNet18 outputs raw logits ([1, 1000]); softmax turns them into probabilities
    const output = outputData[session.outputNames[0]];
    const outputSoftmax = tf.softmax(tf.tensor1d(output.data as Float32Array));

    const top5 = tf.topk(outputSoftmax, 5);
    const top5Indices = top5.indices.dataSync();
    const top5Values = top5.values.dataSync();

    // Helper from the example repo: maps class indices to ImageNet class names
    const results = imagenetClassesTopK(top5Indices, top5Values);
    return [results, inferenceTime];
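The snippet measures inference time with two Date objects. If you time several stages, a small wrapper keeps the code tidy. The sketch below uses performance.now(), which is monotonic and has higher resolution than Date. timed is a hypothetical helper, not part of ONNX Runtime Web.

```typescript
// Hypothetical helper: runs an async task and returns its result
// together with the elapsed wall-clock time in seconds.
async function timed<T>(task: () => Promise<T>): Promise<[T, number]> {
  const start = performance.now();
  const result = await task();
  const seconds = (performance.now() - start) / 1000;
  return [result, seconds];
}
```

For example: `const [outputData, inferenceTime] = await timed(() => session.run(feeds));`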

Below is the result:

[Figure: example ResNet18 classification output]

3.2 Finding Pre-trained Models on Platforms like Hugging Face and Kaggle

Although the aforementioned methods are practical, pre-trained models from frameworks like TensorFlow are relatively scarce, and training high-quality models independently can be quite challenging. Thus, the real-world application scenarios might be limited. However, if we can run any model from platforms like Hugging Face and Kaggle, the application scope will greatly expand.

For example, I found a powerful model on Hugging Face: RMBG-1.4. This background-removal model accurately separates the foreground objects in an image from the background.

Let's take a look at its impressive performance:

[Figure: RMBG-1.4 background-removal examples]

Now, let's implement this functionality in the browser. Excited?

3.2.1 Preparation

First, load the model as shown on the RMBG-1.4 model card:

from transformers import AutoModelForImageSegmentation

model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)

Next, write code to test the performance of this model:

from transformers import AutoModelForImageSegmentation
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
import numpy as np
import torch
from skimage import io
from PIL import Image

model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)

def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image

def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array

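device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_input_size = [1024, 1024]

image_path = "./profile.jpg"
orig_im = io.imread(image_path)
orig_im_size = orig_im.shape[0:2]
image = preprocess_image(orig_im, model_input_size).to(device)

# Put the model on the same device as the input and switch to inference mode
model.to(device)
model.eval()

# No gradients are needed for inference
with torch.no_grad():
    result = model(image)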

result_image = postprocess_image(result[0][0], orig_im_size)

pil_im = Image.fromarray(result_image)
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
orig_image = Image.open(image_path)
no_bg_image.paste(orig_image, mask=pil_im)

no_bg_image.show()

Running the above code should successfully segment the objects in the image.

3.2.2 Exporting the Model to ONNX Format

Add the following lines to the end of the script above to generate an ONNX file. The model uses newer operators, so exporting with opset version 7 may fail. Instead, use a recent opset (the code below uses PyTorch's default) and run the export on the CPU:

input_names = ["actual_input_1"]
output_names = ["output1"]

torch.onnx.export(model, image, "rmbg.onnx", verbose=False, input_names=input_names, output_names=output_names)

3.2.3 Loading the Model with ONNX Runtime Web

(Reference code: https://github.com/ymrdf/web-ai-examples/blob/main/src/rmbg/predict.ts. Please extract the public/rmbg.onnx file before running the code.)

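const session = await ort.InferenceSession.create('/rmbg.onnx', {
  // 'cpu' runs on the WebAssembly backend, which supports all operators
  // (the WebGL backend may not support the newer ones this model uses)
  executionProviders: ['cpu'],
  graphOptimizationLevel: 'all',
});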

3.2.4 Running Inference

async function runInference(session: ort.InferenceSession, preprocessedData: ort.Tensor): Promise<ort.Tensor> {
  // Map the model's (single) input name to the preprocessed [1, 3, 1024, 1024] tensor
  const feeds: Record<string, ort.Tensor> = {};
  feeds[session.inputNames[0]] = preprocessedData;
  const outputData = await session.run(feeds);
  // The first output is the [1, 1, 1024, 1024] foreground mask
  return outputData[session.outputNames[0]];
}

3.2.5 Data Preprocessing

You might be surprised at how simple the code for model inference is, but the tricky part is data preprocessing. For example, this model requires an input tensor with the shape [1, 3, 1024, 1024]. So, you need to get an image and process it into a tensor of this shape. However, the tensor operations provided by ONNX Runtime Web are limited, so I usually use TensorFlow.js for data processing and then convert the TensorFlow tensor into an ONNX Runtime Web tensor.

Here’s an example illustrating the preprocessing steps:

import * as tf from '@tensorflow/tfjs';

export async function getImageTfTensorFromPath(path: string): Promise<tf.Tensor> {
  return new Promise((resolve, reject) => {
    const $image = new Image();
    $image.crossOrigin = 'Anonymous';
    $image.onload = () => {
      // tf.tidy disposes every intermediate tensor except the one returned
      const normalizedTensor = tf.tidy(() => {
        const tensor = tf.browser.fromPixels($image) // 1. Image element -> [h, w, 3] tensor
          .resizeBilinear([1024, 1024])              // 2. Resize to [1024, 1024, 3]
          .cast('float32')                           // 3. Make sure values are floats
          .div(255)                                  // 4. Scale pixel values to [0, 1]
          .transpose([2, 0, 1])                      // 5. [1024, 1024, 3] -> [3, 1024, 1024]
          .expandDims();                             // 6. Add batch dim -> [1, 3, 1024, 1024]

        // 7. Standardize per channel, (x - mean) / std, as in the Python preprocess_image
        const mean = tf.tensor([0.5, 0.5, 0.5]).reshape([1, 3, 1, 1]);
        const std = tf.tensor([1.0, 1.0, 1.0]).reshape([1, 3, 1, 1]);
        return tensor.sub(mean).div(std);
      });
      resolve(normalizedTensor);
    };
    // Reject instead of hanging forever if the image fails to load
    $image.onerror = () => reject(new Error(`Failed to load image: ${path}`));
    $image.src = path;
  });
}

Finally, convert the TensorFlow tensor to an ONNX Runtime Web tensor (conversion function available in the source code).
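If you would rather not pull in TensorFlow.js just for preprocessing, you can write the same transformation by hand. The sketch below is an alternative under stated assumptions, not the repository's code. The hypothetical rgbaToChw takes RGBA pixel data, for example from a canvas on which the image was drawn at 1024x1024. It produces the planar, channel-first Float32Array that a [1, 3, H, W] input expects, applying (x / 255 - mean) / std per channel.

```typescript
// Hypothetical helper: interleaved RGBA pixels (as from ctx.getImageData(...).data)
// -> planar CHW float32 data: all R values, then all G, then all B.
function rgbaToChw(
  rgba: ArrayLike<number>,
  width: number,
  height: number,
  mean: [number, number, number],
  std: [number, number, number],
): Float32Array {
  const plane = width * height;
  const out = new Float32Array(3 * plane);
  for (let i = 0; i < plane; i++) {
    for (let c = 0; c < 3; c++) {
      // 4 bytes per input pixel; the alpha byte (offset 3) is dropped
      out[c * plane + i] = (rgba[i * 4 + c] / 255 - mean[c]) / std[c];
    }
  }
  return out;
}
```

For RMBG-1.4, mean is [0.5, 0.5, 0.5] and std is [1, 1, 1], matching preprocess_image in the Python code. For the ResNet18 example, use the ImageNet values from the PyTorch transforms. You can then wrap the result as new ort.Tensor('float32', data, [1, 3, height, width]).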

3.2.6 Data Post-processing

The model generates data in the shape [1, 1, 1024, 1024], with each value being a float between 0 and 1, representing whether the corresponding pixel should be kept.

Here's how to convert this data into an image:

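// Context: $image is the loaded input image, alphaExpanded is the ONNX output tensor
// returned by runInference, and convertOnnxTensorToTfTensor is a helper from the example repo.

// Convert the image element to a [1024, 1024, 3] tensor
const tensor = tf.browser.fromPixels($image).resizeBilinear([1024, 1024]);

// Convert the ONNX output ([1, 1, 1024, 1024]) to a tf.Tensor
const alpha4 = convertOnnxTensorToTfTensor(alphaExpanded);

// Remove the channel dimension: [1, 1, 1024, 1024] => [1, 1024, 1024]
let alpha3 = tf.squeeze(alpha4, [1]);

// [1, 1024, 1024] => [1024, 1024, 1]
alpha3 = tf.reshape(alpha3, [1024, 1024, 1]);

// Scale mask values from [0, 1] to [0, 255] so they can serve as the alpha channel
alpha3 = tf.mul(alpha3, 255);
combine(tensor, alpha3);

function combine(imageTensor: tf.Tensor, alphaChannel: tf.Tensor) {
  // Append the mask as a 4th channel: [1024, 1024, 3] + [1024, 1024, 1] => [1024, 1024, 4] (RGBA)
  const combinedTensor = tf.concat([imageTensor, alphaChannel], -1);

  // Read the tensor data; Uint8ClampedArray rounds and clamps values into 0-255
  combinedTensor.data().then(data => {
    const clampedArray = new Uint8ClampedArray(data);

    // Create an ImageData object (the target canvas should be 1024x1024)
    const imageData = new ImageData(clampedArray, 1024, 1024);

    // Draw to canvas
    const canvas = document.querySelector<HTMLCanvasElement>("#test");
    const ctx = canvas?.getContext('2d');
    ctx?.putImageData(imageData, 0, 0);
  });
}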
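You can also do the same compositing without tensors. The sketch below uses applyMask, a hypothetical helper. It interleaves RGB pixel values with the mask as an alpha channel, producing the Uint8ClampedArray that ImageData expects. Its optional normalize flag first rescales the mask to [0, 1] using its minimum and maximum. The Python postprocess_image does this step; the TensorFlow.js version above skips it.

```typescript
// Hypothetical helper: interleaved RGB values (3 per pixel) + one mask value per
// pixel -> RGBA bytes for ImageData. With normalize = true, the mask is
// min-max rescaled to [0, 1] first, like postprocess_image in the Python code.
function applyMask(rgb: ArrayLike<number>, mask: ArrayLike<number>, normalize = false): Uint8ClampedArray {
  const n = mask.length;
  let lo = 0;
  let hi = 1;
  if (normalize) {
    lo = Infinity;
    hi = -Infinity;
    for (let i = 0; i < n; i++) {
      lo = Math.min(lo, mask[i]);
      hi = Math.max(hi, mask[i]);
    }
  }
  const range = hi - lo || 1; // avoid dividing by zero for a constant mask
  const out = new Uint8ClampedArray(n * 4);
  for (let i = 0; i < n; i++) {
    out[i * 4] = rgb[i * 3];
    out[i * 4 + 1] = rgb[i * 3 + 1];
    out[i * 4 + 2] = rgb[i * 3 + 2];
    // Alpha: 0 = fully transparent (background), 255 = fully opaque (foreground)
    out[i * 4 + 3] = ((mask[i] - lo) / range) * 255;
  }
  return out;
}
```

For example: `ctx.putImageData(new ImageData(applyMask(rgb, mask, true), 1024, 1024), 0, 0)`. Here rgb holds 1024 x 1024 x 3 values and mask holds 1024 x 1024 values.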

Feeding a test image through this pipeline produces the result below:

[Figures: input image and background-removed output]

3.3 Running Your Own Trained Model

3.3.1 Training a Handwritten Digit Recognition Model with PyTorch

We have already covered how to run your own trained models with TensorFlow.js. However, many books and courses teach with PyTorch, so front-end developers getting into machine learning often start there. Let's see how to load a PyTorch-trained model in the browser. We'll train a handwritten digit recognition model in PyTorch that mirrors the earlier TensorFlow model.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Load training and test data
training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor())

batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Check the data shape
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Set device (using GPU if available)
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Define the neural network model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.BatchNorm2d(6), nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.BatchNorm1d(120), nn.ReLU(),
            nn.Linear(120, 84), nn.BatchNorm1d(84), nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

# Initialize weights
def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)

# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)  # Try different learning rates for optimization

# Training function
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# Test function
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    
# Training the model
epochs = 5
model.linear_relu_stack.apply(init_weights)
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

3.3.2 Exporting the Model to an ONNX File

Once you're satisfied with the model's performance, save the model by running the following code to generate an ONNX file:

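# Switch to inference mode (BatchNorm uses its running statistics) and export on the CPU
model = model.to("cpu").eval()
dummy_input = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)  # [N, C, H, W] sample input
torch.onnx.export(model, dummy_input, "numberRecog.onnx", verbose=False, opset_version=7)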

3.3.3 Loading and Running the Model

(Reference code: https://github.com/ymrdf/web-ai-examples/blob/main/src/onnx/numberRecog.ts)

export async function inference(path: string): Promise<[Int32Array, number]> {
  // Helpers from the example repo: load the image as a [1, 1, 28, 28] tensor,
  // then copy it into an ort.Tensor
  const imageTensor = await getImageTfTensorFromPath(path);
  const preprocessedData = await convertTfTensorToOnnxTensor(imageTensor);
  const session = await ort.InferenceSession.create('/numberRecog.onnx', {
    executionProviders: ['cpu'],
    graphOptimizationLevel: 'all'
  });

  const start = new Date();
  const feeds: Record<string, ort.Tensor> = {};
  feeds[session.inputNames[0]] = preprocessedData;
  const outputData = await session.run(feeds);
  const end = new Date();
  const inferenceTime = (end.getTime() - start.getTime()) / 1000; // seconds
  const output = outputData[session.outputNames[0]];

  // The PyTorch model outputs raw logits ([1, 10]); softmax turns them into probabilities
  const predictions = convertOnnxTensorToTfTensor(output);
  const squeezed_tensor = tf.squeeze(predictions);
  const outputSoftmax = tf.softmax(squeezed_tensor);
  const top5 = tf.topk(outputSoftmax, 5);
  // topk indices are int32, so dataSync() returns an Int32Array
  const top5Indices = top5.indices.dataSync() as Int32Array;
  return [top5Indices, inferenceTime];
}
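This inference function runs InferenceSession.create on every call, so the model is fetched and initialized again for each image. A common fix is to create the session once and reuse it. The hypothetical once helper below caches the promise returned by any async factory. If loading fails, it clears the cache so the next call can retry.

```typescript
// Hypothetical helper: wraps an async factory so it runs at most once;
// every caller shares the same cached promise.
function once<T>(factory: () => Promise<T>): () => Promise<T> {
  let cached: Promise<T> | null = null;
  return () => {
    if (cached === null) {
      cached = factory().catch((err) => {
        cached = null; // forget the failure so a later call can retry
        throw err;
      });
    }
    return cached;
  };
}
```

For example, define `const getSession = once(() => ort.InferenceSession.create('/numberRecog.onnx', { executionProviders: ['cpu'] }));` and call `const session = await getSession();` inside inference. The same pattern applies to the loadModel functions in the TensorFlow.js examples.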

The results should be similar to the model trained with TensorFlow:

[Figures: example predictions from the ONNX digit recognition model]

Limitations of ONNX Runtime Web

While ONNX Runtime Web is useful, it has some limitations. Its WebGL backend supports only older ONNX opsets, and it does not implement some operators at all. The WebAssembly backend supports all operators, but it may be slower. Large ONNX files also take longer to download, which slows down page loading.

Summary

In this article, we explored how to run machine learning models in front-end development. Beyond ready-made libraries, we focused on two approaches: TensorFlow.js and ONNX Runtime Web. Both follow the same pattern: first save or convert the model into the appropriate format, then load and execute it with the corresponding runtime.

TensorFlow.js allows models created with the TensorFlow framework to run directly in the browser, offering developers a convenient method for model training and deployment. However, the existing resources for TensorFlow.js are relatively limited, and its conversion tool, tensorflowjs_wizard, can be somewhat challenging to use.

ONNX Runtime Web, on the other hand, can load ONNX models exported from many frameworks, which gives it greater flexibility. However, its WebGL backend does not support some operators. For compatibility you may need the WebAssembly backend instead, at the expense of execution speed.

For data processing, ONNX Runtime Web offers relatively few tensor operations. We can use @tensorflow/tfjs to process the data and then convert the results into ONNX Runtime Web tensors.

By learning and mastering these techniques, we can seamlessly integrate AI technologies into front-end development, elevating the user experience of web applications to a new level. Whether you are a front-end developer or an AI enthusiast, continuing to explore and learn about AI-related technologies is undoubtedly the key to unlocking the door to future advancements. Let's embrace AI technologies together, boldly explore unknown frontiers, and build smarter and more user-friendly applications!

I hope this article provides you with valuable insights, and may we grow together on the path of merging front-end development with AI technology to create more possibilities!