Enterprise Java

Building Predictive APIs with TensorFlow and Spring Boot

1. Why Combine AI/ML with Spring Boot?

Modern applications increasingly need smart capabilities – from recommendation engines to fraud detection. While Python dominates ML development, Java teams can leverage:

  • TensorFlow Java for model inference
  • Spring Boot for scalable API delivery
  • DJL (Deep Java Library) as an alternative framework

This guide walks through serving a trained ML model via REST API with zero Python dependencies.

2. Architecture Overview

[Python Environment] -- Trains Model --> SavedModel.pb
                            |
                            v
[Java Service] <-- Loads Model --> [Spring Boot REST API]
                            |
                            v
[Client Apps] <-- Get Predictions

Key components:

  1. TensorFlow SavedModel (exported from Python)
  2. Spring Boot web layer
  3. TensorFlow Java API for inference

Step 1: Train and Export Model (Python)

# train.py
import numpy as np
import tensorflow as tf

# Toy training data: 100 samples, 3 features, 1 regression target
X_train = np.random.rand(100, 3).astype(np.float32)
y_train = np.random.rand(100, 1).astype(np.float32)

# Sample neural network
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=10)

# Export for Java
tf.saved_model.save(model, "saved_model")

This creates a saved_model/ directory with:

  • saved_model.pb (architecture)
  • variables/ (trained weights)

Step 2: Spring Boot Integration

Dependencies (pom.xml)

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <version>0.4.1</version>
</dependency>

Load Model in Java

import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;

@Service
public class Predictor {
    private SavedModelBundle model;

    @PostConstruct
    public void init() {
        // "serve" is the standard tag for models exported for serving
        this.model = SavedModelBundle.load(
            "src/main/resources/saved_model",
            "serve"
        );
    }

    public float predict(float[] input) {
        // The model expects a 2-D [batch, features] tensor,
        // so wrap the single input row in a batch of one
        try (TFloat32 inputTensor = TFloat32.tensorOf(
                 StdArrays.ndCopyOf(new float[][] { input }));
             TFloat32 result = (TFloat32) model
                 .function("serving_default")  // default signature exported by Keras
                 .call(inputTensor)) {
            return result.getFloat(0, 0);
        }
    }
}

Step 3: Expose as REST API

@RestController
@RequestMapping("/api/predict")
public class PredictionController {
    
    @Autowired
    private Predictor predictor;
    
    @PostMapping
    public PredictionResponse predict(@RequestBody PredictionRequest request) {
        float result = predictor.predict(request.getFeatures());
        return new PredictionResponse(result);
    }
}
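The controller relies on PredictionRequest and PredictionResponse DTOs that aren't shown above. A minimal sketch, with field names assumed from the JSON payload in the sample request (each class would be public in its own file in a real project):

```java
// Request DTO: Jackson binds {"features": [...]} via the
// no-arg constructor and setter.
class PredictionRequest {
    private float[] features;

    PredictionRequest() {}

    public float[] getFeatures() { return features; }
    public void setFeatures(float[] features) { this.features = features; }
}

// Response DTO: serialized back to JSON as {"prediction": ...}.
class PredictionResponse {
    private final float prediction;

    PredictionResponse(float prediction) { this.prediction = prediction; }

    public float getPrediction() { return prediction; }
}
```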

Sample request:

curl -X POST http://localhost:8080/api/predict \
  -H "Content-Type: application/json" \
  -d '{"features": [0.1, 0.5, 0.3]}'
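The same call can be issued from another JVM service using the JDK's built-in HttpClient (Java 11+); a minimal sketch, assuming the service runs on localhost:8080 as above:

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

class PredictionClient {
    // Builds the same POST request the curl example sends.
    static HttpRequest buildRequest(String baseUrl, String featuresJson) {
        return HttpRequest.newBuilder()
                .uri(URI.create(baseUrl + "/api/predict"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(featuresJson))
                .build();
    }

    // Sends the request and returns the raw JSON response body.
    static String predict(String baseUrl, String featuresJson) throws Exception {
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(buildRequest(baseUrl, featuresJson),
                      HttpResponse.BodyHandlers.ofString());
        return response.body();
    }
}
```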

3. Performance Optimization Tips

1. Batching Predictions
Process multiple inputs in one session run:

float[][] batchInputs = ...;
TFloat32 batchTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(batchInputs));

2. GPU Acceleration
Add CUDA dependencies for NVIDIA GPUs:

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform-gpu</artifactId>
    <version>0.4.1</version>
</dependency>

3. Model Warmup
Initialize model at startup to avoid first-call latency:

@Bean
public CommandLineRunner warmup(Predictor predictor) {
    // Dummy input sized to the model's feature count (3 in this example)
    return args -> predictor.predict(new float[3]);
}
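Tip 1 can also be applied at the service layer by accumulating individual requests until a full batch is ready for a single session run. A minimal, illustrative accumulator (the class name and flush threshold are assumptions, not part of TensorFlow or Spring):

```java
import java.util.ArrayList;
import java.util.List;

// Collects single prediction inputs and releases them as one 2-D batch
// once the threshold is reached, so the model runs once per batch
// instead of once per request.
class BatchAccumulator {
    private final int batchSize;
    private final List<float[]> pending = new ArrayList<>();

    BatchAccumulator(int batchSize) { this.batchSize = batchSize; }

    // Returns the full batch when enough rows have accumulated, else null.
    synchronized float[][] add(float[] input) {
        pending.add(input);
        if (pending.size() < batchSize) {
            return null;
        }
        float[][] batch = pending.toArray(new float[0][]);
        pending.clear();
        return batch;
    }
}
```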

4. Alternative: DJL (Deep Java Library)

For more Java-native ML workflows:

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.NoopTranslator;
import java.nio.file.Paths;

// Load a trained model (here a PyTorch artifact) directly in Java
Model model = Model.newInstance("linear");
model.load(Paths.get("model.pt"));

try (NDManager manager = NDManager.newBaseManager();
     Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator())) {
    NDArray input = manager.create(new float[]{...});
    NDList result = predictor.predict(new NDList(input));
}

Advantages:

  • Unified API for TensorFlow/PyTorch/MXNet
  • No SWIG/JNI overhead
  • Built-in image preprocessing

5. Conclusion

Key takeaways:
✅ Serve TensorFlow models without Python in production
✅ Achieve single-digit millisecond latency per prediction for small models on typical server hardware
✅ Scale horizontally like any Spring Boot service

Next Steps:

  1. Try the TensorFlow Java examples
  2. Explore DJL’s Spring Boot starter
  3. Monitor performance with Micrometer metrics

Eleftheria Drosopoulou

Eleftheria is an experienced Business Analyst with a robust background in the computer software industry. Proficient in Computer Software Training, Digital Marketing, HTML Scripting, and Microsoft Office, she brings a wealth of technical skills to the table. Additionally, she has a love for writing articles on various tech subjects, showcasing a talent for translating complex concepts into accessible content.