基于SpringBoot2.7.1的深度学习库DL4J的开发环境测试功能

基于SpringBoot2.7.1的深度学习库DL4J的开发环境测试功能
SpringBoot2.7.1
JDK11
DL4J版本 1.0.0-M2.1
测试结果效果图:
[测试结果截图]
训练数据 loan_data.csv（放在 classpath 下，例如 src/main/resources；无表头，各列依次为：信用评分、收入、贷款金额、就业状态(0/1)、审批结果标签(1=批准，0=拒绝)）：
700,50000,20000,1,1
650,45000,15000,1,1
600,30000,25000,0,0
720,60000,22000,1,1
580,29000,18000,0,0

训练模型的代码:LoanApprovalModel.java

package com.cwgis.pg.dl4j;

import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//import org.nd4j.linalg.dataset.api.iterator.RecordReaderDataSetIterator;

import java.io.File;

/**
 * Trains a small feed-forward classifier on {@code loan_data.csv} (read from the classpath) and
 * writes it, together with the fitted feature normalizer, to {@code loan_approval_model.zip} in the
 * working directory.
 */
public class LoanApprovalModel {
    /** CSV training data, resolved from the classpath; no header row. */
    private static final String TRAINING_DATA = "loan_data.csv";
    /** Output archive (network + normalizer), written to the working directory. */
    private static final String MODEL_FILE = "loan_approval_model.zip";
    /** Column index of the label (approve/reject); columns 0-3 are the features. */
    private static final int LABEL_INDEX = 4;
    private static final int NUM_FEATURES = 4;
    private static final int HIDDEN_UNITS = 3;
    /** Approve or Reject. */
    private static final int NUM_CLASSES = 2;
    private static final int BATCH_SIZE = 5;
    private static final int EPOCHS = 1000;

    /**
     * Trains the network and saves it to {@value #MODEL_FILE}.
     *
     * @param args unused; kept so existing callers keep compiling
     * @throws Exception if the training data cannot be read or the model cannot be written
     */
    public void test(String[] args) throws Exception {
        // The record reader holds an open file handle; close it once training is done.
        try (CSVRecordReader recordReader = new CSVRecordReader(0, ',')) {
            // NOTE(review): ClassPathResource#getFile only works when the resource is a real file on
            // disk (e.g. running from the IDE); it fails when the CSV is packaged inside a jar.
            recordReader.initialize(new FileSplit(new ClassPathResource(TRAINING_DATA).getFile()));
            DataSetIterator iterator =
                    new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, LABEL_INDEX, NUM_CLASSES);

            // Features live on very different scales (credit score in the hundreds, income in the
            // tens of thousands), so standardize them before training.
            DataNormalization normalizer = new NormalizerStandardize();
            normalizer.fit(iterator);
            iterator.setPreProcessor(normalizer);

            MultiLayerNetwork model = new MultiLayerNetwork(buildConfiguration());
            model.init();
            model.setListeners(new ScoreIterationListener(100));

            // Train: one full pass over the data per epoch.
            for (int epoch = 0; epoch < EPOCHS; epoch++) {
                iterator.reset();
                model.fit(iterator);
            }

            // Save the model, then bundle the normalizer into the same archive. Without it, the
            // inference side has no way to apply the training-time scaling to new inputs.
            File modelFile = new File(MODEL_FILE);
            model.save(modelFile, true);
            ModelSerializer.addNormalizerToModel(modelFile, normalizer);
            System.out.println("生成完毕loan_approval_model.zip");
        }
    }

    /** Builds the NUM_FEATURES -> HIDDEN_UNITS (ReLU) -> NUM_CLASSES (softmax) network. */
    private static MultiLayerConfiguration buildConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed(1000)
                // Replacement for the removed pre-1.0 `.learningRate(0.01)` call: the learning rate
                // is now configured on the updater.
                .updater(new Sgd(0.01))
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .l2(0.01)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(NUM_FEATURES).nOut(HIDDEN_UNITS).build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(HIDDEN_UNITS).nOut(NUM_CLASSES).build())
                .backpropType(BackpropType.Standard)
                .build();
    }
}


请求参数类：LoanApplication.java（文件名须与 public 类名 LoanApplication 一致）

package com.cwgis.pg.dl4j;

import lombok.Data;

/**
 * JSON request body for {@code POST /ai/loan/approve}.
 *
 * <p>The four fields are the model's input features. The controller feeds them to the network in
 * declaration order, which matches feature columns 0-3 of {@code loan_data.csv}.
 */
@Data
public class LoanApplication {
    private double creditScore;
    private double income;
    private double loanAmount;
    // 0 or 1 in the training data (column 3 of loan_data.csv).
    private int employmentStatus;

    // Getters, setters, equals/hashCode and toString are generated by Lombok's @Data.
}

预测功能（REST API）
LoanApprovalController.java

package com.cwgis.pg.dl4j;


import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.web.bind.annotation.*;

import java.io.File;
import java.io.IOException;

@RestController
@RequestMapping("/ai/loan")
public class LoanApprovalController {

    private MultiLayerNetwork model;

    public LoanApprovalController() throws IOException {
        // Load the trained model
        File f=new File("loan_approval_model.zip");
        System.out.println(f.getAbsolutePath());
        //D:\project\svn\puri\galaxy-ai\galaxy-ai-dl4j\loan_approval_model.zip
        model = MultiLayerNetwork.load(f, true);
    }

    @PostMapping("/approve")
    public String approveLoan(@RequestBody LoanApplication loanApplication) {
        // Prepare input data
        INDArray input = Nd4j.create(new double[]{
                loanApplication.getCreditScore(),
                loanApplication.getIncome(),
                loanApplication.getLoanAmount(),
                loanApplication.getEmploymentStatus()
        }, 1, 4);

        // Make prediction
        INDArray output = model.output(input);
        int prediction = Nd4j.argMax(output, 1).getInt(0);

        return prediction == 1 ? "Approved" : "Rejected";
    }


}

环境配置参数
pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <groupId>com.cwgis</groupId>
  <artifactId>galaxy-ai-dl4j</artifactId>
  <version>1.0-SNAPSHOT</version>

  <name>galaxy-ai-dl4j</name>
  <!-- FIXME change it to the project's website -->
  <url>http://www.cwgis.com</url>

  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <maven.compiler.source>11</maven.compiler.source>
    <maven.compiler.release>11</maven.compiler.release>

    <spring-boot.version>2.7.1</spring-boot.version>

    <dl4j.version>1.0.0-M2.1</dl4j.version>

  </properties>

  <dependencyManagement>
    <dependencies>
      <!-- Declared before the Spring Boot BOM: for imported BOMs the first declaration wins,
           so this pins the JUnit 5 version. -->
      <dependency>
        <groupId>org.junit</groupId>
        <artifactId>junit-bom</artifactId>
        <version>5.11.0</version>
        <type>pom</type>
        <scope>import</scope>
      </dependency>
      <!-- Scope "import" only takes effect inside dependencyManagement; this BOM also supplies
           the Lombok version used below. -->
      <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-dependencies</artifactId>
        <version>${spring-boot.version}</version>
        <type>pom</type>
        <scope>import</scope>
      </dependency>
    </dependencies>
  </dependencyManagement>

  <dependencies>
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-api</artifactId>
      <scope>test</scope>
    </dependency>
    <!-- Optionally: parameterized tests support -->
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-params</artifactId>
      <scope>test</scope>
    </dependency>

    <!--Spring Boot Dependencies -->
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-web</artifactId>
      <version>${spring-boot.version}</version>
    </dependency>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-devtools</artifactId>
      <version>${spring-boot.version}</version>
      <scope>runtime</scope>
      <optional>true</optional>
    </dependency>
    <dependency>
      <groupId>org.springframework.boot</groupId>
      <artifactId>spring-boot-starter-test</artifactId>
      <version>${spring-boot.version}</version>
      <scope>test</scope>
    </dependency>

    <!--AI Deeplearning4j Dependencies -->
    <!-- DL4J core (declared once; it was previously duplicated further down) -->
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-core</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-nn</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <!-- ND4J CPU backend -->
    <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-native-platform</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- DataVec (for CSV reading) -->
    <dependency>
      <groupId>org.datavec</groupId>
      <artifactId>datavec-api</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.datavec</groupId>
      <artifactId>datavec-local</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- DL4J GPU backend (CUDA 11.6). NOTE: this is a second ND4J backend next to
         nd4j-native-platform above and needs a matching local CUDA installation; drop it if only
         the CPU backend is wanted. -->
    <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-cuda-11.6-platform</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- DL4J model zoo -->
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-zoo</artifactId>
      <version>${dl4j.version}</version>
    </dependency>

    <!-- Annotation processor only needed at compile time; version managed by the Spring Boot BOM
         (replaces the deprecated RELEASE meta-version). -->
    <dependency>
      <groupId>org.projectlombok</groupId>
      <artifactId>lombok</artifactId>
      <scope>provided</scope>
    </dependency>

  </dependencies>

  <build>
    <pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
      <plugins>
        <!-- clean lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#clean_Lifecycle -->
        <plugin>
          <artifactId>maven-clean-plugin</artifactId>
          <version>3.4.0</version>
        </plugin>
        <!-- default lifecycle, jar packaging: see https://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
        <plugin>
          <artifactId>maven-resources-plugin</artifactId>
          <version>3.3.1</version>
        </plugin>
        <plugin>
          <artifactId>maven-compiler-plugin</artifactId>
          <version>3.13.0</version>
        </plugin>
        <plugin>
          <artifactId>maven-surefire-plugin</artifactId>
          <version>3.3.0</version>
        </plugin>
        <plugin>
          <artifactId>maven-jar-plugin</artifactId>
          <version>3.4.2</version>
        </plugin>
        <plugin>
          <artifactId>maven-install-plugin</artifactId>
          <version>3.1.2</version>
        </plugin>
        <plugin>
          <artifactId>maven-deploy-plugin</artifactId>
          <version>3.1.2</version>
        </plugin>
        <!-- site lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#site_Lifecycle -->
        <plugin>
          <artifactId>maven-site-plugin</artifactId>
          <version>3.12.1</version>
        </plugin>
        <plugin>
          <artifactId>maven-project-info-reports-plugin</artifactId>
          <version>3.6.1</version>
        </plugin>
      </plugins>
    </pluginManagement>
  </build>

  <repositories>
    <repository>
      <id>spring-milestones</id>
      <name>Spring Milestones</name>
      <url>https://repo.spring.io/milestone</url>
      <snapshots>
        <enabled>false</enabled>
      </snapshots>
    </repository>
    <repository>
      <id>aliyun-central</id>
      <name>aliyun-central</name>
      <url>https://maven.aliyun.com/repository/central</url>
    </repository>
    <repository>
      <id>aliyun-public</id>
      <name>aliyun-public</name>
      <url>https://maven.aliyun.com/repository/public</url>
    </repository>
    <repository>
      <id>aliyun-snapshots</id>
      <name>aliyun-snapshots</name>
      <url>https://maven.aliyun.com/repository/apache-snapshots</url>
    </repository>
    <repository>
      <id>aliyun-plugin</id>
      <name>aliyun-plugin</name>
      <url>https://maven.aliyun.com/repository/gradle-plugin</url>
    </repository>
    <repository>
      <id>osgeo</id>
      <name>OSGeo Release Repository</name>
      <url>https://repo.osgeo.org/repository/release/</url>
    </repository>
  </repositories>
</project>

添加阿里云 Maven 仓库地址（放在 pom.xml 的 <project> 元素内）：

<!-- Place inside the <project> element of pom.xml. -->
<repositories>
  <repository>
    <id>spring-milestones</id>
    <name>Spring Milestones</name>
    <url>https://repo.spring.io/milestone</url>
    <snapshots>
      <enabled>false</enabled>
    </snapshots>
  </repository>
  <repository>
    <id>aliyun-central</id>
    <name>aliyun-central</name>
    <url>https://maven.aliyun.com/repository/central</url>
  </repository>
  <repository>
    <id>aliyun-public</id>
    <name>aliyun-public</name>
    <url>https://maven.aliyun.com/repository/public</url>
  </repository>
  <repository>
    <id>aliyun-snapshots</id>
    <name>aliyun-snapshots</name>
    <url>https://maven.aliyun.com/repository/apache-snapshots</url>
  </repository>
  <repository>
    <id>aliyun-plugin</id>
    <name>aliyun-plugin</name>
    <url>https://maven.aliyun.com/repository/gradle-plugin</url>
  </repository>
  <repository>
    <id>osgeo</id>
    <name>OSGeo Release Repository</name>
    <url>https://repo.osgeo.org/repository/release/</url>
  </repository>
</repositories>

本博客地址：https://blog.csdn.net/hsg77

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值