Building an AI Application in Java with Deeplearning4j

Bayram EKER
4 min read · May 18, 2024

In this tutorial, we’ll build a simple AI application in Java using the Deeplearning4j library. We’ll set up the project structure, write the code that trains and evaluates a small feed-forward neural network, and run the application. This guide is aimed at Java developers who want a first hands-on look at machine learning on the JVM.

Introduction

Artificial Intelligence (AI) is transforming industries by automating tasks, enhancing decision-making, and creating new opportunities. While Python is the most popular language for AI, Java offers robust libraries like Deeplearning4j that bring powerful AI capabilities to the Java ecosystem. In this tutorial, we’ll build a simple AI application that demonstrates basic machine learning concepts using Java.

Project Structure

We’ll start by defining a clean and modular project structure. This helps in maintaining and scaling the project efficiently.

ai-project
├── pom.xml
└── src
    ├── main
    │   ├── java
    │   │   └── com
    │   │       └── example
    │   │           └── ai
    │   │               ├── App.java
    │   │               ├── model
    │   │               │   └── DataModel.java
    │   │               └── service
    │   │                   └── AIService.java
    │   └── resources
    └── test
        └── java
            └── com
                └── example
                    └── ai
                        └── service
                            └── AIServiceTest.java

Dependencies

We’ll use Maven for dependency management. Here’s the pom.xml file:

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>ai-project</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-beta7</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>1.0.0-beta7</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>1.0.0-beta7</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.12</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.30</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.30</version>
        </dependency>
    </dependencies>

    <build>
        <sourceDirectory>src/main/java</sourceDirectory>
        <testSourceDirectory>src/test/java</testSourceDirectory>
        <resources>
            <resource>
                <directory>src/main/resources</directory>
            </resource>
        </resources>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>exec-maven-plugin</artifactId>
                <version>3.0.0</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>java</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <mainClass>com.example.ai.App</mainClass>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

Setting Up the Project

  1. Clone the Repository:
     git clone https://github.com/yourusername/ai-project.git
     cd ai-project
  2. Install Dependencies and Compile the Project:
     mvn clean compile

Writing the Code

1. Data Model

The DataModel class represents the structure of our input data.

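package com.example.ai.model;

// One training or test example: two input features and a target label.
public class DataModel {
    private double feature1;
    private double feature2;
    private double label;

    public DataModel(double feature1, double feature2, double label) {
        this.feature1 = feature1;
        this.feature2 = feature2;
        this.label = label;
    }

    // Getters used by AIService; add setters if you need mutable data
    public double getFeature1() {
        return feature1;
    }

    public double getFeature2() {
        return feature2;
    }

    public double getLabel() {
        return label;
    }
}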

2. AI Service

The AIService class handles the creation, training, and evaluation of the machine learning model.

package com.example.ai.service;

import com.example.ai.model.DataModel;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.List;

public class AIService {
    private MultiLayerNetwork model;

    public AIService() {
        // 2 inputs -> 3 hidden units (ReLU) -> 1 linear output, trained with MSE loss
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Sgd(0.1))
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(2)
                        .nOut(3)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .nIn(3)
                        .nOut(1)
                        .activation(Activation.IDENTITY)
                        .build())
                .build();

        model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100));
    }

    public void trainModel(List<DataModel> dataList) {
        // Preparing training data: one row per example
        int dataSize = dataList.size();
        INDArray input = Nd4j.create(dataSize, 2); // 2D matrix
        INDArray labels = Nd4j.create(dataSize, 1); // 2D matrix

        for (int i = 0; i < dataSize; i++) {
            DataModel data = dataList.get(i);
            input.putRow(i, Nd4j.create(new double[]{data.getFeature1(), data.getFeature2()}));
            labels.putRow(i, Nd4j.create(new double[]{data.getLabel()}));
        }

        DataSet dataSet = new DataSet(input, labels);
        // Mini-batches of up to 10 examples
        DataSetIterator iterator = new ListDataSetIterator<>(dataSet.asList(), 10);

        // Training the model: each fit(iterator) call is one pass (epoch) over the data
        int nEpochs = 1000;
        for (int i = 0; i < nEpochs; i++) {
            model.fit(iterator);
            if (i % 100 == 0) {
                System.out.println("Score at epoch " + i + " is " + model.score());
            }
        }
    }

    public void evaluateModel(DataModel testData) {
        INDArray input = Nd4j.create(new double[][]{{testData.getFeature1(), testData.getFeature2()}}); // 2D matrix
        INDArray output = model.output(input);
        System.out.println("Model Output: " + output);
    }
}
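
To make the configuration above more concrete, the following plain-Java sketch traces what a network of this shape (2 inputs, 3 ReLU hidden units, 1 linear output) computes for a single input, and how a squared error against a label is obtained. It does not use Deeplearning4j. The class name ForwardPassSketch and every weight and bias value are hypothetical, picked only to keep the arithmetic easy to follow; the real network initializes and updates its own weights during training.

```java
// Illustrative only: hand-rolled forward pass for a 2 -> 3 -> 1 network.
// All weights and biases are made-up values, not what Deeplearning4j would learn.
class ForwardPassSketch {

    // Hidden layer: 3 units, each with 2 input weights (hypothetical values).
    static final double[][] HIDDEN_WEIGHTS = {
            {1.0, 1.0},
            {1.0, -1.0},
            {-1.0, -1.0}
    };
    static final double[] HIDDEN_BIAS = {0.0, 0.0, 0.0};

    // Output layer: 1 unit with 3 input weights (hypothetical values).
    static final double[] OUTPUT_WEIGHTS = {1.0, 0.5, 2.0};
    static final double OUTPUT_BIAS = 0.0;

    // ReLU activation: negative values become 0, positive values pass through.
    static double relu(double x) {
        return Math.max(0.0, x);
    }

    // Forward pass: weighted sums + ReLU in the hidden layer, then a plain
    // weighted sum (identity activation) in the output layer.
    static double predict(double feature1, double feature2) {
        double output = OUTPUT_BIAS;
        for (int j = 0; j < HIDDEN_WEIGHTS.length; j++) {
            double z = HIDDEN_WEIGHTS[j][0] * feature1
                    + HIDDEN_WEIGHTS[j][1] * feature2
                    + HIDDEN_BIAS[j];
            output += OUTPUT_WEIGHTS[j] * relu(z);
        }
        return output;
    }

    // Squared error for one example, the quantity an MSE loss averages.
    static double squaredError(double predicted, double label) {
        double diff = predicted - label;
        return diff * diff;
    }

    public static void main(String[] args) {
        // With the weights above, the hidden activations for (0.5, 0.5)
        // are (1.0, 0.0, 0.0), so the output is 1.0.
        double predicted = predict(0.5, 0.5);
        System.out.println("Prediction: " + predicted);
        System.out.println("Squared error vs. label 1.0: " + squaredError(predicted, 1.0));
    }
}
```

Training amounts to adjusting the weight and bias values so that this squared error shrinks across the training examples; in the tutorial code that job is delegated to the Sgd updater.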

3. Main Application

The App class serves as the entry point to the application, orchestrating the training and evaluation processes.

package com.example.ai;

import com.example.ai.model.DataModel;
import com.example.ai.service.AIService;

import java.util.ArrayList;
import java.util.List;

public class App {
    public static void main(String[] args) {
        AIService aiService = new AIService();

        // Creating training data
        List<DataModel> trainingData = new ArrayList<>();
        trainingData.add(new DataModel(0.1, 0.2, 0.3));
        trainingData.add(new DataModel(0.2, 0.3, 0.5));
        trainingData.add(new DataModel(0.3, 0.4, 0.7));
        trainingData.add(new DataModel(0.4, 0.5, 0.9));

        aiService.trainModel(trainingData);

        // Evaluating the model with test data
        DataModel testData = new DataModel(0.5, 0.5, 1.0);
        aiService.evaluateModel(testData);
    }
}
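
In the training data above, every label is the sum of its two features (0.1 + 0.2 = 0.3, and so on), and the test example (0.5, 0.5) carries the label 1.0. A model that has picked up this relationship would therefore print a value close to 1.0, although the test input lies slightly outside the training range, so some deviation is to be expected. The sketch below checks the pattern and shows one way a printed prediction could be scored against its label. The class ToyDataCheck, its method names, and the sample prediction 0.93 are all hypothetical and not part of the project.

```java
// Illustrative only: checks the pattern in the tutorial's toy data
// and scores a single prediction against its expected label.
class ToyDataCheck {

    // Rows copied from App.java: {feature1, feature2, label}.
    static final double[][] TRAINING_ROWS = {
            {0.1, 0.2, 0.3},
            {0.2, 0.3, 0.5},
            {0.3, 0.4, 0.7},
            {0.4, 0.5, 0.9}
    };

    // True if every row satisfies label == feature1 + feature2 within the tolerance.
    // A tolerance is needed because decimal fractions are not exact in floating point.
    static boolean followsSumRule(double[][] rows, double tolerance) {
        for (double[] row : rows) {
            if (Math.abs(row[0] + row[1] - row[2]) > tolerance) {
                return false;
            }
        }
        return true;
    }

    // Absolute difference between a model output and the expected label.
    static double absoluteError(double predicted, double expected) {
        return Math.abs(predicted - expected);
    }

    public static void main(String[] args) {
        System.out.println("Labels equal feature sums: " + followsSumRule(TRAINING_ROWS, 1e-9));
        // 0.93 is a made-up stand-in for whatever the trained network prints.
        System.out.println("Error vs. expected 1.0: " + absoluteError(0.93, 1.0));
    }
}
```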

4. Testing

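The AIServiceTest class contains basic smoke tests for the AI service: they confirm that training and evaluation run without throwing an exception, but they do not check the accuracy of the model’s predictions.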

package com.example.ai.service;

import com.example.ai.model.DataModel;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.*;

public class AIServiceTest {
    private AIService aiService;

    @Before
    public void setUp() {
        aiService = new AIService();
    }

    @Test
    public void testTrainModel() {
        List<DataModel> trainingData = new ArrayList<>();
        trainingData.add(new DataModel(0.1, 0.2, 0.3));
        trainingData.add(new DataModel(0.2, 0.3, 0.5));
        trainingData.add(new DataModel(0.3, 0.4, 0.7));
        trainingData.add(new DataModel(0.4, 0.5, 0.9));

        aiService.trainModel(trainingData);
        // Smoke check only: reaching this line means training finished without an exception
        assertNotNull(trainingData);
    }

    @Test
    public void testEvaluateModel() {
        DataModel testData = new DataModel(0.5, 0.5, 1.0);
        // Runs a forward pass on the freshly initialized (untrained) model
        aiService.evaluateModel(testData);
        // Smoke check only: reaching this line means evaluation finished without an exception
        assertNotNull(testData);
    }
}

Running the Project

To run the project, use the following Maven command:

mvn exec:java -Dexec.mainClass="com.example.ai.App"

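This command runs the main class com.example.ai.App. Because the exec-maven-plugin configuration in pom.xml already sets mainClass, the -Dexec.mainClass flag is optional. Note that exec:java does not compile the sources itself, so run mvn clean compile first.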

Testing the Project

To run the tests, use the following Maven command:

mvn test

This command will run all the tests defined in the src/test directory.

Conclusion

In this tutorial, we’ve built a simple AI application in Java using Deeplearning4j. We’ve covered setting up the project, writing the necessary code for training and evaluating a machine learning model, and running the application. This project serves as a foundation for more complex AI applications and demonstrates the power of using Java for AI development.

Improvements

  • Data Normalization: Normalize the input data to improve model performance.
  • Cross-Validation: Implement cross-validation to better assess model performance.
  • Hyperparameter Tuning: Experiment with different learning rates, batch sizes, and model architectures.
  • Logging: Implement more detailed logging for better monitoring of the training process.
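
As an illustration of the first item, here is a minimal sketch of min-max scaling in plain Java, which rescales each feature column to the range [0, 1]. The class and method names are hypothetical, and mapping a constant column to 0.0 is just one possible convention. Deeplearning4j also ships ready-made normalizers such as NormalizerMinMaxScaler, which would usually be the better choice in a real project.

```java
import java.util.Arrays;

// Illustrative only: column-wise min-max scaling to [0, 1] without any library.
class MinMaxSketch {

    // Returns a new matrix in which every column is rescaled to [0, 1].
    // A constant column (max == min) is mapped to 0.0 to avoid dividing by zero.
    static double[][] scaleColumns(double[][] rows) {
        int nRows = rows.length;
        int nCols = rows[0].length;
        double[][] scaled = new double[nRows][nCols];

        for (int c = 0; c < nCols; c++) {
            // Find the smallest and largest value in this column.
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : rows) {
                min = Math.min(min, row[c]);
                max = Math.max(max, row[c]);
            }
            double range = max - min;
            for (int r = 0; r < nRows; r++) {
                scaled[r][c] = (range == 0.0) ? 0.0 : (rows[r][c] - min) / range;
            }
        }
        return scaled;
    }

    public static void main(String[] args) {
        // Feature columns from the tutorial's training data.
        double[][] features = {
                {0.1, 0.2},
                {0.2, 0.3},
                {0.3, 0.4},
                {0.4, 0.5}
        };
        for (double[] row : scaleColumns(features)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

In practice the minimum and maximum would be computed from the training set only and then reused unchanged for any test input, so that training and evaluation see the same scaling.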

Feel free to clone the repository and experiment with different model configurations and datasets. Happy coding!
