變換

變換#

eval(*args)

array 或由 array 組成的樹狀結構進行求值。

async_eval(*args)

非同步地對 array 或由 array 組成的樹狀結構求值。

compile(fun[, inputs, outputs, shapeless])

回傳一個編譯後的函式,其輸出與 fun 相同。

checkpoint(fun)

Transform the passed callable to one that performs gradient checkpointing with respect to the inputs of the callable.

custom_function(*args, **kwargs)

建立可自訂梯度與 vmap 定義的函式。

disable_compile()

全域停用編譯。

enable_compile()

全域啟用編譯。

grad(fun[, argnums, argnames])

回傳一個計算 fun 梯度的函式。

value_and_grad(fun[, argnums, argnames])

回傳一個計算 fun 值與梯度的函式。

jvp(fun, primals, tangents)

計算雅可比-向量乘積。

vjp(fun, primals, cotangents)

計算向量-雅可比乘積。

vmap(fun[, in_axes, out_axes])

回傳 fun 的向量化版本。