Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 574f6ea

Browse files
committed
WIP
1 parent da0367c commit 574f6ea

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

overlay.nix

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ in
144144
torchVersion = "2.9";
145145
xpuPackages = final.xpuPackages_2025_2;
146146
};
147+
148+
triton-xpu_2_10 = callPackage ./pkgs/python-modules/triton-xpu {
149+
torchVersion = "2.10";
150+
xpuPackages = final.xpuPackages_2025_3;
151+
};
147152
}
148153
)
149154
(import ./pkgs/python-modules/hooks)

pkgs/python-modules/torch/binary/generic.nix

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,8 @@
4848
effectiveStdenv ? if cudaSupport then cudaPackages.backendStdenv else stdenv,
4949
}:
5050
let
51-
effectiveTriton =
52-
if cudaSupport then
53-
triton-cuda
54-
else if xpuSupport then
55-
python.pkgs.triton-xpu_2_8
56-
else
57-
triton;
58-
5951
archs = (import ../archs.nix).${lib.versions.majorMinor version};
52+
torchMajorMinor = lib.versions.majorMinor version;
6053

6154
supportedTorchCudaCapabilities =
6255
let
@@ -68,16 +61,27 @@ let
6861
supportedCudaCapabilities = lib.intersectLists cudaPackages.flags.cudaCapabilities supportedTorchCudaCapabilities;
6962
inherit (archs) supportedTorchRocmArchs;
7063

64+
xpuTritonVersions = {
65+
"2.8" = python.pkgs.triton-xpu_2_8;
66+
"2.9" = python.pkgs.triton-xpu_2_9;
67+
"2.10" = python.pkgs.triton-xpu_2_10;
68+
};
69+
70+
effectiveTriton =
71+
if cudaSupport then
72+
triton-cuda
73+
else if xpuSupport then
74+
xpuTritonVersions.${torchMajorMinor}
75+
else
76+
triton;
77+
7178
aotritonVersions = with rocmPackages; {
7279
"2.8" = aotriton_0_10;
7380
"2.9" = aotriton_0_11;
7481
"2.10" = aotriton_0_11_1;
7582
};
7683

7784
aotriton =
78-
let
79-
torchMajorMinor = lib.versions.majorMinor version;
80-
in
8185
aotritonVersions.${torchMajorMinor}
8286
or (throw "aotriton version is not specified Torch ${torchMajorMinor}");
8387

pkgs/python-modules/triton-xpu/default.nix

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,28 @@ let
5656
hash = "sha256-N8NBAkkpOcbgap4loPJJW6E5bjG+TixCh/HN259RyjI=";
5757
};
5858
};
59+
"2.10" = {
60+
# https://github.com/intel/intel-xpu-backend-for-triton/blob/225cdbde3ea155d5ed4c0aad1f2aa4bd2b3c4a3d/cmake/llvm-hash.txt
61+
llvm = {
62+
rev = "f6ded0be897e2878612dd903f7e8bb85448269e5";
63+
hash = "sha256-T76zHZZ2bp3Ye9GTV+MgbKqMbtmMGElMFsWuCkiWqrM=";
64+
};
65+
# https://github.com/pytorch/pytorch/tree/v2.10.0-rc7/.ci/docker/ci_commit_pins
66+
triton = {
67+
rev = "225cdbde3ea155d5ed4c0aad1f2aa4bd2b3c4a3d";
68+
hash = "sha256-AuNk4FMBwi7y1zWGhN/P0JsYwPuKV79JBLDDw6IVouA=";
69+
};
70+
# https://github.com/intel/intel-xpu-backend-for-triton/blob/225cdbde3ea155d5ed4c0aad1f2aa4bd2b3c4a3d/third_party/intel/lib/Target/SPIRV/spirv-llvm-translator.conf
71+
spirv_llm = {
72+
rev = "daba8b217bc266806ac00095262d1af0ba2ee610";
73+
hash = "sha256-X/Pk1GpA1Se6UFp1UIbNAW1JLTj3vgFtg9b7Niv3/ro=";
74+
};
75+
# https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/daba8b217bc266806ac00095262d1af0ba2ee610/spirv-headers-tag.conf
76+
spirv_headers = {
77+
rev = "9e3836d7d6023843a72ecd3fbf3f09b1b6747a9e";
78+
hash = "sha256-N8NBAkkpOcbgap4loPJJW6E5bjG+TixCh/HN259RyjI=";
79+
};
80+
};
5981
};
6082
tritonVersions =
6183
torchTritonVersions.${torchVersion} or (throw "Unsupported Torch version: ${torchVersion}");
@@ -67,7 +89,7 @@ let
6789
"SPIRV"
6890
];
6991
}
70-
// lib.optionalAttrs (torchVersion == "2.9") {
92+
// lib.optionalAttrs (lib.versionAtLeast torchVersion "2.9") {
7193
llvmProjectsToBuild = [
7294
"mlir"
7395
"llvm"
@@ -119,6 +141,10 @@ buildPythonPackage rec {
119141
sed -i 's/-Werror//g' $NIX_BUILD_TOP/source/CMakeLists.txt
120142
sed -i 's/ninja==1.11.1.4/ninja>=1.11.1/' $NIX_BUILD_TOP/source/pyproject.toml
121143
''}
144+
${lib.optionalString (torchVersion == "2.10") ''
145+
sed -i 's/-Werror//g' $NIX_BUILD_TOP/source/CMakeLists.txt
146+
sed -i 's/ninja<1.13.0/ninja/' $NIX_BUILD_TOP/source/pyproject.toml
147+
''}
122148
sed -i '/if (NOT SPIRVToLLVMTranslator_FOUND)/,/endif (NOT SPIRVToLLVMTranslator_FOUND)/c\
123149
set(SPIRVToLLVMTranslator_SOURCE_DIR "${spirvLlvmTranslatorSrc}")\n\
124150
set(SPIRVToLLVMTranslator_BINARY_DIR \''${CMAKE_CURRENT_BINARY_DIR}/SPIRVToLLVMTranslator-build)\n\

0 commit comments

Comments
 (0)