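# Smoke test: a small Python script that checks JAX is running on a CUDA GPU.
{ jax
, jaxlib
, pkgs
}:

pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
  import jax
  from jax import random

  # Fail fast if JAX did not pick up a GPU backend.
  assert jax.devices()[0].platform == "gpu"

  # Multiply on the device and wait, so runtime errors surface here.
  rng = random.PRNGKey(0)
  x = random.normal(rng, (100, 100))
  (x @ x).block_until_ready()

  print("success!")
''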