Building Predictive APIs with TensorFlow and Spring Boot
1. Why Combine AI/ML with Spring Boot?
Modern applications increasingly need smart capabilities – from recommendation engines to fraud detection. While Python dominates ML development, Java teams can leverage:
- TensorFlow Java for model inference
- Spring Boot for scalable API delivery
- DJL (Deep Java Library) as an alternative framework
This guide walks through serving a trained ML model via REST API with zero Python dependencies.
2. Architecture Overview
[Python Environment] -- Trains Model --> SavedModel.pb
        |
        v
[Java Service] <-- Loads Model --> [Spring Boot REST API]
        |
        v
[Client Apps] <-- Gets Predictions
Key components:
- TensorFlow SavedModel (exported from Python)
- Spring Boot web layer
- TensorFlow Java API for inference
Step 1: Train and Export Model (Python)
# train.py
import tensorflow as tf

# Sample neural network
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=10)  # X_train, y_train: your training data

# Export for Java
tf.saved_model.save(model, "saved_model")
This creates a saved_model/ directory with:
- saved_model.pb (architecture)
- variables/ (trained weights)
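The Java service will need the graph's input and output tensor names to feed and fetch. One quick way to find them is a throwaway sketch that loads the bundle and prints its signature definitions (TensorFlow's saved_model_cli show --dir saved_model --all prints the same information from the command line):

import org.tensorflow.SavedModelBundle;

public class InspectModel {
    public static void main(String[] args) {
        // Print each signature's input/output tensor names
        try (SavedModelBundle model = SavedModelBundle.load("saved_model", "serve")) {
            model.metaGraphDef().getSignatureDefMap()
                 .forEach((name, signature) -> System.out.println(name + " -> " + signature));
        }
    }
}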
Step 2: Spring Boot Integration
Dependencies (pom.xml)
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform</artifactId>
    <version>0.4.1</version>
</dependency>
Load Model in Java
import javax.annotation.PostConstruct;

import org.springframework.stereotype.Component;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

@Component
public class Predictor {

    private SavedModelBundle model;

    @PostConstruct
    public void init() {
        // "serve" is the standard tag for models exported for serving.
        // Note: a source-tree path like this only resolves when running
        // from the project directory; adjust for packaged deployments.
        this.model = SavedModelBundle.load(
            "src/main/resources/saved_model", "serve");
    }

    public float predict(float[] input) {
        // Tensor names ("dense_input", "dense_1") depend on your model;
        // inspect the SavedModel's signatures if these don't match.
        try (TFloat32 inputTensor = TFloat32.tensorOf(
                 StdArrays.ndCopyOf(new float[][] { input }));  // shape [1, features]
             TFloat32 result = (TFloat32) model.session()
                 .runner()
                 .feed("dense_input", inputTensor)
                 .fetch("dense_1")
                 .run()
                 .get(0)) {
            return result.getFloat(0, 0); // first row, first column of [1, 1] output
        }
    }
}
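The SavedModelBundle holds native memory outside the JVM heap, so it is worth releasing it on shutdown. A minimal sketch, added inside the Predictor class above (javax.annotation.PreDestroy):

// Inside the Predictor class from above
@PreDestroy
public void shutdown() {
    if (model != null) {
        model.close(); // releases native resources held by the bundle
    }
}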
Step 3: Expose as REST API
@RestController
@RequestMapping("/api/predict")
public class PredictionController {

    @Autowired
    private Predictor predictor;

    @PostMapping
    public PredictionResponse predict(@RequestBody PredictionRequest request) {
        float result = predictor.predict(request.getFeatures());
        return new PredictionResponse(result);
    }
}
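The PredictionRequest and PredictionResponse DTOs referenced above are not defined in this guide; a minimal sketch, with field names chosen as assumptions to match the JSON below:

public class PredictionRequest {
    private float[] features; // input feature vector

    public float[] getFeatures() { return features; }
    public void setFeatures(float[] features) { this.features = features; }
}

public class PredictionResponse {
    private final float prediction; // model output

    public PredictionResponse(float prediction) { this.prediction = prediction; }
    public float getPrediction() { return prediction; }
}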
Sample request:
curl -X POST http://localhost:8080/api/predict \
  -H "Content-Type: application/json" \
  -d '{"features": [0.1, 0.5, 0.3]}'
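A successful call returns the PredictionResponse serialized as JSON, e.g. {"prediction": 0.42} (the value here is purely illustrative).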
3. Performance Optimization Tips
1. Batching Predictions
Process multiple inputs in one session run:
float[][] batchInputs = ...; // one row per input
// requires org.tensorflow.ndarray.StdArrays
TFloat32 batchTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(batchInputs)); // shape [batch, features]
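Put together, a batch variant of the earlier predict method could look like this (a sketch assuming the same tensor names and a [batch, 1] output shape):

public float[] predictBatch(float[][] batchInputs) {
    try (TFloat32 batchTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(batchInputs));
         TFloat32 result = (TFloat32) model.session()
             .runner()
             .feed("dense_input", batchTensor)
             .fetch("dense_1")
             .run()
             .get(0)) {
        float[] predictions = new float[batchInputs.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = result.getFloat(i, 0); // one prediction per row
        }
        return predictions;
    }
}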
2. GPU Acceleration
Swap tensorflow-core-platform for the GPU variant to run on NVIDIA GPUs (requires a CUDA-capable setup):
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow-core-platform-gpu</artifactId>
    <version>0.4.1</version>
</dependency>
3. Model Warmup
Initialize model at startup to avoid first-call latency:
@Bean
public CommandLineRunner warmup(Predictor predictor) {
    // inputSize = the number of features the model expects
    return args -> predictor.predict(new float[inputSize]);
}
4. Alternative: DJL (Deep Java Library)
For more Java-native ML workflows:
import java.nio.file.Paths;

import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.NoopTranslator;

// Load a model directly in Java (a PyTorch artifact here; DJL selects the engine)
Model model = Model.newInstance("linear");
model.load(Paths.get("model.pt"));

try (NDManager manager = NDManager.newBaseManager()) {
    NDArray input = manager.create(new float[]{...}); // your feature values
    // newPredictor needs a Translator; NoopTranslator passes NDLists through unchanged
    Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
    NDList result = predictor.predict(new NDList(input));
}
Advantages:
- Unified API for TensorFlow/PyTorch/MXNet
- No hand-written SWIG/JNI bindings to maintain
- Built-in image preprocessing
5. Conclusion
Key takeaways:
✅ Serve TensorFlow models without Python in production
✅ Achieve <10 ms per-prediction latency for small models on warmed-up instances
✅ Scale horizontally like any Spring Boot service
Next Steps:
- Try the TensorFlow Java examples
- Explore DJL’s Spring Boot starter
- Monitor performance with Micrometer metrics (a starter sketch follows)
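A minimal sketch of that last item, assuming Spring Boot Actuator provides a MeterRegistry; the meter name prediction.latency and the TimedPredictor wrapper are arbitrary choices:

import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import org.springframework.stereotype.Component;

@Component
public class TimedPredictor {

    private final Predictor predictor;
    private final Timer timer;

    public TimedPredictor(Predictor predictor, MeterRegistry registry) {
        this.predictor = predictor;
        // Tracks count, total time, and max latency of prediction calls
        this.timer = Timer.builder("prediction.latency").register(registry);
    }

    public float predict(float[] features) {
        return timer.record(() -> predictor.predict(features));
    }
}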