Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 7a24559

Browse files
authored
Add support for (limited) Python dependencies: nvidia-cutlass-dsl and einops (#302)
This change adds support for kernel dependencies. Dependencies are specified in `build.toml`, e.g.: ```toml [general] python-depends = [ "nvidia-cutlass-dsl" ] ``` The set of dependencies is limited to avoid bringing back the problem of resolving Python dependencies. Currently only `nvidia-cutlass-dsl` and `einops` are supported. This PR also adds the derivations for the `nvidia-cutlass-dsl` package.
1 parent 1d12b13 commit 7a24559

File tree

26 files changed

+485
-127
lines changed

26 files changed

+485
-127
lines changed

build2cmake/src/config/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ pub mod v1;
55

66
mod v2;
77
use serde_value::Value;
8-
pub use v2::{Backend, Build, Dependencies, Kernel, Torch};
8+
pub use v2::{Backend, Build, Dependency, General, Kernel, Torch};
99

1010
#[derive(Debug)]
1111
pub enum BuildCompat {

build2cmake/src/config/v1.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf};
22

33
use serde::Deserialize;
44

5-
use super::v2::Dependencies;
5+
use super::v2::Dependency;
66

77
#[derive(Debug, Deserialize)]
88
#[serde(deny_unknown_fields)]
@@ -40,7 +40,7 @@ pub struct Kernel {
4040
pub rocm_archs: Option<Vec<String>>,
4141
#[serde(default)]
4242
pub language: Language,
43-
pub depends: Vec<Dependencies>,
43+
pub depends: Vec<Dependency>,
4444
pub include: Option<Vec<String>>,
4545
pub src: Vec<String>,
4646
}

build2cmake/src/config/v2.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ pub struct General {
5454
pub cuda_minver: Option<Version>,
5555

5656
pub hub: Option<Hub>,
57+
58+
pub python_depends: Option<Vec<PythonDependency>>,
5759
}
5860

5961
impl General {
@@ -70,6 +72,22 @@ pub struct Hub {
7072
pub branch: Option<String>,
7173
}
7274

75+
#[derive(Clone, Debug, Deserialize, Serialize)]
76+
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
77+
pub enum PythonDependency {
78+
Einops,
79+
NvidiaCutlassDsl,
80+
}
81+
82+
impl Display for PythonDependency {
83+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84+
match self {
85+
PythonDependency::Einops => write!(f, "einops"),
86+
PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"),
87+
}
88+
}
89+
}
90+
7391
#[derive(Debug, Deserialize, Clone, Serialize)]
7492
#[serde(deny_unknown_fields)]
7593
pub struct Torch {
@@ -107,7 +125,7 @@ pub enum Kernel {
107125
#[serde(rename_all = "kebab-case")]
108126
Cpu {
109127
cxx_flags: Option<Vec<String>>,
110-
depends: Vec<Dependencies>,
128+
depends: Vec<Dependency>,
111129
include: Option<Vec<String>>,
112130
src: Vec<String>,
113131
},
@@ -117,21 +135,21 @@ pub enum Kernel {
117135
cuda_flags: Option<Vec<String>>,
118136
cuda_minver: Option<Version>,
119137
cxx_flags: Option<Vec<String>>,
120-
depends: Vec<Dependencies>,
138+
depends: Vec<Dependency>,
121139
include: Option<Vec<String>>,
122140
src: Vec<String>,
123141
},
124142
#[serde(rename_all = "kebab-case")]
125143
Metal {
126144
cxx_flags: Option<Vec<String>>,
127-
depends: Vec<Dependencies>,
145+
depends: Vec<Dependency>,
128146
include: Option<Vec<String>>,
129147
src: Vec<String>,
130148
},
131149
#[serde(rename_all = "kebab-case")]
132150
Rocm {
133151
cxx_flags: Option<Vec<String>>,
134-
depends: Vec<Dependencies>,
152+
depends: Vec<Dependency>,
135153
rocm_archs: Option<Vec<String>>,
136154
hip_flags: Option<Vec<String>>,
137155
include: Option<Vec<String>>,
@@ -140,7 +158,7 @@ pub enum Kernel {
140158
#[serde(rename_all = "kebab-case")]
141159
Xpu {
142160
cxx_flags: Option<Vec<String>>,
143-
depends: Vec<Dependencies>,
161+
depends: Vec<Dependency>,
144162
sycl_flags: Option<Vec<String>>,
145163
include: Option<Vec<String>>,
146164
src: Vec<String>,
@@ -178,7 +196,7 @@ impl Kernel {
178196
}
179197
}
180198

181-
pub fn depends(&self) -> &[Dependencies] {
199+
pub fn depends(&self) -> &[Dependency] {
182200
match self {
183201
Kernel::Cpu { depends, .. }
184202
| Kernel::Cuda { depends, .. }
@@ -239,7 +257,7 @@ impl FromStr for Backend {
239257
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
240258
#[non_exhaustive]
241259
#[serde(rename_all = "lowercase")]
242-
pub enum Dependencies {
260+
pub enum Dependency {
243261
#[serde(rename = "cutlass_2_10")]
244262
Cutlass2_10,
245263
#[serde(rename = "cutlass_3_5")]
@@ -284,6 +302,7 @@ impl General {
284302
cuda_maxver: None,
285303
cuda_minver: None,
286304
hub: None,
305+
python_depends: None,
287306
}
288307
}
289308
}

build2cmake/src/templates/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ requires = [
66
"setuptools>=61",
77
"torch",
88
"wheel",
9+
{{python_dependencies}}
910
]
1011
build-backend = "setuptools.build_meta"

build2cmake/src/templates/universal/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
name = "{{ name }}"
33
version = "0.0.1"
44
requires-python = ">= 3.9"
5-
dependencies = ["torch>=2.4"]
5+
dependencies = [
6+
"torch>=2.8",
7+
{{python_dependencies}}
8+
]
69

710
[tool.setuptools]
811
package-dir = { "" = "torch-ext" }

build2cmake/src/torch/common.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use eyre::{Context, Result};
2+
use itertools::Itertools;
3+
use minijinja::{context, Environment};
4+
5+
use crate::{config::General, FileSet};
6+
7+
pub fn write_pyproject_toml(
8+
env: &Environment,
9+
general: &General,
10+
file_set: &mut FileSet,
11+
) -> Result<()> {
12+
let writer = file_set.entry("pyproject.toml");
13+
14+
let python_dependencies = general
15+
.python_depends
16+
.as_ref()
17+
.unwrap_or(&vec![])
18+
.iter()
19+
.map(|d| format!("\"{d}\""))
20+
.join(", ");
21+
22+
env.get_template("pyproject.toml")
23+
.wrap_err("Cannot get pyproject.toml template")?
24+
.render_to_write(
25+
context! {
26+
python_dependencies => python_dependencies,
27+
},
28+
writer,
29+
)
30+
.wrap_err("Cannot render kernel template")?;
31+
32+
Ok(())
33+
}

build2cmake/src/torch/cpu.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use eyre::{bail, Context, Result};
44
use itertools::Itertools;
55
use minijinja::{context, Environment};
66

7-
use super::kernel_ops_identifier;
7+
use super::{common::write_pyproject_toml, kernel_ops_identifier};
88
use crate::{
99
config::{Build, Kernel, Torch},
1010
fileset::FileSet,
@@ -47,7 +47,7 @@ pub fn write_torch_ext_cpu(
4747

4848
write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?;
4949

50-
write_pyproject_toml(env, &mut file_set)?;
50+
write_pyproject_toml(env, &build.general, &mut file_set)?;
5151

5252
write_torch_registration_macros(&mut file_set)?;
5353

@@ -209,17 +209,6 @@ fn write_ops_py(
209209
Ok(())
210210
}
211211

212-
fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> {
213-
let writer = file_set.entry("pyproject.toml");
214-
215-
env.get_template("pyproject.toml")
216-
.wrap_err("Cannot get pyproject.toml template")?
217-
.render_to_write(context! {}, writer)
218-
.wrap_err("Cannot render kernel template")?;
219-
220-
Ok(())
221-
}
222-
223212
fn write_setup_py(
224213
env: &Environment,
225214
torch: &Torch,

build2cmake/src/torch/cuda.rs

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use eyre::{bail, Context, Result};
77
use itertools::Itertools;
88
use minijinja::{context, Environment};
99

10+
use super::common::write_pyproject_toml;
1011
use super::kernel_ops_identifier;
11-
use crate::config::{Backend, Build, Dependencies, Kernel, Torch};
12+
use crate::config::{Backend, Build, Dependency, Kernel, Torch};
1213
use crate::version::Version;
1314
use crate::FileSet;
1415

@@ -60,7 +61,7 @@ pub fn write_torch_ext_cuda(
6061

6162
write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?;
6263

63-
write_pyproject_toml(env, &mut file_set)?;
64+
write_pyproject_toml(env, &build.general, &mut file_set)?;
6465

6566
write_torch_registration_macros(&mut file_set)?;
6667

@@ -78,17 +79,6 @@ fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
7879
Ok(())
7980
}
8081

81-
fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> {
82-
let writer = file_set.entry("pyproject.toml");
83-
84-
env.get_template("pyproject.toml")
85-
.wrap_err("Cannot get pyproject.toml template")?
86-
.render_to_write(context! {}, writer)
87-
.wrap_err("Cannot render kernel template")?;
88-
89-
Ok(())
90-
}
91-
9282
fn write_setup_py(
9383
env: &Environment,
9484
torch: &Torch,
@@ -230,7 +220,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
230220

231221
for dep in deps {
232222
match dep {
233-
Dependencies::Cutlass2_10 => {
223+
Dependency::Cutlass2_10 => {
234224
env.get_template("cuda/dep-cutlass.cmake")
235225
.wrap_err("Cannot get CUTLASS dependency template")?
236226
.render_to_write(
@@ -241,7 +231,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
241231
)
242232
.wrap_err("Cannot render CUTLASS dependency template")?;
243233
}
244-
Dependencies::Cutlass3_5 => {
234+
Dependency::Cutlass3_5 => {
245235
env.get_template("cuda/dep-cutlass.cmake")
246236
.wrap_err("Cannot get CUTLASS dependency template")?
247237
.render_to_write(
@@ -252,7 +242,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
252242
)
253243
.wrap_err("Cannot render CUTLASS dependency template")?;
254244
}
255-
Dependencies::Cutlass3_6 => {
245+
Dependency::Cutlass3_6 => {
256246
env.get_template("cuda/dep-cutlass.cmake")
257247
.wrap_err("Cannot get CUTLASS dependency template")?
258248
.render_to_write(
@@ -263,7 +253,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
263253
)
264254
.wrap_err("Cannot render CUTLASS dependency template")?;
265255
}
266-
Dependencies::Cutlass3_8 => {
256+
Dependency::Cutlass3_8 => {
267257
env.get_template("cuda/dep-cutlass.cmake")
268258
.wrap_err("Cannot get CUTLASS dependency template")?
269259
.render_to_write(
@@ -274,7 +264,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
274264
)
275265
.wrap_err("Cannot render CUTLASS dependency template")?;
276266
}
277-
Dependencies::Cutlass3_9 => {
267+
Dependency::Cutlass3_9 => {
278268
env.get_template("cuda/dep-cutlass.cmake")
279269
.wrap_err("Cannot get CUTLASS dependency template")?
280270
.render_to_write(
@@ -285,7 +275,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
285275
)
286276
.wrap_err("Cannot render CUTLASS dependency template")?;
287277
}
288-
Dependencies::Cutlass4_0 => {
278+
Dependency::Cutlass4_0 => {
289279
env.get_template("cuda/dep-cutlass.cmake")
290280
.wrap_err("Cannot get CUTLASS dependency template")?
291281
.render_to_write(
@@ -296,7 +286,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
296286
)
297287
.wrap_err("Cannot render CUTLASS dependency template")?;
298288
}
299-
Dependencies::Torch => (),
289+
Dependency::Torch => (),
300290
_ => {
301291
eprintln!("Warning: CUDA backend doesn't need/support dependency: {dep:?}");
302292
}

build2cmake/src/torch/metal.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use eyre::{bail, Context, Result};
44
use itertools::Itertools;
55
use minijinja::{context, Environment};
66

7-
use super::kernel_ops_identifier;
7+
use super::{common::write_pyproject_toml, kernel_ops_identifier};
88
use crate::{
99
config::{Build, Kernel, Torch},
1010
fileset::FileSet,
@@ -49,7 +49,7 @@ pub fn write_torch_ext_metal(
4949

5050
write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?;
5151

52-
write_pyproject_toml(env, &mut file_set)?;
52+
write_pyproject_toml(env, &build.general, &mut file_set)?;
5353

5454
write_torch_registration_macros(&mut file_set)?;
5555

@@ -225,17 +225,6 @@ fn write_ops_py(
225225
Ok(())
226226
}
227227

228-
fn write_pyproject_toml(env: &Environment, file_set: &mut FileSet) -> Result<()> {
229-
let writer = file_set.entry("pyproject.toml");
230-
231-
env.get_template("pyproject.toml")
232-
.wrap_err("Cannot get pyproject.toml template")?
233-
.render_to_write(context! {}, writer)
234-
.wrap_err("Cannot render kernel template")?;
235-
236-
Ok(())
237-
}
238-
239228
fn write_setup_py(
240229
env: &Environment,
241230
torch: &Torch,

build2cmake/src/torch/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ pub use cpu::write_torch_ext_cpu;
44
mod cuda;
55
pub use cuda::write_torch_ext_cuda;
66

7+
pub mod common;
8+
79
mod metal;
810
pub use metal::write_torch_ext_metal;
911

0 commit comments

Comments
 (0)