创建项目
先捋清楚 java tensorflow 的依赖关系
Maven集成 Tensorflow
以前是引入 tensorflow 里面是 libtensorflow 和 tensorflow_jni
最新的 tensorflow java 独立于 tensorflow仓库进行更新
引入要改为 tensorflow-core-platform 里面其实就是 tensorflow-java 和 各个平台的 javacpp
就是 java的Api 和TensorFlow大版本没有关系了 tensorflow java 0.3.0 -> 0.4.0 都是 2.4.1的Tensoflow
以前的 libtensorflow 版本是跟着Tensoflow大版本更新的
HelloTensorFlow
https://www.tensorflow.org/install/lang_java?hl=zh_cn
官网给的第一个例子
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| import org.tensorflow.ConcreteFunction; import org.tensorflow.Signature; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.math.Add; import org.tensorflow.types.TInt32;
public class HelloTensorFlow {
public static void main(String[] args) throws Exception { System.out.println("Hello TensorFlow " + TensorFlow.version());
try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl); Tensor<TInt32> x = TInt32.scalarOf(10); Tensor<TInt32> dblX = dbl.call(x).expect(TInt32.DTYPE)) { System.out.println(x.data().getInt() + " doubled is " + dblX.data().getInt()); } }
private static Signature dbl(Ops tf) { Placeholder<TInt32> x = tf.placeholder(TInt32.DTYPE); Add<TInt32> dblX = tf.math.add(x, x); return Signature.builder().input("x", x).output("dbl", dblX).build(); } }
|
这个例子就已经出错了 正确的应该是:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| public class HelloTensorFlow { public static void main(String[] args) throws Exception { System.out.println("Hello TensorFlow " + TensorFlow.version());
try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl); TInt32 x = TInt32.scalarOf(10); Tensor dblX = dbl.call(x)) { System.out.println(x.getInt() + " doubled is " + ((TInt32) dblX).getInt()); } }
private static Signature dbl(Ops tf) { Placeholder<TInt32> x = tf.placeholder(TInt32.class); Add<TInt32> dblX = tf.math.add(x, x); return Signature.builder().input("x", x).output("dbl", dblX).build(); } }
|
使用MAVEN时需要添加
-Djavacpp.platform=linux-x86_64
指定目标平台
不然 windows linux macos x86 x64 一堆jar都会被导入
源码
预览: