
How To Use YOLOv9 With OpenCV for Java

Greg Bizup
Aug 15, 2024

Your stack is Java, but all the YOLOv9 code is in Python

This was my problem. My entire software stack for a project I'm working on is Java, and I needed to run YOLOv9 object detection. Most machine learning code nowadays is written in Python, and trying to get it working in Java is like getting dropped in the desert with no compass. This is a very common problem: Java is the backend language of choice for many companies and developers, while Python is the language of choice for AI/ML. Unfortunately, there is not much support or instruction yet on how to move from one to the other.

The best option for Java users is OpenCV's DNN module, but it is sparsely documented and there are few examples for Java. How do you preprocess images and consume the output from YOLOv9 in Java? It took me about a week to figure this all out, and I'm sharing it here so you don't have to reinvent the wheel.

This guide assumes you have already trained a YOLOv9 model according to your needs, and are familiar with the official YOLOv9 repository. It also assumes you have some familiarity with Java OpenCV. It is intended for Java users who want to bring their trained model into their Java applications. The following sections attempt to provide all the information needed to understand how it works, along with code examples.

Understanding The YOLOv9 Output

Working with OpenCV in Java is much more bare-bones than using the premade Python scripts, so I found it very helpful to understand the output layer of the model when working through this problem. This knowledge is key to understanding the rest of the guide. The output is a three-dimensional array of shape batch_size × (N + 4) × 8400, where N is the number of classes your model detects.

The size of the first dimension equals the batch size: each index along it corresponds to one image. If you are running inference on one image at a time, as the code below does, this dimension always has size one.

The second dimension has size N + 4, where N is the number of classes the model can detect. So, if your model detects N = 10 classes, the second dimension has size 14. The first four values describe the bounding box in pixel coordinates of the model input image (640 × 640 here), expressed as (center_x, center_y, width, height). The remaining N values are the confidence scores that the box contains each corresponding class.

Finally, the third dimension has 8400 entries (for a 640 × 640 input), which correspond to the "anchors" of the model. In other words, there are 8400 possible locations where a bounding box can be found per image, and for each of these locations there is a vector of length N + 4 holding the bounding box plus a confidence score for each class. For further discussion of the output layer of the YOLO model, please see this GitHub thread.
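To make the shape concrete, here is a small plain-Java sketch (no OpenCV needed). It assumes a 640 × 640 input and YOLOv9's three detection strides of 8, 16 and 32 pixels; if you changed the input size or the detection head, the count will differ. Each stride yields a (640 / stride)² grid of anchor points, and 80² + 40² + 20² = 8400. The hypothetical `flatIndex` helper shows where element (batch, row, anchor) sits when the tensor is stored as one row-major array, as in a continuous OpenCV `Mat`.

```java
class YoloOutputShape {

    /** Number of anchor points for a square input: one per grid cell at each stride. */
    static int anchorCount(int inputSize, int[] strides) {
        int total = 0;
        for (int stride : strides) {
            int cells = inputSize / stride; // e.g. 640 / 8 = 80 cells per side
            total += cells * cells;         // 80 * 80 = 6400 anchors at this stride
        }
        return total;
    }

    /**
     * Position of element (batch, row, anchor) in a row-major buffer of shape
     * batchSize x (numClasses + 4) x numAnchors.
     */
    static int flatIndex(int batch, int row, int anchor, int numClasses, int numAnchors) {
        int rows = numClasses + 4;
        return (batch * rows + row) * numAnchors + anchor;
    }

    public static void main(String[] args) {
        int anchors = anchorCount(640, new int[] {8, 16, 32});
        System.out.println(anchors); // 8400

        // With N = 10 classes, the score for class 3 of anchor 100 is in row 4 + 3.
        System.out.println(flatIndex(0, 4 + 3, 100, 10, anchors)); // 58900
    }
}
```

Rows 0 to 3 of every anchor are the box; rows 4 to N + 3 are the class scores, which is why the Java code later reads class scores with `rowRange(4, NUM_CLASSES + 4)`.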

Putting it all together

The idea is to loop through all 8400 anchors. For each anchor, we look at its class scores (the values from the fifth position onward) and take the maximum. If that maximum exceeds a chosen confidence threshold, typically 0.5, we keep the anchor: the index of the maximum score is the predicted class, and the first four values give the bounding box. The final step is to apply a technique called Non-Max Suppression (NMS), which removes duplicate boxes around the same object. The remaining boxes can then be used downstream in your program for whatever you need.
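Here is a minimal plain-Java sketch of the filtering step described above, before NMS. It assumes the output has already been copied into a flat `float[]` laid out as (N + 4) rows × numAnchors columns (with OpenCV's Java API, something like `output.get(0, 0, data)` on the reshaped 2D `Mat` would fill such an array), and that the first four rows are center_x, center_y, width and height. `Detection` and `decode` are hypothetical names, and the record requires Java 16 or later.

```java
import java.util.ArrayList;
import java.util.List;

class YoloDecodeSketch {

    /** One candidate box in corner form (left, top, width, height). */
    record Detection(int classId, float confidence,
                     double left, double top, double width, double height) {}

    /**
     * data holds (numClasses + 4) rows x numAnchors columns, row-major.
     * Rows 0-3: center_x, center_y, width, height. Rows 4 onward: per-class scores.
     */
    static List<Detection> decode(float[] data, int numClasses, int numAnchors, float confThreshold) {
        List<Detection> out = new ArrayList<>();
        for (int a = 0; a < numAnchors; a++) {
            // Argmax over this anchor's class scores (one column of the output).
            int bestClass = -1;
            float bestScore = Float.NEGATIVE_INFINITY;
            for (int c = 0; c < numClasses; c++) {
                float score = data[(4 + c) * numAnchors + a];
                if (score > bestScore) {
                    bestScore = score;
                    bestClass = c;
                }
            }
            if (bestScore <= confThreshold) {
                continue; // low-confidence anchor: skip it
            }
            double cx = data[a];
            double cy = data[numAnchors + a];
            double w = data[2 * numAnchors + a];
            double h = data[3 * numAnchors + a];
            // Convert from center form to top-left corner form.
            out.add(new Detection(bestClass, bestScore, cx - w / 2, cy - h / 2, w, h));
        }
        return out;
    }

    public static void main(String[] args) {
        // Toy output: 2 classes, 3 anchors; each column is one anchor.
        float[] data = {
            100, 200, 300,      // center_x
            100, 200, 300,      // center_y
            20, 40, 60,         // width
            20, 40, 60,         // height
            0.9f, 0.1f, 0.3f,   // class 0 scores
            0.05f, 0.2f, 0.7f   // class 1 scores
        };
        // Anchor 1 is dropped (best score 0.2); anchors 0 and 2 are kept.
        for (Detection d : decode(data, 2, 3, 0.5f)) {
            System.out.println(d);
        }
    }
}
```

The surviving detections would still overlap each other in a real image, which is what the NMS step then cleans up.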

Convert PyTorch model weights to ONNX format

OpenCV does not support loading PyTorch weights, but it does support ONNX. ONNX is an open format for representing neural networks, designed to be portable between machine learning frameworks. Therefore, the model weights need to be converted from PyTorch to ONNX format. The official YOLOv9 repository has a script called export.py which makes this conversion easy. Here is the shell command I have been using in my Kaggle notebook to export to ONNX:


!python export.py \
  --include onnx \
  --weights /kaggle/working/yolov9/runs/train/exp/weights/best.pt \
  --data {dataset.location}/data.yaml 

The exclamation mark at the beginning of the command is there because I am running it from a Jupyter notebook. The {dataset.location} argument to --data refers to a dataset that was loaded using the Roboflow API. Once you have run this command, download your ONNX weights and move them into your Java project.

Build The Correct OpenCV version

Just a quick note to avoid confusion. I have tried this using a few different versions of OpenCV, but encountered bugs when running OpenCV 4.9 or lower. It is possible that YOLOv9 did not work well with earlier versions of the DNN module, so I would recommend building OpenCV 4.10 or higher. You can follow this guide to build OpenCV for Java, and just replace the old version in the guide with the latest version.
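If you load OpenCV at runtime, it can be worth failing fast when an older build ends up on the classpath. A minimal sketch, assuming you pass it the `Core.VERSION` string exposed by the OpenCV Java bindings (for example "4.10.0"); `isAtLeast` is a hypothetical helper, not part of OpenCV. The version parts are compared as integers, because a plain string comparison would rank "4.9.0" above "4.10.0".

```java
class OpenCvVersionCheck {

    /** True if a version string such as "4.10.0" is at least major.minor. */
    static boolean isAtLeast(String version, int major, int minor) {
        String[] parts = version.trim().split("\\.");
        int maj = Integer.parseInt(parts[0]);
        int min = parts.length > 1 ? Integer.parseInt(parts[1]) : 0;
        // Numeric comparison: 10 > 9, whereas "10".compareTo("9") < 0.
        return maj > major || (maj == major && min >= minor);
    }

    public static void main(String[] args) {
        // In a real application, pass org.opencv.core.Core.VERSION here
        // after the native library has been loaded.
        System.out.println(isAtLeast("4.10.0", 4, 10)); // true
        System.out.println(isAtLeast("4.9.0", 4, 10));  // false
    }
}
```

Only the major and minor parts are parsed, so suffixes such as "4.10.0-dev" in the patch part do not break the check.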

Write the Java code

Here is the Java code to load the model, preprocess the image, get the predictions, and draw bounding boxes. It is adapted from suddh123's example on GitHub. I made some adjustments to his code to make it work with the YOLOv9 model, and added a class called BoxPrediction, which represents an output from the YOLOv9 model and has a convenient method for scaling the prediction to a new image size. If you use this code, adjust the final variables CLASS_NAMES, NUM_CLASSES, MODEL_PATH, and OPENCV_RESOURCE_NAME (and the input image path in main) to match your project.

// Class YoloShapesSolver, Adjust final variables to your needs
package com.sadcaptcha.opencv_dnn;

import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;

import org.opencv.core.Core;
import org.opencv.core.Core.MinMaxLocResult;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.MatOfInt;
import org.opencv.core.MatOfRect2d;
import org.opencv.core.Point;
import org.opencv.core.Rect;
import org.opencv.core.Rect2d;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.utils.Converters;

/**
 * Special kudos to suddh123 on Github for writing this code:
 * https://github.com/suddh123/YOLO-object-detection-in-java/blob/code/yolo.java
 */
public class YoloShapesSolver {

	private static final String OPENCV_RESOURCE_NAME = "opencv_4100.so";
	private static final String MODEL_PATH = "src/main/resources/yolov9-3d-shapes-100-images-100-epochs.onnx";
	private static final Size TARGET_IMAGE_SIZE = new Size(640, 640);
	private static final float SCALE_FACTOR = 1f / 255f;
	private static final int NUM_CLASSES = 4;
	private static final float CONF_THRESHOLD = 0.5f;
	private static final float NMS_THRESHOLD = 0.5f;

	private static final String[] CLASS_NAMES = new String[] { "class_1", "class_2", "class_3", "class_4"};

	public static void main(String[] args) {
		loadOpenCVFromResources();
		Net model = loadModel();
		Mat image = loadImage("src/test/resources/shapes0.jpeg");
		Mat blob = preprocess(image);
		PreNmsModelResult result = getBoxPredictions(model, blob);
		List<BoxPrediction> preds = applyNonMaxSuppression(result);
		Size originalImageSize = image.size();
		for (BoxPrediction pred : preds) {
			// Predictions are in 640x640 model-input coordinates; map them back to the original image
			BoxPrediction scaled = pred.scale(originalImageSize);
			Imgproc.rectangle(image, scaled.getBox(), new Scalar(0, 0, 255), 2);
			Imgproc.circle(image, scaled.getCenter(), 2, new Scalar(0, 0, 255), 3);
			Imgproc.putText(image, scaled.getClassName(), scaled.getCenter(), Imgproc.FONT_HERSHEY_SIMPLEX, 1,
					new Scalar(0, 0, 0), 2);
		}
		Imgcodecs.imwrite("res.jpg", image);
	}

	/**
	 * Run the model on a preprocessed blob and collect every candidate box whose
	 * best class score exceeds CONF_THRESHOLD. Non-max suppression is NOT applied
	 * here; see applyNonMaxSuppression.
	 */
	private static PreNmsModelResult getBoxPredictions(Net model, Mat blob) {
		// YOLOv9 output shape: batch_size x (NUM_CLASSES + 4) x 8400.
		// The batch size is 1, so we drop the batch dimension by reshaping to
		// NUM_CLASSES + 4 rows. Each of the 8400 columns is then one anchor:
		// rows 0-3 hold center_x, center_y, width, height (in 640x640 input pixels),
		// and rows 4 onward hold the score for each class.
		// For each column we take the maximum class score; if it exceeds the
		// threshold, we keep the box, its confidence, and the index of its class.
		model.setInput(blob);
		Mat output = model.forward().reshape(0, NUM_CLASSES + 4);

		PreNmsModelResult result = new PreNmsModelResult();
		for (int i = 0; i < output.cols(); i++) {
			Mat column = output.col(i);
			Mat classScores = column.rowRange(4, NUM_CLASSES + 4);
			MinMaxLocResult mm = Core.minMaxLoc(classScores);
			float confidence = (float) mm.maxVal;

			if (confidence > CONF_THRESHOLD) {
				double centerX = column.get(0, 0)[0];
				double centerY = column.get(1, 0)[0];
				double width = column.get(2, 0)[0];
				double height = column.get(3, 0)[0];
				// Convert from center form to top-left corner form for Rect2d
				double left = centerX - width / 2;
				double top = centerY - height / 2;
				result.addClassId((int) mm.maxLoc.y); // row within classScores = class index
				result.addConfidence(confidence);
				result.addBox(new Rect2d(left, top, width, height));
			}
		}
		return result;
	}

	/**
	 * Remove duplicate boxes with non-max suppression and wrap the survivors as
	 * BoxPrediction objects in 640x640 model-input coordinates.
	 */
	private static List<BoxPrediction> applyNonMaxSuppression(PreNmsModelResult input) {
		List<BoxPrediction> preds = new ArrayList<>();
		if (input.getBoxes().isEmpty()) {
			return preds; // nothing passed the confidence threshold
		}
		MatOfFloat confs = new MatOfFloat(Converters.vector_float_to_Mat(input.getConfidences()));
		Rect2d[] boxesArray = input.getBoxes().toArray(new Rect2d[0]);
		MatOfRect2d boxesMat = new MatOfRect2d(boxesArray);
		MatOfInt indices = new MatOfInt();
		Dnn.NMSBoxes(boxesMat, confs, CONF_THRESHOLD, NMS_THRESHOLD, indices);
		for (int idx : indices.toArray()) {
			Rect2d b = boxesArray[idx];
			Rect box = new Rect((int) b.x, (int) b.y, (int) b.width, (int) b.height);
			int classId = input.getClassIds().get(idx);
			preds.add(new BoxPrediction.Builder()
					.withBox(box)
					.withCenter(new Point(b.x + b.width / 2, b.y + b.height / 2))
					.withClassName(getClassName(classId))
					.withClassIndex(classId)
					.withImageSize(TARGET_IMAGE_SIZE)
					.build());
		}
		return preds;
	}

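	private static void loadOpenCVFromResources() {
		java.net.URL resource = YoloShapesSolver.class.getClassLoader().getResource(OPENCV_RESOURCE_NAME);
		if (resource == null) {
			throw new IllegalStateException("OpenCV native library not found on classpath: " + OPENCV_RESOURCE_NAME);
		}
		try {
			// Paths.get(URI) decodes characters such as %20 and gives a valid OS path (also on Windows),
			// unlike stripping "file:" by hand. This only works when the library is a plain file, not inside a JAR.
			System.load(java.nio.file.Paths.get(resource.toURI()).toString());
		} catch (URISyntaxException e) {
			throw new RuntimeException(e);
		}
	}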

	private static Net loadModel() {
		return Dnn.readNetFromONNX(MODEL_PATH);
	}

	private static Mat preprocess(Mat image) {
		return Dnn.blobFromImage(image, SCALE_FACTOR, TARGET_IMAGE_SIZE, new Scalar(0, 0, 0), true, false);
	}

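	private static Mat loadImage(String path) {
		Mat image = Imgcodecs.imread(path, Imgcodecs.IMREAD_COLOR);
		if (image.empty()) {
			// imread does not throw on a bad path; it silently returns an empty Mat
			throw new IllegalArgumentException("Could not read image: " + path);
		}
		return image;
	}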

	private static String getClassName(int loc) {
		return CLASS_NAMES[loc];
	}

}
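Dnn.NMSBoxes does the duplicate removal in a single call. To show roughly what it does, here is a simplified greedy NMS in plain Java: drop boxes at or below the score threshold, sort the rest by score, and keep a box only if its intersection-over-union (IoU) with every already-kept box is at most the NMS threshold. This is a sketch, not OpenCV's implementation: it omits the optional `eta` and `top_k` parameters, and, like the call in applyNonMaxSuppression, it ignores class labels, so two overlapping boxes of different classes can suppress each other. Recent OpenCV versions also provide a class-aware `NMSBoxesBatched`; check the documentation for your build if that matters for your data. `iou` and `nms` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class GreedyNmsSketch {

    /** Intersection over union of two boxes given as {left, top, width, height}. */
    static double iou(double[] a, double[] b) {
        double x1 = Math.max(a[0], b[0]);
        double y1 = Math.max(a[1], b[1]);
        double x2 = Math.min(a[0] + a[2], b[0] + b[2]);
        double y2 = Math.min(a[1] + a[3], b[1] + b[3]);
        double inter = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
        double union = a[2] * a[3] + b[2] * b[3] - inter;
        return union <= 0 ? 0 : inter / union;
    }

    /** Indices of the boxes that survive, highest score first. */
    static List<Integer> nms(double[][] boxes, float[] scores, float scoreThreshold, float iouThreshold) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > scoreThreshold) {
                order.add(i);
            }
        }
        // Visit candidates from most to least confident.
        order.sort(Comparator.comparingDouble((Integer i) -> scores[i]).reversed());

        List<Integer> keep = new ArrayList<>();
        for (int i : order) {
            boolean suppressed = false;
            for (int k : keep) {
                if (iou(boxes[i], boxes[k]) > iouThreshold) {
                    suppressed = true; // overlaps a higher-scoring kept box too much
                    break;
                }
            }
            if (!suppressed) {
                keep.add(i);
            }
        }
        return keep;
    }

    public static void main(String[] args) {
        double[][] boxes = {
            {10, 10, 100, 100},  // A
            {15, 15, 100, 100},  // B: almost the same box as A
            {300, 300, 50, 50}   // C: far away from both
        };
        float[] scores = {0.8f, 0.9f, 0.6f};
        // B wins over A (higher score, IoU about 0.82); C is kept.
        System.out.println(nms(boxes, scores, 0.5f, 0.5f)); // [1, 2]
    }
}
```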

Here is an additional class called PreNmsModelResult, which is used to bring the model output into a single object. It represents the output of the model before NMS is applied:

// Class PreNmsModelResult
package com.sadcaptcha.opencv_dnn;

import java.util.ArrayList;
import java.util.List;

import org.opencv.core.Rect2d;

/**
 * Contains parallel lists for all the class IDs, confidence values, and bounding
 * rectangles after running one image through the model. NMS has not been applied.
 * Index i in each list refers to the same candidate box.
 */
public class PreNmsModelResult {

	private List<Integer> classIds = new ArrayList<>();
	private List<Float> confidences = new ArrayList<>();
	private List<Rect2d> boxes = new ArrayList<>();

	public List<Rect2d> getBoxes() {
		return boxes;
	}

	public List<Integer> getClassIds() {
		return classIds;
	}

	public List<Float> getConfidences() {
		return confidences;
	}

	public void setBoxes(List<Rect2d> boxes) {
		this.boxes = boxes;
	}

	public void setClassIds(List<Integer> classIds) {
		this.classIds = classIds;
	}

	public void setConfidences(List<Float> confidences) {
		this.confidences = confidences;
	}

	public void addConfidence(Float confidence) {
		confidences.add(confidence);
	}

	public void addBox(Rect2d box) {
		boxes.add(box);
	}

	public void addClassId(Integer classId) {
		classIds.add(classId);
	}

}
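Parallel lists work, but nothing stops them from drifting out of sync if one `add` call is forgotten. As an alternative design (not the one used in this post), a list of Java 16+ records keeps each candidate's fields together, and the arrays an NMS call needs can be derived from that single list on demand. `Candidate`, `scores` and `boxes` are hypothetical names; this is a plain-Java sketch, not a drop-in replacement for PreNmsModelResult.

```java
import java.util.List;

class CandidateListSketch {

    /** One pre-NMS candidate; its fields cannot drift out of sync with each other. */
    record Candidate(int classId, float confidence,
                     double left, double top, double width, double height) {}

    /** Confidence scores in candidate order, ready for an NMS call. */
    static float[] scores(List<Candidate> candidates) {
        float[] out = new float[candidates.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = candidates.get(i).confidence();
        }
        return out;
    }

    /** Boxes as {left, top, width, height} rows in candidate order. */
    static double[][] boxes(List<Candidate> candidates) {
        double[][] out = new double[candidates.size()][];
        for (int i = 0; i < out.length; i++) {
            Candidate c = candidates.get(i);
            out[i] = new double[] {c.left(), c.top(), c.width(), c.height()};
        }
        return out;
    }

    public static void main(String[] args) {
        List<Candidate> candidates = List.of(
                new Candidate(0, 0.9f, 90, 90, 20, 20),
                new Candidate(1, 0.7f, 270, 270, 60, 60));
        float[] s = scores(candidates);
        double[][] b = boxes(candidates);
        // An index returned by NMS maps straight back to candidates.get(index).
        System.out.println(s.length + " candidates, first box left = " + b[0][0]);
    }
}
```

With OpenCV, the `MatOfFloat` and `MatOfRect2d` passed to `Dnn.NMSBoxes` would then be built from this one list instead of three separate ones.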

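And finally, here is the BoxPrediction class, which represents a single detection: the bounding box derived from the first four values of the YOLOv9 output, its center, the predicted class, and the size of the image the coordinates refer to. It has a Builder class which lets us easily create instances, as well as a scale() method which returns a new BoxPrediction scaled to a different image size.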

// Class BoxPrediction
package com.sadcaptcha.opencv_dnn;

import org.opencv.core.Point;
import org.opencv.core.Rect;
import org.opencv.core.Size;

/**
 * This class represents a predicted box from a YOLO model
 */
public class BoxPrediction {

	private String className;
	private int classIndex;
	private Rect box;
	private Point center;
	private Size imageSize;

	/**
	 * @param newSize new image size to scale this prediction to
	 * @return new BoxPrediction with box and center scaled accordingly
	 */
	public BoxPrediction scale(Size newSize) {
		if (imageSize == null) {
			throw new IllegalStateException(
					"Cannot rescale prediction because imageSize instance property is null. Please run setImageSize");
		}
		double xScaleFactor = newSize.width / imageSize.width;
		double yScaleFactor = newSize.height / imageSize.height;
		int width = (int) (xScaleFactor * box.width);
		int height = (int) (yScaleFactor * box.height);
		int left = (int) (xScaleFactor * box.x);
		int top = (int) (yScaleFactor * box.y);
		return new BoxPrediction.Builder()
				.withClassIndex(classIndex)
				.withClassName(className)
				.withBox(new Rect(left, top, width, height))
				.withCenter(new Point(center.x * xScaleFactor, center.y * yScaleFactor))
				.withImageSize(newSize)
				.build();
	}

	public Size getImageSize() {
		return imageSize;
	}

	public void setImageSize(Size imageSize) {
		this.imageSize = imageSize;
	}

	public String getClassName() {
		return className;
	}

	public Rect getBox() {
		return box;
	}

	public Point getCenter() {
		return center;
	}

	public int getClassIndex() {
		return classIndex;
	}

	public void setCenter(Point center) {
		this.center = center;
	}

	public void setBox(Rect box) {
		this.box = box;
	}

	public void setClassName(String className) {
		this.className = className;
	}

	public void setClassIndex(int classIndex) {
		this.classIndex = classIndex;
	}

	public static class Builder {
		private String className;
		private int classIndex;
		private Rect box;
		private Point center;
		private Size imageSize;

		public Builder withClassIndex(int classIndex) {
			this.classIndex = classIndex;
			return this;
		}

		public Builder withBox(Rect box) {
			this.box = box;
			return this;
		}

		public Builder withCenter(Point center) {
			this.center = center;
			return this;
		}

		public Builder withClassName(String className) {
			this.className = className;
			return this;
		}

		public Builder withImageSize(Size imageSize) {
			this.imageSize = imageSize;
			return this;
		}

		public BoxPrediction build() {
			BoxPrediction pred = new BoxPrediction();
			pred.setImageSize(this.imageSize);
			pred.setClassName(this.className);
			pred.setClassIndex(this.classIndex);
			pred.setBox(this.box);
			pred.setCenter(this.center);
			return pred;
		}

	}

}
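scale() uses independent x and y factors because preprocess calls blobFromImage with crop=false, which resizes the image straight to 640 × 640 without preserving its aspect ratio. Here is the same arithmetic in isolation, as a plain-Java sketch with a hypothetical `scaleBox` helper. It only applies to this stretch-resize; if you switched to letterboxed preprocessing, the mapping back would also have to remove the padding.

```java
class ScaleSketch {

    /** Maps a box {left, top, width, height} from one image size to another. */
    static int[] scaleBox(int[] box, double fromW, double fromH, double toW, double toH) {
        double sx = toW / fromW; // horizontal scale factor
        double sy = toH / fromH; // vertical scale factor
        return new int[] {
            (int) (box[0] * sx), (int) (box[1] * sy),
            (int) (box[2] * sx), (int) (box[3] * sy)
        };
    }

    public static void main(String[] args) {
        // A box predicted on the 640x640 blob, mapped back to a 1280x720 photo:
        // sx = 2.0, sy = 1.125.
        int[] scaled = scaleBox(new int[] {320, 320, 64, 64}, 640, 640, 1280, 720);
        System.out.println(java.util.Arrays.toString(scaled)); // [640, 360, 128, 72]
    }
}
```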