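Using a Rust function with trait bounds in Python
PyO3 allows for easy conversion from Rust to Python for certain functions and classes (see the conversion table). However, converting Rust code that requires a given trait implementation as an argument is not always straightforward.
This tutorial explains how to convert a Rust function that takes a trait as an argument so that it can be used from Python, with Python classes that implement the same methods as the trait.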
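Why is this useful?
Pros
- Make your Rust code available to Python users
- Code complex algorithms in Rust with the help of the borrow checker
Cons
- Not as fast as native Rust (type conversion has to be performed, and part of the code runs in Python)
- You need to adapt your code to expose it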
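Example
Let's start with a basic example: the implementation of an optimization solver operating on a given model.
Suppose we have a solve function that operates on a model and mutates its state. The argument of the function can be any model that implements the Model trait: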
#![allow(dead_code)]
pub trait Model {
fn set_variables(&mut self, inputs: &Vec<f64>);
fn compute(&mut self);
fn get_results(&self) -> Vec<f64>;
}
pub fn solve<T: Model>(model: &mut T) {
println!("Magic solver that mutates the model into a resolved state");
}
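To make the trait bound concrete, here is a minimal sketch of a native Rust model being passed to a generic function bounded by Model. The SquareModel struct and the demo_solve function are hypothetical names introduced only for this illustration; demo_solve is a stand-in that simply drives the three trait methods once, and is not the real solver.

```rust
/// Same trait as in the example above.
pub trait Model {
    fn set_variables(&mut self, inputs: &Vec<f64>);
    fn compute(&mut self);
    fn get_results(&self) -> Vec<f64>;
}

/// Hypothetical stand-in for a solver: it feeds one fixed input vector
/// to the model, runs the computation, and returns the results.
pub fn demo_solve<T: Model>(model: &mut T) -> Vec<f64> {
    model.set_variables(&vec![2.0, 3.0]);
    model.compute();
    model.get_results()
}

/// Hypothetical native Rust model computing x^2 - 3 for each input.
#[derive(Default)]
pub struct SquareModel {
    inputs: Vec<f64>,
    results: Vec<f64>,
}

impl Model for SquareModel {
    fn set_variables(&mut self, inputs: &Vec<f64>) {
        self.inputs = inputs.clone();
    }

    fn compute(&mut self) {
        self.results = self.inputs.iter().map(|x| x * x - 3.0).collect();
    }

    fn get_results(&self) -> Vec<f64> {
        self.results.clone()
    }
}

fn main() {
    let mut model = SquareModel::default();
    // Accepted because SquareModel implements Model.
    let results = demo_solve(&mut model);
    println!("{:?}", results); // prints "[1.0, 6.0]"
}
```

Any type can be passed this way as long as it has an explicit impl of the trait, which is exactly what a Python object lacks.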
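Let's assume we have the following constraints:
- We cannot change that code, as it already runs on many Rust models.
- We also have many Python models that cannot be solved, as this solver is not available in that language.
Rewriting it in Python would be cumbersome and error-prone, as everything is already available in Rust.
How could we expose this solver to Python using PyO3?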
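Implementation of the trait bounds for the Python class
If a Python class implements the same three methods as the Model trait, it seems logical that it could be used with the solver. However, it is not possible to pass a Py<PyAny> to it, because that type does not implement the Rust trait (even if the Python model has the required methods).
In order to implement the trait, we must write a wrapper on the Rust side that forwards the calls to the Python model. The method signatures must be the same as in the trait, keeping in mind that the Rust trait cannot be changed just to make the code available in Python.
The following Python model is the one we want to expose; it already contains all the required methods: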
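The wrapper idea can be sketched in plain Rust, without PyO3, to show why having the right method names is not enough. ForeignModel, Wrapper and run_once are hypothetical names used only for this illustration; ForeignModel plays the role of an object whose type we cannot make implement the trait directly.

```rust
pub trait Model {
    fn set_variables(&mut self, inputs: &Vec<f64>);
    fn compute(&mut self);
    fn get_results(&self) -> Vec<f64>;
}

/// Hypothetical generic consumer with a trait bound.
pub fn run_once<T: Model>(model: &mut T, inputs: &Vec<f64>) -> Vec<f64> {
    model.set_variables(inputs);
    model.compute();
    model.get_results()
}

/// Hypothetical "foreign" type: it has methods with the right names, but no
/// `impl Model`, so passing it to `run_once` would be rejected by the compiler.
#[derive(Default)]
pub struct ForeignModel {
    inputs: Vec<f64>,
    results: Vec<f64>,
}

impl ForeignModel {
    pub fn set_variables(&mut self, inputs: Vec<f64>) {
        self.inputs = inputs;
    }
    pub fn compute(&mut self) {
        self.results = self.inputs.iter().map(|x| x + 1.0).collect();
    }
    pub fn get_results(&self) -> Vec<f64> {
        self.results.clone()
    }
}

/// Wrapper that owns the foreign object and implements the trait by
/// delegating each trait method to it, adapting argument types on the way.
pub struct Wrapper {
    inner: ForeignModel,
}

impl Model for Wrapper {
    fn set_variables(&mut self, inputs: &Vec<f64>) {
        self.inner.set_variables(inputs.clone());
    }
    fn compute(&mut self) {
        self.inner.compute();
    }
    fn get_results(&self) -> Vec<f64> {
        self.inner.get_results()
    }
}

fn main() {
    let mut wrapped = Wrapper { inner: ForeignModel::default() };
    let results = run_once(&mut wrapped, &vec![1.0, 2.0]);
    println!("{:?}", results); // prints "[2.0, 3.0]"
}
```

The PyO3 wrapper has the same shape: a struct that owns the Python object and a trait impl whose methods forward to it.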
class Model:
    def set_variables(self, inputs):
        self.inputs = inputs

    def compute(self):
        self.results = [elt**2 - 3 for elt in self.inputs]

    def get_results(self):
        return self.results
The following wrapper calls the Python model from Rust, using a struct to hold the model as a PyAny object:
#![allow(dead_code)]
use pyo3::prelude::*;
use pyo3::types::PyList;
pub trait Model {
fn set_variables(&mut self, inputs: &Vec<f64>);
fn compute(&mut self);
fn get_results(&self) -> Vec<f64>;
}
struct UserModel {
model: Py<PyAny>,
}
impl Model for UserModel {
fn set_variables(&mut self, var: &Vec<f64>) {
println!("Rust calling Python to set the variables");
Python::attach(|py| {
self.model
.bind(py)
.call_method("set_variables", (PyList::new(py, var).unwrap(),), None)
.unwrap();
})
}
fn get_results(&self) -> Vec<f64> {
println!("Rust calling Python to get the results");
Python::attach(|py| {
self.model
.bind(py)
.call_method("get_results", (), None)
.unwrap()
.extract()
.unwrap()
})
}
fn compute(&mut self) {
println!("Rust calling Python to perform the computation");
Python::attach(|py| {
self.model
.bind(py)
.call_method("compute", (), None)
.unwrap();
})
}
}
Now that this part is implemented, let's expose the model wrapper to Python. We add the PyO3 annotations and a constructor:
#![allow(dead_code)]
fn main() {}
pub trait Model {
fn set_variables(&mut self, inputs: &Vec<f64>);
fn compute(&mut self);
fn get_results(&self) -> Vec<f64>;
}
use pyo3::prelude::*;
#[pyclass]
struct UserModel {
model: Py<PyAny>,
}
#[pymethods]
impl UserModel {
#[new]
pub fn new(model: Py<PyAny>) -> Self {
UserModel { model }
}
}
#[pymodule]
mod trait_exposure {
#[pymodule_export]
use super::UserModel;
}
Now we add the PyO3 annotations to the trait implementation:
#[pymethods]
impl Model for UserModel {
    // the previous trait implementation
}
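However, the previous code will not compile. The compilation error is: error: #[pymethods] cannot be used on trait impl blocks
That's a bummer! However, we can write a second wrapper around these functions to call them directly. This wrapper will also perform the type conversions between Python and Rust.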
#![allow(dead_code)]
use pyo3::prelude::*;
use pyo3::types::PyList;
pub trait Model {
fn set_variables(&mut self, inputs: &Vec<f64>);
fn compute(&mut self);
fn get_results(&self) -> Vec<f64>;
}
#[pyclass]
struct UserModel {
model: Py<PyAny>,
}
impl Model for UserModel {
fn set_variables(&mut self, var: &Vec<f64>) {
println!("Rust calling Python to set the variables");
Python::attach(|py| {
self.model.bind(py)
.call_method("set_variables", (PyList::new(py, var).unwrap(),), None)
.unwrap();
})
}
fn get_results(&self) -> Vec<f64> {
println!("Rust calling Python to get the results");
Python::attach(|py| {
self.model
.bind(py)
.call_method("get_results", (), None)
.unwrap()
.extract()
.unwrap()
})
}
fn compute(&mut self) {
println!("Rust calling Python to perform the computation");
Python::attach(|py| {
self.model
.bind(py)
.call_method("compute", (), None)
.unwrap();
})
}
}
#[pymethods]
impl UserModel {
pub fn set_variables(&mut self, var: Vec<f64>) {
println!("Python calling Rust to set the variables");
Model::set_variables(self, &var)
}
pub fn get_results(&mut self) -> Vec<f64> {
println!("Python calling Rust to get the results");
Model::get_results(self)
}
pub fn compute(&mut self) {
println!("Python calling Rust to perform the computation");
Model::compute(self)
}
}
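This wrapper handles the type conversion between the PyO3 requirements and the trait. In order to meet the PyO3 requirements, this wrapper must:
- return a type that PyO3 can convert into a Python object, optionally wrapped in PyResult
- use only owned values, not references, in the method signatures (Vec<f64> rather than &Vec<f64>)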
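The forwarding pattern used by this second wrapper can be tried in plain Rust. The sketch below leaves PyO3 out entirely and uses a hypothetical LocalModel struct that stores its data directly; it only illustrates how inherent methods taking owned values can share their names with the trait methods and delegate to them through the explicit Model::method(self, ...) form.

```rust
pub trait Model {
    fn set_variables(&mut self, inputs: &Vec<f64>);
    fn compute(&mut self);
    fn get_results(&self) -> Vec<f64>;
}

/// Hypothetical model that keeps its data in Rust (no Python object involved).
#[derive(Default)]
pub struct LocalModel {
    inputs: Vec<f64>,
    results: Vec<f64>,
}

/// Trait implementation with the borrowed signature required by the trait.
impl Model for LocalModel {
    fn set_variables(&mut self, inputs: &Vec<f64>) {
        self.inputs = inputs.clone();
    }
    fn compute(&mut self) {
        self.results = self.inputs.iter().map(|x| x * x - 3.0).collect();
    }
    fn get_results(&self) -> Vec<f64> {
        self.results.clone()
    }
}

/// Inherent methods with the same names, taking owned values.
impl LocalModel {
    pub fn set_variables(&mut self, var: Vec<f64>) {
        // Convert the owned vector into the borrowed form the trait expects.
        Model::set_variables(self, &var)
    }
    pub fn compute(&mut self) {
        // A plain `self.compute()` would resolve to this inherent method and
        // recurse forever, so the trait method is named explicitly.
        Model::compute(self)
    }
    pub fn get_results(&self) -> Vec<f64> {
        Model::get_results(self)
    }
}

fn main() {
    let mut model = LocalModel::default();
    // Method-call syntax picks the inherent methods first.
    model.set_variables(vec![2.0]);
    model.compute();
    println!("{:?}", model.get_results()); // prints "[1.0]"
}
```

Inherent methods take priority over trait methods in method-call syntax, which is why callers that use the dot syntax reach the owned-value versions.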
Let's run this Python file:
class Model:
    def set_variables(self, inputs):
        self.inputs = inputs

    def compute(self):
        self.results = [elt**2 - 3 for elt in self.inputs]

    def get_results(self):
        return self.results

if __name__ == "__main__":
    import trait_exposure

    myModel = Model()
    my_rust_model = trait_exposure.UserModel(myModel)
    my_rust_model.set_variables([2.0])
    print("Print value from Python: ", myModel.inputs)
    my_rust_model.compute()
    print("Print value from Python through Rust: ", my_rust_model.get_results())
    print("Print value directly from Python: ", myModel.get_results())
This outputs:
Python calling Rust to set the variables
Rust calling Python to set the variables
Print value from Python: [2.0]
Python calling Rust to perform the computation
Rust calling Python to perform the computation
Python calling Rust to get the results
Rust calling Python to get the results
Print value from Python through Rust: [1.0]
Print value directly from Python: [1.0]
We have now successfully exposed to Python a Rust model that implements the Model trait!
We will expose the solve function next, but first let's talk about type errors.
Type errors in Python
What happens if you have type errors when using Python and how can you improve the error messages?
Wrong types in Python function arguments
First, suppose your Python file calls my_rust_model.set_variables(2.0) instead of my_rust_model.set_variables([2.0]).
The Rust signature expects a vector, which corresponds to a list in Python. What happens if we pass a single value instead of a vector?
When the Python script runs, we get:
File "main.py", line 15, in <module>
my_rust_model.set_variables(2.0)
TypeError
It is a type error and Python points to it, so it’s easy to identify and solve.
Wrong types in Python method signatures
Now suppose that one of the methods of our Model class returns the wrong type, for example get_results, which is expected to return a Vec<f64> in Rust, that is, a list in Python.
class Model:
    def set_variables(self, inputs):
        self.inputs = inputs

    def compute(self):
        self.results = [elt**2 - 3 for elt in self.inputs]

    def get_results(self):
        return self.results[0]
        # return self.results  <-- this is the expected output
This call results in the following panic:
pyo3_runtime.PanicException: called `Result::unwrap()` on an `Err` value: PyErr { type: Py(0x10dcf79f0, PhantomData) }
This error message is not helpful for a Python user who knows nothing about Rust, or for someone who does not know that PyO3 was used to interface with the Rust code.
However, as we are responsible for making the Rust code available to Python, we can do something about it.
The issue is that we called unwrap wherever we could, so any error returned by PyO3 turns into a panic that is forwarded directly to the end user.
Let's modify the code performing the type conversion to give the Python user a helpful error message.
Our get_results method used the following call to perform the type conversion:
#![allow(dead_code)]
use pyo3::prelude::*;
use pyo3::types::PyList;
pub trait Model {
fn set_variables(&mut self, inputs: &Vec<f64>);
fn compute(&mut self);
fn get_results(&self) -> Vec<f64>;
}
#[pyclass]
struct UserModel {
model: Py<PyAny>,
}
impl Model for UserModel {
fn get_results(&self) -> Vec<f64> {
println!("Rust calling Python to get the results");
Python::attach(|py| {
self.model
.bind(py)
.call_method("get_results", (), None)
.unwrap()
.extract()
.unwrap()
})
}
fn set_variables(&mut self, var: &Vec<f64>) {
println!("Rust calling Python to set the variables");
Python::attach(|py| {
self.model.bind(py)
.call_method("set_variables", (PyList::new(py, var).unwrap(),), None)
.unwrap();
})
}
fn compute(&mut self) {
println!("Rust calling Python to perform the computation");
Python::attach(|py| {
self.model
.bind(py)
.call_method("compute", (), None)
.unwrap();
})
}
}
Let’s break it down in order to perform better error handling:
#![allow(dead_code)]
use pyo3::prelude::*;
use pyo3::types::PyList;
pub trait Model {
fn set_variables(&mut self, inputs: &Vec<f64>);
fn compute(&mut self);
fn get_results(&self) -> Vec<f64>;
}
#[pyclass]
struct UserModel {
model: Py<PyAny>,
}
impl Model for UserModel {
fn get_results(&self) -> Vec<f64> {
println!("Get results from Rust calling Python");
Python::attach(|py| {
let py_result: Bound<'_, PyAny> = self
.model
.bind(py)
.call_method("get_results", (), None)
.unwrap();
if py_result.get_type().name().unwrap() != "list" {
panic!(
"Expected a list for the get_results() method signature, got {}",
py_result.get_type().name().unwrap()
);
}
py_result.extract()
})
.unwrap()
}
fn set_variables(&mut self, var: &Vec<f64>) {
println!("Rust calling Python to set the variables");
Python::attach(|py| {
self.model.bind(py)
.call_method("set_variables", (PyList::new(py, var).unwrap(),), None)
.unwrap();
})
}
fn compute(&mut self) {
println!("Rust calling Python to perform the computation");
Python::attach(|py| {
self.model
.bind(py)
.call_method("compute", (), None)
.unwrap();
})
}
}
By doing so, you capture the result of the Python call and check its type, so that a clearer error message can be delivered before the value is unwrapped.
Of course, this does not cover every possible wrong output: the user could return a list of strings instead of a list of floats. In that case a runtime panic would still occur inside PyO3, with an error message that is much harder to decipher for non-Rust users.
It is up to the developer exposing the Rust code to decide how much effort to invest in Python type error handling and improved error messages.
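As an illustration of how far such checks could go, the sketch below validates both the container and each element and reports problems through a Result instead of panicking. It is plain Rust with no PyO3 involved: the Value enum is a hypothetical stand-in for a dynamically typed Python object, and extract_floats is a hypothetical helper, not a PyO3 API.

```rust
/// Hypothetical stand-in for a dynamically typed value returned by Python.
#[derive(Debug, PartialEq)]
pub enum Value {
    Float(f64),
    Str(String),
    List(Vec<Value>),
}

impl Value {
    /// Python-style type name, used only to build error messages.
    fn type_name(&self) -> &'static str {
        match self {
            Value::Float(_) => "float",
            Value::Str(_) => "str",
            Value::List(_) => "list",
        }
    }
}

/// Checks the container type first, then every element, and returns a
/// descriptive error instead of panicking on the first mismatch.
pub fn extract_floats(value: &Value) -> Result<Vec<f64>, String> {
    let items = match value {
        Value::List(items) => items,
        other => {
            return Err(format!(
                "Expected a list for the get_results() method signature, got {}",
                other.type_name()
            ))
        }
    };
    items
        .iter()
        .enumerate()
        .map(|(index, item)| match item {
            Value::Float(x) => Ok(*x),
            other => Err(format!(
                "Expected a float at index {} of the get_results() output, got {}",
                index,
                other.type_name()
            )),
        })
        .collect()
}

fn main() {
    let good = Value::List(vec![Value::Float(1.0), Value::Float(6.0)]);
    let wrong_container = Value::Float(1.0);
    let wrong_element = Value::List(vec![Value::Float(1.0), Value::Str("6".into())]);
    for value in [&good, &wrong_container, &wrong_element] {
        match extract_floats(value) {
            Ok(floats) => println!("ok: {:?}", floats),
            Err(message) => println!("error: {}", message),
        }
    }
}
```

In a real binding the error branch would still have to be surfaced to the Python user in some way; since the Model trait returns a plain Vec<f64>, that choice is left to the wrapper's author.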
The final code
Now let’s expose the solve() function to make it available from Python.
It is not possible to expose the solve function to Python directly, because the type conversion cannot be performed: it requires an object implementing the Model trait as input.
However, UserModel already implements this trait. We can therefore write a wrapper function that takes the UserModel, which has already been exposed to Python, as an argument and calls the core solve function.
The struct must also be made public, since it now appears in the signature of a public function.
#![allow(dead_code)]
fn main() {}
use pyo3::prelude::*;
use pyo3::types::PyList;
pub trait Model {
fn set_variables(&mut self, var: &Vec<f64>);
fn get_results(&self) -> Vec<f64>;
fn compute(&mut self);
}
pub fn solve<T: Model>(model: &mut T) {
println!("Magic solver that mutates the model into a resolved state");
}
#[pyfunction]
#[pyo3(name = "solve")]
pub fn solve_wrapper(model: &mut UserModel) {
solve(model);
}
#[pyclass]
pub struct UserModel {
model: Py<PyAny>,
}
#[pymethods]
impl UserModel {
#[new]
pub fn new(model: Py<PyAny>) -> Self {
UserModel { model }
}
pub fn set_variables(&mut self, var: Vec<f64>) {
println!("Python calling Rust to set the variables");
Model::set_variables(self, &var)
}
pub fn get_results(&mut self) -> Vec<f64> {
println!("Python calling Rust to get the results");
Model::get_results(self)
}
pub fn compute(&mut self) {
Model::compute(self)
}
}
#[pymodule]
mod trait_exposure {
#[pymodule_export]
use super::{UserModel, solve_wrapper};
}
impl Model for UserModel {
fn set_variables(&mut self, var: &Vec<f64>) {
println!("Rust calling Python to set the variables");
Python::attach(|py| {
self.model
.bind(py)
.call_method("set_variables", (PyList::new(py, var).unwrap(),), None)
.unwrap();
})
}
fn get_results(&self) -> Vec<f64> {
println!("Get results from Rust calling Python");
Python::attach(|py| {
let py_result: Bound<'_, PyAny> = self
.model
.bind(py)
.call_method("get_results", (), None)
.unwrap();
if py_result.get_type().name().unwrap() != "list" {
panic!(
"Expected a list for the get_results() method signature, got {}",
py_result.get_type().name().unwrap()
);
}
py_result.extract()
})
.unwrap()
}
fn compute(&mut self) {
println!("Rust calling Python to perform the computation");
Python::attach(|py| {
self.model
.bind(py)
.call_method("compute", (), None)
.unwrap();
})
}
}