<div align="center">
<img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="logo"></img>
</div>
# JAX: Autograd and XLA
JAX is [Autograd](https://github.com/hips/autograd) and
[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md),
brought together for high-performance machine learning research.
With its updated version of [Autograd](https://github.com/hips/autograd),
JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.
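For example, here's a minimal sketch of composing `grad` with itself to take higher-order derivatives (the `tanh` function below is just an illustrative stand-in):
```python
from jax import grad
import jax.numpy as np

def tanh(x):  # a plain Python/NumPy function
  y = np.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)   # reverse-mode derivative of tanh
print(grad_tanh(1.0))    # evaluate d(tanh)/dx at x = 1.0

# grad composes with itself, so higher-order derivatives are just nested calls
print(grad(grad(grad(tanh)))(1.0))
```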
What’s new is that JAX uses
[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)
to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
under the hood by default, with library calls getting just-in-time compiled and
executed. But JAX also lets you just-in-time compile your own Python functions
into XLA-optimized kernels using a one-function API,
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
composed arbitrarily, so you can express sophisticated algorithms and get
maximal performance without leaving Python.
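As a small sketch of the `jit` API (the function and array shapes here are made up for illustration):
```python
from jax import jit
import jax.numpy as np

def slow_f(x):
  # element-wise operations like these are fused into one XLA kernel under jit
  return x * x + x * 2.0

fast_f = jit(slow_f)
x = np.ones((5000, 5000))
fast_f(x)   # the first call compiles; subsequent calls reuse the compiled code
```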
Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable transformations of functions](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Another is [`vmap`](#auto-vectorization-with-vmap)
for automatic vectorization, with more to come.
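As a rough sketch of what `vmap` does (the matrix and batch below are made-up examples), it maps a function written for a single example over a batch axis:
```python
from jax import vmap
import jax.numpy as np

mat = np.ones((150, 100))
batched_x = np.ones((10, 100))   # a batch of 10 vectors

def apply_matrix(v):
  return np.dot(mat, v)          # written for a single vector

# vmap maps apply_matrix over the leading axis of batched_x,
# producing a (10, 150) result without an explicit Python loop.
batched_result = vmap(apply_matrix)(batched_x)
```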
This is a research project, not an official Google product. Expect bugs and
sharp edges. Please help by trying it out, [reporting
bugs](https://github.com/google/jax/issues), and letting us know what you
think!
```python
import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial
def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)
grad_fun = jit(grad(logprob_fun)) # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grads
```
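To make the shapes concrete, here is a hedged usage sketch of the functions above; the layer sizes, batch size, and the structure of `params` below are illustrative assumptions, not part of the example itself:
```python
import numpy.random as npr

# Two dense layers, 10 -> 4 -> 1, with params as a list of (W, b) pairs.
params = [(npr.randn(10, 4), npr.randn(4)),
          (npr.randn(4, 1), npr.randn(1))]
inputs = npr.randn(32, 10)   # a batch of 32 examples
targets = npr.randn(32, 1)

batch_grads = grad_fun(params, inputs, targets)     # grads summed over the batch
per_example = perex_grads(params, inputs, targets)  # one grad per example (leading axis 32)
```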
JAX started as a research project by [Matt Johnson](https://github.com/mattjj),
[Roy Frostig](https://github.com/froystig), [Dougal
Maclaurin](https://github.com/dougalm), and [Chris
Leary](https://github.com/learyg), and is now developed [in the
open](https://github.com/google/jax) by a growing number of
[contributors](#contributors).
### Contents
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Installation](#installation)
* [A brief tour](#a-brief-tour)
* [What's supported](#whats-supported)
* [Transformations](#transformations)
* [Random numbers are different](#random-numbers-are-different)
* [Mini-libraries](#mini-libraries)
* [How it works](#how-it-works)
* [What we're working on](#what-were-working-on)
* [Current gotchas](#current-gotchas)
## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb)
- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb)
## Installation
JAX is written in pure Python, but it depends on XLA, which needs to be
compiled and installed as the `jaxlib` package. Use the following instructions
to build JAX from source or install a binary package with pip.
### Building JAX from source
First, obtain the JAX source code:
```bash
git clone https://github.com/google/jax
cd jax
```
To build XLA with CUDA support, you can run
```bash
python build/build.py --enable_cuda
pip install -e build # install jaxlib (includes XLA)
pip install -e . # install jax (pure Python)
```
See `python build/build.py --help` for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. The build
also depends on NumPy, and a compiler toolchain corresponding to that of
Ubuntu 16.04 or newer.
To build XLA without CUDA GPU support (CPU only), drop the `--enable_cuda`:
```bash
python build/build.py
pip install -e build # install jaxlib (includes XLA)
pip install -e . # install jax
```
To upgrade to the latest version from GitHub, just run `git pull` from the JAX
repository root, and rebuild by running `build.py` if necessary. You shouldn't have
to reinstall because `pip install -e` sets up symbolic links from site-packages
into the repository.
### pip installation
Installing XLA with prebuilt binaries via `pip` is still experimental,
especially with GPU support. Let us know on [the issue
tracker](https://github.com/google/jax/issues) if you run into any errors.
To install a CPU-only version, which might be useful for doing local
development on a laptop, you can run
```bash
pip install jax jaxlib # CPU-only version
```
If you want to install JAX with both CPU and GPU support, using existing CUDA
and CUDNN7 installations on your machine (for example, preinstalled on your
cloud VM), you can run
```bash
# install jaxlib
PYTHON_VERSION=py2 # alternatives: py2, py3
CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100
PLATFORM=linux_x86_64 # alternatives: linux_x86_64
pip install https://storage.googleapis.com/jax-wheels/$CUDA_VERSION/jaxlib-0.1-$PYTHON_VERSION-none-$PLATFORM.whl
pip install jax # install jax
```
The library package name must correspond to the version of the existing CUDA
installation you want to use, with `cuda100` for CUDA 10.0, `cuda92` for CUDA
9.2, and `cuda90` for CUDA 9.0. To find your CUDA and CUDNN versions, you can
run commands like these, depending on your CUDNN install path:
```bash
nvcc --version
grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path
```
## A brief tour
```python
In [1]: import jax.numpy as np
In [2]: from jax import random
In [3]: key = random.PRNGKey(0)
In [4]: x = random.normal(key, (5000, 5000))
In [5]: print(np.dot(x, x.T) / 2) # fast!
[[ 2.52727051e+03 8.15895557e+00 -8.53276134e-01 ..., # ...
In [6]: print(np.dot(x, x.T) / 2) # even faster!
[[ 2.52727051e+03 8.15895557e+00 -8.53276134e-01 ..., # ...
```
What’s happening behind the scenes is that JAX is using XLA to just-in-time
(JIT) compile and execute these individual operations on the GPU. First the
`random.normal` call is compiled and the array referred to by `x` is generated
on the GPU. Next, each function called on `x` (namely `transpose`, `dot`, and
`divide`) is individually JIT-compiled and executed, each keeping its results on
the device.
It’s only when a value needs to be printed, plotted, saved, or passed into a raw
NumPy function that a read-only copy of the value is brought back to the host as
an ndarray and cached. The second call to `dot` is faster because the
JIT-compiled code is cached and reused, saving the compilation time.
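Here's a hedged sketch of that behavior; it assumes a GPU-backed `jaxlib` install and uses ordinary NumPy (imported as `onp`) just to force a host transfer:
```python
import numpy as onp            # ordinary NumPy, runs on the host
import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (5000, 5000))   # created and kept on the device
y = np.dot(x, x.T) / 2                 # computed on the device

# Printing, plotting, saving, or passing y to raw NumPy pulls a read-only
# copy back to the host as a plain ndarray.
host_copy = onp.asarray(y)
print(type(host_copy))                 # <class 'numpy.ndarray'>
```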
The fun really starts when you use `grad` for automatic differentiation and
`jit` to compile your own functions end-to-end. Here’s a more complete toy
example:
```python
from jax import grad, jit
import jax.numpy as np
def sigmoid(x):
  return 0.5 * (np.tanh(x / 2.) + 1)

# Outputs probability of a label being true according to logistic model.
def logistic_predictions(weights, inputs):
  return sigmoid(np.dot(inputs, weights))

# Training loss is the negative log-likelihood of the training labels.
def loss(weights, inputs, targets):
  preds = logistic_predictions(weights, inputs)
  label_logprobs = np.log(preds) * targets + np.log(1 - preds) * (1 - targets)
  return -np.sum(label_logprobs)
```
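From here, `grad` and `jit` finish the job. The tiny dataset below is made up just to make the sketch runnable end-to-end:
```python
# A toy dataset: 4 examples with 3 features each.
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

# Compiled function that returns gradients of the training loss.
training_gradient_fun = jit(grad(loss))

# A few steps of plain gradient descent.
weights = np.array([0.0, 0.0, 0.0])
for i in range(100):
  weights -= 0.1 * training_gradient_fun(weights, inputs, targets)

print("Trained loss: {:0.2f}".format(loss(weights, inputs, targets)))
```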