/*
* Encog(tm) Core v3.3 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2014 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.neural.rbf;
import org.encog.EncogError;
import org.encog.mathutil.randomize.ConsistentRandomizer;
import org.encog.mathutil.randomize.RangeRandomizer;
import org.encog.mathutil.rbf.GaussianFunction;
import org.encog.mathutil.rbf.InverseMultiquadricFunction;
import org.encog.mathutil.rbf.MultiquadricFunction;
import org.encog.mathutil.rbf.RBFEnum;
import org.encog.mathutil.rbf.RadialBasisFunction;
import org.encog.ml.BasicML;
import org.encog.ml.MLEncodable;
import org.encog.ml.MLError;
import org.encog.ml.MLRegression;
import org.encog.ml.MLResettable;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.flat.FlatNetworkRBF;
import org.encog.neural.networks.ContainsFlat;
import org.encog.util.EngineArray;
import org.encog.util.Format;
import org.encog.util.simple.EncogUtility;
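/**
 * Radial basis function (RBF) neural network. Computation and storage of
 * the radial basis functions are delegated to an underlying
 * {@link FlatNetworkRBF}.
 */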
public class RBFNetwork extends BasicML implements MLError, MLRegression,
ContainsFlat, MLResettable, MLEncodable {
/**
 * Serialization version id for this class.
 */
private static final long serialVersionUID = 1L;
/**
 * The underlying flat network. It holds the radial basis functions and
 * performs the forward computation; compute() and the input/output count
 * accessors of this class delegate to it.
 */
private final FlatNetworkRBF flat;
/**
* Construct RBF network.
*/
public RBFNetwork() {
this.flat = new FlatNetworkRBF();
}
/**
 * Construct an RBF network with the given layer sizes, using the same RBF
 * type for every hidden neuron.
 * <p>
 * The hidden neurons are first placed on an equally spaced grid over
 * [-1, 1] in every input dimension. If that layout fails with an
 * {@link EncogError} (for example because hiddenCount is not k^inputCount
 * for some integer k), the centers and widths are randomized within
 * [-1, 1] instead.
 *
 * @param inputCount
 *            The input count.
 * @param hiddenCount
 *            The hidden count; must be at least 1.
 * @param outputCount
 *            The output count.
 * @param t
 *            The RBF type.
 * @throws NeuralNetworkError
 *             if hiddenCount is zero or negative.
 */
public RBFNetwork(final int inputCount, final int hiddenCount,
        final int outputCount, final RBFEnum t) {
    // Reject non-positive counts explicitly. A negative count would
    // otherwise escape as a NegativeArraySizeException from the array
    // allocation below.
    if (hiddenCount <= 0) {
        throw new NeuralNetworkError(
                "RBF network must have at least one hidden neuron, but hiddenCount="
                        + hiddenCount + ".");
    }
    // Placeholder slots, one per hidden neuron. The layout calls below
    // assign each slot through setRBFFunction.
    final RadialBasisFunction[] rbf = new RadialBasisFunction[hiddenCount];
    // Set the standard RBF neuron width.
    // Literature seems to suggest this is a good default value.
    final double volumeNeuronWidth = 2.0 / hiddenCount;
    this.flat = new FlatNetworkRBF(inputCount, rbf.length, outputCount, rbf);
    // NOTE(review): both layout methods are public and overridable, yet
    // they are called from a constructor. A subclass override would run
    // before the subclass is initialised.
    try {
        // Preferred layout: an equally spaced grid of neurons.
        setRBFCentersAndWidthsEqualSpacing(-1, 1, t, volumeNeuronWidth,
                false);
    } catch (final EncogError ex) {
        // The grid is impossible for this neuron count, so fall back to
        // random placement.
        randomizeRBFCentersAndWidths(-1, 1, t);
    }
}
/**
* Construct RBF network.
*
* @param inputCount
* The input count.
* @param outputCount
* The output count.
* @param rbf
* The RBF type.
*/
public RBFNetwork(final int inputCount, final int outputCount,
final RadialBasisFunction[] rbf) {
this.flat = new FlatNetworkRBF(inputCount, rbf.length, outputCount, rbf);
this.flat.setRBF(rbf);
}
/**
 * Compute the regression error of this network over a data set.
 *
 * @param data
 *            The data set to evaluate the network against.
 * @return The regression error as computed by
 *         {@link EncogUtility#calculateRegressionError}.
 */
@Override
public double calculateError(final MLDataSet data) {
    final double regressionError = EncogUtility.calculateRegressionError(
            this, data);
    return regressionError;
}
/**
 * Run the network forward on one input vector.
 *
 * @param input
 *            The input vector; its backing array is passed to the flat
 *            network.
 * @return A newly allocated vector of {@link #getOutputCount()} values
 *         filled in by the flat network.
 */
@Override
public MLData compute(final MLData input) {
    final MLData result = new BasicMLData(getOutputCount());
    final double[] outputBuffer = result.getData();
    this.flat.compute(input.getData(), outputBuffer);
    return result;
}
/**
 * Expose the flat network that backs this RBF network.
 *
 * @return The underlying {@link FlatNetworkRBF}, typed as
 *         {@link FlatNetwork}.
 */
@Override
public FlatNetwork getFlat() {
    return flat;
}
/**
 * Number of inputs, as reported by the underlying flat network.
 *
 * @return The input count.
 */
@Override
public int getInputCount() {
    return flat.getInputCount();
}
/**
 * Number of outputs, as reported by the underlying flat network.
 *
 * @return The output count.
 */
@Override
public int getOutputCount() {
    return flat.getOutputCount();
}
/**
 * Get the radial basis functions of the hidden layer.
 *
 * @return The array returned by the underlying flat network's getRBF().
 *         NOTE(review): it is not visible here whether this is the live
 *         array or a copy.
 */
public RadialBasisFunction[] getRBF() {
    return flat.getRBF();
}
/**
 * Give every hidden neuron a randomly placed, randomly sized RBF.
 * <p>
 * Each hidden neuron receives its own center vector, with one coordinate
 * per input. Every coordinate, and the neuron's width, is drawn with
 * {@code RangeRandomizer.randomize(min, max)}.
 *
 * @param min
 *            Minimum random value.
 * @param max
 *            Max random value.
 * @param t
 *            The type of RBF to create for every hidden neuron.
 */
public void randomizeRBFCentersAndWidths(final double min,
        final double max, final RBFEnum t) {
    final int dimensions = getInputCount();
    final int hiddenCount = this.flat.getRBF().length;
    for (int i = 0; i < hiddenCount; i++) {
        // Allocate a fresh center vector per neuron. Reusing one array would
        // give every neuron the identical center, and could alias their
        // center vectors if the RBF keeps a reference to the array.
        final double[] centers = new double[dimensions];
        for (int d = 0; d < dimensions; d++) {
            centers[d] = RangeRandomizer.randomize(min, max);
        }
        setRBFFunction(i, t, centers, RangeRandomizer.randomize(min, max));
    }
}
/**
 * Replace the radial basis functions of the hidden layer by passing them
 * to the underlying flat network.
 *
 * @param rbf
 *            The new radial basis functions.
 */
public void setRBF(final RadialBasisFunction[] rbf) {
    final FlatNetworkRBF target = this.flat;
    target.setRBF(rbf);
}
/**
 * Set the center and width of every hidden neuron explicitly, creating an
 * RBF of type {@code t} for each one.
 *
 * @param centers
 *            Center positions, one row per hidden neuron. Row n holds the
 *            center of neuron n, with one element per input dimension.
 * @param widths
 *            Widths, one per hidden neuron: widths[n] is the width of
 *            neuron n.
 * @param t
 *            The RBF type to create for every hidden neuron.
 * @throws NeuralNetworkError
 *             if centers or widths has fewer entries than there are hidden
 *             neurons.
 */
public void setRBFCentersAndWidths(final double[][] centers,
        final double[] widths, final RBFEnum t) {
    final int hiddenCount = this.flat.getRBF().length;
    // Validate before touching any neuron. Without this check, a short
    // array fails mid-loop with an ArrayIndexOutOfBoundsException after
    // earlier neurons have already been replaced.
    if (centers.length < hiddenCount || widths.length < hiddenCount) {
        throw new NeuralNetworkError("Expected at least " + hiddenCount
                + " centers and widths (one per hidden neuron), but got "
                + centers.length + " centers and " + widths.length
                + " widths.");
    }
    for (int i = 0; i < hiddenCount; i++) {
        setRBFFunction(i, t, centers[i], widths[i]);
    }
}
/**
* Equally spaces all hidden neurons within the n dimensional variable
* space.
*
* @param minPosition
* The minimum position neurons should be centered. Typically 0.
* @param maxPosition
* The maximum position neurons should be centered. Typically 1
* @param volumeNeuronRBFWidth
* The neuron width of neurons within the mesh.
* @param useWideEdgeRBFs
* Enables wider RBF's around the boundary of the neuron mesh.
*/
public void setRBFCentersAndWidthsEqualSpacing(final double minPosition,
final double maxPosition, final RBFEnum t,
final double volumeNeuronRBFWidth, final boolean useWideEdgeRBFs) {
final int totalNumHiddenNeurons = this.flat.getRBF().length;
final int dimensions = getInputCount();
final double disMinMaxPosition = Math.abs(maxPosition - minPosition);
// Check to make sure we have the correct number of neurons for the
// provided dimensions
final int expectedSideLength = (int) Math.pow(totalNumHiddenNeurons,
1.0 / dimensions);
final double cmp = Math.pow(totalNumHiddenNeurons, 1.0 / dimensions);
if (expectedSideLength != cmp) {
throw new NeuralNetworkError(
"Total number of RBF neurons must be some integer to the power of 'dimensions'.\n"
+ Format.formatDouble(expectedSideLength, 5)
+ " <> " + Format.formatDouble(cmp, 5));
}
final double edgeNeuronRBFWidth = 2.5 * volumeNeuronRBFWidth;
final double[][] centers = new double[totalNumHiddenNeurons][];
final double[] widths = new double[totalNumHiddenNeurons];
for (int i = 0; i < totalNumHiddenNeurons; i++) {
cen
rbf.rar_RBF
版权申诉
105 浏览量
2022-09-14
22:16:38
上传
评论
收藏 3KB RAR 举报
御道御小黑
- 粉丝: 61
- 资源: 1万+
最新资源
- VIVADO中UART IP核使用
- 【深度学习实际案例解析】深度学习实际案例解析
- 封装swagger组件,提供全新UI以及无状态登录接口调用解决方案
- 小龙坎支局2024年4月渠道积分核对数据.xlam
- onlyoffice搭建及与alist使用的view.html
- Quadcopter-UAV-attitude-estimation-linux常用命令大全demo
- Quadcopter-UAV-attitude-estimation-based-on-数据库课程设计
- pbdlib-python-master.zip
- 43904245495352013_base.apk
- 基于springboot+vue + redis的工作流审批系统
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
评论0