JAX(機械学習フレームワーク)とは?特徴・使い方・インストール方法を徹底解説

プログラミング・IT

機械学習の世界では、TensorFlowやPyTorchが主流のフレームワークとして広く使われてきましたが、近年、Googleが開発した「JAX」が研究者やエンジニアの間で急速に注目を集めています。
JAXは、NumPy互換のAPIを持ちながら、GPU/TPUでの高速実行や自動微分機能を提供する、次世代の数値計算ライブラリです。
Google DeepMindやGoogle ResearchをはじめとするAI研究の最前線で活用され、Vision TransformerやPathways Language Model (PaLM)などの最先端モデルの開発にも使用されています。
この記事では、JAXの基本概念から、TensorFlow/PyTorchとの違い、実際のインストール方法、使い方まで、中学3年生でも理解できるレベルで詳しく解説します。

スポンサーリンク
  1. JAXとは何か
    1. JAXの基本概要
    2. JAXの位置づけ
  2. JAXの主な特徴
    1. 1. NumPy互換のAPI
    2. 2. 自動微分(Automatic Differentiation)
    3. 3. JITコンパイル(Just-In-Time Compilation)
    4. 4. ベクトル化(Vectorization)
    5. 5. 並列化(Parallelization)
    6. 6. 関数型プログラミングの思想
    7. 7. XLA(Accelerated Linear Algebra)コンパイラ
    8. 8. 非同期ディスパッチ(Asynchronous Dispatch)
  3. TensorFlow/PyTorchとの違い
    1. 設計思想の違い
    2. 機能の比較
    3. 使い分けの指針
  4. JAXのエコシステム
    1. Flax – ニューラルネットワーク構築
    2. Optax – 最適化ライブラリ
    3. Haiku – DeepMind製ニューラルネットワークライブラリ
    4. RLax – 強化学習コンポーネント
    5. その他の関連ライブラリ
  5. JAXのインストール方法
    1. システム要件
    2. パッケージの構成
    3. CPU版のインストール
    4. GPU版のインストール(NVIDIA)
    5. TPU版のインストール(Google Cloud)
    6. AMD GPU版のインストール
    7. インストールの確認
    8. 仮想環境の使用(推奨)
    9. トラブルシューティング
  6. JAXの基本的な使い方
    1. NumPy互換APIの使用
    2. 自動微分の使用
    3. JITコンパイルの使用
    4. ベクトル化の使用
    5. 並列化の使用(複数GPU)
  7. JAXの制約と注意点
    1. 純粋関数の要求
    2. 制御フローの制約
    3. イミュータブル(不変)配列
    4. 乱数生成の違い
  8. 実践例: 簡単な線形回帰
  9. パフォーマンスベンチマーク
    1. CPU環境での比較
    2. GPU環境での比較
  10. よくある質問(FAQ)
    1. Q1: JAXは初心者でも使えますか?
    2. Q2: JAXは本番環境で使えますか?
    3. Q3: Windows で JAX は使えますか?
    4. Q4: JAXでプレトレーニング済みモデルは使えますか?
    5. Q5: JAXとNumPyのコードは100%互換ですか?
    6. Q6: PyTorchとJAXはどちらが速いですか?
    7. Q7: JAXでディープラーニング以外の用途は?
    8. Q8: JAXの学習リソースはどこにありますか?
    9. Q9: JAXでモデルを保存・読み込みするには?
    10. Q10: JAXは今後主流になりますか?
  11. まとめ
  12. 参考情報

JAXとは何か

JAXの基本概要

JAX(ジャックス)は、Googleが開発した高性能数値計算と大規模機械学習のために設計されたPythonライブラリです。
2018年に初めて学術論文で紹介され、現在ではGoogle ResearchとGoogle DeepMindの研究チームが中心となって開発を進めています。
NvidiaをはじめとするコミュニティからGitHubに送信された投稿内容・コミットと統合して開発されています。

JAXは単なる機械学習フレームワークではなく、「NumPyの強化版」とも言える数値計算ライブラリであり、その上に機械学習モデルを構築するための基盤となるツールです。

名称の由来:
JAXという名前の正式な意味は公開されていませんが、一部では「Just After eXecution」の略とも言われています。
ただし、これは非公式な解釈であり、Google公式からの明確な説明はありません。

開発目的:
既存の機械学習フレームワーク(TensorFlow/PyTorch)では、定義されたモデルを高速に実行することはできますが、研究開発で必要となる独自の数値計算処理をGPU/TPUで実行するのは困難でした。
JAXは、一般的なPythonコードで書かれた数値計算関数を、そのままの形でXLAコンパイラによってGPU/TPUに最適化されたバイナリーコードにコンパイルできるように設計されています。

JAXの位置づけ

JAXは、2025年2月時点において、Googleのあくまでもリサーチプロジェクトであり、TensorFlowのような公式製品ではありません。
しかし、Google内部では「最も重要な機械学習フレームワークの一つ」として位置づけられ、多くの最先端研究で使用されています。

主な使用例:

  • 言語理解: Pathways Language Model (PaLM)
  • コンピュータビジョン: Vision Transformer (ViT)
  • 物理シミュレーション: Brax、JAX MD
  • 交通予測: Google Mapsの交通予測システム
  • 科学計算: NeurIPS 2024最優秀論文賞を受賞した高次偏微分方程式の効率的解法(STDE)

JAXの主な特徴

JAXには、他の機械学習フレームワークとは異なる独自の特徴があります。

1. NumPy互換のAPI

JAXの最大の特徴は、NumPyとほぼ同じ感覚で使えることです。
NumPyは、Pythonで科学計算を行う際に最も広く使われているライブラリであり、多くのデータサイエンティストや研究者に親しまれています。

使用例:

# NumPyの場合
import numpy as np
x = np.array([1, 2, 3], dtype=np.float32)
y = np.sum(x)

# JAXの場合(ほぼ同じ)
import jax.numpy as jnp
x = jnp.array([1, 2, 3], dtype=jnp.float32)
y = jnp.sum(x)

この互換性により、既存のNumPyコードを簡単にJAXへ移行でき、GPU/TPUでの高速化の恩恵を受けることができます。

2. 自動微分(Automatic Differentiation)

ディープラーニングでは、モデルの学習に「勾配」の計算が不可欠です。
JAXは、Pythonで記述した任意の関数に対して、自動的に導関数(勾配)を計算する機能を持っています。

主な関数:

  • jax.grad: 関数の勾配を計算
  • jax.hessian: ヘッセ行列(二階微分)を計算
  • jax.jacfwd: ヤコビ行列(前方モード)を計算
  • jax.jacrev: ヤコビ行列(後方モード)を計算

TensorFlow/PyTorchとの対応:

  • TensorFlow: GradientTape
  • PyTorch: autograd
  • JAX: grad

JAXの優位性:
JAXは、ループ、分岐、再帰、クロージャを含む複雑な関数でも自動微分が可能です。
また、導関数の導関数の導関数…と、任意の階数の高次導関数を計算できます。

使用例:

from jax import grad
import jax.numpy as jnp

# 二次関数を定義
def quadratic(x):
    return x ** 2 + 2 * x + 1

# 導関数を自動生成
grad_quadratic = grad(quadratic)

# x=3.0での勾配を計算
result = grad_quadratic(3.0)
# 結果: 8.0 (2*3 + 2 = 8)

3. JITコンパイル(Just-In-Time Compilation)

JITコンパイルとは、プログラムの実行時にコードを機械語へコンパイルする技術です。
JAXでは、@jitデコレーターを関数に付けるだけで、XLA(Accelerated Linear Algebra)コンパイラによって最適化されたコードが生成され、実行速度が劇的に向上します。

使用例:

from jax import jit
import jax.numpy as jnp

# JITコンパイルあり
@jit
def fast_function(x):
    return jnp.sum(x ** 2)

# JITコンパイルなし
def slow_function(x):
    return jnp.sum(x ** 2)

速度向上の実例:

  • 実際の研究例では、P100 GPUで500エポックのモデル訓練が、JITコンパイルにより6.25分から1.8分へ短縮されました(約3.5倍の高速化)。
  • GPU環境でのベンチマークテストでは、NumPyと比較して最大658倍の速度を記録した事例もあります。

4. ベクトル化(Vectorization)

vmap関数を使用すると、配列の各要素に対する処理を自動的にベクトル化できます。
これにより、ループを明示的に書かなくても、効率的な並列処理が実現できます。

使用例:

from jax import vmap
import jax.numpy as jnp

# 単一の値に対する関数
def single_value_func(x):
    return x * 2

# vmapで自動ベクトル化
vectorized_func = vmap(single_value_func)

# 配列全体に適用
x = jnp.array([1, 2, 3, 4, 5])
result = vectorized_func(x)
# 結果: [2, 4, 6, 8, 10]

通常、このような処理をループで書くと非効率ですが、vmapはループを関数の基本演算レベルまで押し下げ、行列-ベクトル積を行列-行列積に変換するなどの最適化を行います。

5. 並列化(Parallelization)

pmap関数を使用すると、複数のGPUやTPUに処理を自動的に分散させることができます。

使用例:

from jax import random, pmap
import jax.numpy as jnp

# 8つのGPUそれぞれに5000×6000の乱数行列を作成
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# 各デバイスで並列に行列積を計算
result = pmap(lambda x: jnp.dot(x, x.T))(mats)
# result.shape は (8, 5000, 5000)

この機能により、大規模な計算を簡単にスケールアウトできます。

6. 関数型プログラミングの思想

JAXは、関数型プログラミングの思想に強く影響を受けています。
これは、PyTorchやTensorFlowで使われているオブジェクト指向プログラミングとは異なるアプローチです。

関数型プログラミングの特徴:

  • 純粋関数(Pure Function): 同じ入力に対して常に同じ出力を返す関数
  • 副作用なし(No Side Effects): 関数の外部状態を変更しない
  • イミュータブル(Immutable): データを変更せず、新しいデータを作成する

具体例:

# 手続き型(従来の方法)
x = [1, 2, 3]
x[0] = 10  # xを直接変更

# 関数型(JAXの方法)
x = jnp.array([1, 2, 3])
y = x.at[0].set(10)  # 新しい配列yを作成、xは変更されない

この特性により、JAXはJITコンパイルや並列化が容易になり、バグの少ない安全なコードを書くことができます。

7. XLA(Accelerated Linear Algebra)コンパイラ

XLAは、Googleが開発した線形代数演算に特化したコンパイラです。
JAXは、PythonコードをXLAを通じてGPU/TPU向けに最適化されたバイナリコードへコンパイルします。

コンパイルの流れ:

  1. Python/NumPyコード
  2. JAXによるトレース(関数の構造を解析)
  3. XLAへコンパイル
  4. GPU/TPU向け最適化バイナリコード生成
  5. LLVM経由で実行(多くのCPU/GPU)

この仕組みにより、同じコードがCPU、GPU、TPUのいずれでも、それぞれに最適化された形で実行されます。

8. 非同期ディスパッチ(Asynchronous Dispatch)

JAXは非同期ディスパッチを採用しています。
これは、演算の完了を待たずに、即座にPythonプログラムへ制御を返す仕組みです。

仕組み:

  • JAXは演算結果としてDeviceArrayを返します
  • これは「将来の値(Future)」であり、実際の計算はGPU/TPU上で非同期に実行されます
  • Pythonコードは計算完了を待たずに次の処理を進められます
  • GPU/TPUは待ち時間なく連続して演算を実行できます

この仕組みにより、ハードウェアアクセラレータを効率的に活用できます。

TensorFlow/PyTorchとの違い

JAXと既存の主要フレームワークとの違いを理解することは、適切なツール選択に重要です。

設計思想の違い

TensorFlow/Keras:

  • オールインワンのフレームワーク
  • 機械学習の「定型作業」を簡単に実行できる
  • 高レベルAPI(Keras)で初心者に優しい
  • 応用的な作業にはTensorFlow固有の複雑なコードが必要
  • 本番環境への統合が容易
  • 強力なエコシステム

PyTorch:

  • オブジェクト指向プログラミング
  • define-by-run(動的計算グラフ)
  • NumPy-esqueなAPI
  • 研究と本番の両方に適している
  • 広範なライブラリとユーティリティ
  • プレトレーニング済みモデルが豊富

JAX:

  • 関数型プログラミング
  • NumPy互換の低レベルAPI
  • 機械学習に特化していない(数値計算ライブラリ)
  • 研究開発に適している
  • 定型作業にも一定のコーディングが必要
  • 応用的な作業を通常のPythonプログラミングの感覚で実行可能

機能の比較

機能TensorFlowPyTorchJAX
NumPy互換部分的高い非常に高い
自動微分○(高次導関数も容易)
GPU/TPU対応○(TPU限定的)
JITコンパイル○(XLA)○(TorchScript)○(XLA)
並列化○(pmap)
ベクトル化○(vmap)
関数型××
エコシステム成熟成熟発展途上
学習曲線低-中中-高

使い分けの指針

TensorFlow/Kerasを選ぶべき場合:

  • 機械学習の初心者
  • 定型的なディープラーニングタスク
  • 本番環境への迅速なデプロイが必要
  • モバイルやエッジデバイスでの実行が必要
  • 強力なエコシステムが必要

PyTorchを選ぶべき場合:

  • 一般的な機械学習研究
  • 柔軟なモデル構築が必要
  • 研究と本番の両方に対応したい
  • 豊富なプレトレーニング済みモデルを使いたい
  • オブジェクト指向プログラミングに慣れている

JAXを選ぶべき場合:

  • 最先端の機械学習研究
  • 独自の数値計算アルゴリズムの開発
  • GPU/TPUでの高速な数値計算が必要
  • 関数型プログラミングの利点を活かしたい
  • 高次導関数の計算が必要
  • 大規模な並列計算が必要
  • NumPyコードの高速化が必要

JAXのエコシステム

JAX自体は低レベルの数値計算ライブラリですが、その上に構築された高レベルライブラリのエコシステムが存在します。

Flax – ニューラルネットワーク構築

Flaxは、JAXベースで最も人気のあるニューラルネットワーク構築ライブラリです。
Google ResearchのBrainチームが開発を開始し、現在はオープンソースコミュニティと共同開発されています。

Flax NNX API(新しいAPI):

  • より標準的なPythonに近い
  • 状態管理の複雑さが隠されている
  • 初心者におすすめ

Flax Linen API(従来のAPI):

  • 純粋関数型プログラミングに忠実
  • MaxText、MaxDiffusionなどの強力なフレームワークで使用
  • より高度な制御が可能

使用例(Flax NNX):

import jax
import jax.numpy as jnp
from flax import nnx

# モデル定義(PyTorchに似ている)
class MLP(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        return x

# モデルのインスタンス化
model = MLP(10, 20, 5, rngs=nnx.Rngs(0))

Optax – 最適化ライブラリ

Optaxは、JAX用の最適化アルゴリズムライブラリです。
SGD、Adam、AdamWなど、主要なオプティマイザーが実装されています。

使用例:

import optax

# Adamオプティマイザーを作成
optimizer = optax.adam(learning_rate=0.001)

# 最適化状態を初期化
opt_state = optimizer.init(params)

# 勾配を使ってパラメータを更新
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

Haiku – DeepMind製ニューラルネットワークライブラリ

Haikuは、DeepMindが開発したJAXベースのニューラルネットワークライブラリです。
Optaxと組み合わせて使用することで、強力な機械学習システムを構築できます。

RLax – 強化学習コンポーネント

RLaxは、強化学習(Reinforcement Learning)のアルゴリズムを構築するためのビルディングブロックを提供します。

主な機能:

  • TD学習
  • 方策勾配法
  • Actor-Critic
  • 近接方策最適化(PPO)
  • 非線形価値変換
  • 探索手法

その他の関連ライブラリ

Acme:
完全な強化学習エージェントフレームワーク(RLaxを基盤として構築)

Chex:
JAX対応のテストユーティリティ(単体テスト、アサーション、モック)

Jraph:
グラフニューラルネットワーク(GNN)ライブラリ

Elegy、Stax:
その他のニューラルネットワーク構築ライブラリ

JAXのインストール方法

JAXのインストールは、使用するハードウェア(CPU、GPU、TPU)によって方法が異なります。

システム要件

必須環境:

  • Python 3.7以降(推奨: Python 3.8以降)
  • pip(Pythonパッケージインストーラー)

対応プラットフォーム:

  • Linux (x86_64, aarch64) – 完全サポート
  • macOS (Apple ARM) – 完全サポート
  • Windows (x86_64) – 実験的サポート(WSL推奨)

パッケージの構成

JAXのインストールには、2つのパッケージが必要です。

1. jax:

  • Pure Pythonパッケージ
  • クロスプラットフォーム
  • 主要なAPIを提供

2. jaxlib:

  • コンパイル済みバイナリを含む
  • OS/アクセラレータごとに異なるビルドが必要
  • 実際の計算処理を担当

CPU版のインストール

ノートパソコンなどでローカル開発を行う場合は、CPU版をインストールします。

Linux/macOS/Windows:

pip install --upgrade pip
pip install --upgrade jax

Windows追加要件:

  • Microsoft Visual Studio 2019 Redistributableが必要な場合があります

GPU版のインストール(NVIDIA)

GPU版のインストールには、事前にCUDAとcuDNNのセットアップが必要です。

システム要件:

  • NVIDIA GPU: CUDA 12ならSM version 5.2 (Maxwell)以降、CUDA 13ならSM version 7.5以降
  • NVIDIAドライバ: CUDA 12なら>= 525、CUDA 13なら>= 580 (Linux)
  • libdevice10.bc(cuda-nvvmパッケージに含まれる)

CUDA 13対応:

pip install --upgrade pip
pip install --upgrade "jax[cuda13]"

CUDA 12対応:

pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

注意点:

  • Kepler世代のGPUは、NVIDIAがサポートを終了したため、JAXでもサポートされていません。

TPU版のインストール(Google Cloud)

Google Cloud TPU VMでJAXを使用する場合の手順です。

pip install --upgrade pip
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Google Colabでの注意:

  • TPU v2を使用してください(古いTPUランタイムは非推奨)

AMD GPU版のインストール

AMD GPUのサポートは、ROCm JAX pluginを通じて提供されます。
詳細は公式ドキュメントを参照してください。

インストールの確認

インストールが正しく完了したか確認する方法です。

import jax
import jax.numpy as jnp

# 利用可能なデバイスを表示
print("JAX detected the following devices:")
for i, device in enumerate(jax.devices()):
    print(f"{i}: {device.platform.upper()} ({device.device_kind})")

# 簡単な計算を実行
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10,))
y = jnp.dot(x, x)
print(f"\nSuccessfully executed a simple JAX operation. Result: {y}")

# デフォルトのバックエンドを確認
print(f"\nDefault device: {jax.default_backend()}")

期待される出力(GPU環境):

JAX detected the following devices:
0: GPU (NVIDIA GeForce RTX 3090)
1: CPU (cpu)

Successfully executed a simple JAX operation. Result: 10.811179161071777

Default device: gpu

GPUが認識されているか確認:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# "gpu"と表示されればGPUが使用されている

仮想環境の使用(推奨)

プロジェクトごとに独立した環境を作成することをお勧めします。

venvを使用:

# 仮想環境を作成
python -m venv jax_env

# 仮想環境を有効化
# Linux/macOS:
source jax_env/bin/activate
# Windows (Command Prompt):
jax_env\Scripts\activate.bat
# Windows (PowerShell):
jax_env\Scripts\Activate.ps1

# JAXをインストール
pip install --upgrade pip
pip install --upgrade jax

condaを使用:

# 環境を作成
conda create -n jax_env python=3.10
conda activate jax_env

# JAXをインストール
pip install jax

トラブルシューティング

問題: GPUが認識されない

対処法:

  1. NVIDIAドライバが正しくインストールされているか確認: nvidia-smi
  2. CUDAバージョンとjaxlibのバージョンが一致しているか確認
  3. jax[cuda13]またはjax[cuda12_pip]が正しくインストールされているか確認

問題: “libdevice10.bc not found”エラー

対処法:

  • cuda-nvvmパッケージがCUDAインストールに含まれているか確認
  • 環境変数でCUDAのパスが正しく設定されているか確認

問題: Windows で動作しない

対処法:

  • WSL(Windows Subsystem for Linux)の使用を推奨
  • ネイティブWindowsサポートは実験的段階のため、不具合が発生する可能性があります

JAXの基本的な使い方

ここでは、JAXの基本的な機能を実際のコード例とともに紹介します。

NumPy互換APIの使用

JAXのjax.numpyモジュールは、NumPyとほぼ同じAPIを提供します。

import jax.numpy as jnp
from jax import random

# 乱数キーの生成(JAXの特徴)
key = random.PRNGKey(0)

# 配列の作成
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = jnp.linspace(0, 10, 100)

# 数学演算
z = jnp.sin(x)
w = jnp.exp(y)

# 行列演算
A = random.normal(key, (100, 100))
b = random.normal(key, (100,))
x_result = jnp.linalg.solve(A, b)

# 統計関数
mean = jnp.mean(x)
std = jnp.std(x)

自動微分の使用

JAXの自動微分機能を使った例です。

from jax import grad
import jax.numpy as jnp

# 関数の定義
def loss_function(params, x, y):
    predictions = params[0] * x + params[1]
    return jnp.mean((predictions - y) ** 2)

# 勾配関数の自動生成
grad_loss = grad(loss_function)

# データの準備
x_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = jnp.array([2.1, 4.0, 5.9, 8.1, 10.0])

# 初期パラメータ
params = jnp.array([1.0, 0.0])

# 勾配の計算
gradients = grad_loss(params, x_data, y_data)
print(f"Gradients: {gradients}")

JITコンパイルの使用

関数を高速化する例です。

from jax import jit
import jax.numpy as jnp
import time

# JITコンパイルなし
def slow_function(x):
    for i in range(1000):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# JITコンパイルあり
@jit
def fast_function(x):
    for i in range(1000):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# ベンチマーク
x = jnp.ones((1000, 1000))

# 最初の実行(コンパイル時間を含む)
_ = fast_function(x).block_until_ready()

# 速度比較
start = time.time()
result_slow = slow_function(x).block_until_ready()
time_slow = time.time() - start

start = time.time()
result_fast = fast_function(x).block_until_ready()
time_fast = time.time() - start

print(f"Slow function: {time_slow:.4f}秒")
print(f"Fast function: {time_fast:.4f}秒")
print(f"Speed up: {time_slow/time_fast:.2f}倍")

ベクトル化の使用

配列演算を効率化する例です。

from jax import vmap
import jax.numpy as jnp

# 単一サンプルに対する関数
def compute_loss(params, x, y):
    prediction = jnp.dot(params, x)
    return (prediction - y) ** 2

# バッチ全体に対してベクトル化
batch_loss = vmap(compute_loss, in_axes=(None, 0, 0))

# データの準備
params = jnp.array([1.0, 2.0, 3.0])
X = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 9.0]])
y = jnp.array([14.0, 32.0, 50.0])

# バッチ全体の損失を一度に計算
losses = batch_loss(params, X, y)
print(f"Individual losses: {losses}")
print(f"Mean loss: {jnp.mean(losses)}")

並列化の使用(複数GPU)

複数のGPUで並列処理を行う例です。

from jax import pmap, random
import jax.numpy as jnp

# 並列実行する関数
def parallel_computation(x):
    return jnp.sum(x ** 2)

# 複数デバイスに分散
parallel_func = pmap(parallel_computation)

# デバイス数に応じてデータを分割
n_devices = jax.local_device_count()
data = random.normal(random.PRNGKey(0), (n_devices, 1000, 1000))

# 並列実行
results = parallel_func(data)
print(f"Results from {n_devices} devices: {results}")

JAXの制約と注意点

関数型プログラミングに基づくJAXには、いくつかの制約があります。

純粋関数の要求

JAXのJITコンパイルや変換を使用するには、関数が「純粋」である必要があります。

禁止事項:

  • グローバル変数の読み書き
  • print()などの副作用のある操作
  • ファイル入出力
  • ランダムな状態の暗黙的な使用

悪い例:

global_counter = 0

@jit
def bad_function(x):
    global global_counter
    global_counter += 1  # グローバル変数の変更(副作用)
    print(f"Counter: {global_counter}")  # print(副作用)
    return x * 2

良い例:

@jit
def good_function(x, counter):
    # 副作用なし、新しい値を返すのみ
    new_counter = counter + 1
    return x * 2, new_counter

制御フローの制約

通常のPythonのif文やfor文は、JITコンパイルされた関数内では使用できません。

代替関数:

if文の代替:

  • jax.lax.cond: 条件分岐
from jax import lax

def conditional_func(x):
    return lax.cond(
        x == 0,
        lambda: 10,  # Trueの場合
        lambda: 20   # Falseの場合
    )

switch文の代替:

  • jax.lax.switch: 多分岐
def switch_func(index):
    return lax.switch(
        index,
        (lambda: 10, lambda: 20, lambda: 30)
    )

for文の代替:

  • jax.lax.fori_loop: 固定回数のループ
  • jax.lax.scan: 畳み込み演算
  • jax.vmap: マッピング
from jax import lax

def loop_func(init_val, num_iterations):
    def body_fun(i, val):
        return val + i

    return lax.fori_loop(0, num_iterations, body_fun, init_val)

while文の代替:

  • jax.lax.while_loop: 条件付きループ
def while_func(init_val):
    def cond_fun(val):
        return val < 100

    def body_fun(val):
        return val * 2

    return lax.while_loop(cond_fun, body_fun, init_val)

イミュータブル(不変)配列

JAXの配列は変更できません。
配列を「変更」したい場合は、新しい配列を作成する必要があります。

悪い例(NumPyスタイル):

import numpy as np
x = np.array([1, 2, 3, 4, 5])
x[0] = 10  # NumPyでは可能

良い例(JAXスタイル):

import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])
y = x.at[0].set(10)  # 新しい配列を作成
# x は変更されない
# y は [10, 2, 3, 4, 5]

最適化:
もし元の配列xを以降使用しない場合、JAXは最適化により破壊的更新を行い、メモリ効率を保ちます。

乱数生成の違い

JAXでは、再現性と並列化のため、乱数生成に明示的なキーを使用します。

NumPyの場合(暗黙的なグローバル状態):

import numpy as np

x = np.random.normal(size=(10,))
y = np.random.normal(size=(10,))
# 異なる乱数が生成される

JAXの場合(明示的なキー):

from jax import random

key = random.PRNGKey(0)  # シードキーの生成

# キーを分割して使用
key1, key2 = random.split(key)
x = random.normal(key1, shape=(10,))
y = random.normal(key2, shape=(10,))

この方式により、並列処理での再現性が保証され、デバッグが容易になります。

実践例: 簡単な線形回帰

JAXを使った実践的な例として、線形回帰を実装してみます。

import jax
import jax.numpy as jnp
from jax import grad, jit

# データの生成
key = jax.random.PRNGKey(42)
x_data = jnp.linspace(0, 10, 100)
y_data = 2.5 * x_data + 1.0 + jax.random.normal(key, (100,)) * 0.5

# 損失関数
def loss_fn(params, x, y):
    w, b = params
    predictions = w * x + b
    return jnp.mean((predictions - y) ** 2)

# 勾配関数
grad_fn = jit(grad(loss_fn))

# 初期化
params = jnp.array([0.0, 0.0])
learning_rate = 0.01

# 訓練ループ
for i in range(1000):
    grads = grad_fn(params, x_data, y_data)
    params = params - learning_rate * grads

    if i % 100 == 0:
        loss = loss_fn(params, x_data, y_data)
        print(f"Step {i}, Loss: {loss:.4f}, w: {params[0]:.4f}, b: {params[1]:.4f}")

print(f"\nFinal parameters: w={params[0]:.4f}, b={params[1]:.4f}")
print("True parameters: w=2.5000, b=1.0000")

期待される出力:

Step 0, Loss: 43.7500, w: 0.0000, b: 0.0000
Step 100, Loss: 0.2156, w: 2.4823, b: 1.0234
Step 200, Loss: 0.2155, w: 2.4825, b: 1.0232
...
Step 900, Loss: 0.2155, w: 2.4825, b: 1.0232

Final parameters: w=2.4825, b=1.0232
True parameters: w=2.5000, b=1.0000

パフォーマンスベンチマーク

JAXの実際の性能を、NumPyと比較してみます。

CPU環境での比較

import numpy as np
import jax.numpy as jnp
from jax import jit
import time

# データサイズ
size = 10000

# NumPy版
np_array = np.random.randn(size, size)
start = time.time()
np_result = np.dot(np_array, np_array.T)
np_time = time.time() - start

# JAX版(JITなし)
jax_array = jnp.array(np_array)
start = time.time()
jax_result = jnp.dot(jax_array, jax_array.T).block_until_ready()
jax_time = time.time() - start

# JAX版(JITあり)
@jit
def matrix_multiply(x):
    return jnp.dot(x, x.T)

# ウォームアップ
_ = matrix_multiply(jax_array).block_until_ready()

start = time.time()
jax_jit_result = matrix_multiply(jax_array).block_until_ready()
jax_jit_time = time.time() - start

print(f"NumPy: {np_time:.4f}秒")
print(f"JAX (JITなし): {jax_time:.4f}秒")
print(f"JAX (JITあり): {jax_jit_time:.4f}秒")
print(f"JAX速度向上率: {np_time/jax_jit_time:.2f}倍")

GPU環境での比較

GPU環境では、さらに劇的な速度向上が期待できます。
実際の研究報告では、以下のような結果が得られています。

ベンチマーク結果(実測値):

  • サイズ10,000の整数配列での行列積: NumPyの658倍
  • サイズ10,000のfloat32配列での行列積: NumPyの183倍
  • 500エポックの訓練(P100 GPU): 6.25分→1.8分(3.5倍)

よくある質問(FAQ)

Q1: JAXは初心者でも使えますか?

A: JAXは、NumPyの基本的な使い方を理解していれば、始めることは可能です。
ただし、関数型プログラミングの概念や、純粋関数、イミュータブルなデータ構造などの理解が必要になるため、機械学習の完全な初心者には少しハードルが高いかもしれません。
TensorFlow/KerasやPyTorchで基本を学んでから、JAXに移行することをお勧めします。

Q2: JAXは本番環境で使えますか?

A: JAXは現時点では主に研究用途に適しています。
本番環境への統合、モデルのデプロイ、モバイル対応などのエコシステムは、TensorFlowやPyTorchと比較するとまだ成熟していません。
研究開発やプロトタイピングには非常に優れていますが、本番環境での大規模運用にはTensorFlow/PyTorchの方が適している場合が多いです。

Q3: Windows で JAX は使えますか?

A: Windowsのネイティブサポートは実験的段階です。
現時点では、WSL(Windows Subsystem for Linux)を使用することを強く推奨します。
WSL2上でLinux版のJAXをインストールすれば、GPU含めて問題なく動作します。

Q4: JAXでプレトレーニング済みモデルは使えますか?

A: はい、HuggingFaceはJAX/Flax対応のプレトレーニング済みモデルを多数提供しています。
自然言語処理、画像処理、音声処理など、様々な分野のモデルが利用可能です。
ただし、PyTorchやTensorFlowと比較すると、モデルの選択肢は限られています。

Q5: JAXとNumPyのコードは100%互換ですか?

A: ほぼ互換性がありますが、完全に同じではありません。
主な違いは:

  • 配列がイミュータブル(変更不可)
  • 乱数生成に明示的なキーが必要
  • 一部のNumPy機能がサポートされていない
    ほとんどのNumPyコードは、小さな修正でJAXに移行できます。

Q6: PyTorchとJAXはどちらが速いですか?

A: タスクや実装によって異なりますが、一般的にJAXの方が若干高速な傾向があります。
特にJITコンパイルを効果的に使用した場合、JAXは非常に高いパフォーマンスを発揮します。
ただし、実際の差は使用ケースによって大きく変わるため、両方を試してみることをお勧めします。

Q7: JAXでディープラーニング以外の用途は?

A: はい、JAXは数値計算全般に使用できます。
実際、物理シミュレーション、偏微分方程式の求解、最適化問題、統計分析、科学計算など、幅広い分野で活用されています。
NeurIPS 2024で最優秀論文賞を受賞した研究も、JAXを使った科学計算でした。

Q8: JAXの学習リソースはどこにありますか?

A: 以下のリソースが役立ちます:

  • 公式ドキュメント: https://docs.jax.dev/
  • 公式チュートリアル(Google Colab)
  • 公式GitHubリポジトリ: https://github.com/jax-ml/jax
  • HuggingFaceのFlaxチュートリアル
  • Google Cloud公式ブログのチュートリアル
    日本語では、Qiitaやzenn.devに有志による解説記事があります。

Q9: JAXでモデルを保存・読み込みするには?

A: JAX自体にはモデル保存機能がありませんが、Flaxなどの高レベルライブラリがこの機能を提供しています。
また、picklenumpy.saveを使ってパラメータを保存し、後で読み込むこともできます。
TensorFlowやPyTorchほど統一されたシステムではありませんが、柔軟に対応できます。

Q10: JAXは今後主流になりますか?

A: 現時点では予測困難ですが、Google内部での採用が拡大していること、最先端研究での活用が増えていることから、少なくとも研究分野では重要な選択肢として定着していくと考えられます。
一方、本番環境での大規模運用については、TensorFlow/PyTorchの優位性が続く可能性が高いです。

まとめ

JAXは、NumPy互換のAPIを持ちながら、GPU/TPUでの高速実行、自動微分、JITコンパイル、関数型プログラミングの利点を提供する、次世代の数値計算ライブラリです。

JAXの主な特徴:

  • NumPyとほぼ同じAPI
  • 強力な自動微分(高次導関数も容易)
  • XLAによるJITコンパイル
  • vmap/pmapによるベクトル化・並列化
  • 関数型プログラミングの思想
  • GPU/TPU対応

適している用途:

  • 最先端の機械学習研究
  • 独自の数値計算アルゴリズム開発
  • 高次導関数が必要な科学計算
  • 大規模並列計算
  • NumPyコードの高速化

現時点での制約:

  • エコシステムが発展途上
  • 本番環境への統合が未成熟
  • Windows対応が実験的
  • 関数型プログラミングの学習曲線

今後の展望:
Googleの最先端研究で積極的に採用され、NeurIPS 2024最優秀論文賞などの成果も生まれています。
研究分野ではますます重要な選択肢となる一方、本番環境での運用については、TensorFlow/PyTorchとの使い分けが続くと考えられます。

機械学習研究や科学計算に取り組むエンジニアや研究者にとって、JAXは非常に強力なツールです。
NumPyの知識があれば比較的スムーズに始められるため、ぜひ一度試してみることをお勧めします。

参考情報

この記事は、以下の公式情報およびソースに基づいて作成されています。

公式ドキュメント:

  1. JAX: High performance array computing — JAX documentation
  2. GitHub – jax-ml/jax: Composable transformations of Python+NumPy programs
  3. Installation — JAX documentation
  4. Getting started with JAX for ML — JAX AI Stack

Google公式ブログ:

  1. PyTorch デベロッパー向け JAX 基礎ガイド – Google Cloud 公式ブログ
  2. A guide to JAX for PyTorch developers – Google Cloud Blog
  3. バックプロパゲーションを超えて: JAX の記号的処理能力が科学コンピューティングの新しいフロンティアを切り開く – Google Developers Blog
  4. Using JAX to accelerate our research – Google DeepMind

その他:

  1. JAXによるスケーラブルな機械学習 – ZOZO TECH BLOG
  2. JAX (ライブラリ) – Wikipedia
  3. JAX (software) – Wikipedia

コメント

タイトルとURLをコピーしました