基于SpringBoot2.7.1的深度学习库DL4J的开发环境测试功能
SpringBoot2.7.1
JDK11
DL4J版本 1.0.0-M2.1
测试结果效果图:
训练数据loan_data.csv(无表头,各列依次为:信用评分、收入、贷款金额、就业状态、审批结果标签 0/1)
700,50000,20000,1,1
650,45000,15000,1,1
600,30000,25000,0,0
720,60000,22000,1,1
580,29000,18000,0,0
训练模型的代码:LoanApprovalModel.java
package com.cwgis.pg.dl4j;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//import org.nd4j.linalg.dataset.api.iterator.RecordReaderDataSetIterator;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import java.io.File;
public class LoanApprovalModel {
public void test(String[] args) throws Exception {
// Load dataset
int numLinesToSkip = 0;
char delimiter = ',';
CSVRecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
recordReader.initialize(new FileSplit(new ClassPathResource("loan_data.csv").getFile()));
int labelIndex = 4; // Index of the label (approve/reject)
int numClasses = 2; // Approve or Reject
int batchSize = 5;
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
// Normalize the data
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(iterator);
iterator.setPreProcessor(normalizer);
// Define the network configuration
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(1000)
//.iterations(1000)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
//.learningRate(0.01)
.l2(0.01)
.list()
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(2).build())
//.backprop(true)
.backpropType(BackpropType.Standard)
//.pretrain(false)?
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));
// 训练模型
for (int i = 0; i < 1000; i++) {
iterator.reset();
model.fit(iterator);
}
// 保存模型
model.save(new File("loan_approval_model.zip"), true);
System.out.println("生成完毕loan_approval_model.zip");
}
}
LoanApplication.java
package com.cwgis.pg.dl4j;
import lombok.Data;
/**
 * JSON request body for the loan-approval endpoint.
 *
 * <p>No accessors are written by hand: Lombok's {@code @Data} generates the getters that
 * {@code LoanApprovalController} calls. The controller passes the four fields to the model
 * in declaration order, which matches the first four columns of {@code loan_data.csv}.
 */
@Data
public class LoanApplication {
// Model input feature 0.
private double creditScore;
// Model input feature 1.
private double income;
// Model input feature 2.
private double loanAmount;
// Model input feature 3; the sample training rows only use 0 and 1 here
// (presumably 1 = employed; confirm against the data source).
private int employmentStatus;
// Getters and setters are generated by @Data.
}
预测 REST API 功能
LoanApprovalController.java
package com.cwgis.pg.dl4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
@RestController
@RequestMapping("/ai/loan")
public class LoanApprovalController {
private MultiLayerNetwork model;
public LoanApprovalController() throws IOException {
// Load the trained model
File f=new File("loan_approval_model.zip");
System.out.println(f.getAbsolutePath());
//D:\project\svn\puri\galaxy-ai\galaxy-ai-dl4j\loan_approval_model.zip
model = MultiLayerNetwork.load(f, true);
}
@PostMapping("/approve")
public String approveLoan(@RequestBody LoanApplication loanApplication) {
// Prepare input data
INDArray input = Nd4j.create(new double[]{
loanApplication.getCreditScore(),
loanApplication.getIncome(),
loanApplication.getLoanAmount(),
loanApplication.getEmploymentStatus()
}, 1, 4);
// Make prediction
INDArray output = model.output(input);
int prediction = Nd4j.argMax(output, 1).getInt(0);
return prediction == 1 ? "Approved" : "Rejected";
}
}
环境配置参数
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.cwgis</groupId>
  <artifactId>galaxy-ai-dl4j</artifactId>
  <version>1.0-SNAPSHOT</version>
  <name>galaxy-ai-dl4j</name>
  <!-- FIXME change it to the project's website -->
  <url>http://www.cwgis.com</url>

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <maven.compiler.source>11</maven.compiler.source>
    <maven.compiler.release>11</maven.compiler.release>
    <spring-boot.version>2.7.1</spring-boot.version>
    <dl4j.version>1.0.0-M2.1</dl4j.version>
  </properties>

  <!-- BOM imports. scope=import is only valid here, not under <dependencies>. -->
  <dependencyManagement>
    <dependencies>
      <!-- Listed first so its JUnit version takes precedence over the one managed by Spring Boot -->
      <dependency>
        <groupId>org.junit</groupId>
        <artifactId>junit-bom</artifactId>
        <version>5.11.0</version>
        <type>pom</type>
        <scope>import</scope>
      </dependency>
      <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-dependencies</artifactId>
        <version>${spring-boot.version}</version>
        <type>pom</type>
        <scope>import</scope>
      </dependency>
    </dependencies>
  </dependencyManagement>

  <dependencies>
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-api</artifactId>
      <scope>test</scope>
    </dependency>
    <!-- Optionally: parameterized tests support -->
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-params</artifactId>
      <scope>test</scope>
    </dependency>

    <!-- Spring Boot dependencies; versions come from the spring-boot-dependencies BOM -->
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-devtools</artifactId>
      <scope>runtime</scope>
      <optional>true</optional>
    </dependency>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-test</artifactId>
      <scope>test</scope>
    </dependency>

    <!-- Deeplearning4j core and network API -->
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-core</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-nn</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <!-- Deeplearning4j model zoo -->
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-zoo</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- ND4J CPU backend -->
    <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-native-platform</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <!-- ND4J GPU (CUDA 11.6) backend. NOTE(review): both backends are on the classpath;
         drop this one if no CUDA device is available, to avoid a large unused download. -->
    <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-cuda-11.6-platform</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- DataVec (for CSV reading) -->
    <dependency>
      <groupId>org.datavec</groupId>
      <artifactId>datavec-api</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.datavec</groupId>
      <artifactId>datavec-local</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- Lombok is an annotation processor: the version is managed by the Spring Boot BOM
         (instead of the deprecated RELEASE meta-version) and it is not needed at runtime. -->
    <dependency>
      <groupId>org.projectlombok</groupId>
      <artifactId>lombok</artifactId>
      <scope>provided</scope>
    </dependency>
  </dependencies>

  <build>
    <pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
      <plugins>
        <!-- clean lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#clean_Lifecycle -->
        <plugin>
          <artifactId>maven-clean-plugin</artifactId>
          <version>3.4.0</version>
        </plugin>
        <!-- default lifecycle, jar packaging: see https://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
        <plugin>
          <artifactId>maven-resources-plugin</artifactId>
          <version>3.3.1</version>
        </plugin>
        <plugin>
          <artifactId>maven-compiler-plugin</artifactId>
          <version>3.13.0</version>
        </plugin>
        <plugin>
          <artifactId>maven-surefire-plugin</artifactId>
          <version>3.3.0</version>
        </plugin>
        <plugin>
          <artifactId>maven-jar-plugin</artifactId>
          <version>3.4.2</version>
        </plugin>
        <plugin>
          <artifactId>maven-install-plugin</artifactId>
          <version>3.1.2</version>
        </plugin>
        <plugin>
          <artifactId>maven-deploy-plugin</artifactId>
          <version>3.1.2</version>
        </plugin>
        <!-- site lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#site_Lifecycle -->
        <plugin>
          <artifactId>maven-site-plugin</artifactId>
          <version>3.12.1</version>
        </plugin>
        <plugin>
          <artifactId>maven-project-info-reports-plugin</artifactId>
          <version>3.6.1</version>
        </plugin>
      </plugins>
    </pluginManagement>
  </build>

  <repositories>
    <repository>
      <id>spring-milestones</id>
      <name>Spring Milestones</name>
      <url>https://repo.spring.io/milestone</url>
      <snapshots>
        <enabled>false</enabled>
      </snapshots>
    </repository>
    <repository>
      <id>aliyun-central</id>
      <name>aliyun-central</name>
      <url>https://maven.aliyun.com/repository/central</url>
    </repository>
    <repository>
      <id>aliyun-public</id>
      <name>aliyun-public</name>
      <url>https://maven.aliyun.com/repository/public</url>
    </repository>
    <repository>
      <id>aliyun-snapshots</id>
      <name>aliyun-snapshots</name>
      <url>https://maven.aliyun.com/repository/apache-snapshots</url>
    </repository>
    <repository>
      <id>aliyun-plugin</id>
      <name>aliyun-plugin</name>
      <url>https://maven.aliyun.com/repository/gradle-plugin</url>
    </repository>
    <repository>
      <id>osgeo</id>
      <name>OSGeo Release Repository</name>
      <url>https://repo.osgeo.org/repository/release/</url>
    </repository>
  </repositories>
</project>
添加阿里maven仓库地址
<!-- Extra Maven repositories (Spring milestones, Aliyun mirrors, OSGeo); goes inside <project>. -->
<repositories>
  <repository>
    <id>spring-milestones</id>
    <name>Spring Milestones</name>
    <url>https://repo.spring.io/milestone</url>
    <snapshots>
      <enabled>false</enabled>
    </snapshots>
  </repository>
  <repository>
    <id>aliyun-central</id>
    <name>aliyun-central</name>
    <url>https://maven.aliyun.com/repository/central</url>
  </repository>
  <repository>
    <id>aliyun-public</id>
    <name>aliyun-public</name>
    <url>https://maven.aliyun.com/repository/public</url>
  </repository>
  <repository>
    <id>aliyun-snapshots</id>
    <name>aliyun-snapshots</name>
    <url>https://maven.aliyun.com/repository/apache-snapshots</url>
  </repository>
  <repository>
    <id>aliyun-plugin</id>
    <name>aliyun-plugin</name>
    <url>https://maven.aliyun.com/repository/gradle-plugin</url>
  </repository>
  <repository>
    <id>osgeo</id>
    <name>OSGeo Release Repository</name>
    <url>https://repo.osgeo.org/repository/release/</url>
  </repository>
</repositories>
本blog地址:https://blog.csdn.net/hsg77