about summary refs log tree commit diff
path: root/nixpkgs/pkgs/development/python-modules/jax/test-cuda.nix
blob: d156061f38495cf3003fa8a491516b1b1ca4ef88 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Smoke test for a CUDA-enabled JAX build: produces a `jax-test-cuda`
# executable that exits non-zero unless the first device JAX reports is a
# GPU, then runs a small matrix multiplication and prints "success!".
#
# Arguments:
#   jax, jaxlib  the Python packages under test (made importable by the script)
#   pkgs         package set providing `writers.writePython3Bin`
{ jax
, jaxlib
, pkgs
}:

pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
  import sys

  import jax
  from jax import random

  # Check explicitly instead of with `assert`, which `python -O` (or
  # PYTHONOPTIMIZE) strips; report what JAX found to ease debugging.
  platform = jax.devices()[0].platform
  if platform != "gpu":
      sys.exit(f"expected a gpu device, but JAX reports {platform!r}")

  rng = random.PRNGKey(0)
  x = random.normal(rng, (100, 100))
  # JAX dispatches work asynchronously; block until the result is ready so
  # any device-side failure surfaces here rather than after "success!".
  (x @ x).block_until_ready()

  print("success!")
''