about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/libraries/science/math/libtorch/test
diff options
context:
space:
mode:
Diffstat (limited to 'nixpkgs/pkgs/development/libraries/science/math/libtorch/test')
-rw-r--r--nixpkgs/pkgs/development/libraries/science/math/libtorch/test/CMakeLists.txt5
-rw-r--r--nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix22
-rw-r--r--nixpkgs/pkgs/development/libraries/science/math/libtorch/test/test.cpp20
3 files changed, 47 insertions, 0 deletions
diff --git a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/CMakeLists.txt b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/CMakeLists.txt
new file mode 100644
index 000000000000..4e96704a4c17
--- /dev/null
+++ b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.0)
+find_package(Torch REQUIRED)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+add_executable(test test.cpp)
+target_link_libraries(test "${TORCH_LIBRARIES}")
diff --git a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix
new file mode 100644
index 000000000000..e69807871f46
--- /dev/null
+++ b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/default.nix
@@ -0,0 +1,22 @@
+{ stdenv, cmake, libtorch-bin, symlinkJoin }:
+
+stdenv.mkDerivation {
+  pname = "libtorch-test";
+  version = libtorch-bin.version;
+
+  src = ./.;
+
+  nativeBuildInputs = [ cmake ];
+
+  buildInputs = [ libtorch-bin ];
+
+  doCheck = true;
+
+  installPhase = ''
+    touch $out
+  '';
+
+  checkPhase = ''
+    ./test
+  '';
+}
diff --git a/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/test.cpp b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/test.cpp
new file mode 100644
index 000000000000..ca238fba521d
--- /dev/null
+++ b/nixpkgs/pkgs/development/libraries/science/math/libtorch/test/test.cpp
@@ -0,0 +1,20 @@
+#undef NDEBUG
+#include <cassert>
+
+#include <iostream>
+
+#include <torch/torch.h>
+
+int main() {
+  torch::Tensor tensor = torch::eye(3);
+
+  float checkData[] = {
+    1, 0, 0,
+    0, 1, 0,
+    0, 0, 1
+  };
+
+  torch::Tensor check = torch::from_blob(checkData, {3, 3});
+
+  assert(tensor.allclose(check));
+}