Applying Neural Networks to Continuous Function Approximation and Object Classification

### Applying Neural Networks to Continuous Function Approximation and Object Classification
#### 1. Continuous Function Approximation Test Results
In the continuous function approximation test, the micro-batch method produced the following results:
- The maximum error percentage (maxErrorPerc) was 9.002677165459051E-6, i.e., the maximum error was less than 0.00000900%.
- The average error percentage (averErrorPerc) was 4.567068981414947E-6, i.e., the average error was less than 0.00000457%.
In the chart of the test results, the actual values (black) and the predicted values (white) practically overlap, showing that the micro-batch method can approximate a continuous function with a complex topology to very high accuracy.
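The two metrics follow directly from their definitions; a minimal sketch of computing them from arrays of actual and predicted values (the class name, method name, and arrays are illustrative, not taken from the original program):
```java
public class ErrorMetrics
{
    // Illustrative computation of maxErrorPerc and averErrorPerc:
    // per record, the error percentage is |predicted - actual| / |actual| * 100.
    public static void reportErrors(double[] actual, double[] predicted)
    {
        double maxErrorPerc = 0.0;
        double sumErrorPerc = 0.0;
        for (int i = 0; i < actual.length; i++)
        {
            double errorPerc = Math.abs(predicted[i] - actual[i]) / Math.abs(actual[i]) * 100.0;
            if (errorPerc > maxErrorPerc)
                maxErrorPerc = errorPerc;
            sumErrorPerc += errorPerc;
        }
        System.out.println("maxErrorPerc  = " + maxErrorPerc);
        System.out.println("averErrorPerc = " + sumErrorPerc / actual.length);
    }
}
```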
#### 2. Object Classification Overview
Object classification is the task of recognizing objects and determining which category each of them belongs to. As in many areas of artificial intelligence, classification tasks that humans handle with ease are quite difficult for computers.
#### 3. A Classification Example
Consider classifying books. There are five books from different fields, and for each book the three most frequently used words are given:
| Book domain | Three most frequent words |
| ---- | ---- |
| Medicine | surgery, blood, prescription |
| Programming | file, java, debugging |
| Engineering | combustion, screw, machine |
| Electrical | volt, solenoid, diode |
| Music | adagio, hymn, opera |
In addition, some extra words (such as customer, wind, and so on) are used for the test data set. To simplify processing, every word is assigned a number, as shown in the following table:
| Word | Assigned number |
| ---- | ---- |
| Surgery | 1 |
| Blood | 2 |
| Prescription | 3 |
| File | 4 |
| Java | 5 |
| Debugging | 6 |
| Combustion | 7 |
| screw | 8 |
| machine | 9 |
| Volt | 10 |
| solenoid | 11 |
| diode | 12 |
| adagio | 13 |
| hymn | 14 |
| opera | 15 |
| customer | 16 |
| wind | 17 |
| grass | 18 |
| paper | 19 |
| calculator | 20 |
| flower | 21 |
| printer | 22 |
| desk | 23 |
| photo | 24 |
| map | 25 |
| pen | 26 |
| floor | 27 |
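The word/number assignment above can be kept in a simple lookup table; a minimal sketch (the class and map names are illustrative and not part of the original program):
```java
import java.util.LinkedHashMap;
import java.util.Map;

public class WordNumbers
{
    // Illustrative lookup table mirroring the word/number assignment above.
    static final Map<String, Integer> WORD_NUMBERS = new LinkedHashMap<>();
    static
    {
        String[] words = {
            "surgery", "blood", "prescription", "file", "java", "debugging",
            "combustion", "screw", "machine", "volt", "solenoid", "diode",
            "adagio", "hymn", "opera", "customer", "wind", "grass", "paper",
            "calculator", "flower", "printer", "desk", "photo", "map", "pen", "floor"
        };
        for (int i = 0; i < words.length; i++)
        {
            WORD_NUMBERS.put(words[i], i + 1);   // numbers start at 1
        }
    }

    public static void main(String[] args)
    {
        System.out.println(WORD_NUMBERS.get("java"));   // prints 5
    }
}
```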
#### 4. The Training Data Set
Each record of the training data set contains three word-number fields and five target fields; the target fields indicate which book the record belongs to. For each book, six records are built, covering all possible permutations of its three words (a sketch of generating these permutation records follows the table below). A partial example of the training data set:
| Word1 | Word2 | Word3 | Target1 | Target2 | Target3 | Target4 | Target5 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| 1 | 2 | 3 | 1 | 0 | 0 | 0 | 0 |
| 1 | 3 | 2 | 1 | 0 | 0 | 0 | 0 |
| 2 | 1 | 3 | 1 | 0 | 0 | 0 | 0 |
|... |... |... |... |... |... |... |... |
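A minimal sketch of producing the six permutation records for one book; the class and method names are illustrative, not taken from the original program:
```java
public class PermutationRecords
{
    // Prints the six training records for one book: every permutation of its
    // three word numbers, followed by five target fields (1 in the book's own
    // column, 0 elsewhere).
    public static void printRecords(int w1, int w2, int w3, int bookIndex)
    {
        int[][] perms = {
            {w1, w2, w3}, {w1, w3, w2}, {w2, w1, w3},
            {w2, w3, w1}, {w3, w1, w2}, {w3, w2, w1}
        };
        for (int[] p : perms)
        {
            StringBuilder line = new StringBuilder();
            line.append(p[0]).append(',').append(p[1]).append(',').append(p[2]);
            for (int t = 0; t < 5; t++)
            {
                line.append(',').append(t == bookIndex ? 1 : 0);
            }
            System.out.println(line);
        }
    }

    public static void main(String[] args)
    {
        printRecords(1, 2, 3, 0);   // the medicine book: words 1, 2, 3
    }
}
```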
#### 5. Network Architecture
The network architecture is:
- Input layer: three neurons.
- Hidden layers: six hidden layers with seven neurons each.
- Output layer: five neurons.
The architecture as a mermaid flowchart (a sketch of building this topology with the Encog API follows the diagram):
```mermaid
graph LR
classDef startend fill:#F5EBFF,stroke:#BE8FED,stroke-width:2px;
classDef process fill:#E5F6FF,stroke:#73A6FF,stroke-width:2px;
A([Input layer: 3 neurons]):::startend --> B(Hidden layer 1: 7 neurons):::process
B --> C(Hidden layer 2: 7 neurons):::process
C --> D(Hidden layer 3: 7 neurons):::process
D --> E(Hidden layer 4: 7 neurons):::process
E --> F(Hidden layer 5: 7 neurons):::process
F --> G(Hidden layer 6: 7 neurons):::process
G --> H([Output layer: 5 neurons]):::startend
```
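A minimal sketch of building this topology with the Encog API. The tanh activation is an assumption based on the ActivationTANH import in the Sample6 program below; the class and method names here are illustrative:
```java
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;

public class BuildClassificationNetwork
{
    // Sketch of the 3 - (6 x 7) - 5 topology described above.
    public static BasicNetwork build()
    {
        BasicNetwork network = new BasicNetwork();
        // Input layer: 3 neurons, bias enabled, no activation function
        network.addLayer(new BasicLayer(null, true, 3));
        // Six hidden layers with 7 neurons each
        for (int i = 0; i < 6; i++)
        {
            network.addLayer(new BasicLayer(new ActivationTANH(), true, 7));
        }
        // Output layer: 5 neurons, no bias on the output layer
        network.addLayer(new BasicLayer(new ActivationTANH(), false, 5));
        network.getStructure().finalizeStructure();
        network.reset();   // randomize the initial weights
        return network;
    }
}
```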
#### 6. The Test Data Set
The test data set consists of records whose word numbers were chosen at random and do not belong to any single book, although some records contain one or two words from the most-frequent-word lists. A partial example of the test data set:
| Word1 | Word2 | Word3 | Target1 | Target2 | Target3 | Target4 | Target5 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| 1 | 2 | 16 | 0 | 0 | 0 | 0 | 0 |
| 4 | 17 | 5 | 0 | 0 | 0 | 0 | 0 |
| 8 | 9 | 18 | 0 | 0 | 0 | 0 | 0 |
|... |... |... |... |... |... |... |... |
#### 7. Data Normalization
Before processing, the training and test data sets are normalized to the interval [-1, 1]. The code below shows the normalization program (a template from an earlier two-column example, with that example's column limits); for the classification data sets here, the word-number columns are normalized with Dl = 1, Dh = 60 and the target columns with Dl = 0, Dh = 1, matching the normalization parameters set in the Sample6 program later in this section.
```java
package sample5_norm;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.PrintWriter;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.*;
public class Sample5_Norm
{
// Interval to normalize
static double Nh = 1;
static double Nl = -1;
// First column
static double minXPointDl = 1.00;
static double maxXPointDh = 1000.00;
// Second column - target data
static double minTargetValueDl = 60.00;
static double maxTargetValueDh = 1600.00;
public static double normalize(double value, double Dh, double Dl)
{
double normalizedValue = (value - Dl)*(Nh - Nl)/(Dh - Dl) + Nl;
return normalizedValue;
}
public static void main(String[] args)
{
// Normalize train file
String inputFileName = "C:/Book_Examples/Sample5_Train_Real.csv";
String outputNormFileName = "C:/Book_Examples/Sample5_Train_Norm.csv";
// Normalize test file
//String inputFileName = "C:/Book_Examples/Sample5_Test_Real.csv";
//String outputNormFileName = "C:/Book_Examples/Sample5_Test_Norm.csv";
BufferedReader br = null;
PrintWriter out = null;
String line = "";
String cvsSplitBy = ",";
double inputXPointValue;
double targetXPointValue;
double normInputXPointValue;
double normTargetXPointValue;
String strNormInputXPointValue;
String strNormTargetXPointValue;
String fullLine;
int i = -1;
try
{
Files.deleteIfExists(Paths.get(outputNormFileName));
br = new BufferedReader(new FileReader(inputFileName));
out = new
PrintWriter(new BufferedWriter(new FileWriter(outputNormFileName)));
while ((line = br.readLine()) != null)
{
i++;
if(i == 0)
{
// Write the label line
out.println(line);
}
else
{
// Break the line using a comma as the separator
String[] workFields = line.split(cvsSplitBy);
inputXPointValue = Double.parseDouble(workFields[0]);
targetXPointValue = Double.parseDouble( workFields[1]);
// Normalize these fields
normInputXPointValue =
normalize(inputXPointValue, maxXPointDh, minXPointDl);
normTargetXPointValue =
normalize(targetXPointValue, maxTargetValueDh, minTargetValueDl);
// Convert normalized fields to string, so they can be inserted
//into the output CSV file
strNormInputXPointValue = Double.toString(normInputXPointValue);
strNormTargetXPointValue = Double.toString(normTargetXPointValue);
// Concatenate these fields into a string line with
//a comma separator
fullLine =
strNormInputXPointValue + "," + strNormTargetXPointValue;
// Put fullLine into the output file
out.println(fullLine);
} // End of IF Else
} // end of while
} // end of TRY
catch (FileNotFoundException e)
{
e.printStackTrace();
System.exit(1);
}
catch (IOException io)
{
io.printStackTrace();
}
finally
{
if (br != null)
{
try
{
br.close();
out.close();
}
catch (IOException e)
{
e.printStackTrace();
}
}
}
}
}
```
Partial examples of the normalized training and test data sets are shown below.
**Normalized training data set**
| Word1 | Word2 | Word3 | Target1 | Target2 | Target3 | Target4 | Target5 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| -1 | -0.966101695 | -0.93220339 | 1 | -1 | -1 | -1 | -1 |
| -1 | -0.93220339 | -0.966101695 | 1 | -1 | -1 | -1 | -1 |
| -0.966101695 | -1 | -0.93220339 | 1 | -1 | -1 | -1 | -1 |
|... |... |... |... |... |... |... |... |
**Normalized test data set**
| Word1 | Word2 | Word3 | Target1 | Target2 | Target3 | Target4 | Target5 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| -1 | -0.966101695 | -0.491525424 | -1 | -1 | -1 | -1 | -1 |
| -0.898305085 | -0.457627119 | -0.86440678 | -1 | -1 | -1 | -1 | -1 |
| -0.762711864 | -0.728813559 | -0.423728814 | -1 | -1 | -1 | -1 | -1 |
|... |... |... |... |... |... |... |... |
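As a quick check of the table values, normalizing word number 2 with the word-number limits Dl = 1, Dh = 60 (the values set in the Sample6 program below) reproduces the -0.966101695 entry:
```java
public class NormalizationCheck
{
    public static void main(String[] args)
    {
        // normalizedValue = (value - Dl) * (Nh - Nl) / (Dh - Dl) + Nl
        double norm = (2.0 - 1.0) * (1.0 - (-1.0)) / (60.0 - 1.0) + (-1.0);
        System.out.println(norm);   // prints -0.9661016949152542 (rounded in the table to -0.966101695)
    }
}
```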
#### 8. The Classification Program
The code below is the program that classifies objects with the neural network. Only the beginning of the listing (imports, field declarations, and chart/configuration setup inside getChart()) is shown here; a hedged sketch of the training logic that follows it is given after the listing.
```java
package sample6;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.PrintWriter;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.*;
import java.util.Properties;
import java.time.YearMonth;
import java.awt.Color;
import java.awt.Font;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.time.Month;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import org.encog.Encog;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.engine.network.activation.ActivationReLU;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.buffer.MemoryDataLoader;
import org.encog.ml.data.buffer.codec.CSVDataCODEC;
import org.encog.ml.data.buffer.codec.DataSetCODEC;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.csv.CSVFormat;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;
import org.knowm.xchart.XYSeries;
import org.knowm.xchart.demo.charts.ExampleChart;
import org.knowm.xchart.style.Styler.LegendPosition;
import org.knowm.xchart.style.colors.ChartColor;
import org.knowm.xchart.style.colors.XChartSeriesColors;
import org.knowm.xchart.style.lines.SeriesLines;
import org.knowm.xchart.style.markers.SeriesMarkers;
import org.knowm.xchart.BitmapEncoder;
import org.knowm.xchart.BitmapEncoder.BitmapFormat;
import org.knowm.xchart.QuickChart;
public class Sample6 implements ExampleChart<XYChart>
{
// Interval to normalize data
static double Nh;
static double Nl;
// Normalization parameters for workBook number
static double minWordNumberDl;
static double maxWordNumberDh;
// Normalization parameters for target values
static double minTargetValueDl;
static double maxTargetValueDh;
static double doublePointNumber = 0.00;
static int intPointNumber = 0;
static InputStream input = null;
static double[] arrPrices = new double[2500];
static double normInputWordNumber_01 = 0.00;
static double normInputWordNumber_02 = 0.00;
static double normInputWordNumber_03 = 0.00;
static double denormInputWordNumber_01 = 0.00;
static double denormInputWordNumber_02 = 0.00;
static double denormInputWordNumber_03 = 0.00;
static double normTargetBookNumber_01 = 0.00;
static double normTargetBookNumber_02 = 0.00;
static double normTargetBookNumber_03 = 0.00;
static double normTargetBookNumber_04 = 0.00;
static double normTargetBookNumber_05 = 0.00;
static double normPredictBookNumber_01 = 0.00;
static double normPredictBookNumber_02 = 0.00;
static double normPredictBookNumber_03 = 0.00;
static double normPredictBookNumber_04 = 0.00;
static double normPredictBookNumber_05 = 0.00;
static double denormTargetBookNumber_01 = 0.00;
static double denormTargetBookNumber_02 = 0.00;
static double denormTargetBookNumber_03 = 0.00;
static double denormTargetBookNumber_04 = 0.00;
static double denormTargetBookNumber_05 = 0.00;
static double denormPredictBookNumber_01 = 0.00;
static double denormPredictBookNumber_02 = 0.00;
static double denormPredictBookNumber_03 = 0.00;
static double denormPredictBookNumber_04 = 0.00;
static double denormPredictBookNumber_05 = 0.00;
static double normDifferencePerc = 0.00;
static double denormPredictXPointValue_01 = 0.00;
static double denormPredictXPointValue_02 = 0.00;
static double denormPredictXPointValue_03 = 0.00;
static double denormPredictXPointValue_04 = 0.00;
static double denormPredictXPointValue_05 = 0.00;
static double valueDifference = 0.00;
static int numberOfInputNeurons;
static int numberOfOutputNeurons;
static int intNumberOfRecordsInTestFile;
static String trainFileName;
static String priceFileName;
static String testFileName;
static String chartTrainFileName;
static String chartTestFileName;
static String networkFileName;
static int workingMode;
static String cvsSplitBy = ",";
static int returnCode;
static List<Double> xData = new ArrayList<Double>();
static List<Double> yData1 = new ArrayList<Double>();
static List<Double> yData2 = new ArrayList<Double>();
static XYChart Chart;
@Override
public XYChart getChart()
{
// Create Chart
Chart = new XYChartBuilder().width(900).height(500).title(getClass().
getSimpleName()).xAxisTitle("x").yAxisTitle("y= f(x)").build();
// Customize Chart
Chart.getStyler().setPlotBackgroundColor(ChartColor.
getAWTColor(ChartColor.GREY));
Chart.getStyler().setPlotGridLinesColor(new Color(255, 255, 255));
Chart.getStyler().setChartBackgroundColor(Color.WHITE);
Chart.getStyler().setLegendBackgroundColor(Color.PINK);
Chart.getStyler().setChartFontColor(Color.MAGENTA);
Chart.getStyler().setChartTitleBoxBackgroundColor(new Color(0, 222, 0));
Chart.getStyler().setChartTitleBoxVisible(true);
Chart.getStyler().setChartTitleBoxBorderColor(Color.BLACK);
Chart.getStyler().setPlotGridLinesVisible(true);
Chart.getStyler().setAxisTickPadding(20);
Chart.getStyler().setAxisTickMarkLength(15);
Chart.getStyler().setPlotMargin(20);
Chart.getStyler().setChartTitleVisible(false);
Chart.getStyler().setChartTitleFont(new Font(Font.MONOSPACED, Font.
BOLD, 24));
Chart.getStyler().setLegendFont(new Font(Font.SERIF, Font.PLAIN, 18));
Chart.getStyler().setLegendPosition(LegendPosition.InsideSE);
Chart.getStyler().setLegendSeriesLineLength(12);
Chart.getStyler().setAxisTitleFont(new Font(Font.SANS_SERIF, Font.ITALIC, 18));
Chart.getStyler().setAxisTickLabelsFont(new Font(Font.SERIF, Font.PLAIN, 11));
Chart.getStyler().setDatePattern("yyyy-MM");
Chart.getStyler().setDecimalPattern("#0.00");
// Interval to normalize data
Nh = 1;
Nl = -1;
// Normalization parameters for workBook number
minWordNumberDl = 1.00;
maxWordNumberDh = 60.00;
// Normalization parameters for target values
minTargetValueDl = 0.00;
maxTargetValueDh = 1.00;
// Configuration (comment and uncomment the appropriate configuration)
// For training the network
workingMode = 1;
intNumberOfRecordsInTestFile = 31;
trainFileName = "C:/My_Neural_Network_Book/Book_Examples/Sample6_Norm_Train_File.csv";
// For testing the trained network at non-trained points
//workingMode = 2;
//intNumberOfRecordsInTestFile = 16;
//testFileName = "C:/My_Neural_Network_Book/Book_Examples/Sample6_Norm_Test_File.csv";
networkFileName =
"C:/My_Neural_Network_Book/Book_Examples/Sample6_Saved_Network_File.csv";
numberOfInputNeurons = 3;
numberOfOutputNeurons = 5;
// Check the wo
```
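The listing is cut off inside getChart(). As a hedged sketch (not the book's exact code), the remainder of such an Encog program would typically load the normalized training CSV into memory, build the network from Section 5, train it with resilient propagation, and save it to networkFileName. The helper name loadCSV2Memory, the error limit 0.00000003, and the exact method layout are assumptions; the sketch below is written as methods inside the Sample6 class so it can use the static fields and imports already declared:
```java
// Hedged sketch of the training logic that would follow the configuration above.
static MLDataSet loadCSV2Memory(String fileName, int input, int ideal,
                                boolean headers, CSVFormat format, boolean significance)
{
    // Read the CSV file into an in-memory Encog data set
    DataSetCODEC codec = new CSVDataCODEC(new File(fileName), format, headers,
                                          input, ideal, significance);
    MemoryDataLoader loader = new MemoryDataLoader(codec);
    return loader.external2Memory();
}

static void trainValidateSaveNetwork(BasicNetwork network)
{
    MLDataSet trainingSet = loadCSV2Memory(trainFileName, numberOfInputNeurons,
        numberOfOutputNeurons, true, CSVFormat.ENGLISH, false);
    // Train with resilient propagation until the error falls below the (assumed) limit
    ResilientPropagation train = new ResilientPropagation(network, trainingSet);
    int epoch = 1;
    do
    {
        train.iteration();
        System.out.println("Epoch " + epoch + "  error = " + train.getError());
        epoch++;
    } while (train.getError() > 0.00000003);
    train.finishTraining();
    // Persist the trained network so the test mode (workingMode = 2) can reload it
    EncogDirectoryPersistence.saveObject(new File(networkFileName), network);
    Encog.getInstance().shutdown();
}
```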