TensorFlow Java

创建项目

先捋清楚 java tensorflow 的依赖关系

Maven集成 Tensorflow

以前是引入 tensorflow 里面是 libtensorflow 和 tensorflow_jni

最新的 tensorflow java 独立于 tensorflow仓库进行更新

引入要改为 tensorflow-core-platform 里面其实就是 tensorflow-java 和 各个平台的 javacpp

就是 java的Api 和TensorFlow大版本没有关系了 tensorflow java 0.3.0 -> 0.4.0 都是 2.4.1的Tensoflow

以前的 libtensorflow 版本是跟着Tensoflow大版本更新的

HelloTensorFlow

https://www.tensorflow.org/install/lang_java?hl=zh_cn

官网给的第一个例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;

public class HelloTensorFlow {

public static void main(String[] args) throws Exception {
System.out.println("Hello TensorFlow " + TensorFlow.version());

try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
Tensor<TInt32> x = TInt32.scalarOf(10);
Tensor<TInt32> dblX = dbl.call(x).expect(TInt32.DTYPE)) {
System.out.println(x.data().getInt() + " doubled is " + dblX.data().getInt());
}
}

private static Signature dbl(Ops tf) {
Placeholder<TInt32> x = tf.placeholder(TInt32.DTYPE);
Add<TInt32> dblX = tf.math.add(x, x);
return Signature.builder().input("x", x).output("dbl", dblX).build();
}
}

这个例子就已经出错了 正确的应该是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

public class HelloTensorFlow {
public static void main(String[] args) throws Exception {
System.out.println("Hello TensorFlow " + TensorFlow.version());

try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
TInt32 x = TInt32.scalarOf(10);
Tensor dblX = dbl.call(x)) {
System.out.println(x.getInt() + " doubled is " + ((TInt32) dblX).getInt());
}
}

private static Signature dbl(Ops tf) {
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
Add<TInt32> dblX = tf.math.add(x, x);
return Signature.builder().input("x", x).output("dbl", dblX).build();
}
}

使用MAVEN时需要添加

-Djavacpp.platform=linux-x86_64

指定目标平台

不然 windows linux macos x86 x64 一堆jar都会被导入

源码