/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import lombok.var;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.commons.math3.util.FastMath;
import org.junit.*;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.imports.TFGraphs.NodeReader;
import org.nd4j.linalg.api.blas.params.GemmParams;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.iter.INDArrayIterator;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.*;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.MathUtils;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
/**
* NDArrayTests
*
* @author Adam Gibson
*/
@Slf4j
@RunWith(Parameterized.class)
public class Nd4jTestsC extends BaseNd4jTest {
DataType initialType;
/**
 * Builds the test instance for the supplied backend and records the global
 * data type that is active at construction time, so that {@link #after()}
 * can put it back once a test has run.
 *
 * @param backend backend passed straight through to the {@code BaseNd4jTest} constructor
 */
public Nd4jTestsC(Nd4jBackend backend) {
    super(backend);
    initialType = Nd4j.dataType();
}
/**
 * Per-test setup. Runs the superclass setup first, then switches the global
 * data type to {@code DOUBLE}, fixes the random seed to 123 and turns off the
 * executioner's debug and verbose modes.
 *
 * @throws Exception if any of the setup calls fails
 */
@Before
public void before() throws Exception {
    super.before();

    Nd4j.setDataType(DataType.DOUBLE);
    Nd4j.getRandom().setSeed(123);

    // Same executioner instance is used for both switches.
    final OpExecutioner executioner = Nd4j.getExecutioner();
    executioner.enableDebugMode(false);
    executioner.enableVerboseMode(false);
}
/**
 * Per-test teardown. Runs the superclass teardown and then restores the
 * global data type captured in the constructor.
 *
 * @throws Exception if the superclass teardown fails
 */
@After
public void after() throws Exception {
    super.after();
    Nd4j.setDataType(this.initialType);
}
/**
 * Checks that {@code Nd4j.arange(-2, 2)} yields the values -2, -1, 0, 1.
 */
@Test
public void testArangeNegative() {
    final INDArray actual = Nd4j.arange(-2, 2);
    final double[] expectedValues = {-2, -1, 0, 1};
    final INDArray expected = Nd4j.create(expectedValues);

    assertEquals(expected, actual);
}
/**
 * Checks {@code Nd4j.tri(3, 5, 2)} against a hand-written 3x5 matrix in which
 * each row holds ones up to and including column {@code row + 2} and zeros
 * after that.
 */
@Test
public void testTri() {
    final double[][] expectedValues = {
        {1, 1, 1, 0, 0},
        {1, 1, 1, 1, 0},
        {1, 1, 1, 1, 1}
    };
    final INDArray expected = Nd4j.create(expectedValues);

    final INDArray actual = Nd4j.tri(3, 5, 2);

    assertEquals(expected, actual);
}
/**
 * Checks {@code Nd4j.triu} with a negative diagonal offset.
 *
 * The input is the 4x3 matrix holding 1..12 in row-major order. With
 * {@code k = -1} the expected result keeps the main diagonal, everything above
 * it and the first sub-diagonal, and zeroes the remaining entries.
 */
@Test
public void testTriu() {
    INDArray input = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(4, 3);
    int k = -1;

    INDArray expected = Nd4j.create(new double[][]{
        {1, 2, 3},
        {4, 5, 6},
        {0, 8, 9},
        {0, 0, 12}
    });
    INDArray actual = Nd4j.triu(input, k);

    // JUnit's contract is assertEquals(expected, actual); the original call had
    // the two swapped, which would have produced an inverted failure message.
    assertEquals(expected, actual);
}
/**
 * Checks that {@code Nd4j.diag} applied to a 4x1 column holding 1..4 returns
 * an array of shape 4x4. Only the shape is verified here, not the contents.
 */
@Test
public void testDiag() {
    final INDArray column = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(4, 1);
    final INDArray diagonal = Nd4j.diag(column);

    final long[] expectedShape = {4, 4};
    assertArrayEquals(expectedShape, diagonal.shape());
}
/**
 * Edge case for {@code getRow} on a column taken from a larger matrix.
 *
 * Builds a 100x3 'c'-ordered matrix holding 1..300, takes column 0 reshaped to
 * 100x1, and for every row index checks that the value read from the source
 * matrix equals the value read through the column, through the row obtained
 * from the column, and through a {@code dup()} of that row.
 */
@Test
public void testGetRowEdgeCase() {
    final INDArray source = Nd4j.linspace(1, 300, 300, DataType.DOUBLE).reshape('c', 100, 3);
    final INDArray firstColumn = source.getColumn(0).reshape(100, 1);

    int i = 0;
    while (i < 100) {
        final INDArray rowView = firstColumn.getRow(i);
        final INDArray rowCopy = rowView.dup();

        final double expected = source.getDouble(i, 0);
        final double viaColumn = firstColumn.getDouble(i);
        final double viaRowCopy = rowCopy.getDouble(0);
        final double viaRowView = rowView.getDouble(0);

        // The row index doubles as the assertion message.
        final String label = String.valueOf(i);
        assertEquals(label, expected, viaColumn, 0.0);
        // NOTE(review): the original comments marked the next two assertions as
        // "Fails" - confirm whether that is still the case.
        assertEquals(label, expected, viaRowCopy, 0.0);
        assertEquals(label, expected, viaRowView, 0.0);

        i++;
    }
}
@Test
public void testNd4jEnvironment() {
System.out.println(Nd4j.getExecutioner().getEnvironmentInformation());
int manualNumCores = Integer.parseInt(Nd4j.getExecutioner().getEnvironmentInformation()
.get(Nd4jEnvironment.CPU_CORES_KEY).toString());
assertEquals(Runtime.getRuntime().availableProcessors(), manualNumCores);
assertEquals(Runtime.getR