mlx.core.conv3d


conv3d(input: array, weight: array, /, stride: int | tuple[int, int, int] = 1, padding: int | tuple[int, int, int] = 0, dilation: int | tuple[int, int, int] = 1, groups: int = 1, *, stream: None | Stream | Device = None) -> array

3D convolution over an input with several channels.

Note: Only the default groups=1 is currently supported.

Parameters:
  • input (array) -- Input array of shape (N, D, H, W, C_in).

  • weight (array) -- Weight array of shape (C_out, KD, KH, KW, C_in).

  • stride (int or tuple(int), optional) -- tuple of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: 1.

  • padding (int or tuple(int), optional) -- tuple of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: 0.

  • dilation (int or tuple(int), optional) -- tuple of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: 1.

  • groups (int, optional) -- input feature groups. Default: 1.
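To illustrate how stride, padding, and dilation interact, the sketch below computes the expected output shape using the standard convolution size formula. The helper names `_triple` and `conv3d_output_shape` are hypothetical and not part of MLX; the sketch assumes the output layout is (N, D_out, H_out, W_out, C_out), matching the channels-last shapes listed above.

```python
def _triple(value):
    """Expand a single int to a 3-tuple; pass tuples through unchanged."""
    return (value, value, value) if isinstance(value, int) else tuple(value)


def conv3d_output_shape(input_shape, weight_shape, stride=1, padding=0, dilation=1):
    """Hypothetical helper (not an MLX function): expected conv3d output shape.

    input_shape:  (N, D, H, W, C_in)
    weight_shape: (C_out, KD, KH, KW, C_in)
    """
    n, d, h, w, c_in = input_shape
    c_out, kd, kh, kw, weight_c_in = weight_shape
    if c_in != weight_c_in:
        raise ValueError("input and weight disagree on C_in (groups=1 assumed)")

    stride, padding, dilation = _triple(stride), _triple(padding), _triple(dilation)

    spatial = []
    for size, k, s, p, dil in zip((d, h, w), (kd, kh, kw), stride, padding, dilation):
        effective_kernel = dil * (k - 1) + 1  # kernel extent once dilated
        spatial.append((size + 2 * p - effective_kernel) // s + 1)

    return (n, *spatial, c_out)


# A 3x3x3 kernel with padding=1 and stride=1 preserves the spatial size.
same_shape = conv3d_output_shape((1, 8, 16, 16, 3), (4, 3, 3, 3, 3), padding=1)
print(same_shape)
```

Passing a single int for `stride`, `padding`, or `dilation` applies it to the depth, height, and width dimensions alike, mirroring the parameter descriptions above.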

Returns:

The convolved array.

Return type:

array
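
A minimal usage sketch follows. The tensor sizes are arbitrary values chosen for illustration, and the guarded import is only there so the snippet does nothing on machines where MLX is not installed.

```python
try:
    import mlx.core as mx
except ImportError:  # MLX not installed: skip the example
    mx = None

out_shape = None
if mx is not None:
    # Input in channels-last layout: (N, D, H, W, C_in)
    x = mx.random.normal((1, 8, 16, 16, 3))
    # Weight layout: (C_out, KD, KH, KW, C_in)
    w = mx.random.normal((4, 3, 3, 3, 3))

    # A single int for stride/padding applies to all three spatial dimensions.
    y = mx.conv3d(x, w, stride=1, padding=1)
    out_shape = tuple(y.shape)
    print(out_shape)
```

With a 3x3x3 kernel, `padding=1`, and `stride=1`, the spatial dimensions are preserved and only the channel dimension changes from C_in to C_out.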