diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index cad0d42036..ca703cb742 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -53,6 +53,7 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/MapVector.h" +#include #include #include #include @@ -13045,6 +13046,12 @@ struct BroadcastInDimOpCanon final // Eliminate redundant nested BroadcastInDim. if (auto definingOp = operand.getDefiningOp()) { + DenseElementsAttr denseAttr; + if (matchPattern(definingOp.getOperand(), m_Constant(&denseAttr)) && + !denseAttr.isSplat()) { + // TODO: investigate why this leads to incorrect results + return failure(); + } auto newIndices = llvm::to_vector( llvm::map_range(definingOp.getBroadcastDimensions(), [&dims](int64_t dim) { return dims[dim]; })); diff --git a/test/neuralgcm_test.py b/test/neuralgcm_test.py index edd8d74513..0ed060cd91 100644 --- a/test/neuralgcm_test.py +++ b/test/neuralgcm_test.py @@ -111,9 +111,6 @@ def forward(initial_state, all_forcings): self.atol = 5e-2 self.rtol = 1e-2 - # TODO: we should fix this at some point - self.skip_test_assert = True - if __name__ == "__main__": from test_utils import fix_paths diff --git a/test/test_utils.py b/test/test_utils.py index 5bcb1e568a..5580afb5e0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -51,7 +51,7 @@ def fix_paths(): # https://github.com/jax-ml/jax/blob/af36ae2cd783aea9eaa7979170df760a52542fcd/jax/_src/lib/__init__.py#L185 os.environ["PYTHON_RUNFILES"] = runfiles # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html - os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95" + # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9" cuda_version = 12 cuda_postfix = "_cu12" @@ -419,11 +419,11 @@ def pipelines(): setup_backends() return [ - get_pipeline("JaxPipe"), get_pipeline("Jax"), - get_pipeline("HLOOpt"), + get_pipeline("JaxPipe"), get_pipeline("PartOpt"), get_pipeline("IPartOpt"), + get_pipeline("HLOOpt"), get_pipeline("DefOpt"), get_pipeline("IDefOpt"), ]