@@ -7,8 +7,9 @@ use eyre::{bail, Context, Result};
77use itertools:: Itertools ;
88use minijinja:: { context, Environment } ;
99
10+ use super :: common:: write_pyproject_toml;
1011use super :: kernel_ops_identifier;
11- use crate :: config:: { Backend , Build , Dependencies , Kernel , Torch } ;
12+ use crate :: config:: { Backend , Build , Dependency , Kernel , Torch } ;
1213use crate :: version:: Version ;
1314use crate :: FileSet ;
1415
@@ -60,7 +61,7 @@ pub fn write_torch_ext_cuda(
6061
6162 write_ops_py ( env, & build. general . python_name ( ) , & ops_name, & mut file_set) ?;
6263
63- write_pyproject_toml ( env, & mut file_set) ?;
64+ write_pyproject_toml ( env, & build . general , & mut file_set) ?;
6465
6566 write_torch_registration_macros ( & mut file_set) ?;
6667
@@ -78,17 +79,6 @@ fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {
7879 Ok ( ( ) )
7980}
8081
81- fn write_pyproject_toml ( env : & Environment , file_set : & mut FileSet ) -> Result < ( ) > {
82- let writer = file_set. entry ( "pyproject.toml" ) ;
83-
84- env. get_template ( "pyproject.toml" )
85- . wrap_err ( "Cannot get pyproject.toml template" ) ?
86- . render_to_write ( context ! { } , writer)
87- . wrap_err ( "Cannot render kernel template" ) ?;
88-
89- Ok ( ( ) )
90- }
91-
9282fn write_setup_py (
9383 env : & Environment ,
9484 torch : & Torch ,
@@ -230,7 +220,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
230220
231221 for dep in deps {
232222 match dep {
233- Dependencies :: Cutlass2_10 => {
223+ Dependency :: Cutlass2_10 => {
234224 env. get_template ( "cuda/dep-cutlass.cmake" )
235225 . wrap_err ( "Cannot get CUTLASS dependency template" ) ?
236226 . render_to_write (
@@ -241,7 +231,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
241231 )
242232 . wrap_err ( "Cannot render CUTLASS dependency template" ) ?;
243233 }
244- Dependencies :: Cutlass3_5 => {
234+ Dependency :: Cutlass3_5 => {
245235 env. get_template ( "cuda/dep-cutlass.cmake" )
246236 . wrap_err ( "Cannot get CUTLASS dependency template" ) ?
247237 . render_to_write (
@@ -252,7 +242,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
252242 )
253243 . wrap_err ( "Cannot render CUTLASS dependency template" ) ?;
254244 }
255- Dependencies :: Cutlass3_6 => {
245+ Dependency :: Cutlass3_6 => {
256246 env. get_template ( "cuda/dep-cutlass.cmake" )
257247 . wrap_err ( "Cannot get CUTLASS dependency template" ) ?
258248 . render_to_write (
@@ -263,7 +253,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
263253 )
264254 . wrap_err ( "Cannot render CUTLASS dependency template" ) ?;
265255 }
266- Dependencies :: Cutlass3_8 => {
256+ Dependency :: Cutlass3_8 => {
267257 env. get_template ( "cuda/dep-cutlass.cmake" )
268258 . wrap_err ( "Cannot get CUTLASS dependency template" ) ?
269259 . render_to_write (
@@ -274,7 +264,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
274264 )
275265 . wrap_err ( "Cannot render CUTLASS dependency template" ) ?;
276266 }
277- Dependencies :: Cutlass3_9 => {
267+ Dependency :: Cutlass3_9 => {
278268 env. get_template ( "cuda/dep-cutlass.cmake" )
279269 . wrap_err ( "Cannot get CUTLASS dependency template" ) ?
280270 . render_to_write (
@@ -285,7 +275,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
285275 )
286276 . wrap_err ( "Cannot render CUTLASS dependency template" ) ?;
287277 }
288- Dependencies :: Cutlass4_0 => {
278+ Dependency :: Cutlass4_0 => {
289279 env. get_template ( "cuda/dep-cutlass.cmake" )
290280 . wrap_err ( "Cannot get CUTLASS dependency template" ) ?
291281 . render_to_write (
@@ -296,7 +286,7 @@ fn render_deps(env: &Environment, build: &Build, write: &mut impl Write) -> Resu
296286 )
297287 . wrap_err ( "Cannot render CUTLASS dependency template" ) ?;
298288 }
299- Dependencies :: Torch => ( ) ,
289+ Dependency :: Torch => ( ) ,
300290 _ => {
301291 eprintln ! ( "Warning: CUDA backend doesn't need/support dependency: {dep:?}" ) ;
302292 }
0 commit comments